You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/08/27 02:15:37 UTC

[airflow] branch main updated: Fix `EcsBaseOperator` and `EcsBaseSensor` arguments (#25989)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new dbfa6487b8 Fix `EcsBaseOperator` and `EcsBaseSensor` arguments (#25989)
dbfa6487b8 is described below

commit dbfa6487b820e6c94770404b3ba29ab11ae2a05e
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Sat Aug 27 06:15:29 2022 +0400

    Fix `EcsBaseOperator` and `EcsBaseSensor` arguments (#25989)
---
 airflow/providers/amazon/aws/operators/ecs.py    |  8 +++--
 airflow/providers/amazon/aws/sensors/ecs.py      |  8 +++--
 tests/providers/amazon/aws/operators/test_ecs.py | 39 ++++++++++++++++++++++++
 3 files changed, 49 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py
index 2b5ee08d97..3336ecb9b5 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -49,9 +49,11 @@ DEFAULT_CONN_ID = 'aws_default'
 class EcsBaseOperator(BaseOperator):
     """This is the base operator for all Elastic Container Service operators."""
 
-    def __init__(self, **kwargs):
-        self.aws_conn_id = kwargs.get('aws_conn_id', DEFAULT_CONN_ID)
-        self.region = kwargs.get('region')
+    def __init__(
+        self, *, aws_conn_id: Optional[str] = DEFAULT_CONN_ID, region: Optional[str] = None, **kwargs
+    ):
+        self.aws_conn_id = aws_conn_id
+        self.region = region
         super().__init__(**kwargs)
 
     @cached_property
diff --git a/airflow/providers/amazon/aws/sensors/ecs.py b/airflow/providers/amazon/aws/sensors/ecs.py
index 171e2b14eb..a048f48378 100644
--- a/airflow/providers/amazon/aws/sensors/ecs.py
+++ b/airflow/providers/amazon/aws/sensors/ecs.py
@@ -45,9 +45,11 @@ def _check_failed(current_state, target_state, failure_states):
 class EcsBaseSensor(BaseSensorOperator):
     """Contains general sensor behavior for Elastic Container Service."""
 
-    def __init__(self, **kwargs):
-        self.aws_conn_id = kwargs.get('aws_conn_id', DEFAULT_CONN_ID)
-        self.region = kwargs.get('region')
+    def __init__(
+        self, *, aws_conn_id: Optional[str] = DEFAULT_CONN_ID, region: Optional[str] = None, **kwargs
+    ):
+        self.aws_conn_id = aws_conn_id
+        self.region = region
         super().__init__(**kwargs)
 
     @cached_property
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py
index ecdd4d779e..36ef07688b 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -29,6 +29,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
 from airflow.providers.amazon.aws.hooks.ecs import EcsHook
 from airflow.providers.amazon.aws.operators.ecs import (
+    DEFAULT_CONN_ID,
     EcsBaseOperator,
     EcsCreateClusterOperator,
     EcsDeleteClusterOperator,
@@ -84,6 +85,44 @@ RESPONSE_WITHOUT_FAILURES = {
         }
     ],
 }
+NOTSET = type("ArgumentNotSet", (), {"__str__": lambda self: "argument-not-set"})
+
+
+class TestEcsBaseOperator:
+    """Test Base ECS Operator."""
+
+    @pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
+    @pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
+    def test_initialise_operator(self, aws_conn_id, region_name):
+        """Test initialize operator."""
+        op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
+        op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
+        op = EcsBaseOperator(task_id="test_ecs_base", **op_kw)
+
+        assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else DEFAULT_CONN_ID)
+        assert op.region == (region_name if region_name is not NOTSET else None)
+
+    @mock.patch("airflow.providers.amazon.aws.operators.ecs.EcsHook")
+    @pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
+    @pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
+    def test_hook_and_client(self, mock_ecs_hook_cls, aws_conn_id, region_name):
+        """Test initialize ``EcsHook`` and ``boto3.client``."""
+        mock_ecs_hook = mock_ecs_hook_cls.return_value
+        mock_conn = mock.MagicMock()
+        type(mock_ecs_hook).conn = mock.PropertyMock(return_value=mock_conn)
+
+        op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
+        op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
+        op = EcsBaseOperator(task_id="test_ecs_base_hook_client", **op_kw)
+
+        hook = op.hook
+        assert op.hook is hook
+        mock_ecs_hook_cls.assert_called_once_with(aws_conn_id=op.aws_conn_id, region_name=op.region)
+
+        client = op.client
+        mock_ecs_hook_cls.assert_called_once_with(aws_conn_id=op.aws_conn_id, region_name=op.region)
+        assert client == mock_conn
+        assert op.client is client
 
 
 @pytest.mark.skipif(mock_ecs is None, reason="mock_ecs package not present")