You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/12/05 03:12:20 UTC

[airflow] branch main updated: Migrate amazon provider hooks tests from `unittests` to `pytest` (#28039)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f02a7e9a82 Migrate amazon provider hooks tests from `unittests` to `pytest` (#28039)
f02a7e9a82 is described below

commit f02a7e9a8292909b369daae6d573f58deed04440
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Mon Dec 5 06:12:12 2022 +0300

    Migrate amazon provider hooks tests from `unittests` to `pytest` (#28039)
---
 tests/providers/amazon/aws/hooks/test_athena.py    |  9 +--
 tests/providers/amazon/aws/hooks/test_base_aws.py  | 10 +---
 .../amazon/aws/hooks/test_batch_client.py          | 11 ++--
 .../amazon/aws/hooks/test_batch_waiters.py         |  5 +-
 tests/providers/amazon/aws/hooks/test_datasync.py  | 25 +++------
 tests/providers/amazon/aws/hooks/test_dms_task.py  |  5 +-
 tests/providers/amazon/aws/hooks/test_dynamodb.py  |  5 +-
 tests/providers/amazon/aws/hooks/test_ecs.py       | 35 +++++-------
 .../hooks/test_elasticache_replication_group.py    |  5 +-
 .../amazon/aws/hooks/test_emr_containers.py        |  9 +--
 tests/providers/amazon/aws/hooks/test_glacier.py   | 65 +++++++++-------------
 .../amazon/aws/hooks/test_glue_crawler.py          | 38 +++++--------
 .../amazon/aws/hooks/test_redshift_sql.py          | 14 ++---
 tests/providers/amazon/aws/utils/test_emailer.py   |  4 +-
 tests/providers/amazon/aws/utils/test_redshift.py  |  3 +-
 tests/providers/amazon/aws/utils/test_utils.py     |  3 +-
 16 files changed, 96 insertions(+), 150 deletions(-)

diff --git a/tests/providers/amazon/aws/hooks/test_athena.py b/tests/providers/amazon/aws/hooks/test_athena.py
index 58870bb1d0..a65470acea 100644
--- a/tests/providers/amazon/aws/hooks/test_athena.py
+++ b/tests/providers/amazon/aws/hooks/test_athena.py
@@ -16,7 +16,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
 from unittest import mock
 
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
@@ -48,8 +47,8 @@ MOCK_QUERY_EXECUTION_OUTPUT = {
 }
 
 
-class TestAthenaHook(unittest.TestCase):
-    def setUp(self):
+class TestAthenaHook:
+    def setup_method(self):
         self.athena = AthenaHook(sleep_time=0)
 
     def test_init(self):
@@ -196,7 +195,3 @@ class TestAthenaHook(unittest.TestCase):
         mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT
         result = self.athena.get_output_location(query_execution_id=MOCK_DATA["query_execution_id"])
         assert result == "s3://test_bucket/test.csv"
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py
index ad5a0c147e..837a3d2f89 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -19,7 +19,6 @@ from __future__ import annotations
 
 import json
 import os
-import unittest
 from base64 import b64encode
 from datetime import datetime, timedelta, timezone
 from unittest import mock
@@ -606,7 +605,7 @@ class TestAwsBaseHook:
     def test_connection_region_name(
         self, conn_type, connection_uri, region_name, env_region, expected_region_name
     ):
-        with unittest.mock.patch.dict(
+        with mock.patch.dict(
             "os.environ", AIRFLOW_CONN_TEST_CONN=connection_uri, AWS_DEFAULT_REGION=env_region
         ):
             if conn_type == "client":
@@ -629,10 +628,7 @@ class TestAwsBaseHook:
         ],
     )
     def test_connection_aws_partition(self, conn_type, connection_uri, expected_partition):
-        with unittest.mock.patch.dict(
-            "os.environ",
-            AIRFLOW_CONN_TEST_CONN=connection_uri,
-        ):
+        with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=connection_uri):
             if conn_type == "client":
                 hook = AwsBaseHook(aws_conn_id="test_conn", client_type="dynamodb")
             elif conn_type == "resource":
@@ -772,7 +768,7 @@ class TestAwsBaseHook:
             extra={"verify": conn_verify} if conn_verify is not None else {},
         )
 
