You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by el...@apache.org on 2023/12/28 12:26:46 UTC
(airflow) branch main updated: Use base aws classes in AWS Step Functions Operators/Sensors/Triggers (#36468)
This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 22294abf68 Use base aws classes in AWS Step Functions Operators/Sensors/Triggers (#36468)
22294abf68 is described below
commit 22294abf68f17eefc00ec9b363bfcf1ca21f145a
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Thu Dec 28 16:26:37 2023 +0400
Use base aws classes in AWS Step Functions Operators/Sensors/Triggers (#36468)
---
.../amazon/aws/operators/step_function.py | 63 ++++-----
.../providers/amazon/aws/sensors/step_function.py | 36 +++---
.../providers/amazon/aws/triggers/step_function.py | 9 +-
.../operators/step_functions.rst | 5 +
.../amazon/aws/operators/test_step_function.py | 143 +++++++++++----------
.../amazon/aws/sensors/test_step_function.py | 112 ++++++++--------
6 files changed, 190 insertions(+), 178 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/step_function.py b/airflow/providers/amazon/aws/operators/step_function.py
index 68324df731..e02de32bae 100644
--- a/airflow/providers/amazon/aws/operators/step_function.py
+++ b/airflow/providers/amazon/aws/operators/step_function.py
@@ -22,15 +22,16 @@ from typing import TYPE_CHECKING, Any, Sequence
from airflow.configuration import conf
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
if TYPE_CHECKING:
from airflow.utils.context import Context
-class StepFunctionStartExecutionOperator(BaseOperator):
+class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
"""
An Operator that begins execution of an AWS Step Function State Machine.
@@ -50,10 +51,20 @@ class StepFunctionStartExecutionOperator(BaseOperator):
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False, but can be overridden in config file by setting default_deferrable to True)
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
- template_fields: Sequence[str] = ("state_machine_arn", "name", "input")
- template_ext: Sequence[str] = ()
+ aws_hook_class = StepFunctionHook
+ template_fields: Sequence[str] = aws_template_fields("state_machine_arn", "name", "input")
ui_color = "#f9c915"
def __init__(
@@ -62,8 +73,6 @@ class StepFunctionStartExecutionOperator(BaseOperator):
state_machine_arn: str,
name: str | None = None,
state_machine_input: dict | str | None = None,
- aws_conn_id: str = "aws_default",
- region_name: str | None = None,
waiter_max_attempts: int = 30,
waiter_delay: int = 60,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
@@ -73,18 +82,12 @@ class StepFunctionStartExecutionOperator(BaseOperator):
self.state_machine_arn = state_machine_arn
self.name = name
self.input = state_machine_input
- self.aws_conn_id = aws_conn_id
- self.region_name = region_name
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.deferrable = deferrable
def execute(self, context: Context):
- hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
- execution_arn = hook.start_execution(self.state_machine_arn, self.name, self.input)
-
- if execution_arn is None:
+ if not (execution_arn := self.hook.start_execution(self.state_machine_arn, self.name, self.input)):
raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}")
self.log.info("Started State Machine execution for %s: %s", self.state_machine_arn, execution_arn)
@@ -96,6 +99,8 @@ class StepFunctionStartExecutionOperator(BaseOperator):
waiter_max_attempts=self.waiter_max_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
+ botocore_config=self.botocore_config,
+ verify=self.verify,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
@@ -110,7 +115,7 @@ class StepFunctionStartExecutionOperator(BaseOperator):
return event["execution_arn"]
-class StepFunctionGetExecutionOutputOperator(BaseOperator):
+class StepFunctionGetExecutionOutputOperator(AwsBaseOperator[StepFunctionHook]):
"""
An Operator that returns the output of an AWS Step Function State Machine execution.
@@ -121,30 +126,28 @@ class StepFunctionGetExecutionOutputOperator(BaseOperator):
:ref:`howto/operator:StepFunctionGetExecutionOutputOperator`
:param execution_arn: ARN of the Step Function State Machine Execution
- :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
- template_fields: Sequence[str] = ("execution_arn",)
- template_ext: Sequence[str] = ()
+ aws_hook_class = StepFunctionHook
+ template_fields: Sequence[str] = aws_template_fields("execution_arn")
ui_color = "#f9c915"
- def __init__(
- self,
- *,
- execution_arn: str,
- aws_conn_id: str = "aws_default",
- region_name: str | None = None,
- **kwargs,
- ):
+ def __init__(self, *, execution_arn: str, **kwargs):
super().__init__(**kwargs)
self.execution_arn = execution_arn
- self.aws_conn_id = aws_conn_id
- self.region_name = region_name
def execute(self, context: Context):
- hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
- execution_status = hook.describe_execution(self.execution_arn)
+ execution_status = self.hook.describe_execution(self.execution_arn)
response = None
if "output" in execution_status:
response = json.loads(execution_status["output"])
diff --git a/airflow/providers/amazon/aws/sensors/step_function.py b/airflow/providers/amazon/aws/sensors/step_function.py
index 053a751336..5e0d3cfcf7 100644
--- a/airflow/providers/amazon/aws/sensors/step_function.py
+++ b/airflow/providers/amazon/aws/sensors/step_function.py
@@ -17,20 +17,20 @@
from __future__ import annotations
import json
-from functools import cached_property
from typing import TYPE_CHECKING, Sequence
from deprecated import deprecated
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
if TYPE_CHECKING:
from airflow.utils.context import Context
-class StepFunctionExecutionSensor(BaseSensorOperator):
+class StepFunctionExecutionSensor(AwsBaseSensor[StepFunctionHook]):
"""
Poll the Step Function State Machine Execution until it reaches a terminal state; fails if the task fails.
@@ -42,7 +42,16 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
:ref:`howto/sensor:StepFunctionExecutionSensor`
:param execution_arn: execution_arn to check the state of
- :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
INTERMEDIATE_STATES = ("RUNNING",)
@@ -53,22 +62,13 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
)
SUCCESS_STATES = ("SUCCEEDED",)
- template_fields: Sequence[str] = ("execution_arn",)
- template_ext: Sequence[str] = ()
+ aws_hook_class = StepFunctionHook
+ template_fields: Sequence[str] = aws_template_fields("execution_arn")
ui_color = "#66c3ff"
- def __init__(
- self,
- *,
- execution_arn: str,
- aws_conn_id: str = "aws_default",
- region_name: str | None = None,
- **kwargs,
- ):
+ def __init__(self, *, execution_arn: str, **kwargs):
super().__init__(**kwargs)
self.execution_arn = execution_arn
- self.aws_conn_id = aws_conn_id
- self.region_name = region_name
def poke(self, context: Context):
execution_status = self.hook.describe_execution(self.execution_arn)
@@ -93,7 +93,3 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
def get_hook(self) -> StepFunctionHook:
"""Create and return a StepFunctionHook."""
return self.hook
-
- @cached_property
- def hook(self) -> StepFunctionHook:
- return StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
diff --git a/airflow/providers/amazon/aws/triggers/step_function.py b/airflow/providers/amazon/aws/triggers/step_function.py
index da0f186da9..6fe6af2218 100644
--- a/airflow/providers/amazon/aws/triggers/step_function.py
+++ b/airflow/providers/amazon/aws/triggers/step_function.py
@@ -43,6 +43,7 @@ class StepFunctionsExecutionCompleteTrigger(AwsBaseWaiterTrigger):
waiter_max_attempts: int = 30,
aws_conn_id: str | None = None,
region_name: str | None = None,
+ **kwargs,
) -> None:
super().__init__(
serialized_fields={"execution_arn": execution_arn, "region_name": region_name},
@@ -56,7 +57,13 @@ class StepFunctionsExecutionCompleteTrigger(AwsBaseWaiterTrigger):
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
+ **kwargs,
)
def hook(self) -> AwsGenericHook:
- return StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+ return StepFunctionHook(
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ verify=self.verify,
+ config=self.botocore_config,
+ )
diff --git a/docs/apache-airflow-providers-amazon/operators/step_functions.rst b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
index 7736fa9b16..5ab5d19e68 100644
--- a/docs/apache-airflow-providers-amazon/operators/step_functions.rst
+++ b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
@@ -28,6 +28,11 @@ Prerequisite Tasks
.. include:: ../_partials/prerequisite_tasks.rst
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
Operators
---------
diff --git a/tests/providers/amazon/aws/operators/test_step_function.py b/tests/providers/amazon/aws/operators/test_step_function.py
index 91ccebf7c6..6845a7f98a 100644
--- a/tests/providers/amazon/aws/operators/test_step_function.py
+++ b/tests/providers/amazon/aws/operators/test_step_function.py
@@ -18,12 +18,10 @@
from __future__ import annotations
from unittest import mock
-from unittest.mock import MagicMock
import pytest
-from airflow.exceptions import TaskDeferred
-from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.operators.step_function import (
StepFunctionGetExecutionOutputOperator,
StepFunctionStartExecutionOperator,
@@ -40,104 +38,106 @@ NAME = "NAME"
INPUT = "{}"
+@pytest.fixture
+def mocked_context():
+ return mock.MagicMock(name="FakeContext")
+
+
class TestStepFunctionGetExecutionOutputOperator:
TASK_ID = "step_function_get_execution_output"
- def setup_method(self):
- self.mock_context = MagicMock()
-
def test_init(self):
- # Given / When
- operator = StepFunctionGetExecutionOutputOperator(
+ op = StepFunctionGetExecutionOutputOperator(
task_id=self.TASK_ID,
execution_arn=EXECUTION_ARN,
aws_conn_id=AWS_CONN_ID,
region_name=REGION_NAME,
+ verify="/spam/egg.pem",
+ botocore_config={"read_timeout": 42},
)
-
- # Then
- assert self.TASK_ID == operator.task_id
- assert EXECUTION_ARN == operator.execution_arn
- assert AWS_CONN_ID == operator.aws_conn_id
- assert REGION_NAME == operator.region_name
-
- @mock.patch("airflow.providers.amazon.aws.operators.step_function.StepFunctionHook")
- @pytest.mark.parametrize("response", ["output", "error"])
- def test_execute(self, mock_hook, response):
- # Given
- hook_response = {response: "{}"}
-
- hook_instance = mock_hook.return_value
- hook_instance.describe_execution.return_value = hook_response
-
- operator = StepFunctionGetExecutionOutputOperator(
+ assert op.execution_arn == EXECUTION_ARN
+ assert op.hook.aws_conn_id == AWS_CONN_ID
+ assert op.hook._region_name == REGION_NAME
+ assert op.hook._verify == "/spam/egg.pem"
+ assert op.hook._config is not None
+ assert op.hook._config.read_timeout == 42
+
+ op = StepFunctionGetExecutionOutputOperator(task_id=self.TASK_ID, execution_arn=EXECUTION_ARN)
+ assert op.hook.aws_conn_id == "aws_default"
+ assert op.hook._region_name is None
+ assert op.hook._verify is None
+ assert op.hook._config is None
+
+ @mock.patch.object(StepFunctionGetExecutionOutputOperator, "hook")
+ @pytest.mark.parametrize(
+ "response, expected_output",
+ [
+ pytest.param({"output": '{"foo": "bar"}'}, {"foo": "bar"}, id="output"),
+ pytest.param({"error": '{"spam": "egg"}'}, {"spam": "egg"}, id="error"),
+ pytest.param({"other": '{"baz": "qux"}'}, None, id="other"),
+ ],
+ )
+ def test_execute(self, mocked_hook, mocked_context, response, expected_output):
+ mocked_hook.describe_execution.return_value = response
+ op = StepFunctionGetExecutionOutputOperator(
task_id=self.TASK_ID,
execution_arn=EXECUTION_ARN,
- aws_conn_id=AWS_CONN_ID,
- region_name=REGION_NAME,
+ aws_conn_id=None,
)
-
- # When
- result = operator.execute(self.mock_context)
-
- # Then
- assert {} == result
+ assert op.execute(mocked_context) == expected_output
+ mocked_hook.describe_execution.assert_called_once_with(EXECUTION_ARN)
class TestStepFunctionStartExecutionOperator:
TASK_ID = "step_function_start_execution_task"
- def setup_method(self):
- self.mock_context = MagicMock()
-
def test_init(self):
- # Given / When
- operator = StepFunctionStartExecutionOperator(
+ op = StepFunctionStartExecutionOperator(
task_id=self.TASK_ID,
state_machine_arn=STATE_MACHINE_ARN,
name=NAME,
state_machine_input=INPUT,
aws_conn_id=AWS_CONN_ID,
region_name=REGION_NAME,
+ verify=False,
+ botocore_config={"read_timeout": 42},
)
-
- # Then
- assert self.TASK_ID == operator.task_id
- assert STATE_MACHINE_ARN == operator.state_machine_arn
- assert NAME == operator.name
- assert INPUT == operator.input
- assert AWS_CONN_ID == operator.aws_conn_id
- assert REGION_NAME == operator.region_name
-
- @mock.patch("airflow.providers.amazon.aws.operators.step_function.StepFunctionHook")
- def test_execute(self, mock_hook):
- # Given
+ assert op.state_machine_arn == STATE_MACHINE_ARN
+ assert op.state_machine_arn == STATE_MACHINE_ARN
+ assert op.name == NAME
+ assert op.input == INPUT
+ assert op.hook.aws_conn_id == AWS_CONN_ID
+ assert op.hook._region_name == REGION_NAME
+ assert op.hook._verify is False
+ assert op.hook._config is not None
+ assert op.hook._config.read_timeout == 42
+
+ op = StepFunctionStartExecutionOperator(task_id=self.TASK_ID, state_machine_arn=STATE_MACHINE_ARN)
+ assert op.hook.aws_conn_id == "aws_default"
+ assert op.hook._region_name is None
+ assert op.hook._verify is None
+ assert op.hook._config is None
+
+ @mock.patch.object(StepFunctionStartExecutionOperator, "hook")
+ def test_execute(self, mocked_hook, mocked_context):
hook_response = (
"arn:aws:states:us-east-1:123456789012:execution:"
"pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934"
)
-
- hook_instance = mock_hook.return_value
- hook_instance.start_execution.return_value = hook_response
-
- operator = StepFunctionStartExecutionOperator(
+ mocked_hook.start_execution.return_value = hook_response
+ op = StepFunctionStartExecutionOperator(
task_id=self.TASK_ID,
state_machine_arn=STATE_MACHINE_ARN,
name=NAME,
state_machine_input=INPUT,
- aws_conn_id=AWS_CONN_ID,
- region_name=REGION_NAME,
+ aws_conn_id=None,
)
+ assert op.execute(mocked_context) == hook_response
+ mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT)
- # When
- result = operator.execute(self.mock_context)
-
- # Then
- assert hook_response == result
-
- @mock.patch.object(StepFunctionHook, "start_execution")
- def test_step_function_start_execution_deferrable(self, mock_start_execution):
- mock_start_execution.return_value = "test-execution-arn"
+ @mock.patch.object(StepFunctionStartExecutionOperator, "hook")
+ def test_step_function_start_execution_deferrable(self, mocked_hook):
+ mocked_hook.start_execution.return_value = "test-execution-arn"
operator = StepFunctionStartExecutionOperator(
task_id=self.TASK_ID,
state_machine_arn=STATE_MACHINE_ARN,
@@ -149,3 +149,14 @@ class TestStepFunctionStartExecutionOperator:
)
with pytest.raises(TaskDeferred):
operator.execute(None)
+ mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT)
+
+ @mock.patch.object(StepFunctionStartExecutionOperator, "hook")
+ @pytest.mark.parametrize("execution_arn", [pytest.param(None, id="none"), pytest.param("", id="empty")])
+ def test_step_function_no_execution_arn_returns(self, mocked_hook, execution_arn):
+ mocked_hook.start_execution.return_value = execution_arn
+ op = StepFunctionStartExecutionOperator(
+ task_id=self.TASK_ID, state_machine_arn=STATE_MACHINE_ARN, aws_conn_id=None
+ )
+ with pytest.raises(AirflowException, match="Failed to start State Machine execution"):
+ op.execute({})
diff --git a/tests/providers/amazon/aws/sensors/test_step_function.py b/tests/providers/amazon/aws/sensors/test_step_function.py
index b6a47d49cd..878691dc1c 100644
--- a/tests/providers/amazon/aws/sensors/test_step_function.py
+++ b/tests/providers/amazon/aws/sensors/test_step_function.py
@@ -19,7 +19,6 @@ from __future__ import annotations
import json
from unittest import mock
-from unittest.mock import MagicMock
import pytest
@@ -35,72 +34,63 @@ AWS_CONN_ID = "aws_non_default"
REGION_NAME = "us-west-2"
-class TestStepFunctionExecutionSensor:
- def setup_method(self):
- self.mock_context = MagicMock()
-
- def test_init(self):
- sensor = StepFunctionExecutionSensor(
- task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
- )
-
- assert TASK_ID == sensor.task_id
- assert EXECUTION_ARN == sensor.execution_arn
- assert AWS_CONN_ID == sensor.aws_conn_id
- assert REGION_NAME == sensor.region_name
-
- @pytest.mark.parametrize("mock_status", ["FAILED", "TIMED_OUT", "ABORTED"])
- @mock.patch("airflow.providers.amazon.aws.sensors.step_function.StepFunctionHook")
- def test_exceptions(self, mock_hook, mock_status):
- hook_response = {"status": mock_status}
-
- hook_instance = mock_hook.return_value
- hook_instance.describe_execution.return_value = hook_response
-
- sensor = StepFunctionExecutionSensor(
- task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
- )
-
- with pytest.raises(AirflowException):
- sensor.poke(self.mock_context)
+@pytest.fixture
+def mocked_context():
+ return mock.MagicMock(name="FakeContext")
- @mock.patch("airflow.providers.amazon.aws.sensors.step_function.StepFunctionHook")
- def test_running(self, mock_hook):
- hook_response = {"status": "RUNNING"}
-
- hook_instance = mock_hook.return_value
- hook_instance.describe_execution.return_value = hook_response
-
- sensor = StepFunctionExecutionSensor(
- task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
- )
-
- assert not sensor.poke(self.mock_context)
-
- @mock.patch("airflow.providers.amazon.aws.sensors.step_function.StepFunctionHook")
- def test_succeeded(self, mock_hook):
- hook_response = {"status": "SUCCEEDED"}
-
- hook_instance = mock_hook.return_value
- hook_instance.describe_execution.return_value = hook_response
+class TestStepFunctionExecutionSensor:
+ def test_init(self):
sensor = StepFunctionExecutionSensor(
- task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
+ task_id=TASK_ID,
+ execution_arn=EXECUTION_ARN,
+ aws_conn_id=AWS_CONN_ID,
+ region_name=REGION_NAME,
+ verify=True,
+ botocore_config={"read_timeout": 42},
)
-
- assert sensor.poke(self.mock_context)
-
+ assert sensor.execution_arn == EXECUTION_ARN
+ assert sensor.hook.aws_conn_id == AWS_CONN_ID
+ assert sensor.hook._region_name == REGION_NAME
+ assert sensor.hook._verify is True
+ assert sensor.hook._config is not None
+ assert sensor.hook._config.read_timeout == 42
+
+ sensor = StepFunctionExecutionSensor(task_id=TASK_ID, execution_arn=EXECUTION_ARN)
+ assert sensor.hook.aws_conn_id == "aws_default"
+ assert sensor.hook._region_name is None
+ assert sensor.hook._verify is None
+ assert sensor.hook._config is None
+
+ @mock.patch.object(StepFunctionExecutionSensor, "hook")
+ @pytest.mark.parametrize("status", StepFunctionExecutionSensor.INTERMEDIATE_STATES)
+ def test_running(self, mocked_hook, status, mocked_context):
+ mocked_hook.describe_execution.return_value = {"status": status}
+ sensor = StepFunctionExecutionSensor(task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=None)
+ assert sensor.poke(mocked_context) is False
+
+ @mock.patch.object(StepFunctionExecutionSensor, "hook")
+ @pytest.mark.parametrize("status", StepFunctionExecutionSensor.SUCCESS_STATES)
+ def test_succeeded(self, mocked_hook, status, mocked_context):
+ mocked_hook.describe_execution.return_value = {"status": status}
+ sensor = StepFunctionExecutionSensor(task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=None)
+ assert sensor.poke(mocked_context) is True
+
+ @mock.patch.object(StepFunctionExecutionSensor, "hook")
+ @pytest.mark.parametrize("status", StepFunctionExecutionSensor.FAILURE_STATES)
@pytest.mark.parametrize(
- "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+ "soft_fail, expected_exception",
+ [
+ pytest.param(True, AirflowSkipException, id="soft-fail"),
+ pytest.param(False, AirflowException, id="non-soft-fail"),
+ ],
)
- @mock.patch("airflow.providers.amazon.aws.hooks.step_function.StepFunctionHook.describe_execution")
- def test_fail_poke(self, describe_execution, soft_fail, expected_exception):
+ def test_failure(self, mocked_hook, status, soft_fail, expected_exception, mocked_context):
+ output = {"test": "test"}
+ mocked_hook.describe_execution.return_value = {"status": status, "output": json.dumps(output)}
sensor = StepFunctionExecutionSensor(
- task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
+ task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=None, soft_fail=soft_fail
)
- sensor.soft_fail = soft_fail
- output = '{"test": "test"}'
- describe_execution.return_value = {"status": "FAILED", "output": output}
- message = f"Step Function sensor failed. State Machine Output: {json.loads(output)}"
+ message = f"Step Function sensor failed. State Machine Output: {output}"
with pytest.raises(expected_exception, match=message):
- sensor.poke(context={})
+ sensor.poke(mocked_context)