Posted to commits@airflow.apache.org by po...@apache.org on 2022/06/12 08:51:14 UTC

[airflow] branch main updated: Refactoring EmrClusterLink and add for other AWS EMR Operators (#24294)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 19dd9f5873 Refactoring EmrClusterLink and add for other AWS EMR Operators (#24294)
19dd9f5873 is described below

commit 19dd9f5873098decb41040b0c252a6072a67a356
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Sun Jun 12 12:51:08 2022 +0400

    Refactoring EmrClusterLink and add for other AWS EMR Operators (#24294)
---
 airflow/providers/amazon/aws/hooks/base_aws.py     |  16 ++++
 airflow/providers/amazon/aws/links/__init__.py     |  16 ++++
 airflow/providers/amazon/aws/links/base_aws.py     | 100 ++++++++++++++++++++
 airflow/providers/amazon/aws/links/emr.py          |  28 ++++++
 airflow/providers/amazon/aws/operators/emr.py      |  80 ++++++++--------
 airflow/providers/amazon/provider.yaml             |   2 +-
 tests/providers/amazon/aws/hooks/test_base_aws.py  |  56 ++++++++++-
 tests/providers/amazon/aws/links/__init__.py       |  16 ++++
 tests/providers/amazon/aws/links/test_base.py      |  83 +++++++++++++++++
 tests/providers/amazon/aws/links/test_emr.py       | 103 +++++++++++++++++++++
 .../amazon/aws/operators/test_emr_add_steps.py     |   5 +-
 .../aws/operators/test_emr_create_job_flow.py      |  48 +---------
 .../aws/operators/test_emr_terminate_job_flow.py   |   2 +-
 .../providers/amazon/aws/utils/links_test_utils.py |  25 +++++
 14 files changed, 491 insertions(+), 89 deletions(-)
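
For orientation, the user-visible effect of this refactor: EmrClusterLink moves into a reusable links package and is attached to several EMR operators instead of only EmrCreateJobFlowOperator. Both import paths below resolve after this commit, because operators/emr.py now imports the class from the new module (a minimal sketch, not part of the diff):

    # New canonical location of the link class
    from airflow.providers.amazon.aws.links.emr import EmrClusterLink

    # Still resolves, via the import added to operators/emr.py
    from airflow.providers.amazon.aws.operators.emr import EmrClusterLink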

diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index bbf0bfff83..162be9ce47 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -37,6 +37,7 @@ import botocore
 import botocore.session
 import requests
 import tenacity
+from botocore.client import ClientMeta
 from botocore.config import Config
 from botocore.credentials import ReadOnlyCredentials
 from slugify import slugify
@@ -521,6 +522,21 @@ class AwsBaseHook(BaseHook):
             # Rare possibility - subclasses have not specified a client_type or resource_type
             raise NotImplementedError('Could not get boto3 connection!')
 
+    @cached_property
+    def conn_client_meta(self) -> ClientMeta:
+        conn = self.conn
+        if isinstance(conn, botocore.client.BaseClient):
+            return conn.meta
+        return conn.meta.client.meta
+
+    @property
+    def conn_region_name(self) -> str:
+        return self.conn_client_meta.region_name
+
+    @property
+    def conn_partition(self) -> str:
+        return self.conn_client_meta.partition
+
     def get_conn(self) -> Union[boto3.client, boto3.resource]:
         """
         Get the underlying boto3 client/resource (cached)
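
The new conn_client_meta property normalizes access to botocore's ClientMeta whether the hook was constructed with a client_type or a resource_type, and conn_region_name / conn_partition read the region and partition from it. A minimal usage sketch (assuming an aws_default connection is configured; note that touching these properties resolves self.conn, so the boto3 connection gets created):

    from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

    hook = AwsBaseHook(aws_conn_id="aws_default", client_type="emr")
    # Same answers whether client_type or resource_type was used
    region = hook.conn_region_name   # e.g. "eu-west-1"
    partition = hook.conn_partition  # "aws", "aws-cn", or "aws-us-gov"
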
diff --git a/airflow/providers/amazon/aws/links/__init__.py b/airflow/providers/amazon/aws/links/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/amazon/aws/links/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/amazon/aws/links/base_aws.py b/airflow/providers/amazon/aws/links/base_aws.py
new file mode 100644
index 0000000000..362f92e76c
--- /dev/null
+++ b/airflow/providers/amazon/aws/links/base_aws.py
@@ -0,0 +1,100 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+from typing import TYPE_CHECKING, ClassVar, Optional
+from urllib.parse import quote_plus
+
+from airflow.models import BaseOperatorLink, XCom
+
+if TYPE_CHECKING:
+    from airflow.models import BaseOperator
+    from airflow.models.taskinstance import TaskInstanceKey
+    from airflow.utils.context import Context
+
+
+BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"
+
+
+class BaseAwsLink(BaseOperatorLink):
+    """Base Helper class for constructing AWS Console Link"""
+
+    name: ClassVar[str]
+    key: ClassVar[str]
+    format_str: ClassVar[str]
+
+    @staticmethod
+    def get_aws_domain(aws_partition) -> Optional[str]:
+        if aws_partition == "aws":
+            return "aws.amazon.com"
+        elif aws_partition == "aws-cn":
+            return "amazonaws.cn"
+        elif aws_partition == "aws-us-gov":
+            return "amazonaws-us-gov.com"
+
+        return None
+
+    def get_link(
+        self,
+        operator,
+        dttm: Optional[datetime] = None,
+        ti_key: Optional["TaskInstanceKey"] = None,
+    ) -> str:
+        """
+        Link to Amazon Web Services Console.
+
+        :param operator: airflow operator
+        :param ti_key: TaskInstance ID to return link for
+        :param dttm: execution date. Used for compatibility with Airflow 2.2
+        :return: link to external system
+        """
+        if ti_key is not None:
+            conf = XCom.get_value(key=self.key, ti_key=ti_key)
+        elif not dttm:
+            conf = {}
+        else:
+            conf = XCom.get_one(
+                key=self.key,
+                dag_id=operator.dag.dag_id,
+                task_id=operator.task_id,
+                execution_date=dttm,
+            )
+        if not conf:
+            return ""
+
+        # urlencode special characters, e.g. CloudWatch links contain the `/` character.
+        quoted_conf = {k: quote_plus(v) if isinstance(v, str) else v for k, v in conf.items()}
+        return self.format_str.format(**quoted_conf)
+
+    @classmethod
+    def persist(
+        cls, context: "Context", operator: "BaseOperator", region_name: str, aws_partition: str, **kwargs
+    ) -> None:
+        """Store link information into XCom"""
+        if not operator.do_xcom_push:
+            return
+
+        operator.xcom_push(
+            context,
+            key=cls.key,
+            value={
+                "region_name": region_name,
+                "aws_domain": cls.get_aws_domain(aws_partition),
+                **kwargs,
+            },
+        )
diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py
new file mode 100644
index 0000000000..ea46341dd7
--- /dev/null
+++ b/airflow/providers/amazon/aws/links/emr.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class EmrClusterLink(BaseAwsLink):
+    """Helper class for constructing AWS EMR Cluster Link"""
+
+    name = "EMR Cluster"
+    key = "emr_cluster"
+    format_str = (
+        BASE_AWS_CONSOLE_LINK + "/elasticmapreduce/home?region={region_name}#cluster-details:{job_flow_id}"
+    )
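
Concretely, BaseAwsLink.get_link fills format_str with the dict that persist stored in XCom. A quick sketch with hypothetical values, showing the URL this template yields (aws_domain is what get_aws_domain returns for the "aws" partition):

    from airflow.providers.amazon.aws.links.emr import EmrClusterLink

    url = EmrClusterLink.format_str.format(
        aws_domain="aws.amazon.com",  # get_aws_domain("aws")
        region_name="eu-west-1",
        job_flow_id="j-EXAMPLE",      # hypothetical job flow id
    )
    # -> https://console.aws.amazon.com/elasticmapreduce/home?region=eu-west-1#cluster-details:j-EXAMPLE
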
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 67ae54af50..77c079be83 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -17,16 +17,15 @@
 # under the License.
 import ast
 import sys
-from datetime import datetime
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 from uuid import uuid4
 
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator, BaseOperatorLink, XCom
+from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.emr import EmrHook
+from airflow.providers.amazon.aws.links.emr import EmrClusterLink
 
 if TYPE_CHECKING:
-    from airflow.models.taskinstance import TaskInstanceKey
     from airflow.utils.context import Context
 
 
@@ -62,6 +61,7 @@ class EmrAddStepsOperator(BaseOperator):
     template_ext: Sequence[str] = ('.json',)
     template_fields_renderers = {"steps": "json"}
     ui_color = '#f9c915'
+    operator_extra_links = (EmrClusterLink(),)
 
     def __init__(
         self,
@@ -101,6 +101,14 @@ class EmrAddStepsOperator(BaseOperator):
         if self.do_xcom_push:
             context['ti'].xcom_push(key='job_flow_id', value=job_flow_id)
 
+        EmrClusterLink.persist(
+            context=context,
+            operator=self,
+            region_name=emr_hook.conn_region_name,
+            aws_partition=emr_hook.conn_partition,
+            job_flow_id=job_flow_id,
+        )
+
         self.log.info('Adding steps to %s', job_flow_id)
 
         # steps may arrive as a string representing a list
@@ -243,38 +251,6 @@ class EmrContainerOperator(BaseOperator):
                     self.hook.poll_query_status(self.job_id)
 
 
-class EmrClusterLink(BaseOperatorLink):
-    """Operator link for EmrCreateJobFlowOperator. It allows users to access the EMR Cluster"""
-
-    name = 'EMR Cluster'
-
-    def get_link(
-        self,
-        operator,
-        dttm: Optional[datetime] = None,
-        ti_key: Optional["TaskInstanceKey"] = None,
-    ) -> str:
-        """
-        Get link to EMR cluster.
-
-        :param operator: operator
-        :param dttm: datetime
-        :return: url link
-        """
-        if ti_key is not None:
-            flow_id = XCom.get_value(key="return_value", ti_key=ti_key)
-        else:
-            assert dttm
-            flow_id = XCom.get_one(
-                key="return_value", dag_id=operator.dag_id, task_id=operator.task_id, execution_date=dttm
-            )
-        return (
-            f'https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:{flow_id}'
-            if flow_id
-            else ''
-        )
-
-
 class EmrCreateJobFlowOperator(BaseOperator):
     """
     Creates an EMR JobFlow, reading the config from the EMR connection.
@@ -339,8 +315,16 @@ class EmrCreateJobFlowOperator(BaseOperator):
         if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
             raise AirflowException(f'JobFlow creation failed: {response}')
         else:
-            self.log.info('JobFlow with id %s created', response['JobFlowId'])
-            return response['JobFlowId']
+            job_flow_id = response['JobFlowId']
+            self.log.info('JobFlow with id %s created', job_flow_id)
+            EmrClusterLink.persist(
+                context=context,
+                operator=self,
+                region_name=emr.conn_region_name,
+                aws_partition=emr.conn_partition,
+                job_flow_id=job_flow_id,
+            )
+            return job_flow_id
 
 
 class EmrModifyClusterOperator(BaseOperator):
@@ -360,6 +344,7 @@ class EmrModifyClusterOperator(BaseOperator):
     template_fields: Sequence[str] = ('cluster_id', 'step_concurrency_level')
     template_ext: Sequence[str] = ()
     ui_color = '#f9c915'
+    operator_extra_links = (EmrClusterLink(),)
 
     def __init__(
         self, *, cluster_id: str, step_concurrency_level: int, aws_conn_id: str = 'aws_default', **kwargs
@@ -373,12 +358,19 @@ class EmrModifyClusterOperator(BaseOperator):
 
     def execute(self, context: 'Context') -> int:
         emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
-
         emr = emr_hook.get_conn()
 
         if self.do_xcom_push:
             context['ti'].xcom_push(key='cluster_id', value=self.cluster_id)
 
+        EmrClusterLink.persist(
+            context=context,
+            operator=self,
+            region_name=emr_hook.conn_region_name,
+            aws_partition=emr_hook.conn_partition,
+            job_flow_id=self.cluster_id,
+        )
+
         self.log.info('Modifying cluster %s', self.cluster_id)
         response = emr.modify_cluster(
             ClusterId=self.cluster_id, StepConcurrencyLevel=self.step_concurrency_level
@@ -406,6 +398,7 @@ class EmrTerminateJobFlowOperator(BaseOperator):
     template_fields: Sequence[str] = ('job_flow_id',)
     template_ext: Sequence[str] = ()
     ui_color = '#f9c915'
+    operator_extra_links = (EmrClusterLink(),)
 
     def __init__(self, *, job_flow_id: str, aws_conn_id: str = 'aws_default', **kwargs):
         super().__init__(**kwargs)
@@ -413,7 +406,16 @@ class EmrTerminateJobFlowOperator(BaseOperator):
         self.aws_conn_id = aws_conn_id
 
     def execute(self, context: 'Context') -> None:
-        emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
+        emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
+        emr = emr_hook.get_conn()
+
+        EmrClusterLink.persist(
+            context=context,
+            operator=self,
+            region_name=emr_hook.conn_region_name,
+            aws_partition=emr_hook.conn_partition,
+            job_flow_id=self.job_flow_id,
+        )
 
         self.log.info('Terminating JobFlow %s', self.job_flow_id)
         response = emr.terminate_job_flows(JobFlowIds=[self.job_flow_id])
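
With these persist calls in place, the "EMR Cluster" link shows up on EmrAddStepsOperator, EmrCreateJobFlowOperator, EmrModifyClusterOperator and EmrTerminateJobFlowOperator once a task run has pushed region, partition and job flow id to XCom. A minimal DAG sketch wiring two of them together (the aws_default connection and the empty job_flow_overrides are placeholders):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.emr import (
        EmrCreateJobFlowOperator,
        EmrTerminateJobFlowOperator,
    )

    with DAG("emr_link_demo", start_date=datetime(2022, 6, 1), schedule_interval=None) as dag:
        create = EmrCreateJobFlowOperator(
            task_id="create_job_flow",
            job_flow_overrides={},      # placeholder cluster config
        )
        terminate = EmrTerminateJobFlowOperator(
            task_id="terminate_job_flow",
            job_flow_id=create.output,  # XComArg carrying the returned JobFlowId
        )
        create >> terminate
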
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 8e05db7796..bfdd849c67 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -556,7 +556,7 @@ hook-class-names:  # deprecated - to be removed after providers add dependency o
   - airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook
 
 extra-links:
-  - airflow.providers.amazon.aws.operators.emr.EmrClusterLink
+  - airflow.providers.amazon.aws.links.emr.EmrClusterLink
   - airflow.providers.amazon.aws.operators.emr_create_job_flow.EmrClusterLink
 
 connection-types:
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 79f0b0b572..00ffc31163 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -114,7 +114,7 @@ class CustomSessionFactory(BaseSessionFactory):
         return mock.MagicMock()
 
 
-class TestAwsBaseHook(unittest.TestCase):
+class TestAwsBaseHook:
     @conf_vars(
         {("aws", "session_factory"): "tests.providers.amazon.aws.hooks.test_base_aws.CustomSessionFactory"}
     )
@@ -647,6 +647,60 @@ class TestAwsBaseHook(unittest.TestCase):
             assert mock_refresh.call_count == 2
             assert len(expire_on_calls) == 0
 
+    @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present')
+    @mock_dynamodb2
+    @pytest.mark.parametrize("conn_type", ["client", "resource"])
+    @pytest.mark.parametrize(
+        "connection_uri,region_name,env_region,expected_region_name",
+        [
+            ("aws://?region_name=eu-west-1", None, "", "eu-west-1"),
+            ("aws://?region_name=eu-west-1", "cn-north-1", "", "cn-north-1"),
+            ("aws://?region_name=eu-west-1", None, "us-east-2", "eu-west-1"),
+            ("aws://?region_name=eu-west-1", "cn-north-1", "us-gov-east-1", "cn-north-1"),
+            ("aws://?", "cn-north-1", "us-gov-east-1", "cn-north-1"),
+            ("aws://?", None, "us-gov-east-1", "us-gov-east-1"),
+        ],
+    )
+    def test_connection_region_name(
+        self, conn_type, connection_uri, region_name, env_region, expected_region_name
+    ):
+        with unittest.mock.patch.dict(
+            'os.environ', AIRFLOW_CONN_TEST_CONN=connection_uri, AWS_DEFAULT_REGION=env_region
+        ):
+            if conn_type == "client":
+                hook = AwsBaseHook(aws_conn_id='test_conn', region_name=region_name, client_type='dynamodb')
+            elif conn_type == "resource":
+                hook = AwsBaseHook(aws_conn_id='test_conn', region_name=region_name, resource_type='dynamodb')
+            else:
+                raise ValueError(f"Unsupported conn_type={conn_type!r}")
+
+            assert hook.conn_region_name == expected_region_name
+
+    @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamodb2 package not present')
+    @mock_dynamodb2
+    @pytest.mark.parametrize("conn_type", ["client", "resource"])
+    @pytest.mark.parametrize(
+        "connection_uri,expected_partition",
+        [
+            ("aws://?region_name=eu-west-1", "aws"),
+            ("aws://?region_name=cn-north-1", "aws-cn"),
+            ("aws://?region_name=us-gov-east-1", "aws-us-gov"),
+        ],
+    )
+    def test_connection_aws_partition(self, conn_type, connection_uri, expected_partition):
+        with unittest.mock.patch.dict(
+            'os.environ',
+            AIRFLOW_CONN_TEST_CONN=connection_uri,
+        ):
+            if conn_type == "client":
+                hook = AwsBaseHook(aws_conn_id='test_conn', client_type='dynamodb')
+            elif conn_type == "resource":
+                hook = AwsBaseHook(aws_conn_id='test_conn', resource_type='dynamodb')
+            else:
+                raise ValueError(f"Unsupported conn_type={conn_type!r}")
+
+            assert hook.conn_partition == expected_partition
+
 
 class ThrowErrorUntilCount:
     """Holds counter state for invoking a method several times in a row."""
diff --git a/tests/providers/amazon/aws/links/__init__.py b/tests/providers/amazon/aws/links/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/amazon/aws/links/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/amazon/aws/links/test_base.py b/tests/providers/amazon/aws/links/test_base.py
new file mode 100644
index 0000000000..b6bf9073a5
--- /dev/null
+++ b/tests/providers/amazon/aws/links/test_base.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink
+from tests.test_utils.mock_operators import MockOperator
+
+XCOM_KEY = "test_xcom_key"
+CUSTOM_KEYS = {
+    "foo": "bar",
+    "spam": "egg",
+}
+
+
+class SimpleBaseAwsLink(BaseAwsLink):
+    key = XCOM_KEY
+
+
+class TestBaseAwsLink:
+    @pytest.mark.parametrize(
+        "region_name, aws_partition,keywords,expected_value",
+        [
+            ("eu-central-1", "aws", {}, {"region_name": "eu-central-1", "aws_domain": "aws.amazon.com"}),
+            ("cn-north-1", "aws-cn", {}, {"region_name": "cn-north-1", "aws_domain": "amazonaws.cn"}),
+            (
+                "us-gov-east-1",
+                "aws-us-gov",
+                {},
+                {"region_name": "us-gov-east-1", "aws_domain": "amazonaws-us-gov.com"},
+            ),
+            (
+                "eu-west-1",
+                "aws",
+                CUSTOM_KEYS,
+                {"region_name": "eu-west-1", "aws_domain": "aws.amazon.com", **CUSTOM_KEYS},
+            ),
+        ],
+    )
+    def test_persist(self, region_name, aws_partition, keywords, expected_value):
+        mock_context = MagicMock()
+
+        SimpleBaseAwsLink.persist(
+            context=mock_context,
+            operator=MockOperator(task_id="test_task_id"),
+            region_name=region_name,
+            aws_partition=aws_partition,
+            **keywords,
+        )
+
+        ti = mock_context["ti"]
+        ti.xcom_push.assert_called_once_with(
+            execution_date=None,
+            key=XCOM_KEY,
+            value=expected_value,
+        )
+
+    def test_disable_xcom_push(self):
+        mock_context = MagicMock()
+        SimpleBaseAwsLink.persist(
+            context=mock_context,
+            operator=MockOperator(task_id="test_task_id", do_xcom_push=False),
+            region_name="eu-east-1",
+            aws_partition="aws",
+        )
+        ti = mock_context["ti"]
+        ti.xcom_push.assert_not_called()
diff --git a/tests/providers/amazon/aws/links/test_emr.py b/tests/providers/amazon/aws/links/test_emr.py
new file mode 100644
index 0000000000..14510020d9
--- /dev/null
+++ b/tests/providers/amazon/aws/links/test_emr.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.amazon.aws.links.emr import EmrClusterLink
+from airflow.serialization.serialized_objects import SerializedDAG
+from tests.providers.amazon.aws.utils.links_test_utils import link_test_operator
+
+DAG_ID = "test_dag_id"
+TASK_ID = "test_task_id"
+JOB_FLOW_ID = "j-test-flow-id"
+REGION_NAME = "eu-west-1"
+
+
+@pytest.fixture(scope="module")
+def operator_class():
+    return link_test_operator(EmrClusterLink)
+
+
+@pytest.fixture(scope="module")
+def mock_task(operator_class):
+    return operator_class(task_id=TASK_ID)
+
+
+@pytest.fixture()
+def mock_ti(create_task_instance_of_operator, operator_class):
+    return create_task_instance_of_operator(operator_class, dag_id=DAG_ID, task_id=TASK_ID)
+
+
+@pytest.mark.need_serialized_dag
+class TestEmrClusterLink:
+    def test_link_serialize(self, dag_maker, mock_ti):
+        serialized_dag = dag_maker.get_serialized_data()
+
+        assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+            {"airflow.providers.amazon.aws.links.emr.EmrClusterLink": {}}
+        ], "Operator links should exist for serialized DAG"
+
+    @pytest.mark.parametrize(
+        "region_name,aws_partition,aws_domain",
+        [
+            ("eu-west-1", "aws", "aws.amazon.com"),
+            ("cn-north-1", "aws-cn", "amazonaws.cn"),
+            ("us-gov-east-1", "aws-us-gov", "amazonaws-us-gov.com"),
+        ],
+    )
+    def test_emr_cluster_link(self, dag_maker, mock_task, mock_ti, region_name, aws_partition, aws_domain):
+        mock_context = MagicMock()
+        mock_context.__getitem__.side_effect = {"ti": mock_ti}.__getitem__
+
+        EmrClusterLink.persist(
+            context=mock_context,
+            operator=mock_task,
+            region_name=region_name,
+            aws_partition=aws_partition,
+            job_flow_id=JOB_FLOW_ID,
+        )
+
+        expected = (
+            f"https://console.{aws_domain}/elasticmapreduce/home?region={region_name}"
+            f"#cluster-details:{JOB_FLOW_ID}"
+        )
+        assert (
+            mock_ti.task.get_extra_links(mock_ti, EmrClusterLink.name) == expected
+        ), "Operator link should be preserved after execution"
+
+        serialized_dag = dag_maker.get_serialized_data()
+        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+        deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+        assert (
+            deserialized_task.get_extra_links(mock_ti, EmrClusterLink.name) == expected
+        ), "Operator link should be preserved in deserialized tasks after execution"
+
+    def test_empty_xcom(self, dag_maker, mock_ti):
+        serialized_dag = dag_maker.get_serialized_data()
+        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+        deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+        assert (
+            mock_ti.task.get_extra_links(mock_ti, EmrClusterLink.name) == ""
+        ), "Operator link should only be added if job id is available in XCom"
+
+        assert (
+            deserialized_task.get_extra_links(mock_ti, EmrClusterLink.name) == ""
+        ), "Operator link should be empty for deserialized task with no XCom push"
diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py
index f96f7c6f36..05b8d2de4e 100644
--- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py
+++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py
@@ -20,7 +20,7 @@ import json
 import os
 import unittest
 from datetime import timedelta
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, call, patch
 
 import pytest
 from jinja2 import StrictUndefined
@@ -170,8 +170,7 @@ class TestEmrAddStepsOperator(unittest.TestCase):
                 operator.execute(self.mock_context)
 
         ti = self.mock_context['ti']
-
-        ti.xcom_push.assert_called_once_with(key='job_flow_id', value=expected_job_flow_id)
+        ti.assert_has_calls(calls=[call.xcom_push(key='job_flow_id', value=expected_job_flow_id)])
 
     def test_init_with_nonexistent_cluster_name(self):
         cluster_name = 'test_cluster'
diff --git a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
index 5dbda61616..c7ac3bfe4b 100644
--- a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
+++ b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
@@ -22,13 +22,10 @@ import unittest
 from datetime import timedelta
 from unittest.mock import MagicMock, patch
 
-import pytest
 from jinja2 import StrictUndefined
 
 from airflow.models import DAG, DagRun, TaskInstance
-from airflow.models.xcom import XCOM_RETURN_KEY
-from airflow.providers.amazon.aws.operators.emr import EmrClusterLink, EmrCreateJobFlowOperator
-from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
 from airflow.utils import timezone
 from tests.test_utils import AIRFLOW_MAIN_FOLDER
 
@@ -79,6 +76,7 @@ class TestEmrCreateJobFlowOperator(unittest.TestCase):
                 template_undefined=StrictUndefined,
             ),
         )
+        self.mock_context = MagicMock()
 
     def test_init(self):
         assert self.operator.aws_conn_id == 'aws_default'
@@ -129,7 +127,7 @@ class TestEmrCreateJobFlowOperator(unittest.TestCase):
 
         # String in job_flow_overrides (i.e. from loaded as a file) is not "parsed" until inside execute()
         with patch('boto3.session.Session', boto3_session_mock):
-            self.operator.execute(None)
+            self.operator.execute(self.mock_context)
 
         expected_args = {
             'Name': 'test_job_flow',
@@ -161,42 +159,4 @@ class TestEmrCreateJobFlowOperator(unittest.TestCase):
         boto3_session_mock = MagicMock(return_value=emr_session_mock)
 
         with patch('boto3.session.Session', boto3_session_mock):
-            assert self.operator.execute(None) == 'j-8989898989'
-
-
-@pytest.mark.need_serialized_dag
-def test_operator_extra_links(dag_maker, create_task_instance_of_operator):
-    ti = create_task_instance_of_operator(
-        EmrCreateJobFlowOperator, dag_id=TEST_DAG_ID, execution_date=DEFAULT_DATE, task_id=TASK_ID
-    )
-
-    serialized_dag = dag_maker.get_serialized_data()
-    deserialized_dag = SerializedDAG.from_dict(serialized_dag)
-    deserialized_task = deserialized_dag.task_dict[TASK_ID]
-
-    assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
-        {"airflow.providers.amazon.aws.operators.emr.EmrClusterLink": {}}
-    ], "Operator links should exist for serialized DAG"
-
-    assert isinstance(
-        deserialized_task.operator_extra_links[0], EmrClusterLink
-    ), "Operator link type should be preserved during deserialization"
-
-    assert (
-        ti.task.get_extra_links(ti, EmrClusterLink.name) == ""
-    ), "Operator link should only be added if job id is available in XCom"
-
-    assert (
-        deserialized_task.get_extra_links(ti, EmrClusterLink.name) == ""
-    ), "Operator link should be empty for deserialized task with no XCom push"
-
-    ti.xcom_push(key=XCOM_RETURN_KEY, value='j-SomeClusterId')
-
-    expected = "https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:j-SomeClusterId"
-    assert (
-        deserialized_task.get_extra_links(ti, EmrClusterLink.name) == expected
-    ), "Operator link should be preserved in deserialized tasks after execution"
-
-    assert (
-        ti.task.get_extra_links(ti, EmrClusterLink.name) == expected
-    ), "Operator link should be preserved after execution"
+            assert self.operator.execute(self.mock_context) == 'j-8989898989'
diff --git a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
index 42c76d4988..bca75c8aea 100644
--- a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
+++ b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
@@ -42,4 +42,4 @@ class TestEmrTerminateJobFlowOperator(unittest.TestCase):
                 task_id='test_task', job_flow_id='j-8989898989', aws_conn_id='aws_default'
             )
 
-            operator.execute(None)
+            operator.execute(MagicMock())
diff --git a/tests/providers/amazon/aws/utils/links_test_utils.py b/tests/providers/amazon/aws/utils/links_test_utils.py
new file mode 100644
index 0000000000..e449bf99aa
--- /dev/null
+++ b/tests/providers/amazon/aws/utils/links_test_utils.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from tests.test_utils.mock_operators import MockOperator
+
+
+def link_test_operator(*links):
+    class LinkTestOperator(MockOperator):
+        operator_extra_links = tuple(c() for c in links)
+
+    return LinkTestOperator