-        with unittest.mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=mock_conn.get_uri()):
+        with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=mock_conn.get_uri()):
             hook = AwsBaseHook(aws_conn_id="test_conn", verify=verify)
             expected = verify if verify is not None else conn_verify
             assert hook.verify == expected
diff --git a/tests/providers/amazon/aws/hooks/test_batch_client.py b/tests/providers/amazon/aws/hooks/test_batch_client.py
index c1ea153dfd..13726e5518 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_client.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -284,11 +284,11 @@ class TestBatchClient:
                 }
             ]
         }
-        with caplog.at_level(level=logging.getLevelName("WARNING")):
+
+        with caplog.at_level(level=logging.WARNING):
             assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
             assert len(caplog.records) == 1
-            log_record = caplog.records[0]
-            assert "doesn't create AWS CloudWatch Stream" in log_record.message
+            assert "doesn't create AWS CloudWatch Stream" in caplog.messages[0]
 
     def test_job_splunk_logs(self, caplog):
         self.client_mock.describe_jobs.return_value = {
@@ -304,11 +304,10 @@ class TestBatchClient:
                 }
             ]
         }
-        with caplog.at_level(level=logging.getLevelName("WARNING")):
+        with caplog.at_level(level=logging.WARNING):
             assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
             assert len(caplog.records) == 1
-            log_record = caplog.records[0]
-            assert "uses logDriver (splunk). AWS CloudWatch logging disabled." in log_record.message
+            assert "uses logDriver (splunk). AWS CloudWatch logging disabled." in caplog.messages[0]
 
 
 class TestBatchClientDelays:
diff --git a/tests/providers/amazon/aws/hooks/test_batch_waiters.py b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
index 3ff9a154fd..c245ac4da4 100644
--- a/tests/providers/amazon/aws/hooks/test_batch_waiters.py
+++ b/tests/providers/amazon/aws/hooks/test_batch_waiters.py
@@ -30,7 +30,6 @@ derived from the moto test suite for testing the Batch client.
 from __future__ import annotations
 
 import inspect
-import unittest
 from typing import NamedTuple
 from unittest import mock
 
@@ -317,12 +316,12 @@ def test_batch_job_waiting(aws_clients, aws_region, job_queue_name, job_definiti
     assert job_status == "SUCCEEDED"
 
 
-class TestBatchWaiters(unittest.TestCase):
+class TestBatchWaiters:
     @mock.patch.dict("os.environ", AWS_DEFAULT_REGION=AWS_REGION)
     @mock.patch.dict("os.environ", AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID)
     @mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY)
     @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
-    def setUp(self, get_client_type_mock):
+    def setup_method(self, method, get_client_type_mock):
         self.job_id = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
         self.region_name = AWS_REGION
 
diff --git a/tests/providers/amazon/aws/hooks/test_datasync.py b/tests/providers/amazon/aws/hooks/test_datasync.py
index f68b441de8..eeb976e4e0 100644
--- a/tests/providers/amazon/aws/hooks/test_datasync.py
+++ b/tests/providers/amazon/aws/hooks/test_datasync.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
 from unittest import mock
 
 import boto3
@@ -29,7 +28,7 @@ from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
 
 
 @mock_datasync
-class TestDataSyncHook(unittest.TestCase):
+class TestDataSyncHook:
     def test_get_conn(self):
         hook = DataSyncHook(aws_conn_id="aws_default")
         assert hook.get_conn() is not None
@@ -50,21 +49,13 @@ class TestDataSyncHook(unittest.TestCase):
 
 @mock_datasync
 @mock.patch.object(DataSyncHook, "get_conn")
-class TestDataSyncHookMocked(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.source_server_hostname = "host"
-        self.source_subdirectory = "somewhere"
-        self.destination_bucket_name = "my_bucket"
-        self.destination_bucket_dir = "dir"
+class TestDataSyncHookMocked:
+    source_server_hostname = "host"
+    source_subdirectory = "somewhere"
+    destination_bucket_name = "my_bucket"
+    destination_bucket_dir = "dir"
 
-        self.client = None
-        self.hook = None
-        self.source_location_arn = None
-        self.destination_location_arn = None
-        self.task_arn = None
-
-    def setUp(self):
+    def setup_method(self, method):
         self.client = boto3.client("datasync", region_name="us-east-1")
         self.hook = DataSyncHook(aws_conn_id="aws_default", wait_interval_seconds=0)
 
@@ -86,7 +77,7 @@ class TestDataSyncHookMocked(unittest.TestCase):
             DestinationLocationArn=self.destination_location_arn,
         )["TaskArn"]
 
