Posted to commits@airflow.apache.org by ep...@apache.org on 2022/10/18 13:10:22 UTC

[airflow] 06/41: Fix airflow tasks run --local when dags_folder differs from that of processor (#26509)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-4-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 804abc71e674bbdc52b0ed0187bad21ba3a171d7
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Wed Sep 28 15:56:14 2022 -0700

    Fix airflow tasks run --local when dags_folder differs from that of processor (#26509)
    
    Previously the code used the dags_folder of the "current" process (e.g. the celery worker or the k8s executor worker pod) to calculate the relative fileloc based on the full fileloc stored in the serialized dag.  But if the worker's dags_folder differs from the dags_folder configured on the dag processor, then airflow can't calculate the relative path, so it will just use the full path, which in this case will be a bad path.  We can fix this by keeping track of the dags_folder [...]
    
    (cherry picked from commit c94f978a66a7cfc31b6d461bbcbfd0f2ddb2962e)
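
To illustrate the failure mode, here is a minimal sketch (hypothetical paths, not the committed code) of the worker-side calculation before the fix:

    from pathlib import Path

    # The serialized dag stores the full path as seen by the dag processor.
    fileloc = Path("/opt/processor/dags/my_dag.py")
    # The worker is configured with a different dags folder.
    worker_dags_folder = Path("/opt/worker/dags")

    try:
        rel = fileloc.relative_to(worker_dags_folder)
    except ValueError:
        # Not relative to the worker's folder: the old code fell back to the
        # full processor-side path, which does not exist on the worker.
        rel = fileloc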
---
 airflow/models/dag.py                         | 11 +++-
 airflow/serialization/schema.json             |  8 ++-
 airflow/serialization/serialized_objects.py   |  4 +-
 airflow/utils/cli.py                          | 45 ++++++++++++--
 tests/cli/commands/test_task_command.py       | 84 ++++++++++++++++++++++++++-
 tests/dags/test_dags_folder.py                | 38 ++++++++++++
 tests/models/test_dag.py                      | 57 +++++++++++++++++-
 tests/serialization/test_dag_serialization.py | 11 +++-
 tests/utils/test_cli_util.py                  | 21 +++++++
 9 files changed, 268 insertions(+), 11 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 60660ce0fb..1a6b9906b5 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -584,6 +584,11 @@ class DAG(LoggingMixin):
                 f"Bad formatted links are: {wrong_links}"
             )
 
+        # this will only be set at serialization time
+        # its only use is for determining the relative
+        # fileloc based only on the serialized dag
+        self._processor_dags_folder = None
+
     def get_doc_md(self, doc_md: str | None) -> str | None:
         if doc_md is None:
             return doc_md
@@ -1189,7 +1194,11 @@ class DAG(LoggingMixin):
         """File location of the importable dag 'file' relative to the configured DAGs folder."""
         path = pathlib.Path(self.fileloc)
         try:
-            return path.relative_to(settings.DAGS_FOLDER)
+            rel_path = path.relative_to(self._processor_dags_folder or settings.DAGS_FOLDER)
+            if rel_path == pathlib.Path('.'):
+                return path
+            else:
+                return rel_path
         except ValueError:
             # Not relative to DAGS_FOLDER.
             return path
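
Restated as a standalone function (a sketch; the real property reads `settings.DAGS_FOLDER` rather than taking it as a parameter), the patched lookup behaves like this:

    from __future__ import annotations

    import pathlib

    def relative_fileloc(
        fileloc: str, processor_dags_folder: str | None, dags_folder: str
    ) -> pathlib.Path:
        path = pathlib.Path(fileloc)
        try:
            # prefer the dags folder recorded by the serializing processor
            rel_path = path.relative_to(processor_dags_folder or dags_folder)
            if rel_path == pathlib.Path('.'):
                # the dags folder was the dag file itself; no useful relative path
                return path
            return rel_path
        except ValueError:
            # not relative to the dags folder at all
            return path

    assert relative_fileloc('/proc/dags/a.py', '/proc/dags', '/worker/dags') == pathlib.Path('a.py')
    assert relative_fileloc('/tmp/foo.py', None, '/worker/dags') == pathlib.Path('/tmp/foo.py')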
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index ddbedad42c..13e91b33d6 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -133,7 +133,13 @@
         "catchup": { "type": "boolean" },
         "is_subdag": { "type": "boolean" },
         "fileloc": { "type" : "string"},
-        "orientation": { "type" : "string"},
+        "_processor_dags_folder": {
+            "anyOf": [
+                { "type": "null" },
+                {"type": "string"}
+            ]
+        },
+         "orientation": { "type" : "string"},
         "_description": { "type" : "string"},
         "_concurrency": { "type" : "number"},
         "_max_active_tasks": { "type" : "number"},
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 969b6014db..542573fbcc 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -50,7 +50,7 @@ from airflow.providers_manager import ProvidersManager
 from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
 from airflow.serialization.helpers import serialize_template_field
 from airflow.serialization.json_schema import Validator, load_dag_schema
-from airflow.settings import json
+from airflow.settings import DAGS_FOLDER, json
 from airflow.timetables.base import Timetable
 from airflow.utils.code_utils import get_python_source
 from airflow.utils.docs import get_docs_url
@@ -1120,6 +1120,8 @@ class SerializedDAG(DAG, BaseSerialization):
         try:
             serialized_dag = cls.serialize_to_json(dag, cls._decorated_fields)
 
+            serialized_dag['_processor_dags_folder'] = DAGS_FOLDER
+
             # If schedule_interval is backed by timetable, serialize only
             # timetable; vice versa for a timetable backed by schedule_interval.
             if dag.timetable.summary == dag.schedule_interval:
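
Conceptually, the serializing side stamps its own configured dags folder into the payload so the deserializing side can reproduce the original relative path (illustrative values, not the committed code):

    import pathlib

    # On the dag processor (serializing side): record its settings.DAGS_FOLDER.
    payload = {"fileloc": "/proc/dags/a.py", "_processor_dags_folder": "/proc/dags"}

    # On the worker (deserializing side): the stamped folder, not the worker's
    # own DAGS_FOLDER, is used to compute the relative fileloc.
    rel = pathlib.Path(payload["fileloc"]).relative_to(payload["_processor_dags_folder"])
    assert rel == pathlib.Path("a.py")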
diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py
index 37fe0d2702..87313f46f5 100644
--- a/airflow/utils/cli.py
+++ b/airflow/utils/cli.py
@@ -29,6 +29,7 @@ import traceback
 import warnings
 from argparse import Namespace
 from datetime import datetime
+from pathlib import Path
 from typing import TYPE_CHECKING, Callable, TypeVar, cast
 
 from airflow import settings
@@ -43,6 +44,8 @@ T = TypeVar("T", bound=Callable)
 if TYPE_CHECKING:
     from airflow.models.dag import DAG
 
+logger = logging.getLogger(__name__)
+
 
 def _check_cli_args(args):
     if not args:
@@ -181,15 +184,47 @@ def get_dag_by_file_location(dag_id: str):
     return dagbag.dags[dag_id]
 
 
+def _search_for_dag_file(val: str | None) -> str | None:
+    """
+    Search for the file referenced at fileloc.
+
+    By the time we get to this function, we've already run this `val` through `process_subdir`,
+    loaded the DagBag there, and come up empty.  So here, if `val` is a file path, we make
+    a last-ditch effort to find a dag file with the same name in our dags folder. (This
+    avoids the unnecessary dag parsing that would occur if we just parsed the dags folder.)
+
+    If `val` is a path to a file, this likely means that the serializing process had a dags_folder
+    equal to only the dag file in question. This prevents us from determining the relative location.
+    And if the paths are different between worker and dag processor / scheduler, then we won't find
+    the dag at the given location.
+    """
+    if val and Path(val).suffix in ('.zip', '.py'):
+        matches = list(Path(settings.DAGS_FOLDER).rglob(Path(val).name))
+        if len(matches) == 1:
+            return matches[0].as_posix()
+    return None
+
+
 def get_dag(subdir: str | None, dag_id: str) -> DAG:
-    """Returns DAG of a given dag_id"""
+    """
+    Returns DAG of a given dag_id
+
+    First we'll try to use the given subdir.  If that doesn't work, we'll try to
+    find the correct path (assuming it's a file) and failing that, use the configured
+    dags folder.
+    """
     from airflow.models import DagBag
 
-    dagbag = DagBag(process_subdir(subdir))
+    first_path = process_subdir(subdir)
+    dagbag = DagBag(first_path)
     if dag_id not in dagbag.dags:
-        raise AirflowException(
-            f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
-        )
+        fallback_path = _search_for_dag_file(subdir) or settings.DAGS_FOLDER
+        logger.warning("Dag %r not found in path %s; trying path %s", dag_id, first_path, fallback_path)
+        dagbag = DagBag(dag_folder=fallback_path)
+        if dag_id not in dagbag.dags:
+            raise AirflowException(
+                f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
+            )
     return dagbag.dags[dag_id]
 
 
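A runnable sketch of the fallback lookup, with the dags folder passed explicitly instead of read from `settings` (the layout below is hypothetical):

    from __future__ import annotations

    import tempfile
    from pathlib import Path

    def search_for_dag_file(val: str | None, dags_folder: str) -> str | None:
        if val and Path(val).suffix in ('.zip', '.py'):
            matches = list(Path(dags_folder).rglob(Path(val).name))
            if len(matches) == 1:
                return matches[0].as_posix()
        return None

    with tempfile.TemporaryDirectory() as td:
        (Path(td) / 'team_a').mkdir()
        (Path(td) / 'team_a' / 'etl.py').touch()
        # a unique basename match anywhere under the dags folder resolves
        assert search_for_dag_file('/stale/path/etl.py', td) == (Path(td) / 'team_a' / 'etl.py').as_posix()
        # directories and unknown names return None (caller falls back to DAGS_FOLDER)
        assert search_for_dag_file(td, td) is None
        assert search_for_dag_file('/stale/path/other.py', td) is None
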
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 56a955b0ab..03b9259f8d 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -22,9 +22,10 @@ import json
 import logging
 import os
 import re
+import tempfile
 import unittest
 from argparse import ArgumentParser
-from contextlib import redirect_stdout
+from contextlib import contextmanager, redirect_stdout
 from pathlib import Path
 from unittest import mock
 
@@ -60,6 +61,13 @@ def reset(dag_id):
         runs.delete()
 
 
+@contextmanager
+def move_back(old_path, new_path):
+    os.rename(old_path, new_path)
+    yield
+    os.rename(new_path, old_path)
+
+
 # TODO: Check if tests needs side effects - locally there's missing DAG
 class TestCliTasks:
     run_id = 'TEST_RUN_ID'
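
Note that `move_back` leaves the file in its new location if the wrapped block raises; a try/finally variant (a sketch, not part of this commit) restores it unconditionally:

    import os
    from contextlib import contextmanager

    @contextmanager
    def move_back(old_path, new_path):
        os.rename(old_path, new_path)
        try:
            yield
        finally:
            # restore the file even if the wrapped block fails
            os.rename(new_path, old_path)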
@@ -183,6 +191,80 @@ class TestCliTasks:
         )
         mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id)
 
+    def test_cli_test_different_path(self, session):
+        """
+        When the dag processor has a different dags folder
+        from the worker, ``airflow tasks run --local`` should still work.
+        """
+        repo_root = Path(__file__).parent.parent.parent.parent
+        orig_file_path = repo_root / 'tests/dags/test_dags_folder.py'
+        orig_dags_folder = orig_file_path.parent
+
+        # parse dag in original path
+        with conf_vars({('core', 'dags_folder'): orig_dags_folder.as_posix()}):
+            dagbag = DagBag(include_examples=False)
+            dag = dagbag.get_dag('test_dags_folder')
+            dagbag.sync_to_db(session=session)
+
+        dag.create_dagrun(
+            state=State.NONE,
+            run_id='abc123',
+            run_type=DagRunType.MANUAL,
+            execution_date=pendulum.now('UTC'),
+            session=session,
+        )
+        session.commit()
+
+        # now let's move the file
+        # and update the dags folder to the new path
+        # since dags_folder now points at the file's new location,
+        # airflow should be able to find the dag.
+        with tempfile.TemporaryDirectory() as td:
+            new_file_path = Path(td) / Path(orig_file_path).name
+            new_dags_folder = new_file_path.parent
+            with move_back(orig_file_path, new_file_path), conf_vars(
+                {('core', 'dags_folder'): new_dags_folder.as_posix()}
+            ):
+                ser_dag = (
+                    session.query(SerializedDagModel)
+                    .filter(SerializedDagModel.dag_id == 'test_dags_folder')
+                    .one()
+                )
+                # confirm that the serialized dag location has not been updated
+                assert ser_dag.fileloc == orig_file_path.as_posix()
+                assert ser_dag.data['dag']['_processor_dags_folder'] == orig_dags_folder.as_posix()
+                assert ser_dag.data['dag']['fileloc'] == orig_file_path.as_posix()
+                assert ser_dag.dag._processor_dags_folder == orig_dags_folder.as_posix()
+                from airflow.settings import DAGS_FOLDER
+
+                assert DAGS_FOLDER == new_dags_folder.as_posix() != orig_dags_folder.as_posix()
+                task_command.task_run(
+                    self.parser.parse_args(
+                        [
+                            'tasks',
+                            'run',
+                            '--ignore-all-dependencies',
+                            '--local',
+                            'test_dags_folder',
+                            'task',
+                            'abc123',
+                        ]
+                    )
+                )
+            ti = (
+                session.query(TaskInstance)
+                .filter(
+                    TaskInstance.task_id == 'task',
+                    TaskInstance.dag_id == 'test_dags_folder',
+                    TaskInstance.run_id == 'abc123',
+                    TaskInstance.map_index == -1,
+                )
+                .one()
+            )
+            assert ti.state == 'success'
+            # verify that the file was in different location when run
+            assert ti.xcom_pull(ti.task_id) == new_file_path.as_posix()
+
     @mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization")
     @mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
     def test_run_get_serialized_dag_fallback(self, mock_local_job, mock_get_dag_by_deserialization):
diff --git a/tests/dags/test_dags_folder.py b/tests/dags/test_dags_folder.py
new file mode 100644
index 0000000000..e4b15a0857
--- /dev/null
+++ b/tests/dags/test_dags_folder.py
@@ -0,0 +1,38 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pendulum
+
+from airflow import DAG
+from airflow.decorators import task
+
+with DAG(
+    dag_id='test_dags_folder',
+    schedule=None,
+    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+    catchup=False,
+) as dag:
+
+    @task(task_id="task")
+    def return_file_path():
+        """Print the Airflow context and ds variable from the context."""
+        print(f"dag file location: {__file__}")
+        return __file__
+
+    return_file_path()
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 62f057c376..54634e463f 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -37,6 +37,7 @@ from dateutil.relativedelta import relativedelta
 from freezegun import freeze_time
 from sqlalchemy import inspect
 
+import airflow
 from airflow import models, settings
 from airflow.configuration import conf
 from airflow.datasets import Dataset
@@ -47,6 +48,7 @@ from airflow.models.baseoperator import BaseOperator
 from airflow.models.dag import DagOwnerAttributes, dag as dag_decorator, get_dataset_triggered_next_run_info
 from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel, TaskOutletDatasetReference
 from airflow.models.param import DagParam, Param, ParamsDict
+from airflow.models.serialized_dag import SerializedDagModel
 from airflow.operators.bash import BashOperator
 from airflow.operators.empty import EmptyOperator
 from airflow.operators.python import PythonOperator
@@ -65,12 +67,24 @@ from airflow.utils.types import DagRunType
 from airflow.utils.weight_rule import WeightRule
 from tests.models import DEFAULT_DATE
 from tests.test_utils.asserts import assert_queries_count
-from tests.test_utils.db import clear_db_dags, clear_db_datasets, clear_db_runs
+from tests.test_utils.config import conf_vars
+from tests.test_utils.db import clear_db_dags, clear_db_datasets, clear_db_runs, clear_db_serialized_dags
 from tests.test_utils.mapping import expand_mapped_task
 from tests.test_utils.timetables import cron_timetable, delta_timetable
 
 TEST_DATE = datetime_tz(2015, 1, 2, 0, 0)
 
+repo_root = Path(airflow.__file__).parent.parent
+
+
+@pytest.fixture
+def clear_dags():
+    clear_db_dags()
+    clear_db_serialized_dags()
+    yield
+    clear_db_dags()
+    clear_db_serialized_dags()
+
 
 class TestDag:
     def setup_method(self) -> None:
@@ -2273,6 +2287,47 @@ class TestDagModel:
 
         assert dag.relative_fileloc == expected_relative
 
+    @pytest.mark.parametrize(
+        'reader_dags_folder', [settings.DAGS_FOLDER, str(repo_root / 'airflow/example_dags')]
+    )
+    @pytest.mark.parametrize(
+        ('fileloc', 'expected_relative'),
+        [
+            (str(Path(settings.DAGS_FOLDER, 'a.py')), Path('a.py')),
+            ('/tmp/foo.py', Path('/tmp/foo.py')),
+        ],
+    )
+    def test_relative_fileloc_serialized(
+        self, fileloc, expected_relative, session, clear_dags, reader_dags_folder
+    ):
+        """
+        The serialized dag model includes the dags folder as configured on the process that
+        serialized the dag.  When the deserializing process determines the relative fileloc,
+        it should use the dags folder of the processor.  So even if the deserializer's dags
+        folder is different (meaning the full path is no longer relative to its dags folder),
+        we should still get the relative fileloc as it existed on the serializer process.
+        When the full path is not relative to the processor's dags folder either, the
+        relative fileloc should just be the full path.
+        """
+        dag = DAG(dag_id='test')
+        dag.fileloc = fileloc
+        sdm = SerializedDagModel(dag)
+        session.add(sdm)
+        session.commit()
+        session.expunge_all()
+        sdm = SerializedDagModel.get(dag.dag_id, session)
+        dag = sdm.dag
+        with conf_vars({('core', 'dags_folder'): reader_dags_folder}):
+            assert dag.relative_fileloc == expected_relative
+
+    def test__processor_dags_folder(self, session):
+        """Only populated after deserializtion"""
+        dag = DAG(dag_id='test')
+        dag.fileloc = '/abc/test.py'
+        assert dag._processor_dags_folder is None
+        sdm = SerializedDagModel(dag)
+        assert sdm.dag._processor_dags_folder == settings.DAGS_FOLDER
+
     @pytest.mark.need_serialized_dag
     def test_dags_needing_dagruns_dataset_triggered_dag_info_queued_times(self, session, dag_maker):
         dataset1 = Dataset(uri="ds1")
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 44a218290e..6b485f7465 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -27,6 +27,7 @@ import os
 import pickle
 from datetime import datetime, timedelta
 from glob import glob
+from pathlib import Path
 from unittest import mock
 
 import pendulum
@@ -34,6 +35,7 @@ import pytest
 from dateutil.relativedelta import FR, relativedelta
 from kubernetes.client import models as k8s
 
+import airflow
 from airflow.datasets import Dataset
 from airflow.exceptions import SerializationError
 from airflow.hooks.base import BaseHook
@@ -63,6 +65,8 @@ from tests.test_utils.config import conf_vars
 from tests.test_utils.mock_operators import CustomOperator, GoogleLink, MockOperator
 from tests.test_utils.timetables import CustomSerializationTimetable, cron_timetable, delta_timetable
 
+repo_root = Path(airflow.__file__).parent.parent
+
 
 class CustomDepOperator(BashOperator):
     """
@@ -133,6 +137,7 @@ serialized_simple_dag_ground_truth = {
         "_dag_id": "simple_dag",
         "doc_md": "### DAG Tutorial Documentation",
         "fileloc": None,
+        "_processor_dags_folder": f"{repo_root}/tests/dags",
         "tasks": [
             {
                 "task_id": "bash_task",
@@ -494,13 +499,17 @@ class TestStringifiedDAGs:
             'default_args',
             "_task_group",
             'params',
+            '_processor_dags_folder',
         }
         fields_to_check = dag.get_serialized_fields() - exclusion_list
         for field in fields_to_check:
             assert getattr(serialized_dag, field) == getattr(
                 dag, field
             ), f'{dag.dag_id}.{field} does not match'
-
+        # _processor_dags_folder is only populated at serialization time
+        # it's only used when relying on the serialized dag to determine a dag's relative path
+        assert dag._processor_dags_folder is None
+        assert serialized_dag._processor_dags_folder == str(repo_root / 'tests/dags')
         if dag.default_args:
             for k, v in dag.default_args.items():
                 if callable(v):
diff --git a/tests/utils/test_cli_util.py b/tests/utils/test_cli_util.py
index 126f25d90d..e814d6bfd2 100644
--- a/tests/utils/test_cli_util.py
+++ b/tests/utils/test_cli_util.py
@@ -24,14 +24,19 @@ import sys
 from argparse import Namespace
 from contextlib import contextmanager
 from datetime import datetime
+from pathlib import Path
 from unittest import mock
 
 import pytest
 
+import airflow
 from airflow import settings
 from airflow.exceptions import AirflowException
 from airflow.models.log import Log
 from airflow.utils import cli, cli_action_loggers, timezone
+from airflow.utils.cli import _search_for_dag_file
+
+repo_root = Path(airflow.__file__).parent.parent
 
 
 class TestCliUtil:
@@ -189,3 +194,19 @@ def fail_func(_):
 @cli.action_cli(check_db=False)
 def success_func(_):
     pass
+
+
+def test__search_for_dags_file():
+    dags_folder = settings.DAGS_FOLDER
+    assert _search_for_dag_file('') is None
+    assert _search_for_dag_file(None) is None
+    # if it's a file and exactly one match is found under the dags folder, return the full path
+    assert _search_for_dag_file('any/hi/test_dags_folder.py') == str(
+        Path(dags_folder) / 'test_dags_folder.py'
+    )
+    # if a folder, even if it exists, should return None (caller falls back to dags folder)
+    existing_folder = Path(settings.DAGS_FOLDER, 'subdir1')
+    assert existing_folder.exists()
+    assert _search_for_dag_file(existing_folder.as_posix()) is None
+    # when multiple files are found, return None (caller falls back to the dags folder)
+    assert _search_for_dag_file('any/hi/__init__.py') is None