You are viewing a plain-text version of this content. The canonical link for it is on the original archive page (the hyperlink was not preserved in this plain-text copy).
Posted to commits@airflow.apache.org by el...@apache.org on 2023/12/28 12:26:46 UTC

(airflow) branch main updated: Use base aws classes in AWS Step Functions Operators/Sensors/Triggers (#36468)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 22294abf68 Use base aws classes in AWS Step Functions Operators/Sensors/Triggers (#36468)
22294abf68 is described below

commit 22294abf68f17eefc00ec9b363bfcf1ca21f145a
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Thu Dec 28 16:26:37 2023 +0400

    Use base aws classes in AWS Step Functions Operators/Sensors/Triggers (#36468)
---
 .../amazon/aws/operators/step_function.py          |  63 ++++-----
 .../providers/amazon/aws/sensors/step_function.py  |  36 +++---
 .../providers/amazon/aws/triggers/step_function.py |   9 +-
 .../operators/step_functions.rst                   |   5 +
 .../amazon/aws/operators/test_step_function.py     | 143 +++++++++++----------
 .../amazon/aws/sensors/test_step_function.py       | 112 ++++++++--------
 6 files changed, 190 insertions(+), 178 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/step_function.py b/airflow/providers/amazon/aws/operators/step_function.py
index 68324df731..e02de32bae 100644
--- a/airflow/providers/amazon/aws/operators/step_function.py
+++ b/airflow/providers/amazon/aws/operators/step_function.py
@@ -22,15 +22,16 @@ from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
 from airflow.providers.amazon.aws.triggers.step_function import StepFunctionsExecutionCompleteTrigger
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class StepFunctionStartExecutionOperator(BaseOperator):
+class StepFunctionStartExecutionOperator(AwsBaseOperator[StepFunctionHook]):
     """
     An Operator that begins execution of an AWS Step Function State Machine.
 
@@ -50,10 +51,20 @@ class StepFunctionStartExecutionOperator(BaseOperator):
     :param deferrable: If True, the operator will wait asynchronously for the job to complete.
         This implies waiting for completion. This mode requires aiobotocore module to be installed.
         (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("state_machine_arn", "name", "input")
-    template_ext: Sequence[str] = ()
+    aws_hook_class = StepFunctionHook
+    template_fields: Sequence[str] = aws_template_fields("state_machine_arn", "name", "input")
     ui_color = "#f9c915"
 
     def __init__(
@@ -62,8 +73,6 @@ class StepFunctionStartExecutionOperator(BaseOperator):
         state_machine_arn: str,
         name: str | None = None,
         state_machine_input: dict | str | None = None,
-        aws_conn_id: str = "aws_default",
-        region_name: str | None = None,
         waiter_max_attempts: int = 30,
         waiter_delay: int = 60,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
@@ -73,18 +82,12 @@ class StepFunctionStartExecutionOperator(BaseOperator):
         self.state_machine_arn = state_machine_arn
         self.name = name
         self.input = state_machine_input
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
         self.deferrable = deferrable
 
     def execute(self, context: Context):
-        hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
-        execution_arn = hook.start_execution(self.state_machine_arn, self.name, self.input)
-
-        if execution_arn is None:
+        if not (execution_arn := self.hook.start_execution(self.state_machine_arn, self.name, self.input)):
             raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}")
 
         self.log.info("Started State Machine execution for %s: %s", self.state_machine_arn, execution_arn)
@@ -96,6 +99,8 @@ class StepFunctionStartExecutionOperator(BaseOperator):
                     waiter_max_attempts=self.waiter_max_attempts,
                     aws_conn_id=self.aws_conn_id,
                     region_name=self.region_name,
+                    botocore_config=self.botocore_config,
+                    verify=self.verify,
                 ),
                 method_name="execute_complete",
                 timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay),
@@ -110,7 +115,7 @@ class StepFunctionStartExecutionOperator(BaseOperator):
         return event["execution_arn"]
 
 
-class StepFunctionGetExecutionOutputOperator(BaseOperator):
+class StepFunctionGetExecutionOutputOperator(AwsBaseOperator[StepFunctionHook]):
     """
     An Operator that returns the output of an AWS Step Function State Machine execution.
 
@@ -121,30 +126,28 @@ class StepFunctionGetExecutionOutputOperator(BaseOperator):
         :ref:`howto/operator:StepFunctionGetExecutionOutputOperator`
 
     :param execution_arn: ARN of the Step Function State Machine Execution
-    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("execution_arn",)
-    template_ext: Sequence[str] = ()
+    aws_hook_class = StepFunctionHook
+    template_fields: Sequence[str] = aws_template_fields("execution_arn")
     ui_color = "#f9c915"
 
-    def __init__(
-        self,
-        *,
-        execution_arn: str,
-        aws_conn_id: str = "aws_default",
-        region_name: str | None = None,
-        **kwargs,
-    ):
+    def __init__(self, *, execution_arn: str, **kwargs):
         super().__init__(**kwargs)
         self.execution_arn = execution_arn
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
 
     def execute(self, context: Context):
-        hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
-
-        execution_status = hook.describe_execution(self.execution_arn)
+        execution_status = self.hook.describe_execution(self.execution_arn)
         response = None
         if "output" in execution_status:
             response = json.loads(execution_status["output"])
diff --git a/airflow/providers/amazon/aws/sensors/step_function.py b/airflow/providers/amazon/aws/sensors/step_function.py
index 053a751336..5e0d3cfcf7 100644
--- a/airflow/providers/amazon/aws/sensors/step_function.py
+++ b/airflow/providers/amazon/aws/sensors/step_function.py
@@ -17,20 +17,20 @@
 from __future__ import annotations
 
 import json
-from functools import cached_property
 from typing import TYPE_CHECKING, Sequence
 
 from deprecated import deprecated
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class StepFunctionExecutionSensor(BaseSensorOperator):
+class StepFunctionExecutionSensor(AwsBaseSensor[StepFunctionHook]):
     """
     Poll the Step Function State Machine Execution until it reaches a terminal state; fails if the task fails.
 
@@ -42,7 +42,16 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
         :ref:`howto/sensor:StepFunctionExecutionSensor`
 
     :param execution_arn: execution_arn to check the state of
-    :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
     INTERMEDIATE_STATES = ("RUNNING",)
@@ -53,22 +62,13 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
     )
     SUCCESS_STATES = ("SUCCEEDED",)
 
-    template_fields: Sequence[str] = ("execution_arn",)
-    template_ext: Sequence[str] = ()
+    aws_hook_class = StepFunctionHook
+    template_fields: Sequence[str] = aws_template_fields("execution_arn")
     ui_color = "#66c3ff"
 
-    def __init__(
-        self,
-        *,
-        execution_arn: str,
-        aws_conn_id: str = "aws_default",
-        region_name: str | None = None,
-        **kwargs,
-    ):
+    def __init__(self, *, execution_arn: str, **kwargs):
         super().__init__(**kwargs)
         self.execution_arn = execution_arn
-        self.aws_conn_id = aws_conn_id
-        self.region_name = region_name
 
     def poke(self, context: Context):
         execution_status = self.hook.describe_execution(self.execution_arn)
@@ -93,7 +93,3 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
     def get_hook(self) -> StepFunctionHook:
         """Create and return a StepFunctionHook."""
         return self.hook
-
-    @cached_property
-    def hook(self) -> StepFunctionHook:
-        return StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
diff --git a/airflow/providers/amazon/aws/triggers/step_function.py b/airflow/providers/amazon/aws/triggers/step_function.py
index da0f186da9..6fe6af2218 100644
--- a/airflow/providers/amazon/aws/triggers/step_function.py
+++ b/airflow/providers/amazon/aws/triggers/step_function.py
@@ -43,6 +43,7 @@ class StepFunctionsExecutionCompleteTrigger(AwsBaseWaiterTrigger):
         waiter_max_attempts: int = 30,
         aws_conn_id: str | None = None,
         region_name: str | None = None,
+        **kwargs,
     ) -> None:
         super().__init__(
             serialized_fields={"execution_arn": execution_arn, "region_name": region_name},
@@ -56,7 +57,13 @@ class StepFunctionsExecutionCompleteTrigger(AwsBaseWaiterTrigger):
             waiter_delay=waiter_delay,
             waiter_max_attempts=waiter_max_attempts,
             aws_conn_id=aws_conn_id,
+            **kwargs,
         )
 
     def hook(self) -> AwsGenericHook:
-        return StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+        return StepFunctionHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
diff --git a/docs/apache-airflow-providers-amazon/operators/step_functions.rst b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
index 7736fa9b16..5ab5d19e68 100644
--- a/docs/apache-airflow-providers-amazon/operators/step_functions.rst
+++ b/docs/apache-airflow-providers-amazon/operators/step_functions.rst
@@ -28,6 +28,11 @@ Prerequisite Tasks
 
 .. include:: ../_partials/prerequisite_tasks.rst
 
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
 Operators
 ---------
 
diff --git a/tests/providers/amazon/aws/operators/test_step_function.py b/tests/providers/amazon/aws/operators/test_step_function.py
index 91ccebf7c6..6845a7f98a 100644
--- a/tests/providers/amazon/aws/operators/test_step_function.py
+++ b/tests/providers/amazon/aws/operators/test_step_function.py
@@ -18,12 +18,10 @@
 from __future__ import annotations
 
 from unittest import mock
-from unittest.mock import MagicMock
 
 import pytest
 
-from airflow.exceptions import TaskDeferred
-from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.operators.step_function import (
     StepFunctionGetExecutionOutputOperator,
     StepFunctionStartExecutionOperator,
@@ -40,104 +38,106 @@ NAME = "NAME"
 INPUT = "{}"
 
 
+@pytest.fixture
+def mocked_context():
+    return mock.MagicMock(name="FakeContext")
+
+
 class TestStepFunctionGetExecutionOutputOperator:
     TASK_ID = "step_function_get_execution_output"
 
-    def setup_method(self):
-        self.mock_context = MagicMock()
-
     def test_init(self):
-        # Given / When
-        operator = StepFunctionGetExecutionOutputOperator(
+        op = StepFunctionGetExecutionOutputOperator(
             task_id=self.TASK_ID,
             execution_arn=EXECUTION_ARN,
             aws_conn_id=AWS_CONN_ID,
             region_name=REGION_NAME,
+            verify="/spam/egg.pem",
+            botocore_config={"read_timeout": 42},
         )
-
-        # Then
-        assert self.TASK_ID == operator.task_id
-        assert EXECUTION_ARN == operator.execution_arn
-        assert AWS_CONN_ID == operator.aws_conn_id
-        assert REGION_NAME == operator.region_name
-
-    @mock.patch("airflow.providers.amazon.aws.operators.step_function.StepFunctionHook")
-    @pytest.mark.parametrize("response", ["output", "error"])
-    def test_execute(self, mock_hook, response):
-        # Given
-        hook_response = {response: "{}"}
-
-        hook_instance = mock_hook.return_value
-        hook_instance.describe_execution.return_value = hook_response
-
-        operator = StepFunctionGetExecutionOutputOperator(
+        assert op.execution_arn == EXECUTION_ARN
+        assert op.hook.aws_conn_id == AWS_CONN_ID
+        assert op.hook._region_name == REGION_NAME
+        assert op.hook._verify == "/spam/egg.pem"
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = StepFunctionGetExecutionOutputOperator(task_id=self.TASK_ID, execution_arn=EXECUTION_ARN)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+
+    @mock.patch.object(StepFunctionGetExecutionOutputOperator, "hook")
+    @pytest.mark.parametrize(
+        "response, expected_output",
+        [
+            pytest.param({"output": '{"foo": "bar"}'}, {"foo": "bar"}, id="output"),
+            pytest.param({"error": '{"spam": "egg"}'}, {"spam": "egg"}, id="error"),
+            pytest.param({"other": '{"baz": "qux"}'}, None, id="other"),
+        ],
+    )
+    def test_execute(self, mocked_hook, mocked_context, response, expected_output):
+        mocked_hook.describe_execution.return_value = response
+        op = StepFunctionGetExecutionOutputOperator(
             task_id=self.TASK_ID,
             execution_arn=EXECUTION_ARN,
-            aws_conn_id=AWS_CONN_ID,
-            region_name=REGION_NAME,
+            aws_conn_id=None,
         )
-
-        # When
-        result = operator.execute(self.mock_context)
-
-        # Then
-        assert {} == result
+        assert op.execute(mocked_context) == expected_output
+        mocked_hook.describe_execution.assert_called_once_with(EXECUTION_ARN)
 
 
 class TestStepFunctionStartExecutionOperator:
     TASK_ID = "step_function_start_execution_task"
 
-    def setup_method(self):
-        self.mock_context = MagicMock()
-
     def test_init(self):
-        # Given / When
-        operator = StepFunctionStartExecutionOperator(
+        op = StepFunctionStartExecutionOperator(
             task_id=self.TASK_ID,
             state_machine_arn=STATE_MACHINE_ARN,
             name=NAME,
             state_machine_input=INPUT,
             aws_conn_id=AWS_CONN_ID,
             region_name=REGION_NAME,
+            verify=False,
+            botocore_config={"read_timeout": 42},
         )
-
-        # Then
-        assert self.TASK_ID == operator.task_id
-        assert STATE_MACHINE_ARN == operator.state_machine_arn
-        assert NAME == operator.name
-        assert INPUT == operator.input
-        assert AWS_CONN_ID == operator.aws_conn_id
-        assert REGION_NAME == operator.region_name
-
-    @mock.patch("airflow.providers.amazon.aws.operators.step_function.StepFunctionHook")
-    def test_execute(self, mock_hook):
-        # Given
+        assert op.state_machine_arn == STATE_MACHINE_ARN
+        assert op.state_machine_arn == STATE_MACHINE_ARN
+        assert op.name == NAME
+        assert op.input == INPUT
+        assert op.hook.aws_conn_id == AWS_CONN_ID
+        assert op.hook._region_name == REGION_NAME
+        assert op.hook._verify is False
+        assert op.hook._config is not None
+        assert op.hook._config.read_timeout == 42
+
+        op = StepFunctionStartExecutionOperator(task_id=self.TASK_ID, state_machine_arn=STATE_MACHINE_ARN)
+        assert op.hook.aws_conn_id == "aws_default"
+        assert op.hook._region_name is None
+        assert op.hook._verify is None
+        assert op.hook._config is None
+
+    @mock.patch.object(StepFunctionStartExecutionOperator, "hook")
+    def test_execute(self, mocked_hook, mocked_context):
         hook_response = (
             "arn:aws:states:us-east-1:123456789012:execution:"
             "pseudo-state-machine:020f5b16-b1a1-4149-946f-92dd32d97934"
         )
-
-        hook_instance = mock_hook.return_value
-        hook_instance.start_execution.return_value = hook_response
-
-        operator = StepFunctionStartExecutionOperator(
+        mocked_hook.start_execution.return_value = hook_response
+        op = StepFunctionStartExecutionOperator(
             task_id=self.TASK_ID,
             state_machine_arn=STATE_MACHINE_ARN,
             name=NAME,
             state_machine_input=INPUT,
-            aws_conn_id=AWS_CONN_ID,
-            region_name=REGION_NAME,
+            aws_conn_id=None,
         )
+        assert op.execute(mocked_context) == hook_response
+        mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT)
 
-        # When
-        result = operator.execute(self.mock_context)
-
-        # Then
-        assert hook_response == result
-
-    @mock.patch.object(StepFunctionHook, "start_execution")
-    def test_step_function_start_execution_deferrable(self, mock_start_execution):
-        mock_start_execution.return_value = "test-execution-arn"
+    @mock.patch.object(StepFunctionStartExecutionOperator, "hook")
+    def test_step_function_start_execution_deferrable(self, mocked_hook):
+        mocked_hook.start_execution.return_value = "test-execution-arn"
         operator = StepFunctionStartExecutionOperator(
             task_id=self.TASK_ID,
             state_machine_arn=STATE_MACHINE_ARN,
@@ -149,3 +149,14 @@ class TestStepFunctionStartExecutionOperator:
         )
         with pytest.raises(TaskDeferred):
             operator.execute(None)
+        mocked_hook.start_execution.assert_called_once_with(STATE_MACHINE_ARN, NAME, INPUT)
+
+    @mock.patch.object(StepFunctionStartExecutionOperator, "hook")
+    @pytest.mark.parametrize("execution_arn", [pytest.param(None, id="none"), pytest.param("", id="empty")])
+    def test_step_function_no_execution_arn_returns(self, mocked_hook, execution_arn):
+        mocked_hook.start_execution.return_value = execution_arn
+        op = StepFunctionStartExecutionOperator(
+            task_id=self.TASK_ID, state_machine_arn=STATE_MACHINE_ARN, aws_conn_id=None
+        )
+        with pytest.raises(AirflowException, match="Failed to start State Machine execution"):
+            op.execute({})
diff --git a/tests/providers/amazon/aws/sensors/test_step_function.py b/tests/providers/amazon/aws/sensors/test_step_function.py
index b6a47d49cd..878691dc1c 100644
--- a/tests/providers/amazon/aws/sensors/test_step_function.py
+++ b/tests/providers/amazon/aws/sensors/test_step_function.py
@@ -19,7 +19,6 @@ from __future__ import annotations
 
 import json
 from unittest import mock
-from unittest.mock import MagicMock
 
 import pytest
 
@@ -35,72 +34,63 @@ AWS_CONN_ID = "aws_non_default"
 REGION_NAME = "us-west-2"
 
 
-class TestStepFunctionExecutionSensor:
-    def setup_method(self):
-        self.mock_context = MagicMock()
-
-    def test_init(self):
-        sensor = StepFunctionExecutionSensor(
-            task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
-        )
-
-        assert TASK_ID == sensor.task_id
-        assert EXECUTION_ARN == sensor.execution_arn
-        assert AWS_CONN_ID == sensor.aws_conn_id
-        assert REGION_NAME == sensor.region_name
-
-    @pytest.mark.parametrize("mock_status", ["FAILED", "TIMED_OUT", "ABORTED"])
-    @mock.patch("airflow.providers.amazon.aws.sensors.step_function.StepFunctionHook")
-    def test_exceptions(self, mock_hook, mock_status):
-        hook_response = {"status": mock_status}
-
-        hook_instance = mock_hook.return_value
-        hook_instance.describe_execution.return_value = hook_response
-
-        sensor = StepFunctionExecutionSensor(
-            task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
-        )
-
-        with pytest.raises(AirflowException):
-            sensor.poke(self.mock_context)
+@pytest.fixture
+def mocked_context():
+    return mock.MagicMock(name="FakeContext")
 
-    @mock.patch("airflow.providers.amazon.aws.sensors.step_function.StepFunctionHook")
-    def test_running(self, mock_hook):
-        hook_response = {"status": "RUNNING"}
-
-        hook_instance = mock_hook.return_value
-        hook_instance.describe_execution.return_value = hook_response
-
-        sensor = StepFunctionExecutionSensor(
-            task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
-        )
-
-        assert not sensor.poke(self.mock_context)
-
-    @mock.patch("airflow.providers.amazon.aws.sensors.step_function.StepFunctionHook")
-    def test_succeeded(self, mock_hook):
-        hook_response = {"status": "SUCCEEDED"}
-
-        hook_instance = mock_hook.return_value
-        hook_instance.describe_execution.return_value = hook_response
 
+class TestStepFunctionExecutionSensor:
+    def test_init(self):
         sensor = StepFunctionExecutionSensor(
-            task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
+            task_id=TASK_ID,
+            execution_arn=EXECUTION_ARN,
+            aws_conn_id=AWS_CONN_ID,
+            region_name=REGION_NAME,
+            verify=True,
+            botocore_config={"read_timeout": 42},
         )
-
-        assert sensor.poke(self.mock_context)
-
+        assert sensor.execution_arn == EXECUTION_ARN
+        assert sensor.hook.aws_conn_id == AWS_CONN_ID
+        assert sensor.hook._region_name == REGION_NAME
+        assert sensor.hook._verify is True
+        assert sensor.hook._config is not None
+        assert sensor.hook._config.read_timeout == 42
+
+        sensor = StepFunctionExecutionSensor(task_id=TASK_ID, execution_arn=EXECUTION_ARN)
+        assert sensor.hook.aws_conn_id == "aws_default"
+        assert sensor.hook._region_name is None
+        assert sensor.hook._verify is None
+        assert sensor.hook._config is None
+
+    @mock.patch.object(StepFunctionExecutionSensor, "hook")
+    @pytest.mark.parametrize("status", StepFunctionExecutionSensor.INTERMEDIATE_STATES)
+    def test_running(self, mocked_hook, status, mocked_context):
+        mocked_hook.describe_execution.return_value = {"status": status}
+        sensor = StepFunctionExecutionSensor(task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=None)
+        assert sensor.poke(mocked_context) is False
+
+    @mock.patch.object(StepFunctionExecutionSensor, "hook")
+    @pytest.mark.parametrize("status", StepFunctionExecutionSensor.SUCCESS_STATES)
+    def test_succeeded(self, mocked_hook, status, mocked_context):
+        mocked_hook.describe_execution.return_value = {"status": status}
+        sensor = StepFunctionExecutionSensor(task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=None)
+        assert sensor.poke(mocked_context) is True
+
+    @mock.patch.object(StepFunctionExecutionSensor, "hook")
+    @pytest.mark.parametrize("status", StepFunctionExecutionSensor.FAILURE_STATES)
     @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+        "soft_fail, expected_exception",
+        [
+            pytest.param(True, AirflowSkipException, id="soft-fail"),
+            pytest.param(False, AirflowException, id="non-soft-fail"),
+        ],
     )
-    @mock.patch("airflow.providers.amazon.aws.hooks.step_function.StepFunctionHook.describe_execution")
-    def test_fail_poke(self, describe_execution, soft_fail, expected_exception):
+    def test_failure(self, mocked_hook, status, soft_fail, expected_exception, mocked_context):
+        output = {"test": "test"}
+        mocked_hook.describe_execution.return_value = {"status": status, "output": json.dumps(output)}
         sensor = StepFunctionExecutionSensor(
-            task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=AWS_CONN_ID, region_name=REGION_NAME
+            task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=None, soft_fail=soft_fail
         )
-        sensor.soft_fail = soft_fail
-        output = '{"test": "test"}'
-        describe_execution.return_value = {"status": "FAILED", "output": output}
-        message = f"Step Function sensor failed. State Machine Output: {json.loads(output)}"
+        message = f"Step Function sensor failed. State Machine Output: {output}"
         with pytest.raises(expected_exception, match=message):
-            sensor.poke(context={})
+            sensor.poke(mocked_context)