-    def tearDown(self):
+    def teardown_method(self, method):
         # Delete all tasks:
         tasks = self.client.list_tasks()
         for task in tasks["Tasks"]:
diff --git a/tests/providers/amazon/aws/hooks/test_dms_task.py b/tests/providers/amazon/aws/hooks/test_dms_task.py
index efe5561cd3..9d66df55c2 100644
--- a/tests/providers/amazon/aws/hooks/test_dms_task.py
+++ b/tests/providers/amazon/aws/hooks/test_dms_task.py
@@ -17,7 +17,6 @@
 from __future__ import annotations
 
 import json
-import unittest
 from typing import Any
 from unittest import mock
 
@@ -68,8 +67,8 @@ MOCK_STOP_RESPONSE: dict[str, Any] = {"ReplicationTask": {**MOCK_TASK_RESPONSE_D
 MOCK_DELETE_RESPONSE: dict[str, Any] = {"ReplicationTask": {**MOCK_TASK_RESPONSE_DATA, "Status": "deleting"}}
 
 
-class TestDmsHook(unittest.TestCase):
-    def setUp(self):
+class TestDmsHook:
+    def setup_method(self):
         self.dms = DmsHook()
 
     def test_init(self):
diff --git a/tests/providers/amazon/aws/hooks/test_dynamodb.py b/tests/providers/amazon/aws/hooks/test_dynamodb.py
index 7c06c4c304..8c5886639c 100644
--- a/tests/providers/amazon/aws/hooks/test_dynamodb.py
+++ b/tests/providers/amazon/aws/hooks/test_dynamodb.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
 import uuid
 
 from moto import mock_dynamodb
@@ -25,7 +24,7 @@ from moto import mock_dynamodb
 from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
 
 
-class TestDynamoDBHook(unittest.TestCase):
+class TestDynamoDBHook:
     @mock_dynamodb
     def test_get_conn_returns_a_boto3_connection(self):
         hook = DynamoDBHook(aws_conn_id="aws_default")
@@ -39,7 +38,7 @@ class TestDynamoDBHook(unittest.TestCase):
         )
 
         # this table needs to be created in production
-        table = hook.get_conn().create_table(
+        hook.get_conn().create_table(
             TableName="test_airflow",
             KeySchema=[
                 {"AttributeName": "id", "KeyType": "HASH"},
diff --git a/tests/providers/amazon/aws/hooks/test_ecs.py b/tests/providers/amazon/aws/hooks/test_ecs.py
index b7477c372e..d9a4f53fa8 100644
--- a/tests/providers/amazon/aws/hooks/test_ecs.py
+++ b/tests/providers/amazon/aws/hooks/test_ecs.py
@@ -16,7 +16,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
 from datetime import timedelta
 from unittest import mock
 
@@ -56,38 +55,34 @@ class TestEksHooks:
         assert EcsHook().get_task_state(cluster="cluster_name", task="task_name") == "ACTIVE"
 
 
-class TestShouldRetry(unittest.TestCase):
+class TestShouldRetry:
     def test_return_true_on_valid_reason(self):
-        self.assertTrue(should_retry(EcsOperatorError([{"reason": "RESOURCE:MEMORY"}], "Foo")))
+        assert should_retry(EcsOperatorError([{"reason": "RESOURCE:MEMORY"}], "Foo"))
 
     def test_return_false_on_invalid_reason(self):
-        self.assertFalse(should_retry(EcsOperatorError([{"reason": "CLUSTER_NOT_FOUND"}], "Foo")))
+        assert not should_retry(EcsOperatorError([{"reason": "CLUSTER_NOT_FOUND"}], "Foo"))
 
 
-class TestShouldRetryEni(unittest.TestCase):
+class TestShouldRetryEni:
     def test_return_true_on_valid_reason(self):
-        self.assertTrue(
-            should_retry_eni(
-                EcsTaskFailToStart(
-                    "The task failed to start due to: "
-                    "Timeout waiting for network interface provisioning to complete."
-                )
+        assert should_retry_eni(
+            EcsTaskFailToStart(
+                "The task failed to start due to: "
+                "Timeout waiting for network interface provisioning to complete."
             )
         )
 
     def test_return_false_on_invalid_reason(self):
-        self.assertFalse(
-            should_retry_eni(
-                EcsTaskFailToStart(
-                    "The task failed to start due to: "
-                    "CannotPullContainerError: "
-                    "ref pull has been retried 5 time(s): failed to resolve reference"
-                )
+        assert not should_retry_eni(
+            EcsTaskFailToStart(
+                "The task failed to start due to: "
+                "CannotPullContainerError: "
+                "ref pull has been retried 5 time(s): failed to resolve reference"
             )
         )
 
 
-class TestEcsTaskLogFetcher(unittest.TestCase):
+class TestEcsTaskLogFetcher:
     @mock.patch("logging.Logger")
     def set_up_log_fetcher(self, logger_mock):
         self.logger_mock = logger_mock
@@ -99,7 +94,7 @@ class TestEcsTaskLogFetcher(unittest.TestCase):
             logger=logger_mock,
         )
 
-    def setUp(self):
+    def setup_method(self):
         self.set_up_log_fetcher()
 
     @mock.patch(
diff --git a/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py b/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py
index 18766a1e7f..8c72720ddc 100644
--- a/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py
+++ b/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-from unittest import TestCase
 from unittest.mock import Mock
 
 import pytest
@@ -26,7 +25,7 @@ from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.elasticache_replication_group import ElastiCacheReplicationGroupHook
 
 
-class TestElastiCacheReplicationGroupHook(TestCase):
+class TestElastiCacheReplicationGroupHook:
     REPLICATION_GROUP_ID = "test-elasticache-replication-group-hook"
 
     REPLICATION_GROUP_CONFIG = {
@@ -44,7 +43,7 @@ class TestElastiCacheReplicationGroupHook(TestCase):
         {"creating", "available", "modifying", "deleting", "create - failed", "snapshotting"}
     )
 
-    def setUp(self):
+    def setup_method(self):
         self.hook = ElastiCacheReplicationGroupHook()
         # noinspection PyPropertyAccess
         self.hook.conn = Mock()
diff --git a/tests/providers/amazon/aws/hooks/test_emr_containers.py b/tests/providers/amazon/aws/hooks/test_emr_containers.py
index 7bd1255b07..8a5f1303a6 100644
--- a/tests/providers/amazon/aws/hooks/test_emr_containers.py
+++ b/tests/providers/amazon/aws/hooks/test_emr_containers.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
 from unittest import mock
 
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
@@ -47,8 +46,8 @@ JOB2_RUN_DESCRIPTION = {
 }
 
 
-class TestEmrContainerHook(unittest.TestCase):
-    def setUp(self):
+class TestEmrContainerHook:
+    def setup_method(self):
         self.emr_containers = EmrContainerHook(virtual_cluster_id="vc1234")
 
     def test_init(self):
@@ -110,7 +109,9 @@ class TestEmrContainerHook(unittest.TestCase):
         mock_session.return_value = emr_session_mock
         emr_client_mock.describe_job_run.return_value = JOB2_RUN_DESCRIPTION
 
-        query_status = self.emr_containers.poll_query_status(job_id="job123456", max_polling_attempts=2)
+        query_status = self.emr_containers.poll_query_status(
+            job_id="job123456", max_polling_attempts=2, poll_interval=0
+        )
         # should poll until max_tries is reached since query is in non-terminal state
         assert emr_client_mock.describe_job_run.call_count == 2
         assert query_status == "RUNNING"
diff --git a/tests/providers/amazon/aws/hooks/test_glacier.py b/tests/providers/amazon/aws/hooks/test_glacier.py
index 28f0bca78c..cce5669960 100644
--- a/tests/providers/amazon/aws/hooks/test_glacier.py
+++ b/tests/providers/amazon/aws/hooks/test_glacier.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
+import logging
 from unittest import mock
 
 from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
@@ -30,8 +30,8 @@ RESPONSE_BODY = {"body": "data"}
 JOB_STATUS = {"Action": "", "StatusCode": "Succeeded"}
 
 
-class TestAmazonGlacierHook(unittest.TestCase):
-    def setUp(self):
+class TestAmazonGlacierHook:
+    def setup_method(self):
         with mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.__init__", return_value=None):
             self.hook = GlacierHook(aws_conn_id="aws_default")
 
@@ -47,25 +47,21 @@ class TestAmazonGlacierHook(unittest.TestCase):
         assert job_id == result
 
     @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
-    def test_retrieve_inventory_should_log_mgs(self, mock_conn):
+    def test_retrieve_inventory_should_log_mgs(self, mock_conn, caplog):
         # given
         job_id = {"jobId": "1234abcd"}
         # when
-        with self.assertLogs() as log:
+
+        with caplog.at_level(logging.INFO, logger=self.hook.log.name):
+            caplog.clear()
             mock_conn.return_value.initiate_job.return_value = job_id
             self.hook.retrieve_inventory(VAULT_NAME)
-            # then
-            self.assertEqual(
-                log.output,
-                [
-                    "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:"
-                    f"Retrieving inventory for vault: {VAULT_NAME}",
-                    "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:"
-                    f"Initiated inventory-retrieval job for: {VAULT_NAME}",
-                    "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:"
-                    f"Retrieval Job ID: {job_id.get('jobId')}",
-                ],
-            )
+        # then
+        assert caplog.messages == [
+            f"Retrieving inventory for vault: {VAULT_NAME}",
+            f"Initiated inventory-retrieval job for: {VAULT_NAME}",
+            f"Retrieval Job ID: {job_id.get('jobId')}",
+        ]
 
     @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
     def test_retrieve_inventory_results_should_return_response(self, mock_conn):
@@ -77,19 +73,14 @@ class TestAmazonGlacierHook(unittest.TestCase):
         assert response == RESPONSE_BODY
 
     @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
-    def test_retrieve_inventory_results_should_log_mgs(self, mock_conn):
+    def test_retrieve_inventory_results_should_log_mgs(self, mock_conn, caplog):
         # when
-        with self.assertLogs() as log:
+        with caplog.at_level(logging.INFO, logger=self.hook.log.name):
+            caplog.clear()
             mock_conn.return_value.get_job_output.return_value = REQUEST_RESULT
             self.hook.retrieve_inventory_results(VAULT_NAME, JOB_ID)
-            # then
-            self.assertEqual(
-                log.output,
-                [
-                    "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:"
-                    f"Retrieving the job results for vault: {VAULT_NAME}...",
-                ],
-            )
+        # then
+        assert caplog.messages == [f"Retrieving the job results for vault: {VAULT_NAME}..."]
 
     @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
     def test_describe_job_should_return_status_succeeded(self, mock_conn):
@@ -101,18 +92,14 @@ class TestAmazonGlacierHook(unittest.TestCase):
         assert response == JOB_STATUS
 
     @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn")
-    def test_describe_job_should_log_mgs(self, mock_conn):
+    def test_describe_job_should_log_mgs(self, mock_conn, caplog):
         # when
-        with self.assertLogs() as log:
+        with caplog.at_level(logging.INFO, logger=self.hook.log.name):
+            caplog.clear()
             mock_conn.return_value.describe_job.return_value = JOB_STATUS
             self.hook.describe_job(VAULT_NAME, JOB_ID)
-            # then
-            self.assertEqual(
-                log.output,
-                [
-                    "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:"
-                    f"Retrieving status for vault: {VAULT_NAME} and job {JOB_ID}",
-                    "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:"
-                    f"Job status: {JOB_STATUS.get('Action')}, code status: {JOB_STATUS.get('StatusCode')}",
-                ],
-            )
+        # then
+        assert caplog.messages == [
+            f"Retrieving status for vault: {VAULT_NAME} and job {JOB_ID}",
+            f"Job status: {JOB_STATUS.get('Action')}, code status: {JOB_STATUS.get('StatusCode')}",
+        ]
diff --git a/tests/providers/amazon/aws/hooks/test_glue_crawler.py b/tests/providers/amazon/aws/hooks/test_glue_crawler.py
index ec966cb683..ac2d3cba2c 100644
--- a/tests/providers/amazon/aws/hooks/test_glue_crawler.py
+++ b/tests/providers/amazon/aws/hooks/test_glue_crawler.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
 from copy import deepcopy
 from unittest import mock
 
@@ -83,18 +82,16 @@ mock_config = {
 }
 
 
-class TestGlueCrawlerHook(unittest.TestCase):
-    @classmethod
-    def setUp(cls):
-        cls.hook = GlueCrawlerHook(aws_conn_id="aws_default")
+class TestGlueCrawlerHook:
+    def setup_method(self):
+        self.hook = GlueCrawlerHook(aws_conn_id="aws_default")
 
     def test_init(self):
-        self.assertEqual(self.hook.aws_conn_id, "aws_default")
+        assert self.hook.aws_conn_id == "aws_default"
 
     @mock.patch.object(GlueCrawlerHook, "get_conn")
     def test_has_crawler(self, mock_get_conn):
-        response = self.hook.has_crawler(mock_crawler_name)
-        self.assertEqual(response, True)
+        assert self.hook.has_crawler(mock_crawler_name) is True
         mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)
 
     @mock.patch.object(GlueCrawlerHook, "get_conn")
@@ -104,8 +101,7 @@ class TestGlueCrawlerHook(unittest.TestCase):
 
         mock_get_conn.return_value.exceptions.EntityNotFoundException = MockException
         mock_get_conn.return_value.get_crawler.side_effect = MockException("AAA")
-        response = self.hook.has_crawler(mock_crawler_name)
-        self.assertEqual(response, False)
+        assert self.hook.has_crawler(mock_crawler_name) is False
         mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)
 
     @mock.patch.object(GlueCrawlerHook, "get_conn")
@@ -114,30 +110,28 @@ class TestGlueCrawlerHook(unittest.TestCase):
 
         mock_config_two = deepcopy(mock_config)
         mock_config_two["Role"] = "test-2-role"
-        response = self.hook.update_crawler(**mock_config_two)
-        self.assertEqual(response, True)
+        assert self.hook.update_crawler(**mock_config_two) is True
         mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)
         mock_get_conn.return_value.update_crawler.assert_called_once_with(**mock_config_two)
 
     @mock.patch.object(GlueCrawlerHook, "get_conn")
     def test_update_crawler_not_needed(self, mock_get_conn):
         mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config}
-        response = self.hook.update_crawler(**mock_config)
-        self.assertEqual(response, False)
+        assert self.hook.update_crawler(**mock_config) is False
         mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name)
 
     @mock.patch.object(GlueCrawlerHook, "get_conn")
     def test_create_crawler(self, mock_get_conn):
         mock_get_conn.return_value.create_crawler.return_value = {"Crawler": {"Name": mock_crawler_name}}
         glue_crawler = self.hook.create_crawler(**mock_config)
-        self.assertIn("Crawler", glue_crawler)
-        self.assertIn("Name", glue_crawler["Crawler"])
-        self.assertEqual(glue_crawler["Crawler"]["Name"], mock_crawler_name)
+        assert "Crawler" in glue_crawler
+        assert "Name" in glue_crawler["Crawler"]
+        assert glue_crawler["Crawler"]["Name"] == mock_crawler_name
 
     @mock.patch.object(GlueCrawlerHook, "get_conn")
     def test_start_crawler(self, mock_get_conn):
         result = self.hook.start_crawler(mock_crawler_name)
-        self.assertEqual(result, mock_get_conn.return_value.start_crawler.return_value)
+        assert result == mock_get_conn.return_value.start_crawler.return_value
 
         mock_get_conn.return_value.start_crawler.assert_called_once_with(Name=mock_crawler_name)
 
@@ -159,7 +153,7 @@ class TestGlueCrawlerHook(unittest.TestCase):
             ]
         }
         result = self.hook.wait_for_crawler_completion(mock_crawler_name)
-        self.assertEqual(result, "MOCK_STATUS")
+        assert result == "MOCK_STATUS"
         mock_get_conn.assert_has_calls(
             [
                 mock.call(),
@@ -195,7 +189,7 @@ class TestGlueCrawlerHook(unittest.TestCase):
             },
         ]
         result = self.hook.wait_for_crawler_completion(mock_crawler_name)
-        self.assertEqual(result, "MOCK_STATUS")
+        assert result == "MOCK_STATUS"
         mock_get_conn.assert_has_calls(
             [
                 mock.call(),
@@ -208,7 +202,3 @@ class TestGlueCrawlerHook(unittest.TestCase):
                 mock.call(mock_crawler_name),
             ]
         )
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_sql.py b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
index e3af91c9e7..531d6a9b47 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_sql.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_sql.py
@@ -17,19 +17,16 @@
 from __future__ import annotations
 
 import json
-import unittest
 from unittest import mock
 
-from parameterized import parameterized
+import pytest
 
 from airflow.models import Connection
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
 
 
-class TestRedshiftSQLHookConn(unittest.TestCase):
-    def setUp(self):
-        super().setUp()
-
+class TestRedshiftSQLHookConn:
+    def setup_method(self):
         self.connection = Connection(
             conn_type="redshift", login="login", password="password", host="host", port=5439, schema="dev"
         )
@@ -71,7 +68,8 @@ class TestRedshiftSQLHookConn(unittest.TestCase):
             iam=True,
         )
 
