You are viewing a plain-text version of this content. The canonical version is the original mail-archive page; its link was not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/12/16 13:28:30 UTC

[airflow] branch master updated: Annotate DagRun methods with return types (#11486)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new ccaca0a  Annotate DagRun methods with return types (#11486)
ccaca0a is described below

commit ccaca0af3933306c7ca0aa3d78fa2237ff3bfb19
Author: Joshua Carp <jm...@gmail.com>
AuthorDate: Wed Dec 16 08:27:21 2020 -0500

    Annotate DagRun methods with return types (#11486)
---
 airflow/models/dagrun.py | 43 ++++++++++++++++++++++---------------------
 1 file changed, 22 insertions(+), 21 deletions(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 5979237..130deed 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from datetime import datetime
-from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Iterable, List, NamedTuple, Optional, Tuple, Union
 
 from sqlalchemy import (
     Boolean,
@@ -52,6 +52,9 @@ from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, with
 from airflow.utils.state import State
 from airflow.utils.types import DagRunType
 
+if TYPE_CHECKING:
+    from airflow.models.dag import DAG
+
 
 class TISchedulingDecision(NamedTuple):
     """Type of return for DagRun.task_instance_scheduling_decisions"""
@@ -290,9 +293,7 @@ class DagRun(Base, LoggingMixin):
         if no_backfills:
             qry = qry.filter(DR.run_type != DagRunType.BACKFILL_JOB)
 
-        dr = qry.order_by(DR.execution_date).all()
-
-        return dr
+        return qry.order_by(DR.execution_date).all()
 
     @staticmethod
     def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
@@ -300,7 +301,7 @@ class DagRun(Base, LoggingMixin):
         return f"{run_type}__{execution_date.isoformat()}"
 
     @provide_session
-    def get_task_instances(self, state=None, session=None):
+    def get_task_instances(self, state=None, session=None) -> Iterable[TI]:
         """Returns the task instances for this dag run"""
         tis = session.query(TI).filter(
             TI.dag_id == self.dag_id,
@@ -326,7 +327,7 @@ class DagRun(Base, LoggingMixin):
         return tis.all()
 
     @provide_session
-    def get_task_instance(self, task_id: str, session: Session = None):
+    def get_task_instance(self, task_id: str, session: Session = None) -> Optional[TI]:
         """
         Returns the task instance specified by task_id for this dag run
 
@@ -335,15 +336,13 @@ class DagRun(Base, LoggingMixin):
         :param session: Sqlalchemy ORM Session
         :type session: Session
         """
-        ti = (
+        return (
             session.query(TI)
             .filter(TI.dag_id == self.dag_id, TI.execution_date == self.execution_date, TI.task_id == task_id)
             .first()
         )
 
-        return ti
-
-    def get_dag(self):
+    def get_dag(self) -> "DAG":
         """
         Returns the Dag associated with this DagRun.
 
@@ -366,7 +365,7 @@ class DagRun(Base, LoggingMixin):
         return session.query(DagRun).filter(*filters).order_by(DagRun.execution_date.desc()).first()
 
     @provide_session
-    def get_previous_scheduled_dagrun(self, session: Session = None):
+    def get_previous_scheduled_dagrun(self, session: Session = None) -> Optional['DagRun']:
         """The previous, SCHEDULED DagRun, if there is one"""
         dag = self.get_dag()
 
@@ -668,7 +667,7 @@ class DagRun(Base, LoggingMixin):
             session.rollback()
 
     @staticmethod
-    def get_run(session: Session, dag_id: str, execution_date: datetime):
+    def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional['DagRun']:
         """
         Get a single DAG Run
 
@@ -682,27 +681,30 @@ class DagRun(Base, LoggingMixin):
             if one exists. None otherwise.
         :rtype: airflow.models.DagRun
         """
-        qry = session.query(DagRun).filter(
-            DagRun.dag_id == dag_id,
-            DagRun.external_trigger == False,  # noqa pylint: disable=singleton-comparison
-            DagRun.execution_date == execution_date,
+        return (
+            session.query(DagRun)
+            .filter(
+                DagRun.dag_id == dag_id,
+                DagRun.external_trigger == False,  # noqa pylint: disable=singleton-comparison
+                DagRun.execution_date == execution_date,
+            )
+            .first()
         )
-        return qry.first()
 
     @property
-    def is_backfill(self):
+    def is_backfill(self) -> bool:
         return self.run_type == DagRunType.BACKFILL_JOB
 
     @classmethod
     @provide_session
-    def get_latest_runs(cls, session=None):
+    def get_latest_runs(cls, session=None) -> List['DagRun']:
         """Returns the latest DagRun for each DAG"""
         subquery = (
             session.query(cls.dag_id, func.max(cls.execution_date).label('execution_date'))
             .group_by(cls.dag_id)
             .subquery()
         )
-        dagruns = (
+        return (
             session.query(cls)
             .join(
                 subquery,
@@ -710,7 +712,6 @@ class DagRun(Base, LoggingMixin):
             )
             .all()
         )
-        return dagruns
 
     @provide_session
     def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = None) -> int: