You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/08/22 19:34:46 UTC

[airflow] branch main updated: Implement `EmrEksCreateClusterOperator` (#25816)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6b7a343b25 Implement `EmrEksCreateClusterOperator` (#25816)
6b7a343b25 is described below

commit 6b7a343b25b06ab592f19b7e70843dda2d7e0fdb
Author: Phani Kumar <94...@users.noreply.github.com>
AuthorDate: Tue Aug 23 01:04:22 2022 +0530

    Implement `EmrEksCreateClusterOperator` (#25816)
---
 .../amazon/aws/example_dags/example_emr_eks.py     | 18 +++++--
 airflow/providers/amazon/aws/hooks/emr.py          | 26 ++++++++++
 airflow/providers/amazon/aws/operators/emr.py      | 56 ++++++++++++++++++++++
 .../operators/emr_eks.rst                          | 26 ++++++++++
 .../amazon/aws/hooks/test_emr_containers.py        | 17 +++++++
 .../amazon/aws/operators/test_emr_containers.py    | 43 ++++++++++++++++-
 6 files changed, 181 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/amazon/aws/example_dags/example_emr_eks.py b/airflow/providers/amazon/aws/example_dags/example_emr_eks.py
index 0e7774f81d..413ab9e15c 100644
--- a/airflow/providers/amazon/aws/example_dags/example_emr_eks.py
+++ b/airflow/providers/amazon/aws/example_dags/example_emr_eks.py
@@ -20,10 +20,9 @@ from datetime import datetime
 
 from airflow import DAG
 from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator
+from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator, EmrEksCreateClusterOperator
 from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor
 
-VIRTUAL_CLUSTER_ID = os.getenv("VIRTUAL_CLUSTER_ID", "test-cluster")
 JOB_ROLE_ARN = os.getenv("JOB_ROLE_ARN", "arn:aws:iam::012345678912:role/emr_eks_default_role")
 
 # [START howto_operator_emr_eks_config]
@@ -58,10 +57,19 @@ with DAG(
     tags=['example'],
     catchup=False,
 ) as dag:
+    # [START howto_operator_emr_eks_create_cluster]
+    create_emr_eks_cluster = EmrEksCreateClusterOperator(
+        task_id="create_emr_eks_cluster",
+        virtual_cluster_name="emr_eks_virtual_cluster",
+        eks_cluster_name="eks_cluster",
+        eks_namespace="eks_namespace",
+    )
+    # [END howto_operator_emr_eks_create_cluster]
+
     # [START howto_operator_emr_container]
     job_starter = EmrContainerOperator(
         task_id="start_job",
-        virtual_cluster_id=VIRTUAL_CLUSTER_ID,
+        virtual_cluster_id=str(create_emr_eks_cluster.output),
         execution_role_arn=JOB_ROLE_ARN,
         release_label="emr-6.3.0-latest",
         job_driver=JOB_DRIVER_ARG,
@@ -73,7 +81,9 @@ with DAG(
 
     # [START howto_sensor_emr_container]
     job_waiter = EmrContainerSensor(
-        task_id="job_waiter", virtual_cluster_id=VIRTUAL_CLUSTER_ID, job_id=str(job_starter.output)
+        task_id="job_waiter",
+        virtual_cluster_id=str(create_emr_eks_cluster.output),
+        job_id=str(job_starter.output),
     )
     # [END howto_sensor_emr_container]
 
diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index e085bff899..48fc7684f1 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -197,6 +197,32 @@ class EmrContainerHook(AwsBaseHook):
         super().__init__(client_type="emr-containers", *args, **kwargs)  # type: ignore
         self.virtual_cluster_id = virtual_cluster_id
 
+    def create_emr_on_eks_cluster(
+        self,
+        virtual_cluster_name: str,
+        eks_cluster_name: str,
+        eks_namespace: str,
+        tags: Optional[dict] = None,
+    ) -> str:
+        response = self.conn.create_virtual_cluster(
+            name=virtual_cluster_name,
+            containerProvider={
+                "id": eks_cluster_name,
+                "type": "EKS",
+                "info": {"eksInfo": {"namespace": eks_namespace}},
+            },
+            tags=tags or {},
+        )
+
+        if response['ResponseMetadata']['HTTPStatusCode'] != 200:
+            raise AirflowException(f'Create EMR EKS Cluster failed: {response}')
+        else:
+            self.log.info(
+                "Create EMR EKS Cluster success - virtual cluster id %s",
+                response['id'],
+            )
+            return response['id']
+
     def submit_job(
         self,
         name: str,
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 1cac1eb0f5..3b7bda0cd6 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -118,6 +118,62 @@ class EmrAddStepsOperator(BaseOperator):
             return response['StepIds']
 
 
+class EmrEksCreateClusterOperator(BaseOperator):
+    """
+    An operator that creates EMR on EKS virtual clusters.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:EmrEksCreateClusterOperator`
+
+    :param virtual_cluster_name: The name of the EMR EKS virtual cluster to create.
+    :param eks_cluster_name: The EKS cluster used by the EMR virtual cluster.
+    :param eks_namespace: namespace used by the EKS cluster.
+    :param virtual_cluster_id: The EMR on EKS virtual cluster id.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param tags: The tags assigned to created cluster.
+        Defaults to None
+    """
+
+    template_fields: Sequence[str] = (
+        "virtual_cluster_name",
+        "eks_cluster_name",
+        "eks_namespace",
+    )
+    ui_color = "#f9c915"
+
+    def __init__(
+        self,
+        *,
+        virtual_cluster_name: str,
+        eks_cluster_name: str,
+        eks_namespace: str,
+        virtual_cluster_id: str = '',
+        aws_conn_id: str = "aws_default",
+        tags: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.virtual_cluster_name = virtual_cluster_name
+        self.eks_cluster_name = eks_cluster_name
+        self.eks_namespace = eks_namespace
+        self.virtual_cluster_id = virtual_cluster_id
+        self.aws_conn_id = aws_conn_id
+        self.tags = tags
+
+    @cached_property
+    def hook(self) -> EmrContainerHook:
+        """Create and return an EmrContainerHook."""
+        return EmrContainerHook(self.aws_conn_id)
+
+    def execute(self, context: 'Context') -> Optional[str]:
+        """Create EMR on EKS virtual Cluster"""
+        self.virtual_cluster_id = self.hook.create_emr_on_eks_cluster(
+            self.virtual_cluster_name, self.eks_cluster_name, self.eks_namespace, self.tags
+        )
+        return self.virtual_cluster_id
+
+
 class EmrContainerOperator(BaseOperator):
     """
     An operator that submits jobs to EMR on EKS virtual clusters.
diff --git a/docs/apache-airflow-providers-amazon/operators/emr_eks.rst b/docs/apache-airflow-providers-amazon/operators/emr_eks.rst
index 3e9e58fbca..1116aa6be6 100644
--- a/docs/apache-airflow-providers-amazon/operators/emr_eks.rst
+++ b/docs/apache-airflow-providers-amazon/operators/emr_eks.rst
@@ -31,8 +31,33 @@ Prerequisite Tasks
 Operators
 ---------
 
+
+.. _howto/operator:EmrEksCreateClusterOperator:
+
+
+Create an Amazon EMR EKS virtual cluster
+========================================
+
+
+The ``EmrEksCreateClusterOperator`` will create an Amazon EMR on EKS virtual cluster.
+The example DAG below shows how to create an EMR on EKS virtual cluster.
+
+To create an Amazon EMR cluster on Amazon EKS, you need to specify a virtual cluster name,
+the EKS cluster that you would like to use, and an EKS namespace.
+
+Refer to the `EMR on EKS Development guide <https://docs.aws.amazon.com/emr/latest/EMR-on-EKS-DevelopmentGuide/virtual-cluster.html>`__
+for more details.
+
+.. exampleinclude:: /../../airflow/providers/amazon/aws/example_dags/example_emr_eks.py
+    :language: python
+    :start-after: [START howto_operator_emr_eks_create_cluster]
+    :end-before: [END howto_operator_emr_eks_create_cluster]
+
+
+
 .. _howto/operator:EmrContainerOperator:
 
+
 Submit a job to an Amazon EMR virtual cluster
 =============================================
 
@@ -72,6 +97,7 @@ that gets passed to the operator with the ``aws_conn_id`` parameter. The operato
     :start-after: [START howto_operator_emr_container]
     :end-before: [END howto_operator_emr_container]
 
+
 Sensors
 -------
 
diff --git a/tests/providers/amazon/aws/hooks/test_emr_containers.py b/tests/providers/amazon/aws/hooks/test_emr_containers.py
index 8ad54e26c8..0237f28b2b 100644
--- a/tests/providers/amazon/aws/hooks/test_emr_containers.py
+++ b/tests/providers/amazon/aws/hooks/test_emr_containers.py
@@ -28,6 +28,8 @@ SUBMIT_JOB_SUCCESS_RETURN = {
     'virtualClusterId': 'vc1234',
 }
 
+CREATE_EMR_ON_EKS_CLUSTER_RETURN = {'ResponseMetadata': {'HTTPStatusCode': 200}, 'id': 'vc1234'}
+
 JOB1_RUN_DESCRIPTION = {
     'jobRun': {
         'id': 'job123456',
@@ -53,6 +55,21 @@ class TestEmrContainerHook(unittest.TestCase):
         assert self.emr_containers.aws_conn_id == 'aws_default'
         assert self.emr_containers.virtual_cluster_id == 'vc1234'
 
+    @mock.patch("boto3.session.Session")
+    def test_create_emr_on_eks_cluster(self, mock_session):
+        emr_client_mock = mock.MagicMock()
+        emr_client_mock.create_virtual_cluster.return_value = CREATE_EMR_ON_EKS_CLUSTER_RETURN
+        emr_session_mock = mock.MagicMock()
+        emr_session_mock.client.return_value = emr_client_mock
+        mock_session.return_value = emr_session_mock
+
+        emr_on_eks_create_cluster_response = self.emr_containers.create_emr_on_eks_cluster(
+            virtual_cluster_name="test_virtual_cluster",
+            eks_cluster_name="test_eks_cluster",
+            eks_namespace="test_eks_namespace",
+        )
+        assert emr_on_eks_create_cluster_response == "vc1234"
+
     @mock.patch("boto3.session.Session")
     def test_submit_job(self, mock_session):
         # Mock out the emr_client creator
diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py
index 8bd58e2b6a..3a7dd400d8 100644
--- a/tests/providers/amazon/aws/operators/test_emr_containers.py
+++ b/tests/providers/amazon/aws/operators/test_emr_containers.py
@@ -24,7 +24,7 @@ import pytest
 from airflow import configuration
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
-from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator
+from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator, EmrEksCreateClusterOperator
 
 SUBMIT_JOB_SUCCESS_RETURN = {
     'ResponseMetadata': {'HTTPStatusCode': 200},
@@ -32,6 +32,8 @@ SUBMIT_JOB_SUCCESS_RETURN = {
     'virtualClusterId': 'vc1234',
 }
 
+CREATE_EMR_ON_EKS_CLUSTER_RETURN = {'ResponseMetadata': {'HTTPStatusCode': 200}, 'id': 'vc1234'}
+
 GENERATED_UUID = '800647a9-adda-4237-94e6-f542c85fa55b'
 
 
@@ -137,3 +139,42 @@ class TestEmrContainerOperator(unittest.TestCase):
             assert mock_check_query_status.call_count == 3
             assert 'Final state of EMR Containers job is SUBMITTED' in str(ctx.value)
             assert 'Max tries of poll status exceeded' in str(ctx.value)
+
+
+class TestEmrEksCreateClusterOperator(unittest.TestCase):
+    @mock.patch('airflow.providers.amazon.aws.hooks.emr.EmrContainerHook')
+    def setUp(self, emr_hook_mock):
+        configuration.load_test_config()
+
+        self.emr_hook_mock = emr_hook_mock
+        self.emr_container = EmrEksCreateClusterOperator(
+            task_id='start_cluster',
+            virtual_cluster_name="test_virtual_cluster",
+            eks_cluster_name="test_eks_cluster",
+            eks_namespace="test_eks_namespace",
+            tags={},
+        )
+
+    @mock.patch.object(EmrContainerHook, 'create_emr_on_eks_cluster')
+    def test_emr_on_eks_execute_without_failure(self, mock_create_emr_on_eks_cluster):
+        mock_create_emr_on_eks_cluster.return_value = "vc1234"
+
+        self.emr_container.execute(None)
+
+        mock_create_emr_on_eks_cluster.assert_called_once_with(
+            'test_virtual_cluster', 'test_eks_cluster', 'test_eks_namespace', {}
+        )
+        assert self.emr_container.virtual_cluster_name == 'test_virtual_cluster'
+
+    @mock.patch.object(EmrContainerHook, 'create_emr_on_eks_cluster')
+    def test_emr_on_eks_execute_with_failure(self, mock_create_emr_on_eks_cluster):
+        expected_exception_msg = (
+            "An error occurred (ValidationException) when calling the "
+            "CreateVirtualCluster "
+            "operation:"
+            "A virtual cluster already exists in the given namespace"
+        )
+        mock_create_emr_on_eks_cluster.side_effect = AirflowException(expected_exception_msg)
+        with pytest.raises(AirflowException) as ctx:
+            self.emr_container.execute(None)
+        assert expected_exception_msg in str(ctx.value)