You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/10/02 02:57:35 UTC

[airflow] branch main updated: Add information about Amazon Elastic MapReduce Connection (#26687)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f3ad164aef Add information about Amazon Elastic MapReduce Connection (#26687)
f3ad164aef is described below

commit f3ad164aefb4915ce8c7725a43ddbcd61c830aa5
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Sun Oct 2 06:57:23 2022 +0400

    Add information about Amazon Elastic MapReduce Connection (#26687)
    
    * Added information about Amazon Elastic MapReduce Connection
---
 airflow/providers/amazon/aws/hooks/emr.py          | 105 ++++++++++++++++++---
 airflow/providers/amazon/aws/operators/emr.py      |  15 +--
 .../connections/emr.rst                            |  43 +++++++++
 tests/providers/amazon/aws/hooks/test_emr.py       |  61 +++++++++++-
 4 files changed, 199 insertions(+), 25 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index fb18c32394..3104bf26e8 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import json
 import warnings
 from time import sleep
 from typing import Any, Callable
@@ -30,8 +31,11 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 
 class EmrHook(AwsBaseHook):
     """
-    Interact with AWS EMR. emr_conn_id is only necessary for using the
-    create_job_flow method.
+    Interact with Amazon Elastic MapReduce Service.
+
+    :param emr_conn_id: :ref:`Amazon Elastic MapReduce Connection <howto/connection:emr>`.
+        This attribute is only necessary when using
+        the :meth:`~airflow.providers.amazon.aws.hooks.emr.EmrHook.create_job_flow` method.
 
     Additional arguments (such as ``aws_conn_id``) may be specified and
     are passed down to the underlying AwsBaseHook.
@@ -45,8 +49,8 @@ class EmrHook(AwsBaseHook):
     conn_type = 'emr'
     hook_name = 'Amazon Elastic MapReduce'
 
-    def __init__(self, emr_conn_id: str = default_conn_name, *args, **kwargs) -> None:
-        self.emr_conn_id: str = emr_conn_id
+    def __init__(self, emr_conn_id: str | None = default_conn_name, *args, **kwargs) -> None:
+        self.emr_conn_id = emr_conn_id
         kwargs["client_type"] = "emr"
         super().__init__(*args, **kwargs)
 
@@ -77,22 +81,97 @@ class EmrHook(AwsBaseHook):
 
     def create_job_flow(self, job_flow_overrides: dict[str, Any]) -> dict[str, Any]:
         """
-        Creates a job flow using the config from the EMR connection.
-        Keys of the json extra hash may have the arguments of the boto3
-        run_job_flow method.
-        Overrides for this config may be passed as the job_flow_overrides.
+        Create and start running a new cluster (job flow).
+
+        This method uses ``EmrHook.emr_conn_id`` to receive the initial Amazon EMR cluster configuration.
+        If ``EmrHook.emr_conn_id`` is empty or the connection does not exist, then an empty initial
+        configuration is used.
+
+        :param job_flow_overrides: Used to override parameters in the initial Amazon EMR cluster
+            configuration. The resulting configuration is passed to the boto3 emr client run_job_flow method.
+
+        .. seealso::
+            - :ref:`Amazon Elastic MapReduce Connection <howto/connection:emr>`
+            - `API RunJobFlow <https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html>`_
+            - `boto3 emr client run_job_flow method <https://boto3.amazonaws.com/v1/documentation/\
+               api/latest/reference/services/emr.html#EMR.Client.run_job_flow>`_.
         """
-        try:
-            emr_conn = self.get_connection(self.emr_conn_id)
-            config = emr_conn.extra_dejson.copy()
-        except AirflowNotFoundException:
-            config = {}
+        config = {}
+        if self.emr_conn_id:
+            try:
+                emr_conn = self.get_connection(self.emr_conn_id)
+            except AirflowNotFoundException:
+                warnings.warn(
+                    f"Unable to find {self.hook_name} Connection ID {self.emr_conn_id!r}, "
+                    "using an empty initial configuration. If you want to get rid of this warning "
+                    "message please provide a valid `emr_conn_id` or set it to None.",
+                    UserWarning,
+                    stacklevel=2,
+                )
+            else:
+                if emr_conn.conn_type and emr_conn.conn_type != self.conn_type:
+                    warnings.warn(
+                        f"{self.hook_name} Connection expected connection type {self.conn_type!r}, "
+                        f"Connection {self.emr_conn_id!r} has conn_type={emr_conn.conn_type!r}. "
+                        f"This connection might not work correctly.",
+                        UserWarning,
+                        stacklevel=2,
+                    )
+                config = emr_conn.extra_dejson.copy()
         config.update(job_flow_overrides)
 
         response = self.get_conn().run_job_flow(**config)
 
         return response
 
+    def test_connection(self):
+        """
+        Return a failed state when testing the Amazon Elastic MapReduce Connection (untestable).
+
+        We need to overwrite this method because this hook is based on
+        :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook`,
+        otherwise it will try to test connection to AWS STS by using the default boto3 credential strategy.
+        """
+        msg = (
+            f"{self.hook_name!r} Airflow Connection cannot be tested, by design it stores "
+            f"only key/value pairs and does not make a connection to an external resource."
+        )
+        return False, msg
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Returns custom UI field behaviour for Amazon Elastic MapReduce Connection."""
+        return {
+            "hidden_fields": ["host", "schema", "port", "login", "password"],
+            "relabeling": {
+                "extra": "Run Job Flow Configuration",
+            },
+            "placeholders": {
+                "extra": json.dumps(
+                    {
+                        "Name": "MyClusterName",
+                        "ReleaseLabel": "emr-5.36.0",
+                        "Applications": [{"Name": "Spark"}],
+                        "Instances": {
+                            "InstanceGroups": [
+                                {
+                                    "Name": "Primary node",
+                                    "Market": "SPOT",
+                                    "InstanceRole": "MASTER",
+                                    "InstanceType": "m5.large",
+                                    "InstanceCount": 1,
+                                },
+                            ],
+                            "KeepJobFlowAliveWhenNoSteps": False,
+                            "TerminationProtected": False,
+                        },
+                        "StepConcurrencyLevel": 2,
+                    },
+                    indent=2,
+                ),
+            },
+        }
+
 
 class EmrServerlessHook(AwsBaseHook):
     """
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 5ccff487e9..659252e107 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -332,10 +332,13 @@ class EmrCreateJobFlowOperator(BaseOperator):
         running Airflow in a distributed manner and aws_conn_id is None or
         empty, then default boto3 configuration would be used (and must be
         maintained on each worker node)
-    :param emr_conn_id: emr connection to use for run_job_flow request body.
-        This will be overridden by the job_flow_overrides param
+    :param emr_conn_id: :ref:`Amazon Elastic MapReduce Connection <howto/connection:emr>`.
+        Use to receive an initial Amazon EMR cluster configuration:
+        ``boto3.client('emr').run_job_flow`` request body.
+        If this is None or empty or the connection does not exist,
+        then an empty initial configuration is used.
     :param job_flow_overrides: boto3 style arguments or reference to an arguments file
-        (must be '.json') to override emr_connection extra. (templated)
+        (must be '.json') to override specific ``emr_conn_id`` extra parameters. (templated)
     :param region_name: Region named passed to EmrHook
     """
 
@@ -349,7 +352,7 @@ class EmrCreateJobFlowOperator(BaseOperator):
         self,
         *,
         aws_conn_id: str = 'aws_default',
-        emr_conn_id: str = 'emr_default',
+        emr_conn_id: str | None = 'emr_default',
         job_flow_overrides: str | dict[str, Any] | None = None,
         region_name: str | None = None,
         **kwargs,
@@ -357,9 +360,7 @@ class EmrCreateJobFlowOperator(BaseOperator):
         super().__init__(**kwargs)
         self.aws_conn_id = aws_conn_id
         self.emr_conn_id = emr_conn_id
-        if job_flow_overrides is None:
-            job_flow_overrides = {}
-        self.job_flow_overrides = job_flow_overrides
+        self.job_flow_overrides = job_flow_overrides or {}
         self.region_name = region_name
 
     def execute(self, context: Context) -> str:
diff --git a/docs/apache-airflow-providers-amazon/connections/emr.rst b/docs/apache-airflow-providers-amazon/connections/emr.rst
new file mode 100644
index 0000000000..657a19b5bc
--- /dev/null
+++ b/docs/apache-airflow-providers-amazon/connections/emr.rst
@@ -0,0 +1,43 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/connection:emr:
+
+Amazon Elastic MapReduce (EMR) Connection
+=========================================
+
+.. note::
+  This connection type is only used to store parameters for starting an EMR cluster (``run_job_flow`` boto3 EMR client method).
+
+  This connection is not intended to store any credentials for the ``boto3`` client. If you try to pass any
+  parameters not listed in `RunJobFlow API <https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html>`_,
+  you will get an error like this:
+
+  .. code-block:: text
+
+      Parameter validation failed: Unknown parameter in input: "region_name", must be one of:
+
+  To authenticate to AWS, please use :ref:`Amazon Web Services Connection <howto/connection:aws>`.
+
+Configuring the Connection
+--------------------------
+
+Extra (optional)
+    Specify the parameters (as a JSON dictionary) that are used as the initial configuration
+    in :meth:`airflow.providers.amazon.aws.hooks.emr.EmrHook.create_job_flow` and passed on to the
+    `RunJobFlow API <https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html>`_.
+    All parameters are optional.
diff --git a/tests/providers/amazon/aws/hooks/test_emr.py b/tests/providers/amazon/aws/hooks/test_emr.py
index 7836bc0498..339d49239d 100644
--- a/tests/providers/amazon/aws/hooks/test_emr.py
+++ b/tests/providers/amazon/aws/hooks/test_emr.py
@@ -17,9 +17,10 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
+from unittest import mock
 
 import boto3
+import pytest
 
 from airflow.providers.amazon.aws.hooks.emr import EmrHook
 
@@ -29,8 +30,8 @@ except ImportError:
     mock_emr = None
 
 
-@unittest.skipIf(mock_emr is None, 'moto package not present')
-class TestEmrHook(unittest.TestCase):
+@pytest.mark.skipif(mock_emr is None, reason='moto package not present')
+class TestEmrHook:
     @mock_emr
     def test_get_conn_returns_a_boto3_connection(self):
         hook = EmrHook(aws_conn_id='aws_default', region_name='ap-southeast-2')
@@ -59,13 +60,63 @@ class TestEmrHook(unittest.TestCase):
         # AmiVersion is really old and almost no one will use it anymore, but
         # it's one of the "optional" request params that moto supports - it's
         # coverage of EMR isn't 100% it turns out.
-        cluster = hook.create_job_flow({'Name': 'test_cluster', 'ReleaseLabel': '', 'AmiVersion': '3.2'})
-
+        with pytest.warns(None):  # Expected no warnings if ``emr_conn_id`` exists with correct conn_type
+            cluster = hook.create_job_flow({'Name': 'test_cluster', 'ReleaseLabel': '', 'AmiVersion': '3.2'})
         cluster = client.describe_cluster(ClusterId=cluster['JobFlowId'])['Cluster']
 
         # The AmiVersion comes back as {Requested,Running}AmiVersion fields.
         assert cluster['RequestedAmiVersion'] == '3.2'
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_conn")
+    def test_empty_emr_conn_id(self, mock_boto3_client):
+        """Test empty ``emr_conn_id``."""
+        mock_run_job_flow = mock.MagicMock()
+        mock_boto3_client.return_value.run_job_flow = mock_run_job_flow
+        job_flow_overrides = {"foo": "bar"}
+
+        hook = EmrHook(aws_conn_id="aws_default", emr_conn_id=None)
+        hook.create_job_flow(job_flow_overrides)
+        mock_run_job_flow.assert_called_once_with(**job_flow_overrides)
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_conn")
+    def test_missing_emr_conn_id(self, mock_boto3_client):
+        """Test a non-existent ``emr_conn_id``."""
+        mock_run_job_flow = mock.MagicMock()
+        mock_boto3_client.return_value.run_job_flow = mock_run_job_flow
+        job_flow_overrides = {"foo": "bar"}
+
+        hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="not-exists-emr-conn-id")
+        warning_message = r"Unable to find Amazon Elastic MapReduce Connection ID 'not-exists-emr-conn-id',.*"
+        with pytest.warns(UserWarning, match=warning_message):
+            hook.create_job_flow(job_flow_overrides)
+        mock_run_job_flow.assert_called_once_with(**job_flow_overrides)
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_conn")
+    def test_emr_conn_id_wrong_conn_type(self, mock_boto3_client):
+        """Test an existing ``emr_conn_id`` that has an unexpected ``conn_type``."""
+        mock_run_job_flow = mock.MagicMock()
+        mock_boto3_client.return_value.run_job_flow = mock_run_job_flow
+        job_flow_overrides = {"foo": "bar"}
+
+        with mock.patch.dict("os.environ", AIRFLOW_CONN_WRONG_TYPE_CONN="aws://"):
+            hook = EmrHook(aws_conn_id="aws_default", emr_conn_id="wrong_type_conn")
+            warning_message = (
+                r"Amazon Elastic MapReduce Connection expected connection type 'emr'"
+                r".* This connection might not work correctly."
+            )
+            with pytest.warns(UserWarning, match=warning_message):
+                hook.create_job_flow(job_flow_overrides)
+            mock_run_job_flow.assert_called_once_with(**job_flow_overrides)
+
+    @pytest.mark.parametrize("aws_conn_id", ["aws_default", None])
+    @pytest.mark.parametrize("emr_conn_id", ["emr_default", None])
+    def test_emr_connection(self, aws_conn_id, emr_conn_id):
+        """Test that ``EmrHook.test_connection`` always returns a failed state."""
+        hook = EmrHook(aws_conn_id=aws_conn_id, emr_conn_id=emr_conn_id)
+        result, message = hook.test_connection()
+        assert not result
+        assert message.startswith("'Amazon Elastic MapReduce' Airflow Connection cannot be tested")
+
     @mock_emr
     def test_get_cluster_id_by_name(self):
         """