-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        "conn_params, conn_extra, expected_call_args",
         [
             ({}, {}, {}),
             ({"login": "test"}, {}, {"user": "test"}),
@@ -81,7 +79,7 @@ class TestRedshiftSQLHookConn(unittest.TestCase):
         ],
     )
     @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect")
-    def test_get_conn_overrides_correctly(self, conn_params, conn_extra, expected_call_args, mock_connect):
+    def test_get_conn_overrides_correctly(self, mock_connect, conn_params, conn_extra, expected_call_args):
         with mock.patch(
             "airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.conn",
             Connection(conn_type="redshift", extra=conn_extra, **conn_params),
diff --git a/tests/providers/amazon/aws/utils/test_emailer.py b/tests/providers/amazon/aws/utils/test_emailer.py
index 0e3b2ed7d4..e51a885b67 100644
--- a/tests/providers/amazon/aws/utils/test_emailer.py
+++ b/tests/providers/amazon/aws/utils/test_emailer.py
@@ -17,14 +17,14 @@
 # under the License.
 from __future__ import annotations
 
-from unittest import TestCase, mock
+from unittest import mock
 
 import pytest
 
 from airflow.providers.amazon.aws.utils.emailer import send_email
 
 
-class TestSendEmailSes(TestCase):
+class TestSendEmailSes:
     @mock.patch("airflow.providers.amazon.aws.utils.emailer.SesHook")
     def test_send_ses_email(self, mock_hook):
         send_email(
diff --git a/tests/providers/amazon/aws/utils/test_redshift.py b/tests/providers/amazon/aws/utils/test_redshift.py
index f255e6bfd0..9d546c13d1 100644
--- a/tests/providers/amazon/aws/utils/test_redshift.py
+++ b/tests/providers/amazon/aws/utils/test_redshift.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
 from unittest import mock
 
 from boto3.session import Session
@@ -25,7 +24,7 @@ from boto3.session import Session
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
 
 
-class TestS3ToRedshiftTransfer(unittest.TestCase):
+class TestS3ToRedshiftTransfer:
     @mock.patch("boto3.session.Session")
     def test_build_credentials_block(self, mock_session):
         access_key = "aws_access_key_id"
diff --git a/tests/providers/amazon/aws/utils/test_utils.py b/tests/providers/amazon/aws/utils/test_utils.py
index ced274c6e7..6cf8bbb23e 100644
--- a/tests/providers/amazon/aws/utils/test_utils.py
+++ b/tests/providers/amazon/aws/utils/test_utils.py
@@ -17,7 +17,6 @@
 from __future__ import annotations
 
 from datetime import datetime
-from unittest import TestCase
 
 import pytz
 
@@ -33,7 +32,7 @@ DT = datetime(2000, 1, 1, tzinfo=pytz.UTC)
 EPOCH = 946_684_800
 
 
-class TestUtils(TestCase):
+class TestUtils:
     def test_trim_none_values(self):
         input_object = {
             "test": "test",