Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/21 14:28:17 UTC

[airflow] 14/22: Bump typing-extensions and mypy for ParamSpec (#25088)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit ec8ea0e200d5a13b1414312e017eea1af6fdf98d
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Mon Jul 18 17:20:35 2022 +0800

    Bump typing-extensions and mypy for ParamSpec (#25088)
    
    * Bump typing-extensions and mypy for ParamSpec
    
    I want to use them in some @task signature improvements. Mypy added
    ParamSpec support in 0.950, but let's just bump to the latest since why not.
    
    Changelog of typing-extensions is spotty before 4.0, but ParamSpec was
    introduced some time before that (likely in 2021), and 4.0 seems to be a
    reasonable minimum to bump to.
    
    For more about ParamSpec, read PEP 612: https://peps.python.org/pep-0612/
    
    (cherry picked from commit e32e9c58802fe9363cc87ea283a59218df7cec3a)
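
For context on PEP 612: ParamSpec lets a decorator declare that it preserves
the wrapped callable's parameter types, which is what the planned @task
signature work relies on. A minimal sketch, assuming only typing-extensions
and mypy >= 0.950 (the decorator and function names are illustrative, not
from this commit):

    from typing import Callable, TypeVar

    from typing_extensions import ParamSpec  # in typing itself from Python 3.10

    P = ParamSpec("P")
    R = TypeVar("R")

    def log_calls(func: Callable[P, R]) -> Callable[P, R]:
        """Toy decorator that keeps the wrapped function's exact signature."""
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            print(f"calling {func.__name__}")
            return func(*args, **kwargs)
        return wrapper

    @log_calls
    def add(x: int, y: int) -> int:
        return x + y

    add(1, 2)      # accepted
    add("a", "b")  # flagged by mypy >= 0.950, which understands ParamSpec
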
---
 airflow/jobs/scheduler_job.py                      |  4 +-
 airflow/mypy/plugin/decorators.py                  |  5 +-
 .../amazon/aws/transfers/dynamodb_to_s3.py         |  1 +
 .../providers/amazon/aws/transfers/sql_to_s3.py    | 19 ++++---
 .../providers/google/cloud/operators/cloud_sql.py  |  2 +-
 airflow/providers/microsoft/azure/hooks/cosmos.py  | 62 +++++++++++++---------
 airflow/utils/context.py                           |  2 +-
 .../airflow_breeze/commands/testing_commands.py    |  8 +--
 scripts/in_container/run_migration_reference.py    |  1 +
 setup.cfg                                          |  2 +-
 setup.py                                           |  2 +-
 .../microsoft/azure/hooks/test_azure_cosmos.py     |  8 ++-
 12 files changed, 71 insertions(+), 45 deletions(-)

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 3440832275..3613b9be47 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -170,7 +170,7 @@ class SchedulerJob(BaseJob):
         signal.signal(signal.SIGTERM, self._exit_gracefully)
         signal.signal(signal.SIGUSR2, self._debug_dump)
 
-    def _exit_gracefully(self, signum, frame) -> None:
+    def _exit_gracefully(self, signum: int, frame) -> None:
         """Helper method to clean up processor_agent to avoid leaving orphan processes."""
         if not _is_parent_process():
             # Only the parent process should perform the cleanup.
@@ -181,7 +181,7 @@ class SchedulerJob(BaseJob):
             self.processor_agent.end()
         sys.exit(os.EX_OK)
 
-    def _debug_dump(self, signum, frame):
+    def _debug_dump(self, signum: int, frame) -> None:
         if not _is_parent_process():
             # Only the parent process should perform the debug dump.
             return
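
The hunks above only annotate signum; for reference, a fully typed signal
handler would also annotate frame. A generic sketch of that shape (standalone,
not Airflow code):

    import signal
    import sys
    from types import FrameType
    from typing import Optional

    def _exit_gracefully(signum: int, frame: Optional[FrameType]) -> None:
        """Minimal handler: report the signal and exit cleanly."""
        print(f"received signal {signum}", file=sys.stderr)
        sys.exit(0)

    signal.signal(signal.SIGTERM, _exit_gracefully)
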
diff --git a/airflow/mypy/plugin/decorators.py b/airflow/mypy/plugin/decorators.py
index 76f1af54cd..32e1113876 100644
--- a/airflow/mypy/plugin/decorators.py
+++ b/airflow/mypy/plugin/decorators.py
@@ -68,7 +68,10 @@ def _change_decorator_function_type(
     # Mark provided arguments as optional
     decorator.arg_types = copy.copy(decorated.arg_types)
     for argument in provided_arguments:
-        index = decorated.arg_names.index(argument)
+        try:
+            index = decorated.arg_names.index(argument)
+        except ValueError:
+            continue
         decorated_type = decorated.arg_types[index]
         decorator.arg_types[index] = UnionType.make_union([decorated_type, NoneType()])
         decorated.arg_kinds[index] = ARG_NAMED_OPT
diff --git a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
index a6f5f8da21..218f4dc16c 100644
--- a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -114,6 +114,7 @@ class DynamoDBToS3Operator(BaseOperator):
 
         scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}
         err = None
+        f: IO[Any]
         with NamedTemporaryFile() as f:
             try:
                 f = self._scan_dynamodb_and_upload_to_s3(f, scan_kwargs, table)
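
The "f: IO[Any]" line added above is a small mypy idiom: pre-declaring the
variable with a wider type lets the with-block re-bind it to a different kind
of file object without a type error. A standalone sketch of the pattern (not
the operator code itself):

    import io
    from tempfile import NamedTemporaryFile
    from typing import IO, Any

    f: IO[Any]  # pre-declared so the re-binding below type-checks
    with NamedTemporaryFile() as f:
        f.write(b"example")
        f.seek(0)
        f = io.BytesIO(f.read())  # re-bind f to a different IO type
    print(f.read())  # b'example'
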
diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index f399c27141..d9bebf5a39 100644
--- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -16,8 +16,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import enum
 from collections import namedtuple
-from enum import Enum
 from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
 
@@ -35,10 +35,13 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-FILE_FORMAT = Enum(
-    "FILE_FORMAT",
-    "CSV, JSON, PARQUET",
-)
+class FILE_FORMAT(enum.Enum):
+    """Possible file formats."""
+
+    CSV = enum.auto()
+    JSON = enum.auto()
+    PARQUET = enum.auto()
+
 
 FileOptions = namedtuple('FileOptions', ['mode', 'suffix', 'function'])
 
@@ -118,9 +121,9 @@ class SqlToS3Operator(BaseOperator):
         if "path_or_buf" in self.pd_kwargs:
             raise AirflowException('The argument path_or_buf is not allowed, please remove it')
 
-        self.file_format = getattr(FILE_FORMAT, file_format.upper(), None)
-
-        if self.file_format is None:
+        try:
+            self.file_format = FILE_FORMAT[file_format.upper()]
+        except KeyError:
             raise AirflowException(f"The argument file_format doesn't support {file_format} value.")
 
     @staticmethod
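
Replacing the functional Enum(...) call with a class-based enum makes the
members visible to static type checkers, and FILE_FORMAT[name] raises
KeyError for unknown names, which the new constructor code turns into an
AirflowException. A standalone sketch of that lookup behaviour (the resolve
helper is illustrative, not part of the operator):

    import enum

    class FILE_FORMAT(enum.Enum):
        CSV = enum.auto()
        JSON = enum.auto()
        PARQUET = enum.auto()

    def resolve(file_format: str) -> FILE_FORMAT:
        try:
            return FILE_FORMAT[file_format.upper()]  # lookup by member name
        except KeyError:
            raise ValueError(f"unsupported file format: {file_format!r}")

    print(resolve("csv"))  # FILE_FORMAT.CSV
    # resolve("xlsx") would raise ValueError: unsupported file format: 'xlsx'
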
diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py b/airflow/providers/google/cloud/operators/cloud_sql.py
index 1441f518b4..fb5a88593e 100644
--- a/airflow/providers/google/cloud/operators/cloud_sql.py
+++ b/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
 SETTINGS = 'settings'
 SETTINGS_VERSION = 'settingsVersion'
 
-CLOUD_SQL_CREATE_VALIDATION = [
+CLOUD_SQL_CREATE_VALIDATION: Sequence[dict] = [
     dict(name="name", allow_empty=False),
     dict(
         name="settings",
diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py b/airflow/providers/microsoft/azure/hooks/cosmos.py
index ed475978b0..954b584846 100644
--- a/airflow/providers/microsoft/azure/hooks/cosmos.py
+++ b/airflow/providers/microsoft/azure/hooks/cosmos.py
@@ -23,6 +23,7 @@ Airflow connection of type `azure_cosmos` exists. Authorization can be done by s
 login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify
 the default database and collection to use (see connection `azure_cosmos_default` for an example).
 """
+import json
 import uuid
 from typing import Any, Dict, Optional
 
@@ -140,14 +141,22 @@ class AzureCosmosDBHook(BaseHook):
         existing_container = list(
             self.get_conn()
             .get_database_client(self.__get_database_name(database_name))
-            .query_containers("SELECT * FROM r WHERE r.id=@id", [{"name": "@id", "value": collection_name}])
+            .query_containers(
+                "SELECT * FROM r WHERE r.id=@id",
+                parameters=[json.dumps({"name": "@id", "value": collection_name})],
+            )
         )
         if len(existing_container) == 0:
             return False
 
         return True
 
-    def create_collection(self, collection_name: str, database_name: Optional[str] = None) -> None:
+    def create_collection(
+        self,
+        collection_name: str,
+        database_name: Optional[str] = None,
+        partition_key: Optional[str] = None,
+    ) -> None:
         """Creates a new collection in the CosmosDB database."""
         if collection_name is None:
             raise AirflowBadRequest("Collection name cannot be None.")
@@ -157,13 +166,16 @@ class AzureCosmosDBHook(BaseHook):
         existing_container = list(
             self.get_conn()
             .get_database_client(self.__get_database_name(database_name))
-            .query_containers("SELECT * FROM r WHERE r.id=@id", [{"name": "@id", "value": collection_name}])
+            .query_containers(
+                "SELECT * FROM r WHERE r.id=@id",
+                parameters=[json.dumps({"name": "@id", "value": collection_name})],
+            )
         )
 
         # Only create if we did not find it already existing
         if len(existing_container) == 0:
             self.get_conn().get_database_client(self.__get_database_name(database_name)).create_container(
-                collection_name
+                collection_name, partition_key=partition_key
             )
 
     def does_database_exist(self, database_name: str) -> bool:
@@ -173,10 +185,8 @@ class AzureCosmosDBHook(BaseHook):
 
         existing_database = list(
             self.get_conn().query_databases(
-                {
-                    "query": "SELECT * FROM r WHERE r.id=@id",
-                    "parameters": [{"name": "@id", "value": database_name}],
-                }
+                "SELECT * FROM r WHERE r.id=@id",
+                parameters=[json.dumps({"name": "@id", "value": database_name})],
             )
         )
         if len(existing_database) == 0:
@@ -193,10 +203,8 @@ class AzureCosmosDBHook(BaseHook):
         # to create it twice
         existing_database = list(
             self.get_conn().query_databases(
-                {
-                    "query": "SELECT * FROM r WHERE r.id=@id",
-                    "parameters": [{"name": "@id", "value": database_name}],
-                }
+                "SELECT * FROM r WHERE r.id=@id",
+                parameters=[json.dumps({"name": "@id", "value": database_name})],
             )
         )
 
@@ -267,18 +275,28 @@ class AzureCosmosDBHook(BaseHook):
         return created_documents
 
     def delete_document(
-        self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None
+        self,
+        document_id: str,
+        database_name: Optional[str] = None,
+        collection_name: Optional[str] = None,
+        partition_key: Optional[str] = None,
     ) -> None:
         """Delete an existing document out of a collection in the CosmosDB database."""
         if document_id is None:
             raise AirflowBadRequest("Cannot delete a document without an id")
-
-        self.get_conn().get_database_client(self.__get_database_name(database_name)).get_container_client(
-            self.__get_collection_name(collection_name)
-        ).delete_item(document_id)
+        (
+            self.get_conn()
+            .get_database_client(self.__get_database_name(database_name))
+            .get_container_client(self.__get_collection_name(collection_name))
+            .delete_item(document_id, partition_key=partition_key)
+        )
 
     def get_document(
-        self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None
+        self,
+        document_id: str,
+        database_name: Optional[str] = None,
+        collection_name: Optional[str] = None,
+        partition_key: Optional[str] = None,
     ):
         """Get a document from an existing collection in the CosmosDB database."""
         if document_id is None:
@@ -289,7 +307,7 @@ class AzureCosmosDBHook(BaseHook):
                 self.get_conn()
                 .get_database_client(self.__get_database_name(database_name))
                 .get_container_client(self.__get_collection_name(collection_name))
-                .read_item(document_id)
+                .read_item(document_id, partition_key=partition_key)
             )
         except CosmosHttpResponseError:
             return None
@@ -305,17 +323,13 @@ class AzureCosmosDBHook(BaseHook):
         if sql_string is None:
             raise AirflowBadRequest("SQL query string cannot be None")
 
-        # Query them in SQL
-        query = {'query': sql_string}
-
         try:
             result_iterable = (
                 self.get_conn()
                 .get_database_client(self.__get_database_name(database_name))
                 .get_container_client(self.__get_collection_name(collection_name))
-                .query_items(query, partition_key)
+                .query_items(sql_string, partition_key=partition_key)
             )
-
             return list(result_iterable)
         except CosmosHttpResponseError:
             return None
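
With the new optional partition_key arguments, point reads and deletes through
the hook can target a single logical partition. A hedged usage sketch (the
connection id, database, collection, and key values are illustrative, not from
the commit):

    from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook

    hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_default")
    doc = hook.get_document(
        document_id="run-2022-07-18",
        database_name="airflow_db",
        collection_name="task_state",
        partition_key="run-2022-07-18",  # value under the container's partition key path
    )
    if doc is not None:
        hook.delete_document(
            document_id="run-2022-07-18",
            database_name="airflow_db",
            collection_name="task_state",
            partition_key="run-2022-07-18",
        )
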
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 04dababa24..648a0f9a03 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -175,7 +175,7 @@ class Context(MutableMapping[str, Any]):
     }
 
     def __init__(self, context: Optional[MutableMapping[str, Any]] = None, **kwargs: Any) -> None:
-        self._context = context or {}
+        self._context: MutableMapping[str, Any] = context or {}
         if kwargs:
             self._context.update(kwargs)
         self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy()
diff --git a/dev/breeze/src/airflow_breeze/commands/testing_commands.py b/dev/breeze/src/airflow_breeze/commands/testing_commands.py
index b53333ea64..05aa3aa7e8 100644
--- a/dev/breeze/src/airflow_breeze/commands/testing_commands.py
+++ b/dev/breeze/src/airflow_breeze/commands/testing_commands.py
@@ -197,9 +197,9 @@ def run_with_progress(
 ) -> RunCommandResult:
     title = f"Running tests: {test_type}, Python: {python}, Backend: {backend}:{version}"
     try:
-        with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as f:
+        with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as tf:
             get_console().print(f"[info]Starting test = {title}[/]")
-            thread = MonitoringThread(title=title, file_name=f.name)
+            thread = MonitoringThread(title=title, file_name=tf.name)
             thread.start()
             try:
                 result = run_command(
@@ -208,14 +208,14 @@ def run_with_progress(
                     dry_run=dry_run,
                     env=env_variables,
                     check=False,
-                    stdout=f,
+                    stdout=tf,
                     stderr=subprocess.STDOUT,
                 )
             finally:
                 thread.stop()
                 thread.join()
         with ci_group(f"Result of {title}", message_type=message_type_from_return_code(result.returncode)):
-            with open(f.name) as f:
+            with open(tf.name) as f:
                 shutil.copyfileobj(f, sys.stdout)
     finally:
         os.unlink(f.name)
diff --git a/scripts/in_container/run_migration_reference.py b/scripts/in_container/run_migration_reference.py
index cc05408c2a..12ff265c55 100755
--- a/scripts/in_container/run_migration_reference.py
+++ b/scripts/in_container/run_migration_reference.py
@@ -102,6 +102,7 @@ def revision_suffix(rev: "Script"):
 
 def ensure_airflow_version(revisions: Iterable["Script"]):
     for rev in revisions:
+        assert rev.module.__file__ is not None  # For Mypy.
         file = Path(rev.module.__file__)
         content = file.read_text()
         if not has_version(content):
diff --git a/setup.cfg b/setup.cfg
index 0e1e9f7b84..2c96a0de42 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -146,7 +146,7 @@ install_requires =
     tabulate>=0.7.5
     tenacity>=6.2.0
     termcolor>=1.1.0
-    typing-extensions>=3.7.4
+    typing-extensions>=4.0.0
     unicodecsv>=0.14.1
     werkzeug>=2.0
 
diff --git a/setup.py b/setup.py
index 4d5dbd1bb8..6447281e5d 100644
--- a/setup.py
+++ b/setup.py
@@ -578,7 +578,7 @@ zendesk = [
 # mypyd which does not support installing the types dynamically with --install-types
 mypy_dependencies = [
     # TODO: upgrade to newer versions of MyPy continuously as they are released
-    'mypy==0.910',
+    'mypy==0.950',
     'types-boto',
     'types-certifi',
     'types-croniter',
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
index b407fbdb3c..e157a5276b 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -90,7 +90,9 @@ class TestAzureCosmosDbHook(unittest.TestCase):
         hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
         hook.create_collection(self.test_collection_name, self.test_database_name)
         expected_calls = [
-            mock.call().get_database_client('test_database_name').create_container('test_collection_name')
+            mock.call()
+            .get_database_client('test_database_name')
+            .create_container('test_collection_name', partition_key=None)
         ]
         mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
         mock_cosmos.assert_has_calls(expected_calls)
@@ -100,7 +102,9 @@ class TestAzureCosmosDbHook(unittest.TestCase):
         hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
         hook.create_collection(self.test_collection_name)
         expected_calls = [
-            mock.call().get_database_client('test_database_name').create_container('test_collection_name')
+            mock.call()
+            .get_database_client('test_database_name')
+            .create_container('test_collection_name', partition_key=None)
         ]
         mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
         mock_cosmos.assert_has_calls(expected_calls)