Posted to commits@airflow.apache.org by ur...@apache.org on 2021/12/15 05:19:30 UTC

[airflow] branch main updated: Fix mypy issues in airflow/jobs (#20298)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9844b91  Fix mypy issues in airflow/jobs (#20298)
9844b91 is described below

commit 9844b910470feac9e524832a3affdc657c340bcd
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Dec 15 06:18:57 2021 +0100

    Fix mypy issues in airflow/jobs (#20298)
---
 airflow/jobs/backfill_job.py   | 55 +++++++++++++++++++++--------------------
 airflow/jobs/local_task_job.py |  1 -
 airflow/jobs/scheduler_job.py  | 56 ++++++++++++++++++++++++------------------
 airflow/models/dagrun.py       |  2 +-
 airflow/models/taskinstance.py |  4 +--
 5 files changed, 64 insertions(+), 54 deletions(-)
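
The bulk of the change swaps the legacy State.* attribute lookups for the
specific TaskInstanceState and DagRunState enums (imported at the top of
backfill_job.py below), so that comparisons and assignments line up with the
narrower types the surrounding annotations expect. A minimal sketch of the
pattern follows; it is illustrative only, not part of the commit, and the
summarize() helper is hypothetical:

    from typing import Optional

    from airflow.utils.state import DagRunState, TaskInstanceState


    def summarize(ti_state: Optional[TaskInstanceState], run_state: DagRunState) -> str:
        # Both enums subclass str, so comparisons against values read back from
        # the database keep working at runtime while staying precise for mypy.
        if run_state == DagRunState.RUNNING and ti_state == TaskInstanceState.SCHEDULED:
            return "task is waiting to run inside a running dag run"
        if ti_state in (TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED):
            return "task failed"
        return f"run={run_state}, task={ti_state}"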

diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py
index 2d06dc6..de994b0 100644
--- a/airflow/jobs/backfill_job.py
+++ b/airflow/jobs/backfill_job.py
@@ -45,7 +45,7 @@ from airflow.timetables.base import DagRunInfo
 from airflow.utils import helpers, timezone
 from airflow.utils.configuration import conf as airflow_conf, tmp_configuration_copy
 from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
 
@@ -210,28 +210,28 @@ class BackfillJob(BaseJob):
         for ti in refreshed_tis:
             # Here we remake the key by subtracting 1 to match in memory information
             reduced_key = ti.key.reduced
-            if ti.state == State.SUCCESS:
+            if ti.state == TaskInstanceState.SUCCESS:
                 ti_status.succeeded.add(reduced_key)
                 self.log.debug("Task instance %s succeeded. Don't rerun.", ti)
                 ti_status.running.pop(reduced_key)
                 continue
-            if ti.state == State.SKIPPED:
+            if ti.state == TaskInstanceState.SKIPPED:
                 ti_status.skipped.add(reduced_key)
                 self.log.debug("Task instance %s skipped. Don't rerun.", ti)
                 ti_status.running.pop(reduced_key)
                 continue
-            if ti.state == State.FAILED:
+            if ti.state == TaskInstanceState.FAILED:
                 self.log.error("Task instance %s failed", ti)
                 ti_status.failed.add(reduced_key)
                 ti_status.running.pop(reduced_key)
                 continue
             # special case: if the task needs to run again put it back
-            if ti.state == State.UP_FOR_RETRY:
+            if ti.state == TaskInstanceState.UP_FOR_RETRY:
                 self.log.warning("Task instance %s is up for retry", ti)
                 ti_status.running.pop(reduced_key)
                 ti_status.to_run[ti.key] = ti
             # special case: if the task needs to be rescheduled put it back
-            elif ti.state == State.UP_FOR_RESCHEDULE:
+            elif ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
                 self.log.warning("Task instance %s is up for reschedule", ti)
                 # During handling of reschedule state in ti._handle_reschedule, try number is reduced
                 # by one, so we should not use reduced_key to avoid key error
@@ -256,7 +256,7 @@ class BackfillJob(BaseJob):
         if tis_to_be_scheduled:
             filter_for_tis = TI.filter_for_tis(tis_to_be_scheduled)
             session.query(TI).filter(filter_for_tis).update(
-                values={TI.state: State.SCHEDULED}, synchronize_session=False
+                values={TI.state: TaskInstanceState.SCHEDULED}, synchronize_session=False
             )
 
     def _manage_executor_state(self, running):
@@ -280,7 +280,10 @@ class BackfillJob(BaseJob):
 
             self.log.debug("Executor state: %s task %s", state, ti)
 
-            if state in (State.FAILED, State.SUCCESS) and ti.state in self.STATES_COUNT_AS_RUNNING:
+            if (
+                state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS)
+                and ti.state in self.STATES_COUNT_AS_RUNNING
+            ):
                 msg = (
                     f"Executor reports task instance {ti} finished ({state}) although the task says its "
                     f"{ti.state}. Was the task killed externally? Info: {info}"
@@ -313,7 +316,7 @@ class BackfillJob(BaseJob):
         run: Optional[DagRun]
         if runs:
             run = runs[0]
-            if run.state == State.RUNNING:
+            if run.state == DagRunState.RUNNING:
                 respect_dag_max_active_limit = False
         else:
             run = None
@@ -327,7 +330,7 @@ class BackfillJob(BaseJob):
             execution_date=run_date,
             data_interval=dagrun_info.data_interval,
             start_date=timezone.utcnow(),
-            state=State.RUNNING,
+            state=DagRunState.RUNNING,
             external_trigger=False,
             session=session,
             conf=self.conf,
@@ -339,7 +342,7 @@ class BackfillJob(BaseJob):
         run.dag = dag
 
         # explicitly mark as backfill and running
-        run.state = State.RUNNING
+        run.state = DagRunState.RUNNING
         run.run_type = DagRunType.BACKFILL_JOB
         run.verify_integrity(session=session)
         return run
@@ -371,8 +374,8 @@ class BackfillJob(BaseJob):
             for ti in dag_run.get_task_instances():
                 # all tasks part of the backfill are scheduled to run
                 if ti.state == State.NONE:
-                    ti.set_state(State.SCHEDULED, session=session)
-                if ti.state != State.REMOVED:
+                    ti.set_state(TaskInstanceState.SCHEDULED, session=session)
+                if ti.state != TaskInstanceState.REMOVED:
                     tasks_to_run[ti.key] = ti
             session.commit()
         except Exception:
@@ -448,14 +451,14 @@ class BackfillJob(BaseJob):
 
                 # The task was already marked successful or skipped by a
                 # different Job. Don't rerun it.
-                if ti.state == State.SUCCESS:
+                if ti.state == TaskInstanceState.SUCCESS:
                     ti_status.succeeded.add(key)
                     self.log.debug("Task instance %s succeeded. Don't rerun.", ti)
                     ti_status.to_run.pop(key)
                     if key in ti_status.running:
                         ti_status.running.pop(key)
                     return
-                elif ti.state == State.SKIPPED:
+                elif ti.state == TaskInstanceState.SKIPPED:
                     ti_status.skipped.add(key)
                     self.log.debug("Task instance %s skipped. Don't rerun.", ti)
                     ti_status.to_run.pop(key)
@@ -469,18 +472,18 @@ class BackfillJob(BaseJob):
                     self.log.warning(
                         "FIXME: Task instance %s state was set to None externally. This should not happen", ti
                     )
-                    ti.set_state(State.SCHEDULED, session=session)
+                    ti.set_state(TaskInstanceState.SCHEDULED, session=session)
                 if self.rerun_failed_tasks:
                     # Rerun failed tasks or upstreamed failed tasks
-                    if ti.state in (State.FAILED, State.UPSTREAM_FAILED):
+                    if ti.state in (TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED):
                         self.log.error("Task instance %s with state %s", ti, ti.state)
                         if key in ti_status.running:
                             ti_status.running.pop(key)
                         # Reset the failed task in backfill to scheduled state
-                        ti.set_state(State.SCHEDULED, session=session)
+                        ti.set_state(TaskInstanceState.SCHEDULED, session=session)
                 else:
                     # Default behaviour which works for subdag.
-                    if ti.state in (State.FAILED, State.UPSTREAM_FAILED):
+                    if ti.state in (TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED):
                         self.log.error("Task instance %s with state %s", ti, ti.state)
                         ti_status.failed.add(key)
                         ti_status.to_run.pop(key)
@@ -511,7 +514,7 @@ class BackfillJob(BaseJob):
                     else:
                         self.log.debug('Sending %s to executor', ti)
                         # Skip scheduled state, we are executing immediately
-                        ti.state = State.QUEUED
+                        ti.state = TaskInstanceState.QUEUED
                         ti.queued_by_job_id = self.id
                         ti.queued_dttm = timezone.utcnow()
                         session.merge(ti)
@@ -537,7 +540,7 @@ class BackfillJob(BaseJob):
                     session.commit()
                     return
 
-                if ti.state == State.UPSTREAM_FAILED:
+                if ti.state == TaskInstanceState.UPSTREAM_FAILED:
                     self.log.error("Task instance %s upstream failed", ti)
                     ti_status.failed.add(key)
                     ti_status.to_run.pop(key)
@@ -546,7 +549,7 @@ class BackfillJob(BaseJob):
                     return
 
                 # special case
-                if ti.state == State.UP_FOR_RETRY:
+                if ti.state == TaskInstanceState.UP_FOR_RETRY:
                     self.log.debug("Task instance %s retry period not expired yet", ti)
                     if key in ti_status.running:
                         ti_status.running.pop(key)
@@ -554,7 +557,7 @@ class BackfillJob(BaseJob):
                     return
 
                 # special case
-                if ti.state == State.UP_FOR_RESCHEDULE:
+                if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
                     self.log.debug("Task instance %s reschedule period not expired yet", ti)
                     if key in ti_status.running:
                         ti_status.running.pop(key)
@@ -752,7 +755,7 @@ class BackfillJob(BaseJob):
         for dag_run in dag_runs:
             dag_run.update_state()
             if dag_run.state not in State.finished:
-                dag_run.set_state(State.FAILED)
+                dag_run.set_state(DagRunState.FAILED)
             session.merge(dag_run)
 
     @provide_session
@@ -866,13 +869,13 @@ class BackfillJob(BaseJob):
         running_tis = self.executor.running
 
         # Can't use an update here since it doesn't support joins.
-        resettable_states = [State.SCHEDULED, State.QUEUED]
+        resettable_states = [TaskInstanceState.SCHEDULED, TaskInstanceState.QUEUED]
         if filter_by_dag_run is None:
             resettable_tis = (
                 session.query(TaskInstance)
                 .join(TaskInstance.dag_run)
                 .filter(
-                    DagRun.state == State.RUNNING,
+                    DagRun.state == DagRunState.RUNNING,
                     DagRun.run_type != DagRunType.BACKFILL_JOB,
                     TaskInstance.state.in_(resettable_states),
                 )
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index 0acf46f..67574d8 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -66,7 +66,6 @@ class LocalTaskJob(BaseJob):
         self.pickle_id = pickle_id
         self.mark_success = mark_success
         self.external_executor_id = external_executor_id
-        self.task_runner = None
 
         # terminating state is used so that a job don't try to
         # terminate multiple times
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 51d7b41..77aa33f 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -16,7 +16,6 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-import datetime
 import itertools
 import logging
 import multiprocessing
@@ -107,7 +106,7 @@ class SchedulerJob(BaseJob):
         num_times_parse_dags: int = -1,
         scheduler_idle_sleep_time: float = conf.getfloat('scheduler', 'scheduler_idle_sleep_time'),
         do_pickle: bool = False,
-        log: logging.Logger = None,
+        log: Optional[logging.Logger] = None,
         processor_poll_interval: Optional[float] = None,
         *args,
         **kwargs,
@@ -276,7 +275,7 @@ class SchedulerJob(BaseJob):
             .filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING)
             .join(TI.dag_model)
             .filter(not_(DM.is_paused))
-            .filter(TI.state == State.SCHEDULED)
+            .filter(TI.state == TaskInstanceState.SCHEDULED)
             .options(selectinload('dag_model'))
             .order_by(-TI.priority_weight, DR.execution_date)
         )
@@ -430,7 +429,11 @@ class SchedulerJob(BaseJob):
             session.query(TI).filter(filter_for_tis).update(
                 # TODO[ha]: should we use func.now()? How does that work with DB timezone
                 # on mysql when it's not UTC?
-                {TI.state: State.QUEUED, TI.queued_dttm: timezone.utcnow(), TI.queued_by_job_id: self.id},
+                {
+                    TI.state: TaskInstanceState.QUEUED,
+                    TI.queued_dttm: timezone.utcnow(),
+                    TI.queued_by_job_id: self.id,
+                },
                 synchronize_session=False,
             )
 
@@ -506,7 +509,7 @@ class SchedulerJob(BaseJob):
         """Respond to executor events."""
         if not self.processor_agent:
             raise ValueError("Processor agent is not started.")
-        ti_primary_key_to_try_number_map: Dict[Tuple[str, str, datetime.datetime], int] = {}
+        ti_primary_key_to_try_number_map: Dict[Tuple[str, str, str], int] = {}
         event_buffer = self.executor.get_event_buffer()
         tis_with_right_state: List[TaskInstanceKey] = []
 
@@ -525,7 +528,7 @@ class SchedulerJob(BaseJob):
                 state,
                 ti_key.try_number,
             )
-            if state in (State.FAILED, State.SUCCESS, State.QUEUED):
+            if state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS, TaskInstanceState.QUEUED):
                 tis_with_right_state.append(ti_key)
 
         # Return if no finished tasks
@@ -549,7 +552,7 @@ class SchedulerJob(BaseJob):
             state, info = event_buffer.pop(buffer_key)
 
             # TODO: should we fail RUNNING as well, as we do in Backfills?
-            if state == State.QUEUED:
+            if state == TaskInstanceState.QUEUED:
                 ti.external_executor_id = info
                 self.log.info("Setting external_id for %s to %s", ti, info)
                 continue
@@ -743,7 +746,7 @@ class SchedulerJob(BaseJob):
                 # If the scheduler is doing things, don't sleep. This means when there is work to do, the
                 # scheduler will run "as quick as possible", but when it's stopped, it can sleep, dropping CPU
                 # usage when "idle"
-                time.sleep(min(self._scheduler_idle_sleep_time, next_event))
+                time.sleep(min(self._scheduler_idle_sleep_time, next_event if next_event else 0))
 
             if loop_count >= self.num_runs > 0:
                 self.log.info(
@@ -799,7 +802,7 @@ class SchedulerJob(BaseJob):
 
             self._start_queued_dagruns(session)
             guard.commit()
-            dag_runs = self._get_next_dagruns_to_examine(State.RUNNING, session)
+            dag_runs = self._get_next_dagruns_to_examine(DagRunState.RUNNING, session)
             # Bulk fetch the currently active dag runs for the dags we are
             # examining, rather than making one query per DagRun
 
@@ -922,7 +925,7 @@ class SchedulerJob(BaseJob):
                 dag.create_dagrun(
                     run_type=DagRunType.SCHEDULED,
                     execution_date=dag_model.next_dagrun,
-                    state=State.QUEUED,
+                    state=DagRunState.QUEUED,
                     data_interval=data_interval,
                     external_trigger=False,
                     session=session,
@@ -951,9 +954,9 @@ class SchedulerJob(BaseJob):
     def _start_queued_dagruns(
         self,
         session: Session,
-    ) -> int:
+    ) -> None:
         """Find DagRuns in queued state and decide moving them to running state"""
-        dag_runs = self._get_next_dagruns_to_examine(State.QUEUED, session)
+        dag_runs = self._get_next_dagruns_to_examine(DagRunState.QUEUED, session)
 
         active_runs_of_dags = defaultdict(
             int,
@@ -961,7 +964,7 @@ class SchedulerJob(BaseJob):
         )
 
         def _update_state(dag: DAG, dag_run: DagRun):
-            dag_run.state = State.RUNNING
+            dag_run.state = DagRunState.RUNNING
             dag_run.start_date = timezone.utcnow()
             if dag.timetable.periodic:
                 # TODO: Logically, this should be DagRunInfo.run_after, but the
@@ -1003,11 +1006,13 @@ class SchedulerJob(BaseJob):
         :param dag_run: The DagRun to schedule
         :return: Callback that needs to be executed
         """
+        callback: Optional[DagCallbackRequest] = None
+
         dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
 
         if not dag:
             self.log.error("Couldn't find dag %s in DagBag/DB!", dag_run.dag_id)
-            return 0
+            return callback
         dag_model = DM.get_dagmodel(dag.dag_id, session)
 
         if (
@@ -1015,7 +1020,7 @@ class SchedulerJob(BaseJob):
             and dag.dagrun_timeout
             and dag_run.start_date < timezone.utcnow() - dag.dagrun_timeout
         ):
-            dag_run.set_state(State.FAILED)
+            dag_run.set_state(DagRunState.FAILED)
             unfinished_task_instances = (
                 session.query(TI)
                 .filter(TI.dag_id == dag_run.dag_id)
@@ -1023,7 +1028,7 @@ class SchedulerJob(BaseJob):
                 .filter(TI.state.in_(State.unfinished))
             )
             for task_instance in unfinished_task_instances:
-                task_instance.state = State.SKIPPED
+                task_instance.state = TaskInstanceState.SKIPPED
                 session.merge(task_instance)
             session.flush()
             self.log.info("Run %s of %s has timed-out", dag_run.run_id, dag_run.dag_id)
@@ -1042,12 +1047,12 @@ class SchedulerJob(BaseJob):
 
             # Send SLA & DAG Success/Failure Callbacks to be executed
             self._send_dag_callbacks_to_processor(dag, callback_to_execute)
-
-            return 0
+            # Because we send the callback here, we need to return None
+            return callback
 
         if dag_run.execution_date > timezone.utcnow() and not dag.allow_future_exec_dates:
             self.log.error("Execution date is in future: %s", dag_run.execution_date)
-            return 0
+            return callback
 
         self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session)
         # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else?
@@ -1111,8 +1116,8 @@ class SchedulerJob(BaseJob):
         pools = models.Pool.slots_stats(session=session)
         for pool_name, slot_stats in pools.items():
             Stats.gauge(f'pool.open_slots.{pool_name}', slot_stats["open"])
-            Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[State.QUEUED])  # type: ignore
-            Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats[State.RUNNING])  # type: ignore
+            Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[TaskInstanceState.QUEUED])
+            Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats[TaskInstanceState.RUNNING])
 
     @provide_session
     def heartbeat_callback(self, session: Session = None) -> None:
@@ -1153,7 +1158,7 @@ class SchedulerJob(BaseJob):
                         self.log.info("Marked %d SchedulerJob instances as failed", num_failed)
                         Stats.incr(self.__class__.__name__.lower() + '_end', num_failed)
 
-                    resettable_states = [State.QUEUED, State.RUNNING]
+                    resettable_states = [TaskInstanceState.QUEUED, TaskInstanceState.RUNNING]
                     query = (
                         session.query(TI)
                         .filter(TI.state.in_(resettable_states))
@@ -1214,12 +1219,15 @@ class SchedulerJob(BaseJob):
         """
         num_timed_out_tasks = (
             session.query(TaskInstance)
-            .filter(TaskInstance.state == State.DEFERRED, TaskInstance.trigger_timeout < timezone.utcnow())
+            .filter(
+                TaskInstance.state == TaskInstanceState.DEFERRED,
+                TaskInstance.trigger_timeout < timezone.utcnow(),
+            )
             .update(
                 # We have to schedule these to fail themselves so it doesn't
                 # happen inside the scheduler.
                 {
-                    "state": State.SCHEDULED,
+                    "state": TaskInstanceState.SCHEDULED,
                     "next_method": "__fail__",
                     "next_kwargs": {"error": "Trigger/execution timeout"},
                     "trigger_id": None,
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index e0210eb..83d3507 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -417,7 +417,7 @@ class DagRun(Base, LoggingMixin):
                 tis = tis.filter(TI.state == state)
             else:
                 # this is required to deal with NULL values
-                if TaskInstanceState.NONE in state:
+                if State.NONE in state:
                     if all(x is None for x in state):
                         tis = tis.filter(TI.state.is_(None))
                     else:
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index bc432bb..d8463b9 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -820,7 +820,7 @@ class TaskInstance(Base, LoggingMixin):
         return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number)
 
     @provide_session
-    def set_state(self, state: str, session=NEW_SESSION):
+    def set_state(self, state: Optional[str], session=NEW_SESSION):
         """
         Set TaskInstance state.
 
@@ -1691,7 +1691,7 @@ class TaskInstance(Base, LoggingMixin):
     @provide_session
     def handle_failure(
         self,
-        error: Union[str, BaseException],
+        error: Optional[Union[str, BaseException]] = None,
         test_mode: Optional[bool] = None,
         force_fail: bool = False,
         error_file: Optional[str] = None,
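
The remaining hunks follow a second pattern: parameters that can legitimately
be None (the scheduler's log argument, TaskInstance.set_state's state,
handle_failure's error) get explicit Optional[...] annotations instead of a
bare type with a None default, and local_task_job.py drops the
self.task_runner = None placeholder, whose None assignment would otherwise pin
the attribute's inferred type before the real runner is assigned. A small
illustrative sketch of the Optional fix, under the assumption that mypy is run
with implicit Optional defaults disallowed (start() here is a made-up function,
not Airflow code):

    import logging
    from typing import Optional


    # Rejected when implicit Optional is disallowed, because the annotation
    # does not admit the None default:
    #     def start(log: logging.Logger = None) -> None: ...
    #
    # Spelling out Optional makes the signature honest about accepting None:
    def start(log: Optional[logging.Logger] = None) -> None:
        # Fall back to a module-level logger when the caller passes nothing.
        (log or logging.getLogger(__name__)).info("starting")


    start()                          # uses the fallback logger
    start(logging.getLogger("job"))  # uses the caller-provided logger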