You are viewing a plain-text version of this content. The canonical link to it (shown as "here" in the original page) was not preserved in this copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/12/05 03:12:20 UTC
[airflow] branch main updated: Migrate amazon provider hooks tests from `unittest` to `pytest` (#28039)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f02a7e9a82 Migrate amazon provider hooks tests from `unittest` to `pytest` (#28039)
f02a7e9a82 is described below
commit f02a7e9a8292909b369daae6d573f58deed04440
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Mon Dec 5 06:12:12 2022 +0300
Migrate amazon provider hooks tests from `unittest` to `pytest` (#28039)
---
tests/providers/amazon/aws/hooks/test_athena.py | 9 +--
tests/providers/amazon/aws/hooks/test_base_aws.py | 10 +---
.../amazon/aws/hooks/test_batch_client.py | 11 ++--
.../amazon/aws/hooks/test_batch_waiters.py | 5 +-
tests/providers/amazon/aws/hooks/test_datasync.py | 25 +++------
tests/providers/amazon/aws/hooks/test_dms_task.py | 5 +-
tests/providers/amazon/aws/hooks/test_dynamodb.py | 5 +-
tests/providers/amazon/aws/hooks/test_ecs.py | 35 +++++-------
.../hooks/test_elasticache_replication_group.py | 5 +-
.../amazon/aws/hooks/test_emr_containers.py | 9 +--
tests/providers/amazon/aws/hooks/test_glacier.py | 65 +++++++++-------------
.../amazon/aws/hooks/test_glue_crawler.py | 38 +++++--------
.../amazon/aws/hooks/test_redshift_sql.py | 14 ++---
tests/providers/amazon/aws/utils/test_emailer.py | 4 +-
tests/providers/amazon/aws/utils/test_redshift.py | 3 +-
tests/providers/amazon/aws/utils/test_utils.py | 3 +-
16 files changed, 96 insertions(+), 150 deletions(-)
diff --git a/tests/providers/amazon/aws/hooks/test_athena.py b/tests/providers/amazon/aws/hooks/test_athena.py
index 58870bb1d0..a65470acea 100644
--- a/tests/providers/amazon/aws/hooks/test_athena.py
+++ b/tests/providers/amazon/aws/hooks/test_athena.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
@@ -48,8 +47,8 @@ MOCK_QUERY_EXECUTION_OUTPUT = {
}
-class TestAthenaHook(unittest.TestCase):
- def setUp(self):
+class TestAthenaHook:
+ def setup_method(self):
self.athena = AthenaHook(sleep_time=0)
def test_init(self):
@@ -196,7 +195,3 @@ class TestAthenaHook(unittest.TestCase):
mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT
result = self.athena.get_output_location(query_execution_id=MOCK_DATA["query_execution_id"])
assert result == "s3://test_bucket/test.csv"
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index ad5a0c147e..837a3d2f89 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -19,7 +19,6 @@ from __future__ import annotations
import json
import os
-import unittest
from base64 import b64encode
from datetime import datetime, timedelta, timezone
from unittest import mock
@@ -606,7 +605,7 @@ class TestAwsBaseHook:
def test_connection_region_name(
self, conn_type, connection_uri, region_name, env_region, expected_region_name
):
- with unittest.mock.patch.dict(
+ with mock.patch.dict(
"os.environ", AIRFLOW_CONN_TEST_CONN=connection_uri, AWS_DEFAULT_REGION=env_region
):
if conn_type == "client":
@@ -629,10 +628,7 @@ class TestAwsBaseHook:
],
)
def test_connection_aws_partition(self, conn_type, connection_uri, expected_partition):
- with unittest.mock.patch.dict(
- "os.environ",
- AIRFLOW_CONN_TEST_CONN=connection_uri,
- ):
+ with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=connection_uri):
if conn_type == "client":
hook = AwsBaseHook(aws_conn_id="test_conn", client_type="dynamodb")
elif conn_type == "resource":
@@ -772,7 +768,7 @@ class TestAwsBaseHook:
extra={"verify": conn_verify} if conn_verify is not None else {},
)
- with unittest.mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=mock_conn.get_uri()):
+ with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=mock_conn.get_uri()):
hook = AwsBaseHook(aws_conn_id="test_conn", verify=verify)
expected = verify if verify is not None else conn_verify
assert hook.verify == expected
diff --git a/tests/providers/amazon/aws/hooks/test_batch_client.py b/tests/providers/amazon/aws/hooks/test_batch_client.py
index c1ea153dfd..13726e5518 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_client.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -284,11 +284,11 @@ class TestBatchClient:
}
]
}
- with caplog.at_level(level=logging.getLevelName("WARNING")):
+
+ with caplog.at_level(level=logging.WARNING):
assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
assert len(caplog.records) == 1
- log_record = caplog.records[0]
- assert "doesn't create AWS CloudWatch Stream" in log_record.message
+ assert "doesn't create AWS CloudWatch Stream" in caplog.messages[0]
def test_job_splunk_logs(self, caplog):
self.client_mock.describe_jobs.return_value = {
@@ -304,11 +304,10 @@ class TestBatchClient:
}
]
}
- with caplog.at_level(level=logging.getLevelName("WARNING")):
+ with caplog.at_level(level=logging.WARNING):
assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
assert len(caplog.records) == 1
- log_record = caplog.records[0]
- assert "uses logDriver (splunk). AWS CloudWatch logging disabled." in log_record.message
+ assert "uses logDriver (splunk). AWS CloudWatch logging disabled." in caplog.messages[0]
class TestBatchClientDelays:
diff --git a/tests/providers/amazon/aws/hooks/test_batch_waiters.py b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
index 3ff9a154fd..c245ac4da4 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_waiters.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
@@ -30,7 +30,6 @@ derived from the moto test suite for testing the Batch client.
from __future__ import annotations
import inspect
-import unittest
from typing import NamedTuple
from unittest import mock
@@ -317,12 +316,12 @@ def test_batch_job_waiting(aws_clients, aws_region, job_queue_name, job_definiti
assert job_status == "SUCCEEDED"
-class TestBatchWaiters(unittest.TestCase):
+class TestBatchWaiters:
@mock.patch.dict("os.environ", AWS_DEFAULT_REGION=AWS_REGION)
@mock.patch.dict("os.environ", AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID)
@mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY)
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
- def setUp(self, get_client_type_mock):
+ def setup_method(self, method, get_client_type_mock):
self.job_id = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
self.region_name = AWS_REGION
diff --git a/tests/providers/amazon/aws/hooks/test_datasync.py b/tests/providers/amazon/aws/hooks/test_datasync.py
index f68b441de8..eeb976e4e0 100644
--- a/tests/providers/amazon/aws/hooks/test_datasync.py
+++ b/tests/providers/amazon/aws/hooks/test_datasync.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
import boto3
@@ -29,7 +28,7 @@ from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
@mock_datasync
-class TestDataSyncHook(unittest.TestCase):
+class TestDataSyncHook:
def test_get_conn(self):
hook = DataSyncHook(aws_conn_id="aws_default")
assert hook.get_conn() is not None
@@ -50,21 +49,13 @@ class TestDataSyncHook(unittest.TestCase):
@mock_datasync
@mock.patch.object(DataSyncHook, "get_conn")
-class TestDataSyncHookMocked(unittest.TestCase):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.source_server_hostname = "host"
- self.source_subdirectory = "somewhere"
- self.destination_bucket_name = "my_bucket"
- self.destination_bucket_dir = "dir"
+class TestDataSyncHookMocked:
+ source_server_hostname = "host"
+ source_subdirectory = "somewhere"
+ destination_bucket_name = "my_bucket"
+ destination_bucket_dir = "dir"
- self.client = None
- self.hook = None
- self.source_location_arn = None
- self.destination_location_arn = None
- self.task_arn = None
-
- def setUp(self):
+ def setup_method(self, method):
self.client = boto3.client("datasync", region_name="us-east-1")
self.hook = DataSyncHook(aws_conn_id="aws_default", wait_interval_seconds=0)
@@ -86,7 +77,7 @@ class TestDataSyncHookMocked(unittest.TestCase):
DestinationLocationArn=self.destination_location_arn,
)["TaskArn"]
- def tearDown(self):
+ def teardown_method(self, method):
# Delete all tasks:
tasks = self.client.list_tasks()
for task in tasks["Tasks"]:
diff --git a/tests/providers/amazon/aws/hooks/test_dms_task.py b/tests/providers/amazon/aws/hooks/test_dms_task.py
index efe5561cd3..9d66df55c2 100644
--- a/tests/providers/amazon/aws/hooks/test_dms_task.py
+++ b/tests/providers/amazon/aws/hooks/test_dms_task.py
@@ -17,7 +17,6 @@
from __future__ import annotations
import json
-import unittest
from typing import Any
from unittest import mock
@@ -68,8 +67,8 @@ MOCK_STOP_RESPONSE: dict[str, Any] = {"ReplicationTask": {**MOCK_TASK_RESPONSE_D
MOCK_DELETE_RESPONSE: dict[str, Any] = {"ReplicationTask": {**MOCK_TASK_RESPONSE_DATA, "Status": "deleting"}}
-class TestDmsHook(unittest.TestCase):
- def setUp(self):
+class TestDmsHook:
+ def setup_method(self):
self.dms = DmsHook()
def test_init(self):
diff --git a/tests/providers/amazon/aws/hooks/test_dynamodb.py b/tests/providers/amazon/aws/hooks/test_dynamodb.py
index 7c06c4c304..8c5886639c 100644
--- a/tests/providers/amazon/aws/hooks/test_dynamodb.py
+++ b/tests/providers/amazon/aws/hooks/test_dynamodb.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
import uuid
from moto import mock_dynamodb
@@ -25,7 +24,7 @@ from moto import mock_dynamodb
from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
-class TestDynamoDBHook(unittest.TestCase):
+class TestDynamoDBHook:
@mock_dynamodb
def test_get_conn_returns_a_boto3_connection(self):
hook = DynamoDBHook(aws_conn_id="aws_default")
@@ -39,7 +38,7 @@ class TestDynamoDBHook(unittest.TestCase):
)
# this table needs to be created in production
- table = hook.get_conn().create_table(
+ hook.get_conn().create_table(
TableName="test_airflow",
KeySchema=[
{"AttributeName": "id", "KeyType": "HASH"},
diff --git a/tests/providers/amazon/aws/hooks/test_ecs.py b/tests/providers/amazon/aws/hooks/test_ecs.py
index b7477c372e..d9a4f53fa8 100644
--- a/tests/providers/amazon/aws/hooks/test_ecs.py
+++ b/tests/providers/amazon/aws/hooks/test_ecs.py
@@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations
-import unittest
from datetime import timedelta
from unittest import mock
@@ -56,38 +55,34 @@ class TestEksHooks:
assert EcsHook().get_task_state(cluster="cluster_name", task="task_name") == "ACTIVE"
-class TestShouldRetry(unittest.TestCase):
+class TestShouldRetry:
def test_return_true_on_valid_reason(self):
- self.assertTrue(should_retry(EcsOperatorError([{"reason": "RESOURCE:MEMORY"}], "Foo")))
+ assert should_retry(EcsOperatorError([{"reason": "RESOURCE:MEMORY"}], "Foo"))
def test_return_false_on_invalid_reason(self):
- self.assertFalse(should_retry(EcsOperatorError([{"reason": "CLUSTER_NOT_FOUND"}], "Foo")))
+ assert not should_retry(EcsOperatorError([{"reason": "CLUSTER_NOT_FOUND"}], "Foo"))
-class TestShouldRetryEni(unittest.TestCase):
+class TestShouldRetryEni:
def test_return_true_on_valid_reason(self):
- self.assertTrue(
- should_retry_eni(
- EcsTaskFailToStart(
- "The task failed to start due to: "
- "Timeout waiting for network interface provisioning to complete."
- )
+ assert should_retry_eni(
+ EcsTaskFailToStart(
+ "The task failed to start due to: "
+ "Timeout waiting for network interface provisioning to complete."
)
)
def test_return_false_on_invalid_reason(self):
- self.assertFalse(
- should_retry_eni(
- EcsTaskFailToStart(
- "The task failed to start due to: "
- "CannotPullContainerError: "
- "ref pull has been retried 5 time(s): failed to resolve reference"
- )
+ assert not should_retry_eni(
+ EcsTaskFailToStart(
+ "The task failed to start due to: "
+ "CannotPullContainerError: "
+ "ref pull has been retried 5 time(s): failed to resolve reference"
)
)
-class TestEcsTaskLogFetcher(unittest.TestCase):
+class TestEcsTaskLogFetcher:
@mock.patch("logging.Logger")
def set_up_log_fetcher(self, logger_mock):
self.logger_mock = logger_mock
@@ -99,7 +94,7 @@ class TestEcsTaskLogFetcher(unittest.TestCase):
logger=logger_mock,
)
- def setUp(self):
+ def setup_method(self):
self.set_up_log_fetcher()
@mock.patch(
diff --git a/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py b/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py
index 18766a1e7f..8c72720ddc 100644
--- a/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py
+++ b/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-from unittest import TestCase
from unittest.mock import Mock
import pytest
@@ -26,7 +25,7 @@ from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.elasticache_replication_group import ElastiCacheReplicationGroupHook
-class TestElastiCacheReplicationGroupHook(TestCase):
+class TestElastiCacheReplicationGroupHook:
REPLICATION_GROUP_ID = "test-elasticache-replication-group-hook"
REPLICATION_GROUP_CONFIG = {
@@ -44,7 +43,7 @@ class TestElastiCacheReplicationGroupHook(TestCase):
{"creating", "available", "modifying", "deleting", "create - failed", "snapshotting"}
)
- def setUp(self):
+ def setup_method(self):
self.hook = ElastiCacheReplicationGroupHook()
# noinspection PyPropertyAccess
self.hook.conn = Mock()
diff --git a/tests/providers/amazon/aws/hooks/test_emr_containers.py b/tests/providers/amazon/aws/hooks/test_emr_containers.py
index 7bd1255b07..8a5f1303a6 100644
--- a/tests/providers/amazon/aws/hooks/test_emr_containers.py
+++ b/tests/providers/amazon/aws/hooks/test_emr_containers.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
@@ -47,8 +46,8 @@ JOB2_RUN_DESCRIPTION = {
}
-class TestEmrContainerHook(unittest.TestCase):
- def setUp(self):
+class TestEmrContainerHook:
+ def setup_method(self):
self.emr_containers = EmrContainerHook(virtual_cluster_id="vc1234")
def test_init(self):
@@ -110,7 +109,9 @@ class TestEmrContainerHook(unittest.TestCase):
mock_session.return_value = emr_session_mock
emr_client_mock.describe_job_run.return_value = JOB2_RUN_DESCRIPTION
- query_status = self.emr_containers.poll_query_status(job_id="job123456", max_polling_attempts=2)
+ query_status = self.emr_containers.poll_query_status(
+ job_id="job123456", max_polling_attempts=2, poll_interval=0
+ )
# should poll until max_tries is reached since query is in non-terminal state
assert emr_client_mock.describe_job_run.call_count == 2
assert query_status == "RUNNING"
diff --git a/tests/providers/amazon/aws/hooks/test_glacier.py b/tests/providers/amazon/aws/hooks/test_glacier.py
index 28f0bca78c..cce5669960 100644
--- a/tests/providers/amazon/aws/hooks/test_glacier.py
+++ b/tests/providers/amazon/aws/hooks/test_glacier.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations
-import unittest
+import logging
from unittest import mock
from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
@@ -30,8 +30,8 @@ RESPONSE_BODY = {"body": "data"}
JOB_STATUS = {"Action": "", "StatusCode": "Succeeded"}
-class TestAmazonGlacierHook(unittest.TestCase):
- def setUp(self):
+class TestAmazonGlacierHook:
+ def setup_method(self):
with mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.__init__", return_value=None):
self.hook = GlacierHook(aws_conn_id="aws_default")
@@ -47,25 +47,21 @@ class TestAmazonGlacierHook(unittest.TestCase):
assert job_id == result
@mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
- def test_retrieve_inventory_should_log_mgs(self, mock_conn):
+ def test_retrieve_inventory_should_log_mgs(self, mock_conn, caplog):
# given
job_id = {"jobId": "1234abcd"}
# when
- with self.assertLogs() as log:
+
+ with caplog.at_level(logging.INFO, logger=self.hook.log.name):
+ caplog.clear()
mock_conn.return_value.initiate_job.return_value = job_id
self.hook.retrieve_inventory(VAULT_NAME)
- # then
- self.assertEqual(
- log.output,
- [
- "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:"
- f"Retrieving inventory for vault: {VAULT_NAME}",
- "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:"
- f"Initiated inventory-retrieval job for: {VAULT_NAME}",
- "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:"
- f"Retrieval Job ID: {job_id.get('jobId')}",
- ],
- )
+ # then
+ assert caplog.messages == [
+ f"Retrieving inventory for vault: {VAULT_NAME}",
+ f"Initiated inventory-retrieval job for: {VAULT_NAME}",
+ f"Retrieval Job ID: {job_id.get('jobId')}",
+ ]
@mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
def test_retrieve_inventory_results_should_return_response(self, mock_conn):
@@ -77,19 +73,14 @@ class TestAmazonGlacierHook(unittest.TestCase):
assert response == RESPONSE_BODY
@mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
- def test_retrieve_inventory_results_should_log_mgs(self, mock_conn):
+ def test_retrieve_inventory_results_should_log_mgs(self, mock_conn, caplog):
# when
- with self.assertLogs() as log:
+ with caplog.at_level(logging.INFO, logger=self.hook.log.name):
+ caplog.clear()
mock_conn.return_value.get_job_output.return_value = REQUEST_RESULT
self.hook.retrieve_inventory_results(VAULT_NAME, JOB_ID)
- # then
- self.assertEqual(
- log.output,
- [
- "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:"
- f"Retrieving the job results for vault: {VAULT_NAME}...",
- ],
- )
+ # then
+ assert caplog.messages == [f"Retrieving the job results for vault: {VAULT_NAME}..."]
@mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
def test_describe_job_should_return_status_succeeded(self, mock_conn):
@@ -101,18 +92,14 @@ class TestAmazonGlacierHook(unittest.TestCase):
assert response == JOB_STATUS
@mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
- def test_describe_job_should_log_mgs(self, mock_conn):
+ def test_describe_job_should_log_mgs(self, mock_conn, caplog):
# when
- with self.assertLogs() as log:
+ with caplog.at_level(logging.INFO, logger=self.hook.log.name):
+ caplog.clear()
mock_conn.return_value.describe_job.return_value = JOB_STATUS
self.hook.describe_job(VAULT_NAME, JOB_ID)
- # then
- self.assertEqual(
- log.output,
- [
- "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:"
- f"Retrieving status for vault: {VAULT_NAME} and job {JOB_ID}",
- "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:"
- f"Job status: {JOB_STATUS.get('Action')}, code status: {JOB_STATUS.get('StatusCode')}",
- ],
- )
+ # then
+ assert caplog.messages == [
+ f"Retrieving status for vault: {VAULT_NAME} and job {JOB_ID}",
+ f"Job status: {JOB_STATUS.get('Action')}, code status: {JOB_STATUS.get('StatusCode')}",
+ ]
diff --git a/tests/providers/amazon/aws/hooks/test_glue_crawler.py b/tests/providers/amazon/aws/hooks/test_glue_crawler.py
index ec966cb683..ac2d3cba2c 100644
--- a/tests/providers/amazon/aws/hooks/test_glue_crawler.py
+++ b/tests/providers/amazon/aws/hooks/test_glue_crawler.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from copy import deepcopy
from unittest import mock
@@ -83,18 +82,16 @@ mock_config = {
}
-class TestGlueCrawlerHook(unittest.TestCase):
- @classmethod
- def setUp(cls):
- cls.hook = GlueCrawlerHook(aws_conn_id="aws_default")
+class TestGlueCrawlerHook:
+ def setup_method(self):
+ self.hook = GlueCrawlerHook(aws_conn_id="aws_default")
def test_init(self):
- self.assertEqual(self.hook.aws_conn_id, "aws_default")
+ assert self.hook.aws_conn_id == "aws_default"
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_has_crawler(self, mock_get_conn):
- response = self.hook.has_crawler(mock_crawler_name)
- self.assertEqual(response, True)
+ assert self.hook.has_crawler(mock_crawler_name) is True
mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)
@mock.patch.object(GlueCrawlerHook, "get_conn")
@@ -104,8 +101,7 @@ class TestGlueCrawlerHook(unittest.TestCase):
mock_get_conn.return_value.exceptions.EntityNotFoundException = MockException
mock_get_conn.return_value.get_crawler.side_effect = MockException("AAA")
- response = self.hook.has_crawler(mock_crawler_name)
- self.assertEqual(response, False)
+ assert self.hook.has_crawler(mock_crawler_name) is False
mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)
@mock.patch.object(GlueCrawlerHook, "get_conn")
@@ -114,30 +110,28 @@ class TestGlueCrawlerHook(unittest.TestCase):
mock_config_two = deepcopy(mock_config)
mock_config_two["Role"] = "test-2-role"
- response = self.hook.update_crawler(**mock_config_two)
- self.assertEqual(response, True)
+ assert self.hook.update_crawler(**mock_config_two) is True
mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)
mock_get_conn.return_value.update_crawler.assert_called_once_with(**mock_config_two)
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_update_crawler_not_needed(self, mock_get_conn):
mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}
- response = self.hook.update_crawler(**mock_config)
- self.assertEqual(response, False)
+ assert self.hook.update_crawler(**mock_config) is False
mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_create_crawler(self, mock_get_conn):
mock_get_conn.return_value.create_crawler.return_value = {"Crawler": {"Name": mock_crawler_name}}
glue_crawler = self.hook.create_crawler(**mock_config)
- self.assertIn("Crawler", glue_crawler)
- self.assertIn("Name", glue_crawler["Crawler"])
- self.assertEqual(glue_crawler["Crawler"]["Name"], mock_crawler_name)
+ assert "Crawler" in glue_crawler
+ assert "Name" in glue_crawler["Crawler"]
+ assert glue_crawler["Crawler"]["Name"] == mock_crawler_name
@mock.patch.object(GlueCrawlerHook, "get_conn")
def test_start_crawler(self, mock_get_conn):
result = self.hook.start_crawler(mock_crawler_name)
- self.assertEqual(result, mock_get_conn.return_value.start_crawler.return_value)
+ assert result == mock_get_conn.return_value.start_crawler.return_value
mock_get_conn.return_value.start_crawler.assert_called_once_with(Name=mock_crawler_name)
@@ -159,7 +153,7 @@ class TestGlueCrawlerHook(unittest.TestCase):
]
}
result = self.hook.wait_for_crawler_completion(mock_crawler_name)
- self.assertEqual(result, "MOCK_STATUS")
+ assert result == "MOCK_STATUS"
mock_get_conn.assert_has_calls(
[
mock.call(),
@@ -195,7 +189,7 @@ class TestGlueCrawlerHook(unittest.TestCase):
},
]
result = self.hook.wait_for_crawler_completion(mock_crawler_name)
- self.assertEqual(result, "MOCK_STATUS")
+ assert result == "MOCK_STATUS"
mock_get_conn.assert_has_calls(
[
mock.call(),
@@ -208,7 +202,3 @@ class TestGlueCrawlerHook(unittest.TestCase):
mock.call(mock_crawler_name),
]
)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_sql.py b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
index e3af91c9e7..531d6a9b47 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_sql.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
@@ -17,19 +17,16 @@
from __future__ import annotations
import json
-import unittest
from unittest import mock
-from parameterized import parameterized
+import pytest
from airflow.models import Connection
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
-class TestRedshiftSQLHookConn(unittest.TestCase):
- def setUp(self):
- super().setUp()
-
+class TestRedshiftSQLHookConn:
+ def setup_method(self):
self.connection = Connection(
conn_type="redshift", login="login", password="password", host="host", port=5439, schema="dev"
)
@@ -71,7 +68,8 @@ class TestRedshiftSQLHookConn(unittest.TestCase):
iam=True,
)
- @parameterized.expand(
+ @pytest.mark.parametrize(
+ "conn_params, conn_extra, expected_call_args",
[
({}, {}, {}),
({"login": "test"}, {}, {"user": "test"}),
@@ -81,7 +79,7 @@ class TestRedshiftSQLHookConn(unittest.TestCase):
],
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
- def test_get_conn_overrides_correctly(self, conn_params, conn_extra, expected_call_args, mock_connect):
+ def test_get_conn_overrides_correctly(self, mock_connect, conn_params, conn_extra, expected_call_args):
with mock.patch(
"airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.conn",
Connection(conn_type="redshift", extra=conn_extra, **conn_params),
diff --git a/tests/providers/amazon/aws/utils/test_emailer.py b/tests/providers/amazon/aws/utils/test_emailer.py
index 0e3b2ed7d4..e51a885b67 100644
--- a/tests/providers/amazon/aws/utils/test_emailer.py
+++ b/tests/providers/amazon/aws/utils/test_emailer.py
@@ -17,14 +17,14 @@
# under the License.
from __future__ import annotations
-from unittest import TestCase, mock
+from unittest import mock
import pytest
from airflow.providers.amazon.aws.utils.emailer import send_email
-class TestSendEmailSes(TestCase):
+class TestSendEmailSes:
@mock.patch("airflow.providers.amazon.aws.utils.emailer.SesHook")
def test_send_ses_email(self, mock_hook):
send_email(
diff --git a/tests/providers/amazon/aws/utils/test_redshift.py b/tests/providers/amazon/aws/utils/test_redshift.py
index f255e6bfd0..9d546c13d1 100644
--- a/tests/providers/amazon/aws/utils/test_redshift.py
+++ b/tests/providers/amazon/aws/utils/test_redshift.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations
-import unittest
from unittest import mock
from boto3.session import Session
@@ -25,7 +24,7 @@ from boto3.session import Session
from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
-class TestS3ToRedshiftTransfer(unittest.TestCase):
+class TestS3ToRedshiftTransfer:
@mock.patch("boto3.session.Session")
def test_build_credentials_block(self, mock_session):
access_key = "aws_access_key_id"
diff --git a/tests/providers/amazon/aws/utils/test_utils.py b/tests/providers/amazon/aws/utils/test_utils.py
index ced274c6e7..6cf8bbb23e 100644
--- a/tests/providers/amazon/aws/utils/test_utils.py
+++ b/tests/providers/amazon/aws/utils/test_utils.py
@@ -17,7 +17,6 @@
from __future__ import annotations
from datetime import datetime
-from unittest import TestCase
import pytz
@@ -33,7 +32,7 @@ DT = datetime(2000, 1, 1, tzinfo=pytz.UTC)
EPOCH = 946_684_800
-class TestUtils(TestCase):
+class TestUtils:
def test_trim_none_values(self):
input_object = {
"test": "test",