Posted to commits@airflow.apache.org by ur...@apache.org on 2021/12/15 05:19:30 UTC
[airflow] branch main updated: Fix mypy issues in airflow/jobs (#20298)
This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9844b91 Fix mypy issues in airflow/jobs (#20298)
9844b91 is described below
commit 9844b910470feac9e524832a3affdc657c340bcd
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Wed Dec 15 06:18:57 2021 +0100
Fix mypy issues in airflow/jobs (#20298)
---
airflow/jobs/backfill_job.py | 55 +++++++++++++++++++++--------------------
airflow/jobs/local_task_job.py | 1 -
airflow/jobs/scheduler_job.py | 56 ++++++++++++++++++++++++------------------
airflow/models/dagrun.py | 2 +-
airflow/models/taskinstance.py | 4 +--
5 files changed, 64 insertions(+), 54 deletions(-)
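Most of the hunks below follow one pattern: untyped State.* constants are replaced with members of the TaskInstanceState and DagRunState enums from airflow.utils.state, which gives mypy concrete types to check comparisons and assignments against. A minimal sketch of that pattern, assuming Airflow is importable and that the legacy State constants alias the enum members (illustrative only, not part of the commit):

from airflow.utils.state import DagRunState, State, TaskInstanceState

def summarize(ti_state: TaskInstanceState, dr_state: DagRunState) -> str:
    # The enums are str-based, so runtime comparisons against the old string
    # constants keep working; mypy just gets a real type to reason about.
    assert TaskInstanceState.SUCCESS == State.SUCCESS == "success"
    if ti_state == TaskInstanceState.SUCCESS and dr_state == DagRunState.RUNNING:
        return "task finished while its dag run is still running"
    return f"task={ti_state}, run={dr_state}"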
diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py
index 2d06dc6..de994b0 100644
--- a/airflow/jobs/backfill_job.py
+++ b/airflow/jobs/backfill_job.py
@@ -45,7 +45,7 @@ from airflow.timetables.base import DagRunInfo
from airflow.utils import helpers, timezone
from airflow.utils.configuration import conf as airflow_conf, tmp_configuration_copy
from airflow.utils.session import provide_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunType
@@ -210,28 +210,28 @@ class BackfillJob(BaseJob):
for ti in refreshed_tis:
# Here we remake the key by subtracting 1 to match in memory information
reduced_key = ti.key.reduced
- if ti.state == State.SUCCESS:
+ if ti.state == TaskInstanceState.SUCCESS:
ti_status.succeeded.add(reduced_key)
self.log.debug("Task instance %s succeeded. Don't rerun.", ti)
ti_status.running.pop(reduced_key)
continue
- if ti.state == State.SKIPPED:
+ if ti.state == TaskInstanceState.SKIPPED:
ti_status.skipped.add(reduced_key)
self.log.debug("Task instance %s skipped. Don't rerun.", ti)
ti_status.running.pop(reduced_key)
continue
- if ti.state == State.FAILED:
+ if ti.state == TaskInstanceState.FAILED:
self.log.error("Task instance %s failed", ti)
ti_status.failed.add(reduced_key)
ti_status.running.pop(reduced_key)
continue
# special case: if the task needs to run again put it back
- if ti.state == State.UP_FOR_RETRY:
+ if ti.state == TaskInstanceState.UP_FOR_RETRY:
self.log.warning("Task instance %s is up for retry", ti)
ti_status.running.pop(reduced_key)
ti_status.to_run[ti.key] = ti
# special case: if the task needs to be rescheduled put it back
- elif ti.state == State.UP_FOR_RESCHEDULE:
+ elif ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
self.log.warning("Task instance %s is up for reschedule", ti)
# During handling of reschedule state in ti._handle_reschedule, try number is reduced
# by one, so we should not use reduced_key to avoid key error
@@ -256,7 +256,7 @@ class BackfillJob(BaseJob):
if tis_to_be_scheduled:
filter_for_tis = TI.filter_for_tis(tis_to_be_scheduled)
session.query(TI).filter(filter_for_tis).update(
- values={TI.state: State.SCHEDULED}, synchronize_session=False
+ values={TI.state: TaskInstanceState.SCHEDULED}, synchronize_session=False
)
def _manage_executor_state(self, running):
@@ -280,7 +280,10 @@ class BackfillJob(BaseJob):
self.log.debug("Executor state: %s task %s", state, ti)
- if state in (State.FAILED, State.SUCCESS) and ti.state in self.STATES_COUNT_AS_RUNNING:
+ if (
+ state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS)
+ and ti.state in self.STATES_COUNT_AS_RUNNING
+ ):
msg = (
f"Executor reports task instance {ti} finished ({state}) although the task says its "
f"{ti.state}. Was the task killed externally? Info: {info}"
@@ -313,7 +316,7 @@ class BackfillJob(BaseJob):
run: Optional[DagRun]
if runs:
run = runs[0]
- if run.state == State.RUNNING:
+ if run.state == DagRunState.RUNNING:
respect_dag_max_active_limit = False
else:
run = None
@@ -327,7 +330,7 @@ class BackfillJob(BaseJob):
execution_date=run_date,
data_interval=dagrun_info.data_interval,
start_date=timezone.utcnow(),
- state=State.RUNNING,
+ state=DagRunState.RUNNING,
external_trigger=False,
session=session,
conf=self.conf,
@@ -339,7 +342,7 @@ class BackfillJob(BaseJob):
run.dag = dag
# explicitly mark as backfill and running
- run.state = State.RUNNING
+ run.state = DagRunState.RUNNING
run.run_type = DagRunType.BACKFILL_JOB
run.verify_integrity(session=session)
return run
@@ -371,8 +374,8 @@ class BackfillJob(BaseJob):
for ti in dag_run.get_task_instances():
# all tasks part of the backfill are scheduled to run
if ti.state == State.NONE:
- ti.set_state(State.SCHEDULED, session=session)
- if ti.state != State.REMOVED:
+ ti.set_state(TaskInstanceState.SCHEDULED, session=session)
+ if ti.state != TaskInstanceState.REMOVED:
tasks_to_run[ti.key] = ti
session.commit()
except Exception:
@@ -448,14 +451,14 @@ class BackfillJob(BaseJob):
# The task was already marked successful or skipped by a
# different Job. Don't rerun it.
- if ti.state == State.SUCCESS:
+ if ti.state == TaskInstanceState.SUCCESS:
ti_status.succeeded.add(key)
self.log.debug("Task instance %s succeeded. Don't rerun.", ti)
ti_status.to_run.pop(key)
if key in ti_status.running:
ti_status.running.pop(key)
return
- elif ti.state == State.SKIPPED:
+ elif ti.state == TaskInstanceState.SKIPPED:
ti_status.skipped.add(key)
self.log.debug("Task instance %s skipped. Don't rerun.", ti)
ti_status.to_run.pop(key)
@@ -469,18 +472,18 @@ class BackfillJob(BaseJob):
self.log.warning(
"FIXME: Task instance %s state was set to None externally. This should not happen", ti
)
- ti.set_state(State.SCHEDULED, session=session)
+ ti.set_state(TaskInstanceState.SCHEDULED, session=session)
if self.rerun_failed_tasks:
# Rerun failed tasks or upstreamed failed tasks
- if ti.state in (State.FAILED, State.UPSTREAM_FAILED):
+ if ti.state in (TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED):
self.log.error("Task instance %s with state %s", ti, ti.state)
if key in ti_status.running:
ti_status.running.pop(key)
# Reset the failed task in backfill to scheduled state
- ti.set_state(State.SCHEDULED, session=session)
+ ti.set_state(TaskInstanceState.SCHEDULED, session=session)
else:
# Default behaviour which works for subdag.
- if ti.state in (State.FAILED, State.UPSTREAM_FAILED):
+ if ti.state in (TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED):
self.log.error("Task instance %s with state %s", ti, ti.state)
ti_status.failed.add(key)
ti_status.to_run.pop(key)
@@ -511,7 +514,7 @@ class BackfillJob(BaseJob):
else:
self.log.debug('Sending %s to executor', ti)
# Skip scheduled state, we are executing immediately
- ti.state = State.QUEUED
+ ti.state = TaskInstanceState.QUEUED
ti.queued_by_job_id = self.id
ti.queued_dttm = timezone.utcnow()
session.merge(ti)
@@ -537,7 +540,7 @@ class BackfillJob(BaseJob):
session.commit()
return
- if ti.state == State.UPSTREAM_FAILED:
+ if ti.state == TaskInstanceState.UPSTREAM_FAILED:
self.log.error("Task instance %s upstream failed", ti)
ti_status.failed.add(key)
ti_status.to_run.pop(key)
@@ -546,7 +549,7 @@ class BackfillJob(BaseJob):
return
# special case
- if ti.state == State.UP_FOR_RETRY:
+ if ti.state == TaskInstanceState.UP_FOR_RETRY:
self.log.debug("Task instance %s retry period not expired yet", ti)
if key in ti_status.running:
ti_status.running.pop(key)
@@ -554,7 +557,7 @@ class BackfillJob(BaseJob):
return
# special case
- if ti.state == State.UP_FOR_RESCHEDULE:
+ if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE:
self.log.debug("Task instance %s reschedule period not expired yet", ti)
if key in ti_status.running:
ti_status.running.pop(key)
@@ -752,7 +755,7 @@ class BackfillJob(BaseJob):
for dag_run in dag_runs:
dag_run.update_state()
if dag_run.state not in State.finished:
- dag_run.set_state(State.FAILED)
+ dag_run.set_state(DagRunState.FAILED)
session.merge(dag_run)
@provide_session
@@ -866,13 +869,13 @@ class BackfillJob(BaseJob):
running_tis = self.executor.running
# Can't use an update here since it doesn't support joins.
- resettable_states = [State.SCHEDULED, State.QUEUED]
+ resettable_states = [TaskInstanceState.SCHEDULED, TaskInstanceState.QUEUED]
if filter_by_dag_run is None:
resettable_tis = (
session.query(TaskInstance)
.join(TaskInstance.dag_run)
.filter(
- DagRun.state == State.RUNNING,
+ DagRun.state == DagRunState.RUNNING,
DagRun.run_type != DagRunType.BACKFILL_JOB,
TaskInstance.state.in_(resettable_states),
)
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index 0acf46f..67574d8 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -66,7 +66,6 @@ class LocalTaskJob(BaseJob):
self.pickle_id = pickle_id
self.mark_success = mark_success
self.external_executor_id = external_executor_id
- self.task_runner = None
# terminating state is used so that a job don't try to
# terminate multiple times
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 51d7b41..77aa33f 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -16,7 +16,6 @@
# specific language governing permissions and limitations
# under the License.
#
-import datetime
import itertools
import logging
import multiprocessing
@@ -107,7 +106,7 @@ class SchedulerJob(BaseJob):
num_times_parse_dags: int = -1,
scheduler_idle_sleep_time: float = conf.getfloat('scheduler', 'scheduler_idle_sleep_time'),
do_pickle: bool = False,
- log: logging.Logger = None,
+ log: Optional[logging.Logger] = None,
processor_poll_interval: Optional[float] = None,
*args,
**kwargs,
@@ -276,7 +275,7 @@ class SchedulerJob(BaseJob):
.filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING)
.join(TI.dag_model)
.filter(not_(DM.is_paused))
- .filter(TI.state == State.SCHEDULED)
+ .filter(TI.state == TaskInstanceState.SCHEDULED)
.options(selectinload('dag_model'))
.order_by(-TI.priority_weight, DR.execution_date)
)
@@ -430,7 +429,11 @@ class SchedulerJob(BaseJob):
session.query(TI).filter(filter_for_tis).update(
# TODO[ha]: should we use func.now()? How does that work with DB timezone
# on mysql when it's not UTC?
- {TI.state: State.QUEUED, TI.queued_dttm: timezone.utcnow(), TI.queued_by_job_id: self.id},
+ {
+ TI.state: TaskInstanceState.QUEUED,
+ TI.queued_dttm: timezone.utcnow(),
+ TI.queued_by_job_id: self.id,
+ },
synchronize_session=False,
)
@@ -506,7 +509,7 @@ class SchedulerJob(BaseJob):
"""Respond to executor events."""
if not self.processor_agent:
raise ValueError("Processor agent is not started.")
- ti_primary_key_to_try_number_map: Dict[Tuple[str, str, datetime.datetime], int] = {}
+ ti_primary_key_to_try_number_map: Dict[Tuple[str, str, str], int] = {}
event_buffer = self.executor.get_event_buffer()
tis_with_right_state: List[TaskInstanceKey] = []
@@ -525,7 +528,7 @@ class SchedulerJob(BaseJob):
state,
ti_key.try_number,
)
- if state in (State.FAILED, State.SUCCESS, State.QUEUED):
+ if state in (TaskInstanceState.FAILED, TaskInstanceState.SUCCESS, TaskInstanceState.QUEUED):
tis_with_right_state.append(ti_key)
# Return if no finished tasks
@@ -549,7 +552,7 @@ class SchedulerJob(BaseJob):
state, info = event_buffer.pop(buffer_key)
# TODO: should we fail RUNNING as well, as we do in Backfills?
- if state == State.QUEUED:
+ if state == TaskInstanceState.QUEUED:
ti.external_executor_id = info
self.log.info("Setting external_id for %s to %s", ti, info)
continue
@@ -743,7 +746,7 @@ class SchedulerJob(BaseJob):
# If the scheduler is doing things, don't sleep. This means when there is work to do, the
# scheduler will run "as quick as possible", but when it's stopped, it can sleep, dropping CPU
# usage when "idle"
- time.sleep(min(self._scheduler_idle_sleep_time, next_event))
+ time.sleep(min(self._scheduler_idle_sleep_time, next_event if next_event else 0))
if loop_count >= self.num_runs > 0:
self.log.info(
@@ -799,7 +802,7 @@ class SchedulerJob(BaseJob):
self._start_queued_dagruns(session)
guard.commit()
- dag_runs = self._get_next_dagruns_to_examine(State.RUNNING, session)
+ dag_runs = self._get_next_dagruns_to_examine(DagRunState.RUNNING, session)
# Bulk fetch the currently active dag runs for the dags we are
# examining, rather than making one query per DagRun
@@ -922,7 +925,7 @@ class SchedulerJob(BaseJob):
dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=dag_model.next_dagrun,
- state=State.QUEUED,
+ state=DagRunState.QUEUED,
data_interval=data_interval,
external_trigger=False,
session=session,
@@ -951,9 +954,9 @@ class SchedulerJob(BaseJob):
def _start_queued_dagruns(
self,
session: Session,
- ) -> int:
+ ) -> None:
"""Find DagRuns in queued state and decide moving them to running state"""
- dag_runs = self._get_next_dagruns_to_examine(State.QUEUED, session)
+ dag_runs = self._get_next_dagruns_to_examine(DagRunState.QUEUED, session)
active_runs_of_dags = defaultdict(
int,
@@ -961,7 +964,7 @@ class SchedulerJob(BaseJob):
)
def _update_state(dag: DAG, dag_run: DagRun):
- dag_run.state = State.RUNNING
+ dag_run.state = DagRunState.RUNNING
dag_run.start_date = timezone.utcnow()
if dag.timetable.periodic:
# TODO: Logically, this should be DagRunInfo.run_after, but the
@@ -1003,11 +1006,13 @@ class SchedulerJob(BaseJob):
:param dag_run: The DagRun to schedule
:return: Callback that needs to be executed
"""
+ callback: Optional[DagCallbackRequest] = None
+
dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
if not dag:
self.log.error("Couldn't find dag %s in DagBag/DB!", dag_run.dag_id)
- return 0
+ return callback
dag_model = DM.get_dagmodel(dag.dag_id, session)
if (
@@ -1015,7 +1020,7 @@ class SchedulerJob(BaseJob):
and dag.dagrun_timeout
and dag_run.start_date < timezone.utcnow() - dag.dagrun_timeout
):
- dag_run.set_state(State.FAILED)
+ dag_run.set_state(DagRunState.FAILED)
unfinished_task_instances = (
session.query(TI)
.filter(TI.dag_id == dag_run.dag_id)
@@ -1023,7 +1028,7 @@ class SchedulerJob(BaseJob):
.filter(TI.state.in_(State.unfinished))
)
for task_instance in unfinished_task_instances:
- task_instance.state = State.SKIPPED
+ task_instance.state = TaskInstanceState.SKIPPED
session.merge(task_instance)
session.flush()
self.log.info("Run %s of %s has timed-out", dag_run.run_id, dag_run.dag_id)
@@ -1042,12 +1047,12 @@ class SchedulerJob(BaseJob):
# Send SLA & DAG Success/Failure Callbacks to be executed
self._send_dag_callbacks_to_processor(dag, callback_to_execute)
-
- return 0
+ # Because we send the callback here, we need to return None
+ return callback
if dag_run.execution_date > timezone.utcnow() and not dag.allow_future_exec_dates:
self.log.error("Execution date is in future: %s", dag_run.execution_date)
- return 0
+ return callback
self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session)
# TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else?
@@ -1111,8 +1116,8 @@ class SchedulerJob(BaseJob):
pools = models.Pool.slots_stats(session=session)
for pool_name, slot_stats in pools.items():
Stats.gauge(f'pool.open_slots.{pool_name}', slot_stats["open"])
- Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[State.QUEUED]) # type: ignore
- Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats[State.RUNNING]) # type: ignore
+ Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[TaskInstanceState.QUEUED])
+ Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats[TaskInstanceState.RUNNING])
@provide_session
def heartbeat_callback(self, session: Session = None) -> None:
@@ -1153,7 +1158,7 @@ class SchedulerJob(BaseJob):
self.log.info("Marked %d SchedulerJob instances as failed", num_failed)
Stats.incr(self.__class__.__name__.lower() + '_end', num_failed)
- resettable_states = [State.QUEUED, State.RUNNING]
+ resettable_states = [TaskInstanceState.QUEUED, TaskInstanceState.RUNNING]
query = (
session.query(TI)
.filter(TI.state.in_(resettable_states))
@@ -1214,12 +1219,15 @@ class SchedulerJob(BaseJob):
"""
num_timed_out_tasks = (
session.query(TaskInstance)
- .filter(TaskInstance.state == State.DEFERRED, TaskInstance.trigger_timeout < timezone.utcnow())
+ .filter(
+ TaskInstance.state == TaskInstanceState.DEFERRED,
+ TaskInstance.trigger_timeout < timezone.utcnow(),
+ )
.update(
# We have to schedule these to fail themselves so it doesn't
# happen inside the scheduler.
{
- "state": State.SCHEDULED,
+ "state": TaskInstanceState.SCHEDULED,
"next_method": "__fail__",
"next_kwargs": {"error": "Trigger/execution timeout"},
"trigger_id": None,
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index e0210eb..83d3507 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -417,7 +417,7 @@ class DagRun(Base, LoggingMixin):
tis = tis.filter(TI.state == state)
else:
# this is required to deal with NULL values
- if TaskInstanceState.NONE in state:
+ if State.NONE in state:
if all(x is None for x in state):
tis = tis.filter(TI.state.is_(None))
else:
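The dagrun.py hunk goes the other way: it compares against State.NONE rather than TaskInstanceState.NONE, since the "no state" case is stored as SQL NULL / Python None and is not a member of the TaskInstanceState enum. A small sketch of the distinction, assuming State.NONE is plain None as in airflow.utils.state:

from airflow.utils.state import State

def wants_null_state(states: list) -> bool:
    # "No state" is NULL/None in the database, so membership is tested
    # against None itself rather than a (non-existent) enum member.
    return State.NONE in states  # State.NONE is None

print(wants_null_state([None, "success"]))      # True
print(wants_null_state(["queued", "running"]))  # False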
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index bc432bb..d8463b9 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -820,7 +820,7 @@ class TaskInstance(Base, LoggingMixin):
return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number)
@provide_session
- def set_state(self, state: str, session=NEW_SESSION):
+ def set_state(self, state: Optional[str], session=NEW_SESSION):
"""
Set TaskInstance state.
@@ -1691,7 +1691,7 @@ class TaskInstance(Base, LoggingMixin):
@provide_session
def handle_failure(
self,
- error: Union[str, BaseException],
+ error: Optional[Union[str, BaseException]] = None,
test_mode: Optional[bool] = None,
force_fail: bool = False,
error_file: Optional[str] = None,
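The remaining taskinstance.py and scheduler_job.py changes are plain Optional-annotation fixes: parameters that can legitimately be None (the injected logger, the state passed to set_state, the error passed to handle_failure) are now declared as such. A minimal sketch of the idiom, using hypothetical function names rather than the real method signatures:

import logging
from typing import Optional, Union

def get_log(log: Optional[logging.Logger] = None) -> logging.Logger:
    # A default of None requires Optional[...] in the annotation for mypy.
    return log or logging.getLogger(__name__)

def record_failure(error: Optional[Union[str, BaseException]] = None) -> str:
    # Mirrors the handle_failure change: the error argument may be omitted.
    return f"failed: {error!r}" if error is not None else "failed: no error provided"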