You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2021/08/13 19:49:14 UTC

[airflow] 05/08: Add 'queued' to DagRunState (#16854)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit f4c95c5ad7241c55cc27b13b94e4bccede0b48c8
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Jul 7 21:37:22 2021 +0100

    Add 'queued' to DagRunState (#16854)
    
    This change adds 'queued' to DagRunState and improves typing for DagRun state
    
    Co-authored-by: Kaxil Naik <ka...@gmail.com>
    (cherry picked from commit 5a5f30f9133a6c5f0c41886ff9ae80ea53c73989)
---
 airflow/jobs/scheduler_job.py  |  8 ++++----
 airflow/models/dag.py          | 10 +++++-----
 airflow/models/dagrun.py       | 16 +++++++++-------
 airflow/models/taskinstance.py |  4 ++--
 airflow/utils/state.py         |  2 ++
 5 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index b564717..7a37b25 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -54,7 +54,7 @@ from airflow.utils.event_scheduler import EventScheduler
 from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
 from airflow.utils.session import create_session, provide_session
 from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, skip_locked, with_row_locks
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
 TI = models.TaskInstance
@@ -192,7 +192,7 @@ class SchedulerJob(BaseJob):
 
     @provide_session
     def _change_state_for_tis_without_dagrun(
-        self, old_states: List[str], new_state: str, session: Session = None
+        self, old_states: List[TaskInstanceState], new_state: TaskInstanceState, session: Session = None
     ) -> None:
         """
         For all DAG IDs in the DagBag, look for task instances in the
@@ -266,7 +266,7 @@ class SchedulerJob(BaseJob):
 
     @provide_session
     def __get_concurrency_maps(
-        self, states: List[str], session: Session = None
+        self, states: List[TaskInstanceState], session: Session = None
     ) -> Tuple[DefaultDict[str, int], DefaultDict[Tuple[str, str], int]]:
         """
         Get the concurrency maps.
@@ -936,7 +936,7 @@ class SchedulerJob(BaseJob):
             return num_queued_tis
 
     @retry_db_transaction
-    def _get_next_dagruns_to_examine(self, state, session):
+    def _get_next_dagruns_to_examine(self, state: DagRunState, session: Session):
         """Get Next DagRuns to Examine with retries"""
         return DagRun.next_dagruns_to_examine(state, session)
 
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index a3d06db..2e66b40 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -75,7 +75,7 @@ from airflow.utils.helpers import validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import provide_session
 from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
 from airflow.utils.types import DagRunType, EdgeInfoType
 
 if TYPE_CHECKING:
@@ -1153,7 +1153,7 @@ class DAG(LoggingMixin):
         confirm_prompt=False,
         include_subdags=True,
         include_parentdag=True,
-        dag_run_state: str = State.QUEUED,
+        dag_run_state: DagRunState = DagRunState.QUEUED,
         dry_run=False,
         session=None,
         get_tis=False,
@@ -1369,7 +1369,7 @@ class DAG(LoggingMixin):
         confirm_prompt=False,
         include_subdags=True,
         include_parentdag=False,
-        dag_run_state=State.QUEUED,
+        dag_run_state=DagRunState.QUEUED,
         dry_run=False,
     ):
         all_tis = []
@@ -1731,7 +1731,7 @@ class DAG(LoggingMixin):
     @provide_session
     def create_dagrun(
         self,
-        state: State,
+        state: DagRunState,
         execution_date: Optional[datetime] = None,
         run_id: Optional[str] = None,
         start_date: Optional[datetime] = None,
@@ -1753,7 +1753,7 @@ class DAG(LoggingMixin):
         :param execution_date: the execution date of this dag run
         :type execution_date: datetime.datetime
         :param state: the state of the dag run
-        :type state: airflow.utils.state.State
+        :type state: airflow.utils.state.DagRunState
         :param start_date: the date this dag run should be evaluated
         :type start_date: datetime
         :param external_trigger: whether this dag run is externally triggered
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index a061dcc..6f47077 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -36,7 +36,7 @@ from airflow.utils import callback_requests, timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import provide_session
 from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, with_row_locks
-from airflow.utils.state import State, TaskInstanceState
+from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
 if TYPE_CHECKING:
@@ -110,7 +110,7 @@ class DagRun(Base, LoggingMixin):
         start_date: Optional[datetime] = None,
         external_trigger: Optional[bool] = None,
         conf: Optional[Any] = None,
-        state: Optional[str] = None,
+        state: Optional[DagRunState] = None,
         run_type: Optional[str] = None,
         dag_hash: Optional[str] = None,
         creating_job_id: Optional[int] = None,
@@ -144,7 +144,7 @@ class DagRun(Base, LoggingMixin):
     def get_state(self):
         return self._state
 
-    def set_state(self, state):
+    def set_state(self, state: DagRunState):
         if self._state != state:
             self._state = state
             self.end_date = timezone.utcnow() if self._state in State.finished else None
@@ -170,7 +170,7 @@ class DagRun(Base, LoggingMixin):
     @classmethod
     def next_dagruns_to_examine(
         cls,
-        state: str,
+        state: DagRunState,
         session: Session,
         max_number: Optional[int] = None,
     ):
@@ -219,7 +219,7 @@ class DagRun(Base, LoggingMixin):
         dag_id: Optional[Union[str, List[str]]] = None,
         run_id: Optional[str] = None,
         execution_date: Optional[datetime] = None,
-        state: Optional[str] = None,
+        state: Optional[DagRunState] = None,
         external_trigger: Optional[bool] = None,
         no_backfills: bool = False,
         run_type: Optional[DagRunType] = None,
@@ -239,7 +239,7 @@ class DagRun(Base, LoggingMixin):
         :param execution_date: the execution date
         :type execution_date: datetime.datetime or list[datetime.datetime]
         :param state: the state of the dag run
-        :type state: str
+        :type state: DagRunState
         :param external_trigger: whether this dag run is externally triggered
         :type external_trigger: bool
         :param no_backfills: return no backfills (True), return all (False).
@@ -343,7 +343,9 @@ class DagRun(Base, LoggingMixin):
         return self.dag
 
     @provide_session
-    def get_previous_dagrun(self, state: Optional[str] = None, session: Session = None) -> Optional['DagRun']:
+    def get_previous_dagrun(
+        self, state: Optional[DagRunState] = None, session: Session = None
+    ) -> Optional['DagRun']:
         """The previous DagRun, if there is one"""
         filters = [
             DagRun.dag_id == self.dag_id,
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 0e10567..c715f22 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -70,7 +70,7 @@ from airflow.utils.operator_helpers import context_to_airflow_vars
 from airflow.utils.platform import getuser
 from airflow.utils.session import provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
 from airflow.utils.timeout import timeout
 
 try:
@@ -137,7 +137,7 @@ def clear_task_instances(
     session,
     activate_dag_runs=None,
     dag=None,
-    dag_run_state: Union[str, Literal[False]] = State.QUEUED,
+    dag_run_state: Union[DagRunState, Literal[False]] = DagRunState.QUEUED,
 ):
     """
     Clears a set of task instances, but makes sure the running ones
diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index b1b27d0..e95b409 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -60,6 +60,7 @@ class DagRunState(str, Enum):
     same name in TaskInstanceState.
     """
 
+    QUEUED = "queued"
     RUNNING = "running"
     SUCCESS = "success"
     FAILED = "failed"
@@ -92,6 +93,7 @@ class State:
     task_states: Tuple[Optional[TaskInstanceState], ...] = (None,) + tuple(TaskInstanceState)
 
     dag_states: Tuple[DagRunState, ...] = (
+        DagRunState.QUEUED,
         DagRunState.SUCCESS,
         DagRunState.RUNNING,
         DagRunState.FAILED,