You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/11 19:27:18 UTC
[airflow] branch main updated: Only assert stuff for mypy when type checking (#24937)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ef79a0d1c4 Only assert stuff for mypy when type checking (#24937)
ef79a0d1c4 is described below
commit ef79a0d1c4c0a041d7ebf83b93cbb25aa3778a70
Author: Jed Cunningham <66...@users.noreply.github.com>
AuthorDate: Mon Jul 11 13:27:03 2022 -0600
Only assert stuff for mypy when type checking (#24937)
---
airflow/jobs/local_task_job.py | 5 +++--
airflow/models/skipmixin.py | 3 ++-
airflow/models/taskinstance.py | 9 ++++++---
airflow/providers/cncf/kubernetes/hooks/kubernetes.py | 5 +++--
airflow/providers/elasticsearch/log/es_task_handler.py | 5 +++--
airflow/providers/github/hooks/github.py | 5 +++--
airflow/utils/log/file_task_handler.py | 3 ++-
7 files changed, 22 insertions(+), 13 deletions(-)
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index b73c8992d8..147475da4b 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -17,7 +17,7 @@
# under the License.
#
import signal
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
import psutil
from sqlalchemy.exc import OperationalError
@@ -243,7 +243,8 @@ class LocalTaskJob(BaseJob):
).one()
task = self.task_instance.task
- assert task.dag # For Mypy.
+ if TYPE_CHECKING:
+ assert task.dag
# Get a partial DAG with just the specific tasks we want to examine.
# In order for dep checks to work correctly, we include ourself (so
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index d5b1481cb1..57864a5940 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -148,7 +148,8 @@ class SkipMixin(LoggingMixin):
dag_run = ti.get_dagrun()
task = ti.task
dag = task.dag
- assert dag # For Mypy.
+ if TYPE_CHECKING:
+ assert dag
# At runtime, the downstream list will only be operators
downstream_tasks = cast("List[BaseOperator]", task.downstream_list)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 024a0a4b1b..c198c6f09d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -548,7 +548,8 @@ class TaskInstance(Base, LoggingMixin):
execution_date,
)
if self.task.has_dag():
- assert self.task.dag # For Mypy.
+ if TYPE_CHECKING:
+ assert self.task.dag
execution_date = timezone.make_aware(execution_date, self.task.dag.timezone)
else:
execution_date = timezone.make_aware(execution_date)
@@ -1780,7 +1781,8 @@ class TaskInstance(Base, LoggingMixin):
self.task = self.task.prepare_for_execution()
self.render_templates()
- assert isinstance(self.task, BaseOperator) # For Mypy.
+ if TYPE_CHECKING:
+ assert isinstance(self.task, BaseOperator)
self.task.dry_run()
@provide_session
@@ -1952,7 +1954,8 @@ class TaskInstance(Base, LoggingMixin):
integrate_macros_plugins()
task = self.task
- assert task.dag # For Mypy.
+ if TYPE_CHECKING:
+ assert task.dag
dag: DAG = task.dag
dag_run = self.get_dagrun(session)
diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index 725343211b..3e2356aacb 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -16,7 +16,7 @@
# under the License.
import tempfile
import warnings
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
from kubernetes import client, config, watch
from kubernetes.config import ConfigException
@@ -281,7 +281,8 @@ class KubernetesHook(BaseHook):
if self._is_in_cluster is not None:
return self._is_in_cluster
self.api_client # so we can determine if we are in_cluster or not
- assert self._is_in_cluster is not None
+ if TYPE_CHECKING:
+ assert self._is_in_cluster is not None
return self._is_in_cluster
@cached_property
diff --git a/airflow/providers/elasticsearch/log/es_task_handler.py b/airflow/providers/elasticsearch/log/es_task_handler.py
index 4707f523d6..b53fe3310c 100644
--- a/airflow/providers/elasticsearch/log/es_task_handler.py
+++ b/airflow/providers/elasticsearch/log/es_task_handler.py
@@ -23,7 +23,7 @@ from collections import defaultdict
from datetime import datetime
from operator import attrgetter
from time import time
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from urllib.parse import quote
# Using `from elasticsearch import *` would break elasticsearch mocking used in unit test.
@@ -130,7 +130,8 @@ class ElasticsearchTaskHandler(FileTaskHandler, ExternalLoggingMixin, LoggingMix
except AttributeError: # ti.task is not always set.
data_interval = (dag_run.data_interval_start, dag_run.data_interval_end)
else:
- assert dag is not None # For Mypy.
+ if TYPE_CHECKING:
+ assert dag is not None
data_interval = dag.get_run_data_interval(dag_run)
if self.json_format:
diff --git a/airflow/providers/github/hooks/github.py b/airflow/providers/github/hooks/github.py
index 9a71ef5b38..bb21912115 100644
--- a/airflow/providers/github/hooks/github.py
+++ b/airflow/providers/github/hooks/github.py
@@ -17,7 +17,7 @@
# under the License.
"""This module allows you to connect to GitHub."""
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
from github import Github as GithubClient
@@ -79,7 +79,8 @@ class GithubHook(BaseHook):
def test_connection(self) -> Tuple[bool, str]:
"""Test GitHub connection."""
try:
- assert self.client # For mypy union-attr check of Optional[GithubClient].
+ if TYPE_CHECKING:
+ assert self.client
self.client.get_user().id
return True, "Successfully connected to GitHub."
except Exception as e:
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 471d5b95be..041e4778e3 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -97,7 +97,8 @@ class FileTaskHandler(logging.Handler):
except AttributeError: # ti.task is not always set.
data_interval = (dag_run.data_interval_start, dag_run.data_interval_end)
else:
- assert dag is not None # For Mypy.
+ if TYPE_CHECKING:
+ assert dag is not None
data_interval = dag.get_run_data_interval(dag_run)
if data_interval[0]:
data_interval_start = data_interval[0].isoformat()