Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/21 14:28:17 UTC
[airflow] 14/22: Bump typing-extensions and mypy for ParamSpec (#25088)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit ec8ea0e200d5a13b1414312e017eea1af6fdf98d
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Mon Jul 18 17:20:35 2022 +0800
Bump typing-extensions and mypy for ParamSpec (#25088)
* Bump typing-extensions and mypy for ParamSpec
I want to use them in some @task signature improvements. Mypy added
ParamSpec support in 0.950, but let's just bump to the latest since why not.
The changelog of typing-extensions is spotty before 4.0, but ParamSpec was
introduced some time before that (likely some time in 2021), and 4.0 seems
to be a reasonable minimum to bump to.
For more about ParamSpec, read PEP 612: https://peps.python.org/pep-0612/
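To make the motivation concrete, here is a minimal sketch (not part of this
commit) of what ParamSpec enables: a decorator that preserves the wrapped
function's exact signature for type checkers. It assumes typing-extensions>=4.0
is installed; the names logged and add are illustrative only.

    from typing import Callable, TypeVar
    from typing_extensions import ParamSpec

    P = ParamSpec("P")
    R = TypeVar("R")

    def logged(func: Callable[P, R]) -> Callable[P, R]:
        # The wrapper inherits func's parameter list via P, so call sites
        # are checked against the original signature.
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            print(f"calling {func.__name__}")
            return func(*args, **kwargs)
        return wrapper

    @logged
    def add(x: int, y: int) -> int:
        return x + y

    add(1, 2)      # OK
    add("a", "b")  # flagged by a ParamSpec-aware mypy, though it runs at runtime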
(cherry picked from commit e32e9c58802fe9363cc87ea283a59218df7cec3a)
---
airflow/jobs/scheduler_job.py | 4 +-
airflow/mypy/plugin/decorators.py | 5 +-
.../amazon/aws/transfers/dynamodb_to_s3.py | 1 +
.../providers/amazon/aws/transfers/sql_to_s3.py | 19 ++++---
.../providers/google/cloud/operators/cloud_sql.py | 2 +-
airflow/providers/microsoft/azure/hooks/cosmos.py | 62 +++++++++++++---------
airflow/utils/context.py | 2 +-
.../airflow_breeze/commands/testing_commands.py | 8 +--
scripts/in_container/run_migration_reference.py | 1 +
setup.cfg | 2 +-
setup.py | 2 +-
.../microsoft/azure/hooks/test_azure_cosmos.py | 8 ++-
12 files changed, 71 insertions(+), 45 deletions(-)
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 3440832275..3613b9be47 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -170,7 +170,7 @@ class SchedulerJob(BaseJob):
signal.signal(signal.SIGTERM, self._exit_gracefully)
signal.signal(signal.SIGUSR2, self._debug_dump)
- def _exit_gracefully(self, signum, frame) -> None:
+ def _exit_gracefully(self, signum: int, frame) -> None:
"""Helper method to clean up processor_agent to avoid leaving orphan processes."""
if not _is_parent_process():
# Only the parent process should perform the cleanup.
@@ -181,7 +181,7 @@ class SchedulerJob(BaseJob):
self.processor_agent.end()
sys.exit(os.EX_OK)
- def _debug_dump(self, signum, frame):
+ def _debug_dump(self, signum: int, frame) -> None:
if not _is_parent_process():
# Only the parent process should perform the debug dump.
return
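For context, the annotation above follows the shape signal handlers always
receive: the signal number and the current stack frame. A standalone sketch
(the Optional[FrameType] annotation on frame is my addition, not in the patch):

    import signal
    import types
    from typing import Optional

    def _exit_gracefully(signum: int, frame: Optional[types.FrameType]) -> None:
        # Handlers get the signal number and the interrupted frame
        # (or None when the signal is delivered outside a Python frame).
        print(f"received signal {signum}")

    signal.signal(signal.SIGTERM, _exit_gracefully)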
diff --git a/airflow/mypy/plugin/decorators.py b/airflow/mypy/plugin/decorators.py
index 76f1af54cd..32e1113876 100644
--- a/airflow/mypy/plugin/decorators.py
+++ b/airflow/mypy/plugin/decorators.py
@@ -68,7 +68,10 @@ def _change_decorator_function_type(
# Mark provided arguments as optional
decorator.arg_types = copy.copy(decorated.arg_types)
for argument in provided_arguments:
- index = decorated.arg_names.index(argument)
+ try:
+ index = decorated.arg_names.index(argument)
+ except ValueError:
+ continue
decorated_type = decorated.arg_types[index]
decorator.arg_types[index] = UnionType.make_union([decorated_type, NoneType()])
decorated.arg_kinds[index] = ARG_NAMED_OPT
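The change above swaps a lookup that assumes membership for one that
tolerates a miss: list.index raises ValueError for an absent item, so the
try/except lets the loop skip arguments the decorated function does not
declare. The pattern in isolation, with placeholder names:

    arg_names = ["dag_id", "task_id"]
    for provided in ["task_id", "not_an_argument"]:
        try:
            index = arg_names.index(provided)
        except ValueError:
            continue  # not in the decorated signature; skip it
        print(provided, "at position", index)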
diff --git a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
index a6f5f8da21..218f4dc16c 100644
--- a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -114,6 +114,7 @@ class DynamoDBToS3Operator(BaseOperator):
scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}
err = None
+ f: IO[Any]
with NamedTemporaryFile() as f:
try:
f = self._scan_dynamodb_and_upload_to_s3(f, scan_kwargs, table)
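The bare `f: IO[Any]` declaration exists for mypy: the `with` statement would
otherwise fix f's type to the temporary-file wrapper, and reassigning it from
a helper that returns IO[Any] would be rejected. A reduced sketch, where
rotate is a hypothetical stand-in for _scan_dynamodb_and_upload_to_s3:

    from tempfile import NamedTemporaryFile
    from typing import IO, Any

    def rotate(f: IO[Any]) -> IO[Any]:
        # Hypothetical helper: closes the current file, hands back a new one.
        f.close()
        return NamedTemporaryFile()

    f: IO[Any]  # widen f so the reassignment below type-checks
    with NamedTemporaryFile() as f:
        f = rotate(f)
    f.close()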
diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index f399c27141..d9bebf5a39 100644
--- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -16,8 +16,8 @@
# specific language governing permissions and limitations
# under the License.
+import enum
from collections import namedtuple
-from enum import Enum
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
@@ -35,10 +35,13 @@ if TYPE_CHECKING:
from airflow.utils.context import Context
-FILE_FORMAT = Enum(
- "FILE_FORMAT",
- "CSV, JSON, PARQUET",
-)
+class FILE_FORMAT(enum.Enum):
+ """Possible file formats."""
+
+ CSV = enum.auto()
+ JSON = enum.auto()
+ PARQUET = enum.auto()
+
FileOptions = namedtuple('FileOptions', ['mode', 'suffix', 'function'])
@@ -118,9 +121,9 @@ class SqlToS3Operator(BaseOperator):
if "path_or_buf" in self.pd_kwargs:
raise AirflowException('The argument path_or_buf is not allowed, please remove it')
- self.file_format = getattr(FILE_FORMAT, file_format.upper(), None)
-
- if self.file_format is None:
+ try:
+ self.file_format = FILE_FORMAT[file_format.upper()]
+ except KeyError:
raise AirflowException(f"The argument file_format doesn't support {file_format} value.")
@staticmethod
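Behaviourally the two FILE_FORMAT definitions are equivalent; the class form
just gives mypy (and readers) named members. Indexing the enum by name raises
KeyError for unknown formats, which the constructor converts into an
AirflowException. The same logic in isolation:

    import enum

    class FILE_FORMAT(enum.Enum):
        CSV = enum.auto()
        JSON = enum.auto()
        PARQUET = enum.auto()

    try:
        fmt = FILE_FORMAT["csv".upper()]
    except KeyError:
        raise ValueError("unsupported file format")  # stand-in for AirflowException
    print(fmt)  # FILE_FORMAT.CSV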
diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py b/airflow/providers/google/cloud/operators/cloud_sql.py
index 1441f518b4..fb5a88593e 100644
--- a/airflow/providers/google/cloud/operators/cloud_sql.py
+++ b/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
SETTINGS = 'settings'
SETTINGS_VERSION = 'settingsVersion'
-CLOUD_SQL_CREATE_VALIDATION = [
+CLOUD_SQL_CREATE_VALIDATION: Sequence[dict] = [
dict(name="name", allow_empty=False),
dict(
name="settings",
diff --git a/airflow/providers/microsoft/azure/hooks/cosmos.py b/airflow/providers/microsoft/azure/hooks/cosmos.py
index ed475978b0..954b584846 100644
--- a/airflow/providers/microsoft/azure/hooks/cosmos.py
+++ b/airflow/providers/microsoft/azure/hooks/cosmos.py
@@ -23,6 +23,7 @@ Airflow connection of type `azure_cosmos` exists. Authorization can be done by s
login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify
the default database and collection to use (see connection `azure_cosmos_default` for an example).
"""
+import json
import uuid
from typing import Any, Dict, Optional
@@ -140,14 +141,22 @@ class AzureCosmosDBHook(BaseHook):
existing_container = list(
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
- .query_containers("SELECT * FROM r WHERE r.id=@id", [{"name": "@id", "value": collection_name}])
+ .query_containers(
+ "SELECT * FROM r WHERE r.id=@id",
+ parameters=[json.dumps({"name": "@id", "value": collection_name})],
+ )
)
if len(existing_container) == 0:
return False
return True
- def create_collection(self, collection_name: str, database_name: Optional[str] = None) -> None:
+ def create_collection(
+ self,
+ collection_name: str,
+ database_name: Optional[str] = None,
+ partition_key: Optional[str] = None,
+ ) -> None:
"""Creates a new collection in the CosmosDB database."""
if collection_name is None:
raise AirflowBadRequest("Collection name cannot be None.")
@@ -157,13 +166,16 @@ class AzureCosmosDBHook(BaseHook):
existing_container = list(
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
- .query_containers("SELECT * FROM r WHERE r.id=@id", [{"name": "@id", "value": collection_name}])
+ .query_containers(
+ "SELECT * FROM r WHERE r.id=@id",
+ parameters=[json.dumps({"name": "@id", "value": collection_name})],
+ )
)
# Only create if we did not find it already existing
if len(existing_container) == 0:
self.get_conn().get_database_client(self.__get_database_name(database_name)).create_container(
- collection_name
+ collection_name, partition_key=partition_key
)
def does_database_exist(self, database_name: str) -> bool:
@@ -173,10 +185,8 @@ class AzureCosmosDBHook(BaseHook):
existing_database = list(
self.get_conn().query_databases(
- {
- "query": "SELECT * FROM r WHERE r.id=@id",
- "parameters": [{"name": "@id", "value": database_name}],
- }
+ "SELECT * FROM r WHERE r.id=@id",
+ parameters=[json.dumps({"name": "@id", "value": database_name})],
)
)
if len(existing_database) == 0:
@@ -193,10 +203,8 @@ class AzureCosmosDBHook(BaseHook):
# to create it twice
existing_database = list(
self.get_conn().query_databases(
- {
- "query": "SELECT * FROM r WHERE r.id=@id",
- "parameters": [{"name": "@id", "value": database_name}],
- }
+ "SELECT * FROM r WHERE r.id=@id",
+ parameters=[json.dumps({"name": "@id", "value": database_name})],
)
)
@@ -267,18 +275,28 @@ class AzureCosmosDBHook(BaseHook):
return created_documents
def delete_document(
- self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None
+ self,
+ document_id: str,
+ database_name: Optional[str] = None,
+ collection_name: Optional[str] = None,
+ partition_key: Optional[str] = None,
) -> None:
"""Delete an existing document out of a collection in the CosmosDB database."""
if document_id is None:
raise AirflowBadRequest("Cannot delete a document without an id")
-
- self.get_conn().get_database_client(self.__get_database_name(database_name)).get_container_client(
- self.__get_collection_name(collection_name)
- ).delete_item(document_id)
+ (
+ self.get_conn()
+ .get_database_client(self.__get_database_name(database_name))
+ .get_container_client(self.__get_collection_name(collection_name))
+ .delete_item(document_id, partition_key=partition_key)
+ )
def get_document(
- self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None
+ self,
+ document_id: str,
+ database_name: Optional[str] = None,
+ collection_name: Optional[str] = None,
+ partition_key: Optional[str] = None,
):
"""Get a document from an existing collection in the CosmosDB database."""
if document_id is None:
@@ -289,7 +307,7 @@ class AzureCosmosDBHook(BaseHook):
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
.get_container_client(self.__get_collection_name(collection_name))
- .read_item(document_id)
+ .read_item(document_id, partition_key=partition_key)
)
except CosmosHttpResponseError:
return None
@@ -305,17 +323,13 @@ class AzureCosmosDBHook(BaseHook):
if sql_string is None:
raise AirflowBadRequest("SQL query string cannot be None")
- # Query them in SQL
- query = {'query': sql_string}
-
try:
result_iterable = (
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
.get_container_client(self.__get_collection_name(collection_name))
- .query_items(query, partition_key)
+ .query_items(sql_string, partition_key=partition_key)
)
-
return list(result_iterable)
except CosmosHttpResponseError:
return None
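A hedged usage sketch of the new partition_key parameters added above. It
assumes the Microsoft Azure provider is installed and an azure_cosmos_default
connection is configured; the database, collection, and document ids are
placeholders:

    from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook

    hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_default")
    doc = hook.get_document(
        "my-document-id",
        database_name="mydb",
        collection_name="mycol",
        partition_key="my-partition-value",
    )
    hook.delete_document(
        "my-document-id",
        database_name="mydb",
        collection_name="mycol",
        partition_key="my-partition-value",
    )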
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 04dababa24..648a0f9a03 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -175,7 +175,7 @@ class Context(MutableMapping[str, Any]):
}
def __init__(self, context: Optional[MutableMapping[str, Any]] = None, **kwargs: Any) -> None:
- self._context = context or {}
+ self._context: MutableMapping[str, Any] = context or {}
if kwargs:
self._context.update(kwargs)
self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy()
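The explicit annotation pins self._context to the MutableMapping protocol
instead of whatever mypy infers from `context or {}`. Reduced to its
essentials:

    from typing import Any, MutableMapping, Optional

    class Context:
        def __init__(self, context: Optional[MutableMapping[str, Any]] = None, **kwargs: Any) -> None:
            # Without the annotation, mypy infers the attribute type from
            # the right-hand side; annotating keeps it at the protocol, so
            # any MutableMapping can be assigned later.
            self._context: MutableMapping[str, Any] = context or {}
            if kwargs:
                self._context.update(kwargs)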
diff --git a/dev/breeze/src/airflow_breeze/commands/testing_commands.py b/dev/breeze/src/airflow_breeze/commands/testing_commands.py
index b53333ea64..05aa3aa7e8 100644
--- a/dev/breeze/src/airflow_breeze/commands/testing_commands.py
+++ b/dev/breeze/src/airflow_breeze/commands/testing_commands.py
@@ -197,9 +197,9 @@ def run_with_progress(
) -> RunCommandResult:
title = f"Running tests: {test_type}, Python: {python}, Backend: {backend}:{version}"
try:
- with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as f:
+ with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as tf:
get_console().print(f"[info]Starting test = {title}[/]")
- thread = MonitoringThread(title=title, file_name=f.name)
+ thread = MonitoringThread(title=title, file_name=tf.name)
thread.start()
try:
result = run_command(
@@ -208,14 +208,14 @@ def run_with_progress(
dry_run=dry_run,
env=env_variables,
check=False,
- stdout=f,
+ stdout=tf,
stderr=subprocess.STDOUT,
)
finally:
thread.stop()
thread.join()
with ci_group(f"Result of {title}", message_type=message_type_from_return_code(result.returncode)):
- with open(f.name) as f:
+ with open(tf.name) as f:
shutil.copyfileobj(f, sys.stdout)
finally:
os.unlink(f.name)
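The rename avoids rebinding f from the temporary-file wrapper to the reader
opened over it, which both confuses readers and likely trips mypy's
redefinition checks (the two objects have incompatible types). The two-name
pattern in isolation:

    import os
    import shutil
    import sys
    import tempfile

    with tempfile.NamedTemporaryFile(mode="w+t", delete=False) as tf:
        tf.write("hello\n")
    with open(tf.name) as f:
        shutil.copyfileobj(f, sys.stdout)
    os.unlink(tf.name)  # delete=False means we clean up ourselves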
diff --git a/scripts/in_container/run_migration_reference.py b/scripts/in_container/run_migration_reference.py
index cc05408c2a..12ff265c55 100755
--- a/scripts/in_container/run_migration_reference.py
+++ b/scripts/in_container/run_migration_reference.py
@@ -102,6 +102,7 @@ def revision_suffix(rev: "Script"):
def ensure_airflow_version(revisions: Iterable["Script"]):
for rev in revisions:
+ assert rev.module.__file__ is not None # For Mypy.
file = Path(rev.module.__file__)
content = file.read_text()
if not has_version(content):
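The assert is a standard mypy narrowing idiom: a module's __file__ is typed
Optional[str], and asserting it is not None narrows it to str before Path()
is called. In isolation (module_file is an illustrative name):

    from pathlib import Path
    from typing import Optional

    def module_file(file: Optional[str]) -> Path:
        assert file is not None  # narrows Optional[str] to str for mypy
        return Path(file)

    print(module_file(__file__))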
diff --git a/setup.cfg b/setup.cfg
index 0e1e9f7b84..2c96a0de42 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -146,7 +146,7 @@ install_requires =
tabulate>=0.7.5
tenacity>=6.2.0
termcolor>=1.1.0
- typing-extensions>=3.7.4
+ typing-extensions>=4.0.0
unicodecsv>=0.14.1
werkzeug>=2.0
diff --git a/setup.py b/setup.py
index 4d5dbd1bb8..6447281e5d 100644
--- a/setup.py
+++ b/setup.py
@@ -578,7 +578,7 @@ zendesk = [
# mypyd which does not support installing the types dynamically with --install-types
mypy_dependencies = [
# TODO: upgrade to newer versions of MyPy continuously as they are released
- 'mypy==0.910',
+ 'mypy==0.950',
'types-boto',
'types-certifi',
'types-croniter',
diff --git a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
index b407fbdb3c..e157a5276b 100644
--- a/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
+++ b/tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -90,7 +90,9 @@ class TestAzureCosmosDbHook(unittest.TestCase):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
hook.create_collection(self.test_collection_name, self.test_database_name)
expected_calls = [
- mock.call().get_database_client('test_database_name').create_container('test_collection_name')
+ mock.call()
+ .get_database_client('test_database_name')
+ .create_container('test_collection_name', partition_key=None)
]
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
@@ -100,7 +102,9 @@ class TestAzureCosmosDbHook(unittest.TestCase):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
hook.create_collection(self.test_collection_name)
expected_calls = [
- mock.call().get_database_client('test_database_name').create_container('test_collection_name')
+ mock.call()
+ .get_database_client('test_database_name')
+ .create_container('test_collection_name', partition_key=None)
]
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)