You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) is not preserved in this rendering.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/12/29 21:48:06 UTC

[airflow] branch main updated: Fix MyPy errors for Airflow decorators (#20034)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 59e4b78  Fix MyPy errors for Airflow decorators (#20034)
59e4b78 is described below

commit 59e4b78daa3496cb0358ce34aeb5ebf6f5565ce0
Author: Josh Fell <48...@users.noreply.github.com>
AuthorDate: Wed Dec 29 16:47:36 2021 -0500

    Fix MyPy errors for Airflow decorators (#20034)
    
    Related: #19891
---
 airflow/decorators/__init__.py                |  2 +-
 airflow/decorators/base.py                    | 21 ++++++++++++---------
 airflow/models/baseoperator.py                |  6 +++---
 airflow/operators/python.py                   | 10 +++++-----
 airflow/providers/docker/decorators/docker.py |  9 ++++++---
 airflow/providers/docker/operators/docker.py  |  4 ++--
 6 files changed, 29 insertions(+), 23 deletions(-)

diff --git a/airflow/decorators/__init__.py b/airflow/decorators/__init__.py
index ef490df..47a20d4 100644
--- a/airflow/decorators/__init__.py
+++ b/airflow/decorators/__init__.py
@@ -42,7 +42,7 @@ if TYPE_CHECKING:
         class _DockerTask(_TaskDecorator, DockerDecoratorMixin):
             pass
 
-        _TaskDecorator = _DockerTask
+        _TaskDecorator = _DockerTask  # type: ignore[misc]
     except ImportError:
         pass
 # [END mixin_for_autocomplete]
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 0557125..1a7e717 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -19,12 +19,13 @@ import functools
 import inspect
 import re
 from inspect import signature
-from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, cast
+from typing import Any, Callable, Collection, Dict, Mapping, Optional, Sequence, Type, TypeVar, cast
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.models.dag import DAG, DagContext
 from airflow.models.xcom_arg import XComArg
+from airflow.utils.context import Context
 from airflow.utils.task_group import TaskGroup, TaskGroupContext
 
 
@@ -101,27 +102,29 @@ class DecoratedOperator(BaseOperator):
     :type kwargs_to_upstream: dict
     """
 
-    template_fields: Iterable[str] = ('op_args', 'op_kwargs')
+    template_fields: Sequence[str] = ('op_args', 'op_kwargs')
     template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}
 
     # since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects (e.g protobuf).
-    shallow_copy_attrs = ('python_callable',)
+    shallow_copy_attrs: Sequence[str] = ('python_callable',)
 
     def __init__(
         self,
         *,
         python_callable: Callable,
         task_id: str,
-        op_args: Tuple[Any],
-        op_kwargs: Dict[str, Any],
+        op_args: Optional[Collection[Any]] = None,
+        op_kwargs: Optional[Mapping[str, Any]] = None,
         multiple_outputs: bool = False,
-        kwargs_to_upstream: dict = None,
+        kwargs_to_upstream: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         kwargs['task_id'] = get_unique_task_id(task_id, kwargs.get('dag'), kwargs.get('task_group'))
         self.python_callable = python_callable
         kwargs_to_upstream = kwargs_to_upstream or {}
+        op_args = op_args or []
+        op_kwargs = op_kwargs or {}
 
         # Check that arguments can be binded
         signature(python_callable).bind(*op_args, **op_kwargs)
@@ -130,7 +133,7 @@ class DecoratedOperator(BaseOperator):
         self.op_kwargs = op_kwargs
         super().__init__(**kwargs_to_upstream, **kwargs)
 
-    def execute(self, context: Dict):
+    def execute(self, context: Context):
         return_value = super().execute(context)
         return self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push)
 
@@ -180,7 +183,7 @@ T = TypeVar("T", bound=Callable)
 def task_decorator_factory(
     python_callable: Optional[Callable] = None,
     multiple_outputs: Optional[bool] = None,
-    decorated_operator_class: Type[BaseOperator] = None,
+    decorated_operator_class: Optional[Type[BaseOperator]] = None,
     **kwargs,
 ) -> Callable[[T], T]:
     """
@@ -196,7 +199,7 @@ def task_decorator_factory(
     :type multiple_outputs: bool
     :param decorated_operator_class: The operator that executes the logic needed to run the python function in
         the correct environment
-    :type decorated_operator_class: BaseDecoratedOperator
+    :type decorated_operator_class: BaseOperator
 
     """
     # try to infer from  type annotation
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 1233939..c3c5be6 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -423,9 +423,9 @@ class BaseOperator(Operator, LoggingMixin, DependencyMixin, metaclass=BaseOperat
     """
 
     # For derived classes to define which fields will get jinjaified
-    template_fields: Iterable[str] = ()
+    template_fields: Sequence[str] = ()
     # Defines which files extensions to look for in the templated fields
-    template_ext: Iterable[str] = ()
+    template_ext: Sequence[str] = ()
     # Template field renderers indicating type of the field, for example sql, json, bash
     template_fields_renderers: Dict[str, str] = {}
 
@@ -444,7 +444,7 @@ class BaseOperator(Operator, LoggingMixin, DependencyMixin, metaclass=BaseOperat
     )
 
     # each operator should override this class attr for shallow copy attrs.
-    shallow_copy_attrs: Tuple[str, ...] = ()
+    shallow_copy_attrs: Sequence[str] = ()
 
     # Defines the operator level extra links
     operator_extra_links: Iterable['BaseOperatorLink'] = ()
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 611a203..eb5f710 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -24,7 +24,7 @@ import types
 import warnings
 from tempfile import TemporaryDirectory
 from textwrap import dedent
-from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Union
+from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Union
 
 import dill
 
@@ -131,14 +131,14 @@ class PythonOperator(BaseOperator):
     :type show_return_value_in_logs: bool
     """
 
-    template_fields = ('templates_dict', 'op_args', 'op_kwargs')
+    template_fields: Sequence[str] = ('templates_dict', 'op_args', 'op_kwargs')
     template_fields_renderers = {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}
     BLUE = '#ffefeb'
     ui_color = BLUE
 
     # since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects(e.g protobuf).
-    shallow_copy_attrs = (
+    shallow_copy_attrs: Sequence[str] = (
         'python_callable',
         'op_kwargs',
     )
@@ -149,8 +149,8 @@ class PythonOperator(BaseOperator):
         python_callable: Callable,
         op_args: Optional[Collection[Any]] = None,
         op_kwargs: Optional[Mapping[str, Any]] = None,
-        templates_dict: Optional[Dict] = None,
-        templates_exts: Optional[List[str]] = None,
+        templates_dict: Optional[Dict[str, Any]] = None,
+        templates_exts: Optional[Sequence[str]] = None,
         show_return_value_in_logs: bool = True,
         **kwargs,
     ) -> None:
diff --git a/airflow/providers/docker/decorators/docker.py b/airflow/providers/docker/decorators/docker.py
index e1830ea..a4f79a7 100644
--- a/airflow/providers/docker/decorators/docker.py
+++ b/airflow/providers/docker/decorators/docker.py
@@ -21,7 +21,7 @@ import os
 import pickle
 from tempfile import TemporaryDirectory
 from textwrap import dedent
-from typing import Callable, Dict, Iterable, List, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, TypeVar, Union
 
 import dill
 
@@ -29,6 +29,9 @@ from airflow.decorators.base import DecoratedOperator, task_decorator_factory
 from airflow.providers.docker.operators.docker import DockerOperator
 from airflow.utils.python_virtualenv import remove_task_decorator, write_python_script
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 def _generate_decode_command(env_var, file):
     # We don't need `f.close()` as the interpreter is about to exit anyway
@@ -62,7 +65,7 @@ class _DockerDecoratedOperator(DecoratedOperator, DockerOperator):
     :type multiple_outputs: bool
     """
 
-    template_fields: Iterable[str] = ('op_args', 'op_kwargs')
+    template_fields = ('op_args', 'op_kwargs')
 
     # since we won't mutate the arguments, we should just do the shallow copy
     # there are some cases we can't deepcopy the objects (e.g protobuf).
@@ -79,7 +82,7 @@ class _DockerDecoratedOperator(DecoratedOperator, DockerOperator):
             command=command, retrieve_output=True, retrieve_output_path="/tmp/script.out", **kwargs
         )
 
-    def execute(self, context: Dict):
+    def execute(self, context: "Context") -> Any:
         with TemporaryDirectory(prefix='venv') as tmp_dir:
             input_filename = os.path.join(tmp_dir, 'script.in')
             script_filename = os.path.join(tmp_dir, 'script.py')
diff --git a/airflow/providers/docker/operators/docker.py b/airflow/providers/docker/operators/docker.py
index 652424f..7a048f6 100644
--- a/airflow/providers/docker/operators/docker.py
+++ b/airflow/providers/docker/operators/docker.py
@@ -21,7 +21,7 @@ import io
 import pickle
 import tarfile
 from tempfile import TemporaryDirectory
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Dict, Iterable, List, Optional, Sequence, Union
 
 from docker import APIClient, tls
 from docker.errors import APIError
@@ -152,7 +152,7 @@ class DockerOperator(BaseOperator):
     :type retrieve_output_path: Optional[str]
     """
 
-    template_fields: Iterable[str] = ('image', 'command', 'environment', 'container_name')
+    template_fields: Sequence[str] = ('image', 'command', 'environment', 'container_name')
     template_ext = (
         '.sh',
         '.bash',