You are viewing a plain-text version of this content. (The hyperlink to the canonical version was not preserved in this copy.)
Posted to commits@airflow.apache.org by vi...@apache.org on 2023/10/06 13:34:31 UTC
[airflow] branch main updated: Do not mock `isinstance` in Amazon Tests (#34800)
This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7e896c513a Do not mock `isinstance` in Amazon Tests (#34800)
7e896c513a is described below
commit 7e896c513a0db84a44d03e682451f55f211a61a4
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Fri Oct 6 17:34:24 2023 +0400
Do not mock `isinstance` in Amazon Tests (#34800)
---
tests/providers/amazon/aws/hooks/test_base_aws.py | 73 +++++------
.../amazon/aws/hooks/test_emr_containers.py | 72 +++++------
.../amazon/aws/operators/test_cloud_formation.py | 48 ++------
.../amazon/aws/operators/test_emr_add_steps.py | 133 +++++++++------------
.../amazon/aws/operators/test_emr_containers.py | 46 +++----
.../aws/operators/test_emr_create_job_flow.py | 77 ++++--------
.../aws/operators/test_emr_modify_cluster.py | 42 ++-----
.../aws/operators/test_emr_terminate_job_flow.py | 60 ++++------
.../amazon/aws/sensors/test_cloud_formation.py | 93 +++++---------
.../amazon/aws/sensors/test_emr_job_flow.py | 119 ++++++++----------
.../providers/amazon/aws/sensors/test_emr_step.py | 99 +++++----------
11 files changed, 326 insertions(+), 536 deletions(-)
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index 57350cb741..ccbcd4fff3 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -390,17 +390,19 @@ class TestAwsBaseHook:
assert table.item_count == 0
@pytest.mark.parametrize(
- "client_meta",
+ "hook_params",
[
- AwsBaseHook(client_type="s3").get_client_type().meta,
- AwsBaseHook(resource_type="dynamodb").get_resource_type().meta.client.meta,
+ pytest.param({"client_type": "s3"}, id="client-type"),
+ pytest.param({"resource_type": "dynamodb"}, id="resource-type"),
],
)
- def test_user_agent_extra_update(self, client_meta):
+ def test_user_agent_extra_update(self, hook_params):
"""
We are only looking for the keys appended by the AwsBaseHook. A user_agent string
is a number of key/value pairs such as: `BOTO3/1.25.4 AIRFLOW/2.5.0.DEV0 AMPP/6.0.0`.
"""
+ client_meta = AwsBaseHook(aws_conn_id=None, **hook_params).conn_client_meta
+
expected_user_agent_tag_keys = ["Airflow", "AmPP", "Caller", "DagRunKey"]
result_user_agent_tags = client_meta.config.user_agent.split(" ")
@@ -477,31 +479,25 @@ class TestAwsBaseHook:
return sts_response
with mock.patch(
- "airflow.providers.amazon.aws.hooks.base_aws.requests.Session.get"
- ) as mock_get, mock.patch(
- "airflow.providers.amazon.aws.hooks.base_aws.boto3"
- ) as mock_boto3, mock.patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- mock_get.return_value.ok = True
-
- mock_client = mock_boto3.session.Session.return_value.client
+ "airflow.providers.amazon.aws.hooks.base_aws.BaseSessionFactory._create_basic_session",
+ spec=boto3.session.Session,
+ ) as mocked_basic_session:
+ mocked_basic_session.return_value.region_name = "us-east-2"
+ mock_client = mocked_basic_session.return_value.client
mock_client.return_value.assume_role.side_effect = mock_assume_role
- hook = AwsBaseHook(aws_conn_id=aws_conn_id, client_type="s3")
- hook.get_client_type("s3")
-
- calls_assume_role = [
- mock.call.session.Session().client("sts", config=mock.ANY, endpoint_url=sts_endpoint),
- mock.call.session.Session()
- .client()
- .assume_role(
- RoleArn=role_arn,
- RoleSessionName=slugified_role_session_name,
- ),
- ]
- mock_boto3.assert_has_calls(calls_assume_role)
+ AwsBaseHook(aws_conn_id=aws_conn_id, client_type="s3").get_client_type()
+ mocked_basic_session.assert_has_calls(
+ [
+ mock.call().client("sts", config=mock.ANY, endpoint_url=sts_endpoint),
+ mock.call()
+ .client()
+ .assume_role(
+ RoleArn=role_arn,
+ RoleSessionName=slugified_role_session_name,
+ ),
+ ]
+ )
def test_get_credentials_from_gcp_credentials(self):
mock_connection = Connection(
@@ -684,25 +680,21 @@ class TestAwsBaseHook:
with mock.patch("builtins.__import__", side_effect=import_mock), mock.patch(
"airflow.providers.amazon.aws.hooks.base_aws.requests.Session.get"
) as mock_get, mock.patch(
- "airflow.providers.amazon.aws.hooks.base_aws.boto3"
- ) as mock_boto3, mock.patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- mock_get.return_value.ok = True
-
- mock_client = mock_boto3.session.Session.return_value.client
+ "airflow.providers.amazon.aws.hooks.base_aws.BaseSessionFactory._create_basic_session",
+ spec=boto3.session.Session,
+ ) as mocked_basic_session:
+ mocked_basic_session.return_value.region_name = "us-east-2"
+ mock_client = mocked_basic_session.return_value.client
mock_client.return_value.assume_role_with_saml.side_effect = mock_assume_role_with_saml
- hook = AwsBaseHook(aws_conn_id="aws_default", client_type="s3")
- hook.get_client_type("s3")
+ AwsBaseHook(aws_conn_id="aws_default", client_type="s3").get_client_type()
mock_get.assert_called_once_with(idp_url, auth=mock_auth)
mock_xpath.assert_called_once_with(xpath)
calls_assume_role_with_saml = [
- mock.call.session.Session().client("sts", config=mock.ANY, endpoint_url=sts_endpoint),
- mock.call.session.Session()
+ mock.call().client("sts", config=mock.ANY, endpoint_url=sts_endpoint),
+ mock.call()
.client()
.assume_role_with_saml(
DurationSeconds=duration_seconds,
@@ -711,7 +703,7 @@ class TestAwsBaseHook:
SAMLAssertion=encoded_saml_assertion,
),
]
- mock_boto3.assert_has_calls(calls_assume_role_with_saml)
+ mocked_basic_session.assert_has_calls(calls_assume_role_with_saml)
@mock_iam
def test_expand_role(self):
diff --git a/tests/providers/amazon/aws/hooks/test_emr_containers.py b/tests/providers/amazon/aws/hooks/test_emr_containers.py
index 9be48a08ba..6a6c32bfb3 100644
--- a/tests/providers/amazon/aws/hooks/test_emr_containers.py
+++ b/tests/providers/amazon/aws/hooks/test_emr_containers.py
@@ -19,6 +19,8 @@ from __future__ import annotations
from unittest import mock
+import pytest
+
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
SUBMIT_JOB_SUCCESS_RETURN = {
@@ -46,6 +48,12 @@ JOB2_RUN_DESCRIPTION = {
}
+@pytest.fixture
+def mocked_hook_client():
+ with mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook.conn") as m:
+ yield m
+
+
class TestEmrContainerHook:
def setup_method(self):
self.emr_containers = EmrContainerHook(virtual_cluster_id="vc1234")
@@ -54,14 +62,8 @@ class TestEmrContainerHook:
assert self.emr_containers.aws_conn_id == "aws_default"
assert self.emr_containers.virtual_cluster_id == "vc1234"
- @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True)
- @mock.patch("boto3.session.Session")
- def test_create_emr_on_eks_cluster(self, mock_session, mock_isinstance):
- emr_client_mock = mock.MagicMock()
- emr_client_mock.create_virtual_cluster.return_value = CREATE_EMR_ON_EKS_CLUSTER_RETURN
- emr_session_mock = mock.MagicMock()
- emr_session_mock.client.return_value = emr_client_mock
- mock_session.return_value = emr_session_mock
+ def test_create_emr_on_eks_cluster(self, mocked_hook_client):
+ mocked_hook_client.create_virtual_cluster.return_value = CREATE_EMR_ON_EKS_CLUSTER_RETURN
emr_on_eks_create_cluster_response = self.emr_containers.create_emr_on_eks_cluster(
virtual_cluster_name="test_virtual_cluster",
@@ -70,15 +72,19 @@ class TestEmrContainerHook:
)
assert emr_on_eks_create_cluster_response == "vc1234"
- @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True)
- @mock.patch("boto3.session.Session")
- def test_submit_job(self, mock_session, mock_isinstance):
+ mocked_hook_client.create_virtual_cluster.assert_called_once_with(
+ name="test_virtual_cluster",
+ containerProvider={
+ "id": "test_eks_cluster",
+ "type": "EKS",
+ "info": {"eksInfo": {"namespace": "test_eks_namespace"}},
+ },
+ tags={},
+ )
+
+ def test_submit_job(self, mocked_hook_client):
# Mock out the emr_client creator
- emr_client_mock = mock.MagicMock()
- emr_client_mock.start_job_run.return_value = SUBMIT_JOB_SUCCESS_RETURN
- emr_session_mock = mock.MagicMock()
- emr_session_mock.client.return_value = emr_client_mock
- mock_session.return_value = emr_session_mock
+ mocked_hook_client.start_job_run.return_value = SUBMIT_JOB_SUCCESS_RETURN
emr_containers_job = self.emr_containers.submit_job(
name="test-job-run",
@@ -90,32 +96,30 @@ class TestEmrContainerHook:
)
assert emr_containers_job == "job123456"
- @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True)
- @mock.patch("boto3.session.Session")
- def test_query_status_polling_when_terminal(self, mock_session, mock_isinstance):
- emr_client_mock = mock.MagicMock()
- emr_session_mock = mock.MagicMock()
- emr_session_mock.client.return_value = emr_client_mock
- mock_session.return_value = emr_session_mock
- emr_client_mock.describe_job_run.return_value = JOB1_RUN_DESCRIPTION
+ mocked_hook_client.start_job_run.assert_called_once_with(
+ name="test-job-run",
+ virtualClusterId="vc1234",
+ executionRoleArn="arn:aws:somerole",
+ releaseLabel="emr-6.3.0-latest",
+ jobDriver={},
+ configurationOverrides={},
+ tags={},
+ clientToken="uuidtoken",
+ )
+ def test_query_status_polling_when_terminal(self, mocked_hook_client):
+ mocked_hook_client.describe_job_run.return_value = JOB1_RUN_DESCRIPTION
query_status = self.emr_containers.poll_query_status(job_id="job123456")
# should only poll once since query is already in terminal state
- emr_client_mock.describe_job_run.assert_called_once()
+ mocked_hook_client.describe_job_run.assert_called_once()
assert query_status == "COMPLETED"
- @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.isinstance", return_value=True)
- @mock.patch("boto3.session.Session")
- def test_query_status_polling_with_timeout(self, mock_session, mock_isinstance):
- emr_client_mock = mock.MagicMock()
- emr_session_mock = mock.MagicMock()
- emr_session_mock.client.return_value = emr_client_mock
- mock_session.return_value = emr_session_mock
- emr_client_mock.describe_job_run.return_value = JOB2_RUN_DESCRIPTION
+ def test_query_status_polling_with_timeout(self, mocked_hook_client):
+ mocked_hook_client.describe_job_run.return_value = JOB2_RUN_DESCRIPTION
query_status = self.emr_containers.poll_query_status(
job_id="job123456", max_polling_attempts=2, poll_interval=0
)
# should poll until max_tries is reached since query is in non-terminal state
- assert emr_client_mock.describe_job_run.call_count == 2
+ assert mocked_hook_client.describe_job_run.call_count == 2
assert query_status == "RUNNING"
diff --git a/tests/providers/amazon/aws/operators/test_cloud_formation.py b/tests/providers/amazon/aws/operators/test_cloud_formation.py
index 1a1088adab..df54596b02 100644
--- a/tests/providers/amazon/aws/operators/test_cloud_formation.py
+++ b/tests/providers/amazon/aws/operators/test_cloud_formation.py
@@ -20,6 +20,8 @@ from __future__ import annotations
from unittest import mock
from unittest.mock import MagicMock
+import pytest
+
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.operators.cloud_formation import (
CloudFormationCreateStackOperator,
@@ -31,19 +33,14 @@ DEFAULT_DATE = timezone.datetime(2019, 1, 1)
DEFAULT_ARGS = {"owner": "airflow", "start_date": DEFAULT_DATE}
-class TestCloudFormationCreateStackOperator:
- def setup_method(self):
- # Mock out the cloudformation_client (moto fails with an exception).
- self.cloudformation_client_mock = MagicMock()
-
- # Mock out the emr_client creator
- cloudformation_session_mock = MagicMock()
- cloudformation_session_mock.client.return_value = self.cloudformation_client_mock
- self.boto3_session_mock = MagicMock(return_value=cloudformation_session_mock)
+@pytest.fixture
+def mocked_hook_client():
+ with mock.patch("airflow.providers.amazon.aws.hooks.cloud_formation.CloudFormationHook.conn") as m:
+ yield m
- self.mock_context = MagicMock()
- def test_create_stack(self):
+class TestCloudFormationCreateStackOperator:
+ def test_create_stack(self, mocked_hook_client):
stack_name = "myStack"
timeout = 15
template_body = "My stack body"
@@ -55,30 +52,15 @@ class TestCloudFormationCreateStackOperator:
dag=DAG("test_dag_id", default_args=DEFAULT_ARGS),
)
- with mock.patch("boto3.session.Session", self.boto3_session_mock), mock.patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- operator.execute(self.mock_context)
+ operator.execute(MagicMock())
- self.cloudformation_client_mock.create_stack.assert_any_call(
+ mocked_hook_client.create_stack.assert_any_call(
StackName=stack_name, TemplateBody=template_body, TimeoutInMinutes=timeout
)
class TestCloudFormationDeleteStackOperator:
- def setup_method(self):
- # Mock out the cloudformation_client (moto fails with an exception).
- self.cloudformation_client_mock = MagicMock()
-
- # Mock out the emr_client creator
- cloudformation_session_mock = MagicMock()
- cloudformation_session_mock.client.return_value = self.cloudformation_client_mock
- self.boto3_session_mock = MagicMock(return_value=cloudformation_session_mock)
-
- self.mock_context = MagicMock()
-
- def test_delete_stack(self):
+ def test_delete_stack(self, mocked_hook_client):
stack_name = "myStackToBeDeleted"
operator = CloudFormationDeleteStackOperator(
@@ -87,10 +69,6 @@ class TestCloudFormationDeleteStackOperator:
dag=DAG("test_dag_id", default_args=DEFAULT_ARGS),
)
- with mock.patch("boto3.session.Session", self.boto3_session_mock), mock.patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- operator.execute(self.mock_context)
+ operator.execute(MagicMock())
- self.cloudformation_client_mock.delete_stack.assert_any_call(StackName=stack_name)
+ mocked_hook_client.delete_stack.assert_any_call(StackName=stack_name)
diff --git a/tests/providers/amazon/aws/operators/test_emr_add_steps.py b/tests/providers/amazon/aws/operators/test_emr_add_steps.py
index a2c69b6599..335452b131 100644
--- a/tests/providers/amazon/aws/operators/test_emr_add_steps.py
+++ b/tests/providers/amazon/aws/operators/test_emr_add_steps.py
@@ -27,8 +27,6 @@ from jinja2 import StrictUndefined
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models import DAG, DagRun, TaskInstance
-from airflow.providers.amazon.aws.hooks.emr import EmrHook
-from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.emr import EmrAddStepsOperator
from airflow.providers.amazon.aws.triggers.emr import EmrAddStepsTrigger
from airflow.utils import timezone
@@ -43,6 +41,12 @@ TEMPLATE_SEARCHPATH = os.path.join(
)
+@pytest.fixture
+def mocked_hook_client():
+ with patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") as m:
+ yield m
+
+
class TestEmrAddStepsOperator:
# When
_config = [
@@ -58,17 +62,6 @@ class TestEmrAddStepsOperator:
def setup_method(self):
self.args = {"owner": "airflow", "start_date": DEFAULT_DATE}
-
- # Mock out the emr_client (moto has incorrect response)
- self.emr_client_mock = MagicMock()
-
- # Mock out the emr_client creator
- emr_session_mock = MagicMock()
- emr_session_mock.client.return_value = self.emr_client_mock
- self.boto3_session_mock = MagicMock(return_value=emr_session_mock)
-
- self.mock_context = MagicMock()
-
self.operator = EmrAddStepsOperator(
task_id="test_task",
job_flow_id="j-8989898989",
@@ -78,8 +71,14 @@ class TestEmrAddStepsOperator:
)
def test_init(self):
- assert self.operator.job_flow_id == "j-8989898989"
- assert self.operator.aws_conn_id == "aws_default"
+ op = EmrAddStepsOperator(
+ task_id="test_task",
+ job_flow_id="j-8989898989",
+ aws_conn_id="aws_default",
+ steps=self._config,
+ )
+ assert op.job_flow_id == "j-8989898989"
+ assert op.aws_conn_id == "aws_default"
@pytest.mark.parametrize(
"job_flow_id, job_flow_name",
@@ -120,8 +119,7 @@ class TestEmrAddStepsOperator:
assert self.operator.steps == expected_args
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_render_template_from_file(self, _):
+ def test_render_template_from_file(self, mocked_hook_client):
dag = DAG(
dag_id="test_file",
default_args=self.args,
@@ -137,7 +135,7 @@ class TestEmrAddStepsOperator:
}
]
- self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
+ mocked_hook_client.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
test_task = EmrAddStepsOperator(
task_id="test_task",
@@ -155,78 +153,57 @@ class TestEmrAddStepsOperator:
assert json.loads(test_task.steps) == file_steps
# String in job_flow_overrides (i.e. from loaded as a file) is not "parsed" until inside execute()
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- test_task.execute(None)
+ test_task.execute(MagicMock())
- self.emr_client_mock.add_job_flow_steps.assert_called_once_with(
+ mocked_hook_client.add_job_flow_steps.assert_called_once_with(
JobFlowId="j-8989898989", Steps=file_steps
)
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_execute_returns_step_id(self, _):
- self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
+ def test_execute_returns_step_id(self, mocked_hook_client):
+ mocked_hook_client.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- assert self.operator.execute(self.mock_context) == ["s-2LH3R5GW3A53T"]
+ assert self.operator.execute(MagicMock()) == ["s-2LH3R5GW3A53T"]
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_init_with_cluster_name(self, _):
+ def test_init_with_cluster_name(self, mocked_hook_client):
+ mocked_hook_client.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
+ mock_context = MagicMock()
expected_job_flow_id = "j-1231231234"
- self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
-
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- with patch(
- "airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name"
- ) as mock_get_cluster_id_by_name:
- mock_get_cluster_id_by_name.return_value = expected_job_flow_id
-
- operator = EmrAddStepsOperator(
- task_id="test_task",
- job_flow_name="test_cluster",
- cluster_states=["RUNNING", "WAITING"],
- aws_conn_id="aws_default",
- dag=DAG("test_dag_id", default_args=self.args),
- )
+ operator = EmrAddStepsOperator(
+ task_id="test_task",
+ job_flow_name="test_cluster",
+ cluster_states=["RUNNING", "WAITING"],
+ aws_conn_id="aws_default",
+ dag=DAG("test_dag_id", default_args=self.args),
+ )
- operator.execute(self.mock_context)
+ with patch(
+ "airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name",
+ return_value=expected_job_flow_id,
+ ):
+ operator.execute(mock_context)
- ti = self.mock_context["ti"]
- ti.assert_has_calls(calls=[call.xcom_push(key="job_flow_id", value=expected_job_flow_id)])
+ mocked_ti = mock_context["ti"]
+ mocked_ti.assert_has_calls(calls=[call.xcom_push(key="job_flow_id", value=expected_job_flow_id)])
def test_init_with_nonexistent_cluster_name(self):
cluster_name = "test_cluster"
+ operator = EmrAddStepsOperator(
+ task_id="test_task",
+ job_flow_name=cluster_name,
+ cluster_states=["RUNNING", "WAITING"],
+ aws_conn_id="aws_default",
+ dag=DAG("test_dag_id", default_args=self.args),
+ )
with patch(
- "airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name"
- ) as mock_get_cluster_id_by_name:
- mock_get_cluster_id_by_name.return_value = None
-
- operator = EmrAddStepsOperator(
- task_id="test_task",
- job_flow_name=cluster_name,
- cluster_states=["RUNNING", "WAITING"],
- aws_conn_id="aws_default",
- dag=DAG("test_dag_id", default_args=self.args),
- )
-
- with pytest.raises(AirflowException) as ctx:
- operator.execute(self.mock_context)
- assert str(ctx.value) == f"No cluster found for name: {cluster_name}"
+ "airflow.providers.amazon.aws.hooks.emr.EmrHook.get_cluster_id_by_name", return_value=None
+ ):
+ error_match = rf"No cluster found for name: {cluster_name}"
+ with pytest.raises(AirflowException, match=error_match):
+ operator.execute(MagicMock())
- @patch.object(EmrHook, "conn")
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- @patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.add_job_flow_steps")
- def test_wait_for_completion(self, mock_add_job_flow_steps, *_):
+ def test_wait_for_completion(self, mocked_hook_client):
job_flow_id = "j-8989898989"
operator = EmrAddStepsOperator(
task_id="test_task",
@@ -235,7 +212,11 @@ class TestEmrAddStepsOperator:
dag=DAG("test_dag_id", default_args=self.args),
wait_for_completion=False,
)
- operator.execute(self.mock_context)
+
+ with patch(
+ "airflow.providers.amazon.aws.hooks.emr.EmrHook.add_job_flow_steps"
+ ) as mock_add_job_flow_steps:
+ operator.execute(MagicMock())
mock_add_job_flow_steps.assert_called_once_with(
job_flow_id=job_flow_id,
@@ -275,6 +256,6 @@ class TestEmrAddStepsOperator:
)
with pytest.raises(TaskDeferred) as exc:
- operator.execute(self.mock_context)
+ operator.execute(MagicMock())
assert isinstance(exc.value.trigger, EmrAddStepsTrigger), "Trigger is not a EmrAddStepsTrigger"
diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py
index 00a2eb22aa..d0a4351155 100644
--- a/tests/providers/amazon/aws/operators/test_emr_containers.py
+++ b/tests/providers/amazon/aws/operators/test_emr_containers.py
@@ -17,7 +17,7 @@
from __future__ import annotations
from unittest import mock
-from unittest.mock import MagicMock, patch
+from unittest.mock import patch
import pytest
@@ -38,6 +38,12 @@ CREATE_EMR_ON_EKS_CLUSTER_RETURN = {"ResponseMetadata": {"HTTPStatusCode": 200},
GENERATED_UUID = "800647a9-adda-4237-94e6-f542c85fa55b"
+@pytest.fixture
+def mocked_hook_client():
+ with patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.conn") as m:
+ yield m
+
+
class TestEmrContainerOperator:
def setup_method(self):
conf.load_test_config()
@@ -78,20 +84,11 @@ class TestEmrContainerOperator:
"check_query_status",
side_effect=["PENDING", "PENDING", "SUBMITTED", "RUNNING", "COMPLETED"],
)
- def test_execute_with_polling(self, mock_check_query_status):
+ def test_execute_with_polling(self, mock_check_query_status, mocked_hook_client):
# Mock out the emr_client creator
- emr_client_mock = MagicMock()
- emr_client_mock.start_job_run.return_value = SUBMIT_JOB_SUCCESS_RETURN
- emr_session_mock = MagicMock()
- emr_session_mock.client.return_value = emr_client_mock
- boto3_session_mock = MagicMock(return_value=emr_session_mock)
-
- with patch("boto3.session.Session", boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- assert self.emr_container.execute(None) == "job123456"
- assert mock_check_query_status.call_count == 5
+ mocked_hook_client.start_job_run.return_value = SUBMIT_JOB_SUCCESS_RETURN
+ assert self.emr_container.execute(None) == "job123456"
+ assert mock_check_query_status.call_count == 5
@mock.patch.object(EmrContainerHook, "submit_job")
@mock.patch.object(EmrContainerHook, "check_query_status")
@@ -114,13 +111,9 @@ class TestEmrContainerOperator:
"check_query_status",
side_effect=["PENDING", "PENDING", "SUBMITTED", "RUNNING", "COMPLETED"],
)
- def test_execute_with_polling_timeout(self, mock_check_query_status):
+ def test_execute_with_polling_timeout(self, mock_check_query_status, mocked_hook_client):
# Mock out the emr_client creator
- emr_client_mock = MagicMock()
- emr_client_mock.start_job_run.return_value = SUBMIT_JOB_SUCCESS_RETURN
- emr_session_mock = MagicMock()
- emr_session_mock.client.return_value = emr_client_mock
- boto3_session_mock = MagicMock(return_value=emr_session_mock)
+ mocked_hook_client.start_job_run.return_value = SUBMIT_JOB_SUCCESS_RETURN
timeout_container = EmrContainerOperator(
task_id="start_job",
@@ -134,16 +127,11 @@ class TestEmrContainerOperator:
max_polling_attempts=3,
)
- with patch("boto3.session.Session", boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- with pytest.raises(AirflowException) as ctx:
- timeout_container.execute(None)
+ error_match = "Final state of EMR Containers job is SUBMITTED.*Max tries of poll status exceeded"
+ with pytest.raises(AirflowException, match=error_match):
+ timeout_container.execute(None)
- assert mock_check_query_status.call_count == 3
- assert "Final state of EMR Containers job is SUBMITTED" in str(ctx.value)
- assert "Max tries of poll status exceeded" in str(ctx.value)
+ assert mock_check_query_status.call_count == 3
@mock.patch.object(EmrContainerHook, "submit_job")
@mock.patch.object(
diff --git a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
index 7e531f4a0f..82cb2e245b 100644
--- a/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
+++ b/tests/providers/amazon/aws/operators/test_emr_create_job_flow.py
@@ -28,7 +28,6 @@ from jinja2 import StrictUndefined
from airflow.exceptions import TaskDeferred
from airflow.models import DAG, DagRun, TaskInstance
-from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
from airflow.providers.amazon.aws.triggers.emr import EmrCreateJobFlowTrigger
from airflow.utils import timezone
@@ -49,6 +48,12 @@ TEMPLATE_SEARCHPATH = os.path.join(
)
+@pytest.fixture
+def mocked_hook_client():
+ with patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") as m:
+ yield m
+
+
class TestEmrCreateJobFlowOperator:
# When
_config = {
@@ -68,9 +73,6 @@ class TestEmrCreateJobFlowOperator:
def setup_method(self):
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
-
- # Mock out the emr_client (moto has incorrect response)
- self.emr_client_mock = MagicMock()
self.operator = EmrCreateJobFlowOperator(
task_id=TASK_ID,
aws_conn_id="aws_default",
@@ -118,8 +120,7 @@ class TestEmrCreateJobFlowOperator:
assert self.operator.job_flow_overrides == expected_args
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_render_template_from_file(self, _):
+ def test_render_template_from_file(self, mocked_hook_client):
self.operator.job_flow_overrides = "job.j2.json"
self.operator.params = {"releaseLabel": "5.11.0"}
@@ -128,17 +129,10 @@ class TestEmrCreateJobFlowOperator:
ti.dag_run = dag_run
ti.render_templates()
- self.emr_client_mock.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
- emr_session_mock = MagicMock()
- emr_session_mock.client.return_value = self.emr_client_mock
- boto3_session_mock = MagicMock(return_value=emr_session_mock)
+ mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
# String in job_flow_overrides (i.e. from loaded as a file) is not "parsed" until inside execute()
- with patch("boto3.session.Session", boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- self.operator.execute(self.mock_context)
+ self.operator.execute(self.mock_context)
expected_args = {
"Name": "test_job_flow",
@@ -161,61 +155,32 @@ class TestEmrCreateJobFlowOperator:
assert self.operator.job_flow_overrides == expected_args
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_execute_returns_job_id(self, _):
- self.emr_client_mock.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
-
- # Mock out the emr_client creator
- emr_session_mock = MagicMock()
- emr_session_mock.client.return_value = self.emr_client_mock
- boto3_session_mock = MagicMock(return_value=emr_session_mock)
-
- with patch("boto3.session.Session", boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- assert self.operator.execute(self.mock_context) == JOB_FLOW_ID
+ def test_execute_returns_job_id(self, mocked_hook_client):
+ mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
+ assert self.operator.execute(self.mock_context) == JOB_FLOW_ID
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
@mock.patch("botocore.waiter.get_service_module_name", return_value="emr")
@mock.patch.object(Waiter, "wait")
- def test_execute_with_wait(self, mock_waiter, *_):
- self.emr_client_mock.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
+ def test_execute_with_wait(self, mock_waiter, _, mocked_hook_client):
+ mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
# Mock out the emr_client creator
- emr_session_mock = MagicMock()
- emr_session_mock.client.return_value = self.emr_client_mock
- boto3_session_mock = MagicMock(return_value=emr_session_mock)
self.operator.wait_for_completion = True
- with patch("boto3.session.Session", boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- assert self.operator.execute(self.mock_context) == JOB_FLOW_ID
- mock_waiter.assert_called_once_with(mock.ANY, ClusterId=JOB_FLOW_ID, WaiterConfig=mock.ANY)
- assert_expected_waiter_type(mock_waiter, "job_flow_waiting")
+ assert self.operator.execute(self.mock_context) == JOB_FLOW_ID
+ mock_waiter.assert_called_once_with(mock.ANY, ClusterId=JOB_FLOW_ID, WaiterConfig=mock.ANY)
+ assert_expected_waiter_type(mock_waiter, "job_flow_waiting")
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_create_job_flow_deferrable(self, _):
+ def test_create_job_flow_deferrable(self, mocked_hook_client):
"""
Test to make sure that the operator raises a TaskDeferred exception
if run in deferrable mode.
"""
- self.emr_client_mock.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
-
- # Mock out the emr_client creator
- emr_session_mock = MagicMock()
- emr_session_mock.client.return_value = self.emr_client_mock
- boto3_session_mock = MagicMock(return_value=emr_session_mock)
+ mocked_hook_client.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
self.operator.deferrable = True
- with patch("boto3.session.Session", boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- with pytest.raises(TaskDeferred) as exc:
- self.operator.execute(self.mock_context)
+ with pytest.raises(TaskDeferred) as exc:
+ self.operator.execute(self.mock_context)
assert isinstance(
exc.value.trigger, EmrCreateJobFlowTrigger
diff --git a/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py b/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py
index f29871e521..98d8ba9989 100644
--- a/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py
+++ b/tests/providers/amazon/aws/operators/test_emr_modify_cluster.py
@@ -23,31 +23,25 @@ import pytest
from airflow.exceptions import AirflowException
from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.emr import EmrModifyClusterOperator
from airflow.utils import timezone
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
-
MODIFY_CLUSTER_SUCCESS_RETURN = {"ResponseMetadata": {"HTTPStatusCode": 200}, "StepConcurrencyLevel": 1}
-
MODIFY_CLUSTER_ERROR_RETURN = {"ResponseMetadata": {"HTTPStatusCode": 400}}
+@pytest.fixture
+def mocked_hook_client():
+ with patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") as m:
+ yield m
+
+
class TestEmrModifyClusterOperator:
def setup_method(self):
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
- # Mock out the emr_client (moto has incorrect response)
- self.emr_client_mock = MagicMock()
-
- # Mock out the emr_client creator
- emr_session_mock = MagicMock()
- emr_session_mock.client.return_value = self.emr_client_mock
- self.boto3_session_mock = MagicMock(return_value=emr_session_mock)
-
self.mock_context = MagicMock()
-
self.operator = EmrModifyClusterOperator(
task_id="test_task",
cluster_id="j-8989898989",
@@ -61,23 +55,13 @@ class TestEmrModifyClusterOperator:
assert self.operator.step_concurrency_level == 1
assert self.operator.aws_conn_id == "aws_default"
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_execute_returns_step_concurrency(self, _):
- self.emr_client_mock.modify_cluster.return_value = MODIFY_CLUSTER_SUCCESS_RETURN
+ def test_execute_returns_step_concurrency(self, mocked_hook_client):
+ mocked_hook_client.modify_cluster.return_value = MODIFY_CLUSTER_SUCCESS_RETURN
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- assert self.operator.execute(self.mock_context) == 1
+ assert self.operator.execute(self.mock_context) == 1
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_execute_returns_error(self, _):
- self.emr_client_mock.modify_cluster.return_value = MODIFY_CLUSTER_ERROR_RETURN
+ def test_execute_returns_error(self, mocked_hook_client):
+ mocked_hook_client.modify_cluster.return_value = MODIFY_CLUSTER_ERROR_RETURN
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- with pytest.raises(AirflowException):
- self.operator.execute(self.mock_context)
+ with pytest.raises(AirflowException, match="Modify cluster failed"):
+ self.operator.execute(self.mock_context)
diff --git a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
index 509dcfee0c..2c27c146d2 100644
--- a/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
+++ b/tests/providers/amazon/aws/operators/test_emr_terminate_job_flow.py
@@ -22,52 +22,38 @@ from unittest.mock import MagicMock, patch
import pytest
from airflow.exceptions import TaskDeferred
-from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.emr import EmrTerminateJobFlowOperator
from airflow.providers.amazon.aws.triggers.emr import EmrTerminateJobFlowTrigger
TERMINATE_SUCCESS_RETURN = {"ResponseMetadata": {"HTTPStatusCode": 200}}
-class TestEmrTerminateJobFlowOperator:
- def setup_method(self):
- # Mock out the emr_client (moto has incorrect response)
- mock_emr_client = MagicMock()
- mock_emr_client.terminate_job_flows.return_value = TERMINATE_SUCCESS_RETURN
-
- mock_emr_session = MagicMock()
- mock_emr_session.client.return_value = mock_emr_client
-
- # Mock out the emr_client creator
- self.boto3_session_mock = MagicMock(return_value=mock_emr_session)
+@pytest.fixture
+def mocked_hook_client():
+ with patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") as m:
+ yield m
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_execute_terminates_the_job_flow_and_does_not_error(self, _):
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- operator = EmrTerminateJobFlowOperator(
- task_id="test_task", job_flow_id="j-8989898989", aws_conn_id="aws_default"
- )
+class TestEmrTerminateJobFlowOperator:
+ def test_execute_terminates_the_job_flow_and_does_not_error(self, mocked_hook_client):
+ mocked_hook_client.terminate_job_flows.return_value = TERMINATE_SUCCESS_RETURN
+ operator = EmrTerminateJobFlowOperator(
+ task_id="test_task", job_flow_id="j-8989898989", aws_conn_id="aws_default"
+ )
+
+ operator.execute(MagicMock())
+
+ def test_create_job_flow_deferrable(self, mocked_hook_client):
+ mocked_hook_client.terminate_job_flows.return_value = TERMINATE_SUCCESS_RETURN
+ operator = EmrTerminateJobFlowOperator(
+ task_id="test_task",
+ job_flow_id="j-8989898989",
+ aws_conn_id="aws_default",
+ deferrable=True,
+ )
+
+ with pytest.raises(TaskDeferred) as exc:
operator.execute(MagicMock())
-
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_create_job_flow_deferrable(self, _):
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- operator = EmrTerminateJobFlowOperator(
- task_id="test_task",
- job_flow_id="j-8989898989",
- aws_conn_id="aws_default",
- deferrable=True,
- )
- with pytest.raises(TaskDeferred) as exc:
- operator.execute(MagicMock())
-
assert isinstance(
exc.value.trigger, EmrTerminateJobFlowTrigger
), "Trigger is not a EmrTerminateJobFlowTrigger"
diff --git a/tests/providers/amazon/aws/sensors/test_cloud_formation.py b/tests/providers/amazon/aws/sensors/test_cloud_formation.py
index 9aef7fae6d..51b9c385f1 100644
--- a/tests/providers/amazon/aws/sensors/test_cloud_formation.py
+++ b/tests/providers/amazon/aws/sensors/test_cloud_formation.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations
-from unittest.mock import MagicMock, patch
+from unittest.mock import patch
import boto3
import pytest
@@ -29,65 +29,40 @@ from airflow.providers.amazon.aws.sensors.cloud_formation import (
)
-class TestCloudFormationCreateStackSensor:
- task_id = "test_cloudformation_cluster_create_sensor"
+@pytest.fixture
+def mocked_hook_client():
+ with patch("airflow.providers.amazon.aws.hooks.cloud_formation.CloudFormationHook.conn") as m:
+ yield m
+
+class TestCloudFormationCreateStackSensor:
@mock_cloudformation
def setup_method(self, method):
self.client = boto3.client("cloudformation", region_name="us-east-1")
- self.cloudformation_client_mock = MagicMock()
-
- cloudformation_session_mock = MagicMock()
- cloudformation_session_mock.client.return_value = self.cloudformation_client_mock
-
- self.boto3_session_mock = MagicMock(return_value=cloudformation_session_mock)
-
@mock_cloudformation
def test_poke(self):
- stack_name = "foobar"
- self.client.create_stack(StackName=stack_name, TemplateBody='{"Resources": {}}')
+ self.client.create_stack(StackName="foobar", TemplateBody='{"Resources": {}}')
op = CloudFormationCreateStackSensor(task_id="task", stack_name="foobar")
assert op.poke({})
- def test_poke_false(self):
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- self.cloudformation_client_mock.describe_stacks.return_value = {
- "Stacks": [{"StackStatus": "CREATE_IN_PROGRESS"}]
- }
- op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo")
- assert not op.poke({})
-
- def test_poke_stack_in_unsuccessful_state(self):
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- self.cloudformation_client_mock.describe_stacks.return_value = {
- "Stacks": [{"StackStatus": "bar"}]
- }
- with pytest.raises(ValueError, match="Stack foo in bad state: bar"):
- op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo")
- op.poke({})
+ def test_poke_false(self, mocked_hook_client):
+ mocked_hook_client.describe_stacks.return_value = {"Stacks": [{"StackStatus": "CREATE_IN_PROGRESS"}]}
+ op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo")
+ assert not op.poke({})
+ def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client):
+ mocked_hook_client.describe_stacks.return_value = {"Stacks": [{"StackStatus": "bar"}]}
+ op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo")
+ with pytest.raises(ValueError, match="Stack foo in bad state: bar"):
+ op.poke({})
-class TestCloudFormationDeleteStackSensor:
- task_id = "test_cloudformation_cluster_delete_sensor"
+class TestCloudFormationDeleteStackSensor:
@mock_cloudformation
def setup_method(self, method):
self.client = boto3.client("cloudformation", region_name="us-east-1")
- self.cloudformation_client_mock = MagicMock()
-
- cloudformation_session_mock = MagicMock()
- cloudformation_session_mock.client.return_value = self.cloudformation_client_mock
-
- self.boto3_session_mock = MagicMock(return_value=cloudformation_session_mock)
-
@mock_cloudformation
def test_poke(self):
stack_name = "foobar"
@@ -96,28 +71,16 @@ class TestCloudFormationDeleteStackSensor:
op = CloudFormationDeleteStackSensor(task_id="task", stack_name=stack_name)
assert op.poke({})
- def test_poke_false(self):
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- self.cloudformation_client_mock.describe_stacks.return_value = {
- "Stacks": [{"StackStatus": "DELETE_IN_PROGRESS"}]
- }
- op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo")
- assert not op.poke({})
-
- def test_poke_stack_in_unsuccessful_state(self):
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- self.cloudformation_client_mock.describe_stacks.return_value = {
- "Stacks": [{"StackStatus": "bar"}]
- }
- with pytest.raises(ValueError, match="Stack foo in bad state: bar"):
- op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo")
- op.poke({})
+ def test_poke_false(self, mocked_hook_client):
+ mocked_hook_client.describe_stacks.return_value = {"Stacks": [{"StackStatus": "DELETE_IN_PROGRESS"}]}
+ op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo")
+ assert not op.poke({})
+
+ def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client):
+ mocked_hook_client.describe_stacks.return_value = {"Stacks": [{"StackStatus": "bar"}]}
+ op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo")
+ with pytest.raises(ValueError, match="Stack foo in bad state: bar"):
+ op.poke({})
@mock_cloudformation
def test_poke_stack_does_not_exist(self):
diff --git a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
index ffad3c0ce5..a0a56039fa 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_job_flow.py
@@ -190,67 +190,53 @@ DESCRIBE_CLUSTER_TERMINATED_WITH_ERRORS_RETURN = {
}
-class TestEmrJobFlowSensor:
- def setup_method(self):
- # Mock out the emr_client (moto has incorrect response)
- self.mock_emr_client = MagicMock()
-
- mock_emr_session = MagicMock()
- mock_emr_session.client.return_value = self.mock_emr_client
+@pytest.fixture
+def mocked_hook_client():
+ with mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") as m:
+ yield m
- # Mock out the emr_client creator
- self.boto3_session_mock = MagicMock(return_value=mock_emr_session)
- # Mock context used in execute function
- self.mock_ctx = MagicMock()
-
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self, _):
- self.mock_emr_client.describe_cluster.side_effect = [
+class TestEmrJobFlowSensor:
+ def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_target_state(self, mocked_hook_client):
+ mocked_hook_client.describe_cluster.side_effect = [
DESCRIBE_CLUSTER_STARTING_RETURN,
DESCRIBE_CLUSTER_RUNNING_RETURN,
DESCRIBE_CLUSTER_TERMINATED_RETURN,
]
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- operator = EmrJobFlowSensor(
- task_id="test_task", poke_interval=0, job_flow_id="j-8989898989", aws_conn_id="aws_default"
- )
- operator.execute(self.mock_ctx)
-
- assert self.mock_emr_client.describe_cluster.call_count == 3
+ operator = EmrJobFlowSensor(
+ task_id="test_task", poke_interval=0, job_flow_id="j-8989898989", aws_conn_id="aws_default"
+ )
+ with patch.object(S3Hook, "parse_s3_url", return_value="valid_uri"):
+ operator.execute(MagicMock())
- # make sure it was called with the job_flow_id
- calls = [mock.call(ClusterId="j-8989898989")]
- self.mock_emr_client.describe_cluster.assert_has_calls(calls)
+ assert mocked_hook_client.describe_cluster.call_count == 3
+ # make sure it was called with the job_flow_id
+ calls = [mock.call(ClusterId="j-8989898989")] * 3
+ mocked_hook_client.describe_cluster.assert_has_calls(calls)
- def test_execute_calls_with_the_job_flow_id_until_it_reaches_failed_state_with_exception(self):
- self.mock_emr_client.describe_cluster.side_effect = [
+ def test_execute_calls_with_the_job_flow_id_until_it_reaches_failed_state_with_exception(
+ self, mocked_hook_client
+ ):
+ mocked_hook_client.describe_cluster.side_effect = [
DESCRIBE_CLUSTER_RUNNING_RETURN,
DESCRIBE_CLUSTER_TERMINATED_WITH_ERRORS_RETURN,
]
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- operator = EmrJobFlowSensor(
- task_id="test_task", poke_interval=0, job_flow_id="j-8989898989", aws_conn_id="aws_default"
- )
- with pytest.raises(AirflowException):
- operator.execute(self.mock_ctx)
-
- # make sure we called twice
- assert self.mock_emr_client.describe_cluster.call_count == 2
+ operator = EmrJobFlowSensor(
+ task_id="test_task", poke_interval=0, job_flow_id="j-8989898989", aws_conn_id="aws_default"
+ )
+ with pytest.raises(AirflowException):
+ operator.execute(MagicMock())
- # make sure it was called with the job_flow_id
- self.mock_emr_client.describe_cluster.assert_called_once_with(ClusterId="j-8989898989")
+ # make sure we called twice
+ assert mocked_hook_client.describe_cluster.call_count == 2
+ # make sure it was called with the job_flow_id
+ calls = [mock.call(ClusterId="j-8989898989")] * 2
+ mocked_hook_client.describe_cluster.assert_has_calls(calls=calls)
- def test_different_target_states(self):
- self.mock_emr_client.describe_cluster.side_effect = [
+ def test_different_target_states(self, mocked_hook_client):
+ mocked_hook_client.describe_cluster.side_effect = [
DESCRIBE_CLUSTER_STARTING_RETURN, # return False
DESCRIBE_CLUSTER_BOOTSTRAPPING_RETURN, # return False
DESCRIBE_CLUSTER_RUNNING_RETURN, # return True
@@ -258,28 +244,23 @@ class TestEmrJobFlowSensor:
DESCRIBE_CLUSTER_TERMINATED_RETURN, # will not be used
DESCRIBE_CLUSTER_TERMINATED_WITH_ERRORS_RETURN, # will not be used
]
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- operator = EmrJobFlowSensor(
- task_id="test_task",
- poke_interval=0,
- job_flow_id="j-8989898989",
- aws_conn_id="aws_default",
- target_states=["RUNNING", "WAITING"],
- )
- operator.execute(self.mock_ctx)
+ operator = EmrJobFlowSensor(
+ task_id="test_task",
+ poke_interval=0,
+ job_flow_id="j-8989898989",
+ aws_conn_id="aws_default",
+ target_states=["RUNNING", "WAITING"],
+ )
- assert self.mock_emr_client.describe_cluster.call_count == 3
+ operator.execute(MagicMock())
- # make sure it was called with the job_flow_id
- calls = [mock.call(ClusterId="j-8989898989")]
- self.mock_emr_client.describe_cluster.assert_has_calls(calls)
+ assert mocked_hook_client.describe_cluster.call_count == 3
+ # make sure it was called with the job_flow_id
+ calls = [mock.call(ClusterId="j-8989898989")] * 3
+ mocked_hook_client.describe_cluster.assert_has_calls(calls)
- @mock.patch("airflow.providers.amazon.aws.sensors.emr.EmrJobFlowSensor.poke")
- def test_sensor_defer(self, mock_poke):
+ def test_sensor_defer(self):
"""Test the execute method raise TaskDeferred if running sensor in deferrable mode"""
sensor = EmrJobFlowSensor(
task_id="test_task",
@@ -289,9 +270,11 @@ class TestEmrJobFlowSensor:
target_states=["RUNNING", "WAITING"],
deferrable=True,
)
- mock_poke.return_value = False
- with pytest.raises(TaskDeferred) as exc:
- sensor.execute(context=None)
+
+ with patch.object(EmrJobFlowSensor, "poke", return_value=False):
+ with pytest.raises(TaskDeferred) as exc:
+ sensor.execute(context=None)
+
assert isinstance(
exc.value.trigger, EmrTerminateJobFlowTrigger
- ), f"{exc.value.trigger} is not a EmrTerminateJobFlowTrigger "
+ ), f"{exc.value.trigger} is not a EmrTerminateJobFlowTrigger"
diff --git a/tests/providers/amazon/aws/sensors/test_emr_step.py b/tests/providers/amazon/aws/sensors/test_emr_step.py
index 8387207ba1..34f93d0107 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_step.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_step.py
@@ -26,7 +26,6 @@ from dateutil.tz import tzlocal
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
-from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink
from airflow.providers.amazon.aws.sensors.emr import EmrStepSensor
from airflow.providers.amazon.aws.triggers.emr import EmrStepSensorTrigger
@@ -145,9 +144,14 @@ DESCRIBE_JOB_STEP_COMPLETED_RETURN = {
}
+@pytest.fixture
+def mocked_hook_client():
+ with mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrHook.conn") as m:
+ yield m
+
+
class TestEmrStepSensor:
def setup_method(self):
- self.emr_client_mock = MagicMock()
self.sensor = EmrStepSensor(
task_id="test_task",
poke_interval=0,
@@ -156,84 +160,47 @@ class TestEmrStepSensor:
aws_conn_id="aws_default",
)
- mock_emr_session = MagicMock()
- mock_emr_session.client.return_value = self.emr_client_mock
-
- # Mock out the emr_client creator
- self.boto3_session_mock = MagicMock(return_value=mock_emr_session)
-
- @patch.object(EmrClusterLink, "persist")
- @patch.object(EmrLogsLink, "persist")
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_step_completed(self, *_):
- self.emr_client_mock.describe_step.side_effect = [
+ def test_step_completed(self, mocked_hook_client):
+ mocked_hook_client.describe_step.side_effect = [
DESCRIBE_JOB_STEP_RUNNING_RETURN,
DESCRIBE_JOB_STEP_COMPLETED_RETURN,
]
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- self.sensor.execute(None)
-
- assert self.emr_client_mock.describe_step.call_count == 2
- calls = [
- mock.call(ClusterId="j-8989898989", StepId="s-VK57YR1Z9Z5N"),
- mock.call(ClusterId="j-8989898989", StepId="s-VK57YR1Z9Z5N"),
- ]
- self.emr_client_mock.describe_step.assert_has_calls(calls)
-
- @patch.object(EmrClusterLink, "persist")
- @patch.object(EmrLogsLink, "persist")
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_step_cancelled(self, *_):
- self.emr_client_mock.describe_step.side_effect = [
+ with patch.object(S3Hook, "parse_s3_url", return_value="valid_uri"):
+ self.sensor.execute(MagicMock())
+
+ assert mocked_hook_client.describe_step.call_count == 2
+ calls = [mock.call(ClusterId="j-8989898989", StepId="s-VK57YR1Z9Z5N")] * 2
+ mocked_hook_client.describe_step.assert_has_calls(calls)
+
+ def test_step_cancelled(self, mocked_hook_client):
+ mocked_hook_client.describe_step.side_effect = [
DESCRIBE_JOB_STEP_RUNNING_RETURN,
DESCRIBE_JOB_STEP_CANCELLED_RETURN,
]
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- with pytest.raises(AirflowException):
- self.sensor.execute(None)
-
- @patch.object(EmrClusterLink, "persist")
- @patch.object(EmrLogsLink, "persist")
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_step_failed(self, *_):
- self.emr_client_mock.describe_step.side_effect = [
+ with pytest.raises(AirflowException, match="EMR job failed"):
+ self.sensor.execute(MagicMock())
+
+ def test_step_failed(self, mocked_hook_client):
+ mocked_hook_client.describe_step.side_effect = [
DESCRIBE_JOB_STEP_RUNNING_RETURN,
DESCRIBE_JOB_STEP_FAILED_RETURN,
]
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- with pytest.raises(AirflowException):
- self.sensor.execute(None)
-
- @patch.object(EmrClusterLink, "persist")
- @patch.object(EmrLogsLink, "persist")
- @patch.object(S3Hook, "parse_s3_url", return_value="valid_uri")
- def test_step_interrupted(self, *_):
- self.emr_client_mock.describe_step.side_effect = [
+ with pytest.raises(AirflowException, match="EMR job failed"):
+ self.sensor.execute(MagicMock())
+
+ def test_step_interrupted(self, mocked_hook_client):
+ mocked_hook_client.describe_step.side_effect = [
DESCRIBE_JOB_STEP_RUNNING_RETURN,
DESCRIBE_JOB_STEP_INTERRUPTED_RETURN,
]
- with patch("boto3.session.Session", self.boto3_session_mock), patch(
- "airflow.providers.amazon.aws.hooks.base_aws.isinstance"
- ) as mock_isinstance:
- mock_isinstance.return_value = True
- with pytest.raises(AirflowException):
- self.sensor.execute(None)
+ with pytest.raises(AirflowException):
+ self.sensor.execute(MagicMock())
- @mock.patch("airflow.providers.amazon.aws.sensors.emr.EmrStepSensor.poke")
- def test_sensor_defer(self, mock_poke):
+ def test_sensor_defer(self):
"""Test the execute method raise TaskDeferred if running sensor in deferrable mode"""
sensor = EmrStepSensor(
task_id="test_task",
@@ -244,7 +211,7 @@ class TestEmrStepSensor:
deferrable=True,
)
- mock_poke.return_value = False
- with pytest.raises(TaskDeferred) as exc:
- sensor.execute(context=None)
+ with patch.object(EmrStepSensor, "poke", return_value=False):
+ with pytest.raises(TaskDeferred) as exc:
+ sensor.execute(context=None)
assert isinstance(exc.value.trigger, EmrStepSensorTrigger), "Trigger is not a EmrStepSensorTrigger"