You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/10/18 13:10:31 UTC

[airflow] 15/41: Remove DAG parsing from StandardTaskRunner (#26750)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-4-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 3e090209461b603fa5fb695ee83a02236e08d192
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Thu Sep 29 22:51:46 2022 +0100

    Remove DAG parsing from StandardTaskRunner (#26750)
    
    This makes StandardTaskRunner start faster, because the DAG is now parsed only once, in task_run.
    It also skips parsing the example DAGs when running a task.
    
    (cherry picked from commit ce071172e22fba018889db7dcfac4a4d0fc41cda)
---
 airflow/cli/commands/task_command.py             | 10 +---
 airflow/task/task_runner/standard_task_runner.py |  8 ++-
 airflow/utils/cli.py                             | 19 +++----
 tests/cli/commands/test_task_command.py          | 64 ------------------------
 4 files changed, 10 insertions(+), 91 deletions(-)

diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 9caa8bb4bd..982aa31fd5 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -47,7 +47,6 @@ from airflow.typing_compat import Literal
 from airflow.utils import cli as cli_utils
 from airflow.utils.cli import (
     get_dag,
-    get_dag_by_deserialization,
     get_dag_by_file_location,
     get_dag_by_pickle,
     get_dags,
@@ -364,14 +363,7 @@ def task_run(args, dag=None):
         print(f'Loading pickle id: {args.pickle}')
         dag = get_dag_by_pickle(args.pickle)
     elif not dag:
-        if args.local:
-            try:
-                dag = get_dag_by_deserialization(args.dag_id)
-            except AirflowException:
-                print(f'DAG {args.dag_id} does not exist in the database, trying to parse the dag_file')
-                dag = get_dag(args.subdir, args.dag_id)
-        else:
-            dag = get_dag(args.subdir, args.dag_id)
+        dag = get_dag(args.subdir, args.dag_id, include_examples=False)
     else:
         # Use DAG from parameter
         pass
diff --git a/airflow/task/task_runner/standard_task_runner.py b/airflow/task/task_runner/standard_task_runner.py
index 3c13a28df4..27fd11b1b1 100644
--- a/airflow/task/task_runner/standard_task_runner.py
+++ b/airflow/task/task_runner/standard_task_runner.py
@@ -36,6 +36,7 @@ class StandardTaskRunner(BaseTaskRunner):
     def __init__(self, local_task_job):
         super().__init__(local_task_job)
         self._rc = None
+        self.dag = local_task_job.task_instance.task.dag
 
     def start(self):
         if CAN_FORK and not self.run_as_user:
@@ -64,7 +65,6 @@ class StandardTaskRunner(BaseTaskRunner):
             from airflow import settings
             from airflow.cli.cli_parser import get_parser
             from airflow.sentry import Sentry
-            from airflow.utils.cli import get_dag
 
             # Force a new SQLAlchemy session. We can't share open DB handles
             # between process. The cli code will re-create this as part of its
@@ -92,10 +92,8 @@ class StandardTaskRunner(BaseTaskRunner):
                     dag_id=self._task_instance.dag_id,
                     task_id=self._task_instance.task_id,
                 ):
-                    # parse dag file since `airflow tasks run --local` does not parse dag file
-                    dag = get_dag(args.subdir, args.dag_id)
-                    args.func(args, dag=dag)
-                return_code = 0
+                    args.func(args, dag=self.dag)
+                    return_code = 0
             except Exception as exc:
                 return_code = 1
 
diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py
index 87313f46f5..522bf963e2 100644
--- a/airflow/utils/cli.py
+++ b/airflow/utils/cli.py
@@ -33,6 +33,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Callable, TypeVar, cast
 
 from airflow import settings
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.utils import cli_action_loggers
 from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
@@ -205,7 +206,9 @@ def _search_for_dag_file(val: str | None) -> str | None:
     return None
 
 
-def get_dag(subdir: str | None, dag_id: str) -> DAG:
+def get_dag(
+    subdir: str | None, dag_id: str, include_examples=conf.getboolean('core', 'LOAD_EXAMPLES')
+) -> DAG:
     """
     Returns DAG of a given dag_id
 
@@ -216,11 +219,11 @@ def get_dag(subdir: str | None, dag_id: str) -> DAG:
     from airflow.models import DagBag
 
     first_path = process_subdir(subdir)
-    dagbag = DagBag(first_path)
+    dagbag = DagBag(first_path, include_examples=include_examples)
     if dag_id not in dagbag.dags:
         fallback_path = _search_for_dag_file(subdir) or settings.DAGS_FOLDER
         logger.warning("Dag %r not found in path %s; trying path %s", dag_id, first_path, fallback_path)
-        dagbag = DagBag(dag_folder=fallback_path)
+        dagbag = DagBag(dag_folder=fallback_path, include_examples=include_examples)
         if dag_id not in dagbag.dags:
             raise AirflowException(
                 f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
@@ -228,16 +231,6 @@ def get_dag(subdir: str | None, dag_id: str) -> DAG:
     return dagbag.dags[dag_id]
 
 
-def get_dag_by_deserialization(dag_id: str) -> DAG:
-    from airflow.models.serialized_dag import SerializedDagModel
-
-    dag_model = SerializedDagModel.get(dag_id)
-    if dag_model is None:
-        raise AirflowException(f"Serialized DAG: {dag_id} could not be found")
-
-    return dag_model.dag
-
-
 def get_dags(subdir: str | None, dag_id: str, use_regex: bool = False):
     """Returns DAG(s) matching a given regex or dag_id"""
     from airflow.models import DagBag
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 03b9259f8d..802140755f 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -159,38 +159,6 @@ class TestCliTasks:
             task_command.task_test(args)
         assert capsys.readouterr().out.endswith(f"{not_password}\n")
 
-    @mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization")
-    @mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
-    def test_run_get_serialized_dag(self, mock_local_job, mock_get_dag_by_deserialization):
-        """
-        Test using serialized dag for local task_run
-        """
-        task_id = self.dag.task_ids[0]
-        args = [
-            'tasks',
-            'run',
-            '--ignore-all-dependencies',
-            '--local',
-            self.dag_id,
-            task_id,
-            self.run_id,
-        ]
-        mock_get_dag_by_deserialization.return_value = SerializedDagModel.get(self.dag_id).dag
-
-        task_command.task_run(self.parser.parse_args(args))
-        mock_local_job.assert_called_once_with(
-            task_instance=mock.ANY,
-            mark_success=False,
-            ignore_all_deps=True,
-            ignore_depends_on_past=False,
-            ignore_task_deps=False,
-            ignore_ti_state=False,
-            pickle_id=None,
-            pool=None,
-            external_executor_id=None,
-        )
-        mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id)
-
     def test_cli_test_different_path(self, session):
         """
         When thedag processor has a different dags folder
@@ -265,38 +233,6 @@ class TestCliTasks:
             # verify that the file was in different location when run
             assert ti.xcom_pull(ti.task_id) == new_file_path.as_posix()
 
-    @mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization")
-    @mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
-    def test_run_get_serialized_dag_fallback(self, mock_local_job, mock_get_dag_by_deserialization):
-        """
-        Fallback to parse dag_file when serialized dag does not exist in the db
-        """
-        task_id = self.dag.task_ids[0]
-        args = [
-            'tasks',
-            'run',
-            '--ignore-all-dependencies',
-            '--local',
-            self.dag_id,
-            task_id,
-            self.run_id,
-        ]
-        mock_get_dag_by_deserialization.side_effect = mock.Mock(side_effect=AirflowException('Not found'))
-
-        task_command.task_run(self.parser.parse_args(args))
-        mock_local_job.assert_called_once_with(
-            task_instance=mock.ANY,
-            mark_success=False,
-            ignore_all_deps=True,
-            ignore_depends_on_past=False,
-            ignore_task_deps=False,
-            ignore_ti_state=False,
-            pickle_id=None,
-            pool=None,
-            external_executor_id=None,
-        )
-        mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id)
-
     @mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
     def test_run_with_existing_dag_run_id(self, mock_local_job):
         """