Posted to commits@airflow.apache.org by po...@apache.org on 2022/06/12 08:51:14 UTC
[airflow] branch main updated: Refactoring EmrClusterLink and add for other AWS EMR Operators (#24294)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 19dd9f5873 Refactoring EmrClusterLink and add for other AWS EMR Operators (#24294)
19dd9f5873 is described below
commit 19dd9f5873098decb41040b0c252a6072a67a356
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Sun Jun 12 12:51:08 2022 +0400
Refactoring EmrClusterLink and add for other AWS EMR Operators (#24294)
---
airflow/providers/amazon/aws/hooks/base_aws.py | 16 ++++
airflow/providers/amazon/aws/links/__init__.py | 16 ++++
airflow/providers/amazon/aws/links/base_aws.py | 100 ++++++++++++++++++++
airflow/providers/amazon/aws/links/emr.py | 28 ++++++
airflow/providers/amazon/aws/operators/emr.py | 80 ++++++++--------
airflow/providers/amazon/provider.yaml | 2 +-
tests/providers/amazon/aws/hooks/test_base_aws.py | 56 ++++++++++-
tests/providers/amazon/aws/links/__init__.py | 16 ++++
tests/providers/amazon/aws/links/test_base.py | 83 +++++++++++++++++
tests/providers/amazon/aws/links/test_emr.py | 103 +++++++++++++++++++++
.../amazon/aws/operators/test_emr_add_steps.py | 5 +-
.../aws/operators/test_emr_create_job_flow.py | 48 +---------
.../aws/operators/test_emr_terminate_job_flow.py | 2 +-
.../providers/amazon/aws/utils/links_test_utils.py | 25 +++++
14 files changed, 491 insertions(+), 89 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index bbf0bfff83..162be9ce47 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -37,6 +37,7 @@ import botocore
import botocore.session
import requests
import tenacity
+from botocore.client import ClientMeta
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials
from slugify import slugify
@@ -521,6 +522,21 @@ class AwsBaseHook(BaseHook):
            # Rare possibility - subclasses have not specified a client_type or resource_type
            raise NotImplementedError('Could not get boto3 connection!')
+    @cached_property
+    def conn_client_meta(self) -> ClientMeta:
+        conn = self.conn
+        if isinstance(conn, botocore.client.BaseClient):
+            return conn.meta
+        return conn.meta.client.meta
+
+    @property
+    def conn_region_name(self) -> str:
+        return self.conn_client_meta.region_name
+
+    @property
+    def conn_partition(self) -> str:
+        return self.conn_client_meta.partition
+
    def get_conn(self) -> Union[boto3.client, boto3.resource]:
        """
        Get the underlying boto3 client/resource (cached)
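[Editor's note: the new hook properties can be read off any AwsBaseHook subclass. A minimal sketch, assuming a hypothetical connection id and client_type that are not part of this commit; the region must still be resolvable from the connection or the usual AWS config/env:]

    from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

    # Hypothetical connection id and client_type, for illustration only.
    hook = AwsBaseHook(aws_conn_id="aws_default", client_type="emr")

    # Both values are resolved from the cached botocore client metadata,
    # so they reflect whatever region/partition the connection actually resolved to.
    print(hook.conn_region_name)  # e.g. "eu-west-1"
    print(hook.conn_partition)    # "aws", "aws-cn" or "aws-us-gov"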
diff --git a/airflow/providers/amazon/aws/links/__init__.py b/airflow/providers/amazon/aws/links/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/airflow/providers/amazon/aws/links/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/amazon/aws/links/base_aws.py b/airflow/providers/amazon/aws/links/base_aws.py
new file mode 100644
index 0000000000..362f92e76c
--- /dev/null
+++ b/airflow/providers/amazon/aws/links/base_aws.py
@@ -0,0 +1,100 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datetime import datetime
+from typing import TYPE_CHECKING, ClassVar, Optional
+from urllib.parse import quote_plus
+
+from airflow.models import BaseOperatorLink, XCom
+
+if TYPE_CHECKING:
+    from airflow.models import BaseOperator
+    from airflow.models.taskinstance import TaskInstanceKey
+    from airflow.utils.context import Context
+
+
+BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"
+
+
+class BaseAwsLink(BaseOperatorLink):
+ """Base Helper class for constructing AWS Console Link"""
+
+ name: ClassVar[str]
+ key: ClassVar[str]
+ format_str: ClassVar[str]
+
+ @staticmethod
+ def get_aws_domain(aws_partition) -> Optional[str]:
+ if aws_partition == "aws":
+ return "aws.amazon.com"
+ elif aws_partition == "aws-cn":
+ return "amazonaws.cn"
+ elif aws_partition == "aws-us-gov":
+ return "amazonaws-us-gov.com"
+
+ return None
+
+ def get_link(
+ self,
+ operator,
+ dttm: Optional[datetime] = None,
+ ti_key: Optional["TaskInstanceKey"] = None,
+ ) -> str:
+ """
+ Link to Amazon Web Services Console.
+
+ :param operator: airflow operator
+ :param ti_key: TaskInstance ID to return link for
+ :param dttm: execution date. Uses for compatibility with Airflow 2.2
+ :return: link to external system
+ """
+ if ti_key is not None:
+ conf = XCom.get_value(key=self.key, ti_key=ti_key)
+ elif not dttm:
+ conf = {}
+ else:
+ conf = XCom.get_one(
+ key=self.key,
+ dag_id=operator.dag.dag_id,
+ task_id=operator.task_id,
+ execution_date=dttm,
+ )
+ if not conf:
+ return ""
+
+ # urlencode special characters, e.g.: CloudWatch links contains `/` character.
+ quoted_conf = {k: quote_plus(v) if isinstance(v, str) else v for k, v in conf.items()}
+ return self.format_str.format(**quoted_conf)
+
+ @classmethod
+ def persist(
+ cls, context: "Context", operator: "BaseOperator", region_name: str, aws_partition: str, **kwargs
+ ) -> None:
+ """Store link information into XCom"""
+ if not operator.do_xcom_push:
+ return
+
+ operator.xcom_push(
+ context,
+ key=cls.key,
+ value={
+ "region_name": region_name,
+ "aws_domain": cls.get_aws_domain(aws_partition),
+ **kwargs,
+ },
+ )
diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py
new file mode 100644
index 0000000000..ea46341dd7
--- /dev/null
+++ b/airflow/providers/amazon/aws/links/emr.py
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class EmrClusterLink(BaseAwsLink):
+ """Helper class for constructing AWS EMR Cluster Link"""
+
+ name = "EMR Cluster"
+ key = "emr_cluster"
+ format_str = (
+ BASE_AWS_CONSOLE_LINK + "/elasticmapreduce/home?region={region_name}#cluster-details:{job_flow_id}"
+ )
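[Editor's note: a rough illustration of the URL this format string produces, using made-up region, partition and job flow id values that mirror what EmrClusterLink.persist() stores in XCom:]

    from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink
    from airflow.providers.amazon.aws.links.emr import EmrClusterLink

    # Made-up values, for illustration only.
    conf = {
        "region_name": "eu-west-1",
        "aws_domain": BaseAwsLink.get_aws_domain("aws"),  # -> "aws.amazon.com"
        "job_flow_id": "j-EXAMPLE1234567",
    }
    print(EmrClusterLink.format_str.format(**conf))
    # https://console.aws.amazon.com/elasticmapreduce/home?region=eu-west-1#cluster-details:j-EXAMPLE1234567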
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 67ae54af50..77c079be83 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -17,16 +17,15 @@
# under the License.
import ast
import sys
-from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from uuid import uuid4
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator, BaseOperatorLink, XCom
+from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.emr import EmrHook
+from airflow.providers.amazon.aws.links.emr import EmrClusterLink
if TYPE_CHECKING:
-    from airflow.models.taskinstance import TaskInstanceKey
    from airflow.utils.context import Context
@@ -62,6 +61,7 @@ class EmrAddStepsOperator(BaseOperator):
    template_ext: Sequence[str] = ('.json',)
    template_fields_renderers = {"steps": "json"}
    ui_color = '#f9c915'
+    operator_extra_links = (EmrClusterLink(),)
    def __init__(
        self,
@@ -101,6 +101,14 @@ class EmrAddStepsOperator(BaseOperator):
        if self.do_xcom_push:
            context['ti'].xcom_push(key='job_flow_id', value=job_flow_id)
+        EmrClusterLink.persist(
+            context=context,
+            operator=self,
+            region_name=emr_hook.conn_region_name,
+            aws_partition=emr_hook.conn_partition,
+            job_flow_id=job_flow_id,
+        )
+
        self.log.info('Adding steps to %s', job_flow_id)
        # steps may arrive as a string representing a list
@@ -243,38 +251,6 @@ class EmrContainerOperator(BaseOperator):
        self.hook.poll_query_status(self.job_id)
-class EmrClusterLink(BaseOperatorLink):
-    """Operator link for EmrCreateJobFlowOperator. It allows users to access the EMR Cluster"""
-
-    name = 'EMR Cluster'
-
-    def get_link(
-        self,
-        operator,
-        dttm: Optional[datetime] = None,
-        ti_key: Optional["TaskInstanceKey"] = None,
-    ) -> str:
-        """
-        Get link to EMR cluster.
-
-        :param operator: operator
-        :param dttm: datetime
-        :return: url link
-        """
-        if ti_key is not None:
-            flow_id = XCom.get_value(key="return_value", ti_key=ti_key)
-        else:
-            assert dttm
-            flow_id = XCom.get_one(
-                key="return_value", dag_id=operator.dag_id, task_id=operator.task_id, execution_date=dttm
-            )
-        return (
-            f'https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:{flow_id}'
-            if flow_id
-            else ''
-        )
-
-
class EmrCreateJobFlowOperator(BaseOperator):
    """
    Creates an EMR JobFlow, reading the config from the EMR connection.
@@ -339,8 +315,16 @@ class EmrCreateJobFlowOperator(BaseOperator):
        if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
            raise AirflowException(f'JobFlow creation failed: {response}')
        else:
-            self.log.info('JobFlow with id %s created', response['JobFlowId'])
-            return response['JobFlowId']
+            job_flow_id = response['JobFlowId']
+            self.log.info('JobFlow with id %s created', job_flow_id)
+            EmrClusterLink.persist(
+                context=context,
+                operator=self,
+                region_name=emr.conn_region_name,
+                aws_partition=emr.conn_partition,
+                job_flow_id=job_flow_id,
+            )
+            return job_flow_id
class EmrModifyClusterOperator(BaseOperator):
@@ -360,6 +344,7 @@ class EmrModifyClusterOperator(BaseOperator):
    template_fields: Sequence[str] = ('cluster_id', 'step_concurrency_level')
    template_ext: Sequence[str] = ()
    ui_color = '#f9c915'
+    operator_extra_links = (EmrClusterLink(),)
    def __init__(
        self, *, cluster_id: str, step_concurrency_level: int, aws_conn_id: str = 'aws_default', **kwargs
@@ -373,12 +358,19 @@ class EmrModifyClusterOperator(BaseOperator):
    def execute(self, context: 'Context') -> int:
        emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
-
        emr = emr_hook.get_conn()
        if self.do_xcom_push:
            context['ti'].xcom_push(key='cluster_id', value=self.cluster_id)
+        EmrClusterLink.persist(
+            context=context,
+            operator=self,
+            region_name=emr_hook.conn_region_name,
+            aws_partition=emr_hook.conn_partition,
+            job_flow_id=self.cluster_id,
+        )
+
        self.log.info('Modifying cluster %s', self.cluster_id)
        response = emr.modify_cluster(
            ClusterId=self.cluster_id, StepConcurrencyLevel=self.step_concurrency_level
@@ -406,6 +398,7 @@ class EmrTerminateJobFlowOperator(BaseOperator):
    template_fields: Sequence[str] = ('job_flow_id',)
    template_ext: Sequence[str] = ()
    ui_color = '#f9c915'
+    operator_extra_links = (EmrClusterLink(),)
    def __init__(self, *, job_flow_id: str, aws_conn_id: str = 'aws_default', **kwargs):
        super().__init__(**kwargs)
@@ -413,7 +406,16 @@ class EmrTerminateJobFlowOperator(BaseOperator):
        self.aws_conn_id = aws_conn_id
    def execute(self, context: 'Context') -> None:
-        emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
+        emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
+        emr = emr_hook.get_conn()
+
+        EmrClusterLink.persist(
+            context=context,
+            operator=self,
+            region_name=emr_hook.conn_region_name,
+            aws_partition=emr_hook.conn_partition,
+            job_flow_id=self.job_flow_id,
+        )
        self.log.info('Terminating JobFlow %s', self.job_flow_id)
        response = emr.terminate_job_flows(JobFlowIds=[self.job_flow_id])
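[Editor's note: with operator_extra_links set on these operators, the "EMR Cluster" button appears in the task UI once the link data has been pushed to XCom. A minimal DAG sketch; the dag id, schedule and job flow overrides below are placeholder assumptions, not taken from the commit:]

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.emr import (
        EmrCreateJobFlowOperator,
        EmrTerminateJobFlowOperator,
    )

    # Placeholder DAG settings, for illustration only.
    with DAG(dag_id="emr_cluster_link_demo", start_date=datetime(2022, 6, 1), schedule_interval=None):
        create_cluster = EmrCreateJobFlowOperator(
            task_id="create_cluster",
            job_flow_overrides={"Name": "demo-cluster"},  # hypothetical minimal overrides
        )
        terminate_cluster = EmrTerminateJobFlowOperator(
            task_id="terminate_cluster",
            job_flow_id=create_cluster.output,  # XComArg carrying the created job flow id
        )
        create_cluster >> terminate_cluster

    # After execution, each task's "EMR Cluster" extra link points at the cluster details
    # page in the AWS console for the hook's resolved region and partition.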
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 8e05db7796..bfdd849c67 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -556,7 +556,7 @@ hook-class-names: # deprecated - to be removed after providers add dependency o
- airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook
extra-links:
- - airflow.providers.amazon.aws.operators.emr.EmrClusterLink
+ - airflow.providers.amazon.aws.links.emr.EmrClusterLink
- airflow.providers.amazon.aws.operators.emr_create_job_flow.EmrClusterLink
connection-types:
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 79f0b0b572..00ffc31163 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -114,7 +114,7 @@ class CustomSessionFactory(BaseSessionFactory):
        return mock.MagicMock()
-class TestAwsBaseHook(unittest.TestCase):
+class TestAwsBaseHook:
    @conf_vars(
        {("aws", "session_factory"): "tests.providers.amazon.aws.hooks.test_base_aws.CustomSessionFactory"}
    )
@@ -647,6 +647,60 @@ class TestAwsBaseHook(unittest.TestCase):
        assert mock_refresh.call_count == 2
        assert len(expire_on_calls) == 0
+    @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
+    @mock_dynamodb2
+    @pytest.mark.parametrize("conn_type", ["client", "resource"])
+    @pytest.mark.parametrize(
+        "connection_uri,region_name,env_region,expected_region_name",
+        [
+            ("aws://?region_name=eu-west-1", None, "", "eu-west-1"),
+            ("aws://?region_name=eu-west-1", "cn-north-1", "", "cn-north-1"),
+            ("aws://?region_name=eu-west-1", None, "us-east-2", "eu-west-1"),
+            ("aws://?region_name=eu-west-1", "cn-north-1", "us-gov-east-1", "cn-north-1"),
+            ("aws://?", "cn-north-1", "us-gov-east-1", "cn-north-1"),
+            ("aws://?", None, "us-gov-east-1", "us-gov-east-1"),
+        ],
+    )
+    def test_connection_region_name(
+        self, conn_type, connection_uri, region_name, env_region, expected_region_name
+    ):
+        with unittest.mock.patch.dict(
+            'os.environ', AIRFLOW_CONN_TEST_CONN=connection_uri, AWS_DEFAULT_REGION=env_region
+        ):
+            if conn_type == "client":
+                hook = AwsBaseHook(aws_conn_id='test_conn', region_name=region_name, client_type='dynamodb')
+            elif conn_type == "resource":
+                hook = AwsBaseHook(aws_conn_id='test_conn', region_name=region_name, resource_type='dynamodb')
+            else:
+                raise ValueError(f"Unsupported conn_type={conn_type!r}")
+
+            assert hook.conn_region_name == expected_region_name
+
+    @unittest.skipIf(mock_dynamodb2 is None, 'mock_dynamo2 package not present')
+    @mock_dynamodb2
+    @pytest.mark.parametrize("conn_type", ["client", "resource"])
+    @pytest.mark.parametrize(
+        "connection_uri,expected_partition",
+        [
+            ("aws://?region_name=eu-west-1", "aws"),
+            ("aws://?region_name=cn-north-1", "aws-cn"),
+            ("aws://?region_name=us-gov-east-1", "aws-us-gov"),
+        ],
+    )
+    def test_connection_aws_partition(self, conn_type, connection_uri, expected_partition):
+        with unittest.mock.patch.dict(
+            'os.environ',
+            AIRFLOW_CONN_TEST_CONN=connection_uri,
+        ):
+            if conn_type == "client":
+                hook = AwsBaseHook(aws_conn_id='test_conn', client_type='dynamodb')
+            elif conn_type == "resource":
+                hook = AwsBaseHook(aws_conn_id='test_conn', resource_type='dynamodb')
+            else:
+                raise ValueError(f"Unsupported conn_type={conn_type!r}")
+
+            assert hook.conn_partition == expected_partition
+
class ThrowErrorUntilCount:
    """Holds counter state for invoking a method several times in a row."""
diff --git a/tests/providers/amazon/aws/links/__init__.py b/tests/providers/amazon/aws/links/__init__.py
new file mode 100644
index 0000000000..13a83393a9
--- /dev/null
+++ b/tests/providers/amazon/aws/links/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/amazon/aws/links/test_base.py b/tests/providers/amazon/aws/links/test_base.py
new file mode 100644
index 0000000000..b6bf9073a5
--- /dev/null
+++ b/tests/providers/amazon/aws/links/test_base.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink
+from tests.test_utils.mock_operators import MockOperator
+
+XCOM_KEY = "test_xcom_key"
+CUSTOM_KEYS = {
+ "foo": "bar",
+ "spam": "egg",
+}
+
+
+class SimpleBaseAwsLink(BaseAwsLink):
+    key = XCOM_KEY
+
+
+class TestBaseAwsLink:
+    @pytest.mark.parametrize(
+        "region_name, aws_partition,keywords,expected_value",
+        [
+            ("eu-central-1", "aws", {}, {"region_name": "eu-central-1", "aws_domain": "aws.amazon.com"}),
+            ("cn-north-1", "aws-cn", {}, {"region_name": "cn-north-1", "aws_domain": "amazonaws.cn"}),
+            (
+                "us-gov-east-1",
+                "aws-us-gov",
+                {},
+                {"region_name": "us-gov-east-1", "aws_domain": "amazonaws-us-gov.com"},
+            ),
+            (
+                "eu-west-1",
+                "aws",
+                CUSTOM_KEYS,
+                {"region_name": "eu-west-1", "aws_domain": "aws.amazon.com", **CUSTOM_KEYS},
+            ),
+        ],
+    )
+    def test_persist(self, region_name, aws_partition, keywords, expected_value):
+        mock_context = MagicMock()
+
+        SimpleBaseAwsLink.persist(
+            context=mock_context,
+            operator=MockOperator(task_id="test_task_id"),
+            region_name=region_name,
+            aws_partition=aws_partition,
+            **keywords,
+        )
+
+        ti = mock_context["ti"]
+        ti.xcom_push.assert_called_once_with(
+            execution_date=None,
+            key=XCOM_KEY,
+            value=expected_value,
+        )
+
+    def test_disable_xcom_push(self):
+        mock_context = MagicMock()
+        SimpleBaseAwsLink.persist(
+            context=mock_context,
+            operator=MockOperator(task_id="test_task_id", do_xcom_push=False),
+            region_name="eu-east-1",
+            aws_partition="aws",
+        )
+        ti = mock_context["ti"]
+        ti.xcom_push.assert_not_called()
diff --git a/tests/providers/amazon/aws/links/test_emr.py b/tests/providers/amazon/aws/links/test_emr.py
new file mode 100644
index 0000000000..14510020d9
--- /dev/null
+++ b/tests/providers/amazon/aws/links/test_emr.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.amazon.aws.links.emr import EmrClusterLink
+from airflow.serialization.serialized_objects import SerializedDAG
+from tests.providers.amazon.aws.utils.links_test_utils import link_test_operator
+
+DAG_ID = "test_dag_id"
+TASK_ID = "test_task_id"
+JOB_FLOW_ID = "j-test-flow-id"
+REGION_NAME = "eu-west-1"
+
+
+@pytest.fixture(scope="module")
+def operator_class():
+    return link_test_operator(EmrClusterLink)
+
+
+@pytest.fixture(scope="module")
+def mock_task(operator_class):
+    return operator_class(task_id=TASK_ID)
+
+
+@pytest.fixture()
+def mock_ti(create_task_instance_of_operator, operator_class):
+    return create_task_instance_of_operator(operator_class, dag_id=DAG_ID, task_id=TASK_ID)
+
+
+@pytest.mark.need_serialized_dag
+class TestEmrClusterLink:
+    def test_link_serialize(self, dag_maker, mock_ti):
+        serialized_dag = dag_maker.get_serialized_data()
+
+        assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
+            {"airflow.providers.amazon.aws.links.emr.EmrClusterLink": {}}
+        ], "Operator links should exist for serialized DAG"
+
+    @pytest.mark.parametrize(
+        "region_name,aws_partition,aws_domain",
+        [
+            ("eu-west-1", "aws", "aws.amazon.com"),
+            ("cn-north-1", "aws-cn", "amazonaws.cn"),
+            ("us-gov-east-1", "aws-us-gov", "amazonaws-us-gov.com"),
+        ],
+    )
+    def test_emr_custer_link(self, dag_maker, mock_task, mock_ti, region_name, aws_partition, aws_domain):
+        mock_context = MagicMock()
+        mock_context.__getitem__.side_effect = {"ti": mock_ti}.__getitem__
+
+        EmrClusterLink.persist(
+            context=mock_context,
+            operator=mock_task,
+            region_name=region_name,
+            aws_partition=aws_partition,
+            job_flow_id=JOB_FLOW_ID,
+        )
+
+        expected = (
+            f"https://console.{aws_domain}/elasticmapreduce/home?region={region_name}"
+            f"#cluster-details:{JOB_FLOW_ID}"
+        )
+        assert (
+            mock_ti.task.get_extra_links(mock_ti, EmrClusterLink.name) == expected
+        ), "Operator link should be preserved after execution"
+
+        serialized_dag = dag_maker.get_serialized_data()
+        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+        deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+        assert (
+            deserialized_task.get_extra_links(mock_ti, EmrClusterLink.name) == expected
+        ), "Operator link should be preserved in deserialized tasks after execution"
+
+    def test_empty_xcom(self, dag_maker, mock_ti):
+        serialized_dag = dag_maker.get_serialized_data()
+        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+        deserialized_task = deserialized_dag.task_dict[TASK_ID]
+
+        assert (
+            mock_ti.task.get_extra_links(mock_ti, EmrClusterLink.name) == ""
+        ), "Operator link should only be added if job id is available in XCom"
+
+        assert (
+            deserialized_task.get_extra_links(mock_ti, EmrClusterLink.name) == ""
+        ), "Operator link should be empty for deserialized task with no XCom push"
diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py
index f96f7c6f36..05b8d2de4e 100644
--- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py
+++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py
@@ -20,7 +20,7 @@ import json
import os
import unittest
from datetime import timedelta
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, call, patch
import pytest
from jinja2 import StrictUndefined
@@ -170,8 +170,7 @@ class TestEmrAddStepsOperator(unittest.TestCase):
        operator.execute(self.mock_context)
        ti = self.mock_context['ti']
-
-        ti.xcom_push.assert_called_once_with(key='job_flow_id', value=expected_job_flow_id)
+        ti.assert_has_calls(calls=[call.xcom_push(key='job_flow_id', value=expected_job_flow_id)])
    def test_init_with_nonexistent_cluster_name(self):
        cluster_name = 'test_cluster'
diff --git a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
index 5dbda61616..c7ac3bfe4b 100644
--- a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
+++ b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
@@ -22,13 +22,10 @@ import unittest
from datetime import timedelta
from unittest.mock import MagicMock, patch
-import pytest
from jinja2 import StrictUndefined
from airflow.models import DAG, DagRun, TaskInstance
-from airflow.models.xcom import XCOM_RETURN_KEY
-from airflow.providers.amazon.aws.operators.emr import EmrClusterLink, EmrCreateJobFlowOperator
-from airflow.serialization.serialized_objects import SerializedDAG
+from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
from airflow.utils import timezone
from tests.test_utils import AIRFLOW_MAIN_FOLDER
@@ -79,6 +76,7 @@ class TestEmrCreateJobFlowOperator(unittest.TestCase):
                template_undefined=StrictUndefined,
            ),
        )
+        self.mock_context = MagicMock()
    def test_init(self):
        assert self.operator.aws_conn_id == 'aws_default'
@@ -129,7 +127,7 @@ class TestEmrCreateJobFlowOperator(unittest.TestCase):
        # String in job_flow_overrides (i.e. loaded from a file) is not "parsed" until inside execute()
        with patch('boto3.session.Session', boto3_session_mock):
-            self.operator.execute(None)
+            self.operator.execute(self.mock_context)
        expected_args = {
            'Name': 'test_job_flow',
@@ -161,42 +159,4 @@ class TestEmrCreateJobFlowOperator(unittest.TestCase):
        boto3_session_mock = MagicMock(return_value=emr_session_mock)
        with patch('boto3.session.Session', boto3_session_mock):
-            assert self.operator.execute(None) == 'j-8989898989'
-
-
-@pytest.mark.need_serialized_dag
-def test_operator_extra_links(dag_maker, create_task_instance_of_operator):
-    ti = create_task_instance_of_operator(
-        EmrCreateJobFlowOperator, dag_id=TEST_DAG_ID, execution_date=DEFAULT_DATE, task_id=TASK_ID
-    )
-
-    serialized_dag = dag_maker.get_serialized_data()
-    deserialized_dag = SerializedDAG.from_dict(serialized_dag)
-    deserialized_task = deserialized_dag.task_dict[TASK_ID]
-
-    assert serialized_dag["dag"]["tasks"][0]["_operator_extra_links"] == [
-        {"airflow.providers.amazon.aws.operators.emr.EmrClusterLink": {}}
-    ], "Operator links should exist for serialized DAG"
-
-    assert isinstance(
-        deserialized_task.operator_extra_links[0], EmrClusterLink
-    ), "Operator link type should be preserved during deserialization"
-
-    assert (
-        ti.task.get_extra_links(ti, EmrClusterLink.name) == ""
-    ), "Operator link should only be added if job id is available in XCom"
-
-    assert (
-        deserialized_task.get_extra_links(ti, EmrClusterLink.name) == ""
-    ), "Operator link should be empty for deserialized task with no XCom push"
-
-    ti.xcom_push(key=XCOM_RETURN_KEY, value='j-SomeClusterId')
-
-    expected = "https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:j-SomeClusterId"
-    assert (
-        deserialized_task.get_extra_links(ti, EmrClusterLink.name) == expected
-    ), "Operator link should be preserved in deserialized tasks after execution"
-
-    assert (
-        ti.task.get_extra_links(ti, EmrClusterLink.name) == expected
-    ), "Operator link should be preserved after execution"
+            assert self.operator.execute(self.mock_context) == 'j-8989898989'
diff --git a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
index 42c76d4988..bca75c8aea 100644
--- a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
+++ b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
@@ -42,4 +42,4 @@ class TestEmrTerminateJobFlowOperator(unittest.TestCase):
            task_id='test_task', job_flow_id='j-8989898989', aws_conn_id='aws_default'
        )
-        operator.execute(None)
+        operator.execute(MagicMock())
diff --git a/tests/providers/amazon/aws/utils/links_test_utils.py b/tests/providers/amazon/aws/utils/links_test_utils.py
new file mode 100644
index 0000000000..e449bf99aa
--- /dev/null
+++ b/tests/providers/amazon/aws/utils/links_test_utils.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from tests.test_utils.mock_operators import MockOperator
+
+
+def link_test_operator(*links):
+    class LinkTestOperator(MockOperator):
+        operator_extra_links = tuple(c() for c in links)
+
+    return LinkTestOperator