Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/11 19:27:18 UTC

[airflow] branch main updated: Only assert stuff for mypy when type checking (#24937)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ef79a0d1c4 Only assert stuff for mypy when type checking (#24937)
ef79a0d1c4 is described below

commit ef79a0d1c4c0a041d7ebf83b93cbb25aa3778a70
Author: Jed Cunningham <66...@users.noreply.github.com>
AuthorDate: Mon Jul 11 13:27:03 2022 -0600

    Only assert stuff for mypy when type checking (#24937)
---
 airflow/jobs/local_task_job.py                         | 5 +++--
 airflow/models/skipmixin.py                            | 3 ++-
 airflow/models/taskinstance.py                         | 9 ++++++---
 airflow/providers/cncf/kubernetes/hooks/kubernetes.py  | 5 +++--
 airflow/providers/elasticsearch/log/es_task_handler.py | 5 +++--
 airflow/providers/github/hooks/github.py               | 5 +++--
 airflow/utils/log/file_task_handler.py                 | 3 ++-
 7 files changed, 22 insertions(+), 13 deletions(-)
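
Every hunk below applies the same pattern: an assert that exists only to help mypy narrow an Optional type is wrapped in "if TYPE_CHECKING:", so it is never executed at runtime. A minimal standalone sketch of the idea (the Task and DAG classes and the dag_timezone function here are illustrative stand-ins, not the real Airflow models):

    from typing import TYPE_CHECKING, Optional

    class DAG:
        timezone = "UTC"

    class Task:
        dag: Optional[DAG] = None

    def dag_timezone(task: Task) -> str:
        # typing.TYPE_CHECKING is False at runtime, so this assert never
        # executes; mypy, however, treats the block as reachable and narrows
        # task.dag from Optional[DAG] to DAG, silencing union-attr errors
        # on the attribute access below.
        if TYPE_CHECKING:
            assert task.dag is not None
        return task.dag.timezone

    task = Task()
    task.dag = DAG()
    print(dag_timezone(task))  # prints: UTC

Compared with a bare "assert task.dag  # For Mypy.", the guarded form has no runtime cost, cannot raise AssertionError in production, and is unaffected by python -O stripping asserts; the narrowing still works because mypy treats "if TYPE_CHECKING:" blocks as always taken.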

diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index b73c8992d8..147475da4b 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -17,7 +17,7 @@
 # under the License.
 #
 import signal
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 import psutil
 from sqlalchemy.exc import OperationalError
@@ -243,7 +243,8 @@ class LocalTaskJob(BaseJob):
             ).one()
 
             task = self.task_instance.task
-            assert task.dag  # For Mypy.
+            if TYPE_CHECKING:
+                assert task.dag
 
             # Get a partial DAG with just the specific tasks we want to examine.
             # In order for dep checks to work correctly, we include ourself (so
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index d5b1481cb1..57864a5940 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -148,7 +148,8 @@ class SkipMixin(LoggingMixin):
         dag_run = ti.get_dagrun()
         task = ti.task
         dag = task.dag
-        assert dag  # For Mypy.
+        if TYPE_CHECKING:
+            assert dag
 
         # At runtime, the downstream list will only be operators
         downstream_tasks = cast("List[BaseOperator]", task.downstream_list)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 024a0a4b1b..c198c6f09d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -548,7 +548,8 @@ class TaskInstance(Base, LoggingMixin):
                     execution_date,
                 )
                 if self.task.has_dag():
-                    assert self.task.dag  # For Mypy.
+                    if TYPE_CHECKING:
+                        assert self.task.dag
                     execution_date = timezone.make_aware(execution_date, self.task.dag.timezone)
                 else:
                     execution_date = timezone.make_aware(execution_date)
@@ -1780,7 +1781,8 @@ class TaskInstance(Base, LoggingMixin):
 
         self.task = self.task.prepare_for_execution()
         self.render_templates()
-        assert isinstance(self.task, BaseOperator)  # For Mypy.
+        if TYPE_CHECKING:
+            assert isinstance(self.task, BaseOperator)
         self.task.dry_run()
 
     @provide_session
@@ -1952,7 +1954,8 @@ class TaskInstance(Base, LoggingMixin):
         integrate_macros_plugins()
 
         task = self.task
-        assert task.dag  # For Mypy.
+        if TYPE_CHECKING:
+            assert task.dag
         dag: DAG = task.dag
 
         dag_run = self.get_dagrun(session)
diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index 725343211b..3e2356aacb 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -16,7 +16,7 @@
 # under the License.
 import tempfile
 import warnings
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
 
 from kubernetes import client, config, watch
 from kubernetes.config import ConfigException
@@ -281,7 +281,8 @@ class KubernetesHook(BaseHook):
         if self._is_in_cluster is not None:
             return self._is_in_cluster
         self.api_client  # so we can determine if we are in_cluster or not
-        assert self._is_in_cluster is not None
+        if TYPE_CHECKING:
+            assert self._is_in_cluster is not None
         return self._is_in_cluster
 
     @cached_property
diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py
index 4707f523d6..b53fe3310c 100644
--- a/airflow/providers/elasticsearch/log/es_task_handler.py
+++ b/airflow/providers/elasticsearch/log/es_task_handler.py
@@ -23,7 +23,7 @@ from collections import defaultdict
 from datetime import datetime
 from operator import attrgetter
 from time import time
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 from urllib.parse import quote
 
 # Using `from elasticsearch import *` would break elasticsearch mocking used in unit test.
@@ -130,7 +130,8 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
         except AttributeError:  # ti.task is not always set.
             data_interval = (dag_run.data_interval_start, dag_run.data_interval_end)
         else:
-            assert dag is not None  # For Mypy.
+            if TYPE_CHECKING:
+                assert dag is not None
             data_interval = dag.get_run_data_interval(dag_run)
 
         if self.json_format:
diff --git a/airflow/providers/github/hooks/github.py b/airflow/providers/github/hooks/github.py
index 9a71ef5b38..bb21912115 100644
--- a/airflow/providers/github/hooks/github.py
+++ b/airflow/providers/github/hooks/github.py
@@ -17,7 +17,7 @@
 # under the License.
 
 """This module allows you to connect to GitHub."""
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
 
 from github import Github as GithubClient
 
@@ -79,7 +79,8 @@ class GithubHook(BaseHook):
     def test_connection(self) -> Tuple[bool, str]:
         """Test GitHub connection."""
         try:
-            assert self.client  # For mypy union-attr check of Optional[GithubClient].
+            if TYPE_CHECKING:
+                assert self.client
             self.client.get_user().id
             return True, "Successfully connected to GitHub."
         except Exception as e:
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 471d5b95be..041e4778e3 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -97,7 +97,8 @@ class FileTaskHandler(logging.Handler):
             except AttributeError:  # ti.task is not always set.
                 data_interval = (dag_run.data_interval_start, dag_run.data_interval_end)
             else:
-                assert dag is not None  # For Mypy.
+                if TYPE_CHECKING:
+                    assert dag is not None
                 data_interval = dag.get_run_data_interval(dag_run)
             if data_interval[0]:
                 data_interval_start = data_interval[0].isoformat()