You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/10/18 13:10:31 UTC
[airflow] 15/41: Remove DAG parsing from StandardTaskRunner (#26750)
This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch v2-4-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 3e090209461b603fa5fb695ee83a02236e08d192
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Thu Sep 29 22:51:46 2022 +0100
Remove DAG parsing from StandardTaskRunner (#26750)
This makes StandardTaskRunner start faster, because the DAG is now parsed only once, in task_run.
Also removed parsing of example DAGs when running a task.
(cherry picked from commit ce071172e22fba018889db7dcfac4a4d0fc41cda)
---
airflow/cli/commands/task_command.py | 10 +---
airflow/task/task_runner/standard_task_runner.py | 8 ++-
airflow/utils/cli.py | 19 +++----
tests/cli/commands/test_task_command.py | 64 ------------------------
4 files changed, 10 insertions(+), 91 deletions(-)
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 9caa8bb4bd..982aa31fd5 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -47,7 +47,6 @@ from airflow.typing_compat import Literal
from airflow.utils import cli as cli_utils
from airflow.utils.cli import (
get_dag,
- get_dag_by_deserialization,
get_dag_by_file_location,
get_dag_by_pickle,
get_dags,
@@ -364,14 +363,7 @@ def task_run(args, dag=None):
print(f'Loading pickle id: {args.pickle}')
dag = get_dag_by_pickle(args.pickle)
elif not dag:
- if args.local:
- try:
- dag = get_dag_by_deserialization(args.dag_id)
- except AirflowException:
- print(f'DAG {args.dag_id} does not exist in the database, trying to parse the dag_file')
- dag = get_dag(args.subdir, args.dag_id)
- else:
- dag = get_dag(args.subdir, args.dag_id)
+ dag = get_dag(args.subdir, args.dag_id, include_examples=False)
else:
# Use DAG from parameter
pass
diff --git a/airflow/task/task_runner/standard_task_runner.py b/airflow/task/task_runner/standard_task_runner.py
index 3c13a28df4..27fd11b1b1 100644
--- a/airflow/task/task_runner/standard_task_runner.py
+++ b/airflow/task/task_runner/standard_task_runner.py
@@ -36,6 +36,7 @@ class StandardTaskRunner(BaseTaskRunner):
def __init__(self, local_task_job):
super().__init__(local_task_job)
self._rc = None
+ self.dag = local_task_job.task_instance.task.dag
def start(self):
if CAN_FORK and not self.run_as_user:
@@ -64,7 +65,6 @@ class StandardTaskRunner(BaseTaskRunner):
from airflow import settings
from airflow.cli.cli_parser import get_parser
from airflow.sentry import Sentry
- from airflow.utils.cli import get_dag
# Force a new SQLAlchemy session. We can't share open DB handles
# between process. The cli code will re-create this as part of its
@@ -92,10 +92,8 @@ class StandardTaskRunner(BaseTaskRunner):
dag_id=self._task_instance.dag_id,
task_id=self._task_instance.task_id,
):
- # parse dag file since `airflow tasks run --local` does not parse dag file
- dag = get_dag(args.subdir, args.dag_id)
- args.func(args, dag=dag)
- return_code = 0
+ args.func(args, dag=self.dag)
+ return_code = 0
except Exception as exc:
return_code = 1
diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py
index 87313f46f5..522bf963e2 100644
--- a/airflow/utils/cli.py
+++ b/airflow/utils/cli.py
@@ -33,6 +33,7 @@ from pathlib import Path
from typing import TYPE_CHECKING, Callable, TypeVar, cast
from airflow import settings
+from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.utils import cli_action_loggers
from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
@@ -205,7 +206,9 @@ def _search_for_dag_file(val: str | None) -> str | None:
return None
-def get_dag(subdir: str | None, dag_id: str) -> DAG:
+def get_dag(
+ subdir: str | None, dag_id: str, include_examples=conf.getboolean('core', 'LOAD_EXAMPLES')
+) -> DAG:
"""
Returns DAG of a given dag_id
@@ -216,11 +219,11 @@ def get_dag(subdir: str | None, dag_id: str) -> DAG:
from airflow.models import DagBag
first_path = process_subdir(subdir)
- dagbag = DagBag(first_path)
+ dagbag = DagBag(first_path, include_examples=include_examples)
if dag_id not in dagbag.dags:
fallback_path = _search_for_dag_file(subdir) or settings.DAGS_FOLDER
logger.warning("Dag %r not found in path %s; trying path %s", dag_id, first_path, fallback_path)
- dagbag = DagBag(dag_folder=fallback_path)
+ dagbag = DagBag(dag_folder=fallback_path, include_examples=include_examples)
if dag_id not in dagbag.dags:
raise AirflowException(
f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
@@ -228,16 +231,6 @@ def get_dag(subdir: str | None, dag_id: str) -> DAG:
return dagbag.dags[dag_id]
-def get_dag_by_deserialization(dag_id: str) -> DAG:
- from airflow.models.serialized_dag import SerializedDagModel
-
- dag_model = SerializedDagModel.get(dag_id)
- if dag_model is None:
- raise AirflowException(f"Serialized DAG: {dag_id} could not be found")
-
- return dag_model.dag
-
-
def get_dags(subdir: str | None, dag_id: str, use_regex: bool = False):
"""Returns DAG(s) matching a given regex or dag_id"""
from airflow.models import DagBag
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 03b9259f8d..802140755f 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -159,38 +159,6 @@ class TestCliTasks:
task_command.task_test(args)
assert capsys.readouterr().out.endswith(f"{not_password}\n")
- @mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization")
- @mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
- def test_run_get_serialized_dag(self, mock_local_job, mock_get_dag_by_deserialization):
- """
- Test using serialized dag for local task_run
- """
- task_id = self.dag.task_ids[0]
- args = [
- 'tasks',
- 'run',
- '--ignore-all-dependencies',
- '--local',
- self.dag_id,
- task_id,
- self.run_id,
- ]
- mock_get_dag_by_deserialization.return_value = SerializedDagModel.get(self.dag_id).dag
-
- task_command.task_run(self.parser.parse_args(args))
- mock_local_job.assert_called_once_with(
- task_instance=mock.ANY,
- mark_success=False,
- ignore_all_deps=True,
- ignore_depends_on_past=False,
- ignore_task_deps=False,
- ignore_ti_state=False,
- pickle_id=None,
- pool=None,
- external_executor_id=None,
- )
- mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id)
-
def test_cli_test_different_path(self, session):
"""
When thedag processor has a different dags folder
@@ -265,38 +233,6 @@ class TestCliTasks:
# verify that the file was in different location when run
assert ti.xcom_pull(ti.task_id) == new_file_path.as_posix()
- @mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization")
- @mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
- def test_run_get_serialized_dag_fallback(self, mock_local_job, mock_get_dag_by_deserialization):
- """
- Fallback to parse dag_file when serialized dag does not exist in the db
- """
- task_id = self.dag.task_ids[0]
- args = [
- 'tasks',
- 'run',
- '--ignore-all-dependencies',
- '--local',
- self.dag_id,
- task_id,
- self.run_id,
- ]
- mock_get_dag_by_deserialization.side_effect = mock.Mock(side_effect=AirflowException('Not found'))
-
- task_command.task_run(self.parser.parse_args(args))
- mock_local_job.assert_called_once_with(
- task_instance=mock.ANY,
- mark_success=False,
- ignore_all_deps=True,
- ignore_depends_on_past=False,
- ignore_task_deps=False,
- ignore_ti_state=False,
- pickle_id=None,
- pool=None,
- external_executor_id=None,
- )
- mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id)
-
@mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
def test_run_with_existing_dag_run_id(self, mock_local_job):
"""