You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this conversion.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/08/27 02:15:37 UTC
[airflow] branch main updated: Fix `EcsBaseOperator` and `EcsBaseSensor` arguments (#25989)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new dbfa6487b8 Fix `EcsBaseOperator` and `EcsBaseSensor` arguments (#25989)
dbfa6487b8 is described below
commit dbfa6487b820e6c94770404b3ba29ab11ae2a05e
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Sat Aug 27 06:15:29 2022 +0400
Fix `EcsBaseOperator` and `EcsBaseSensor` arguments (#25989)
---
airflow/providers/amazon/aws/operators/ecs.py | 8 +++--
airflow/providers/amazon/aws/sensors/ecs.py | 8 +++--
tests/providers/amazon/aws/operators/test_ecs.py | 39 ++++++++++++++++++++++++
3 files changed, 49 insertions(+), 6 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py
index 2b5ee08d97..3336ecb9b5 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -49,9 +49,11 @@ DEFAULT_CONN_ID = 'aws_default'
class EcsBaseOperator(BaseOperator):
"""This is the base operator for all Elastic Container Service operators."""
- def __init__(self, **kwargs):
- self.aws_conn_id = kwargs.get('aws_conn_id', DEFAULT_CONN_ID)
- self.region = kwargs.get('region')
+ def __init__(
+ self, *, aws_conn_id: Optional[str] = DEFAULT_CONN_ID, region: Optional[str] = None, **kwargs
+ ):
+ self.aws_conn_id = aws_conn_id
+ self.region = region
super().__init__(**kwargs)
@cached_property
diff --git a/airflow/providers/amazon/aws/sensors/ecs.py b/airflow/providers/amazon/aws/sensors/ecs.py
index 171e2b14eb..a048f48378 100644
--- a/airflow/providers/amazon/aws/sensors/ecs.py
+++ b/airflow/providers/amazon/aws/sensors/ecs.py
@@ -45,9 +45,11 @@ def _check_failed(current_state, target_state, failure_states):
class EcsBaseSensor(BaseSensorOperator):
"""Contains general sensor behavior for Elastic Container Service."""
- def __init__(self, **kwargs):
- self.aws_conn_id = kwargs.get('aws_conn_id', DEFAULT_CONN_ID)
- self.region = kwargs.get('region')
+ def __init__(
+ self, *, aws_conn_id: Optional[str] = DEFAULT_CONN_ID, region: Optional[str] = None, **kwargs
+ ):
+ self.aws_conn_id = aws_conn_id
+ self.region = region
super().__init__(**kwargs)
@cached_property
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py
index ecdd4d779e..36ef07688b 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -29,6 +29,7 @@ from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.aws.operators.ecs import (
+ DEFAULT_CONN_ID,
EcsBaseOperator,
EcsCreateClusterOperator,
EcsDeleteClusterOperator,
@@ -84,6 +85,44 @@ RESPONSE_WITHOUT_FAILURES = {
}
],
}
+NOTSET = type("ArgumentNotSet", (), {"__str__": lambda self: "argument-not-set"})
+
+
+class TestEcsBaseOperator:
+ """Test Base ECS Operator."""
+
+ @pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
+ @pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
+ def test_initialise_operator(self, aws_conn_id, region_name):
+ """Test initialize operator."""
+ op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
+ op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
+ op = EcsBaseOperator(task_id="test_ecs_base", **op_kw)
+
+ assert op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else DEFAULT_CONN_ID)
+ assert op.region == (region_name if region_name is not NOTSET else None)
+
+ @mock.patch("airflow.providers.amazon.aws.operators.ecs.EcsHook")
+ @pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"])
+ @pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"])
+ def test_hook_and_client(self, mock_ecs_hook_cls, aws_conn_id, region_name):
+ """Test initialize ``EcsHook`` and ``boto3.client``."""
+ mock_ecs_hook = mock_ecs_hook_cls.return_value
+ mock_conn = mock.MagicMock()
+ type(mock_ecs_hook).conn = mock.PropertyMock(return_value=mock_conn)
+
+ op_kw = {"aws_conn_id": aws_conn_id, "region": region_name}
+ op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET}
+ op = EcsBaseOperator(task_id="test_ecs_base_hook_client", **op_kw)
+
+ hook = op.hook
+ assert op.hook is hook
+ mock_ecs_hook_cls.assert_called_once_with(aws_conn_id=op.aws_conn_id, region_name=op.region)
+
+ client = op.client
+ mock_ecs_hook_cls.assert_called_once_with(aws_conn_id=op.aws_conn_id, region_name=op.region)
+ assert client == mock_conn
+ assert op.client is client
@pytest.mark.skipif(mock_ecs is None, reason="mock_ecs package not present")