Posted to commits@airflow.apache.org by as...@apache.org on 2021/03/31 14:49:04 UTC

[airflow] branch master updated: Allow pathlib.Path in DagBag and various util fns (#15110)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 6e99ae0  Allow pathlib.Path in DagBag and various util fns (#15110)
6e99ae0 is described below

commit 6e99ae05642758691361dfe9d7b767cfc9a2b616
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Wed Mar 31 15:48:46 2021 +0100

    Allow pathlib.Path in DagBag and various util fns (#15110)
    
    We do a lot of path manipulation in this test file, and it's easier to
    understand by using pathlib without all the nested `os.path.*` calls.
    
    This change adds "support" for passing Path objects to DagBag and
    util functions.
---
 airflow/models/dagbag.py           | 20 ++++++++++++--------
 airflow/utils/dag_processing.py    |  7 +++++--
 airflow/utils/file.py              | 13 ++++++++-----
 tests/utils/test_dag_processing.py | 19 ++++++++++---------
 4 files changed, 35 insertions(+), 24 deletions(-)
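
A minimal sketch of the calling convention this change enables; both calls below
build the same DagBag, and the dags directory shown is hypothetical:

    import pathlib
    from airflow.models.dagbag import DagBag

    # Previously only a str path was accepted.
    bag = DagBag(dag_folder="/opt/airflow/dags")

    # After this change a pathlib.Path works too; it is coerced to
    # str internally before the os.path.* helpers run.
    bag = DagBag(dag_folder=pathlib.Path("/opt/airflow/dags"))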

diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index e5f986a..7099e18 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -27,7 +27,7 @@ import traceback
 import warnings
 import zipfile
 from datetime import datetime, timedelta
-from typing import Dict, List, NamedTuple, Optional
+from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Union
 
 from croniter import CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError, croniter
 from sqlalchemy.exc import OperationalError
@@ -46,6 +46,9 @@ from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
 from airflow.utils.session import provide_session
 from airflow.utils.timeout import timeout
 
+if TYPE_CHECKING:
+    import pathlib
+
 
 class FileLoadStat(NamedTuple):
     """Information about single file"""
@@ -89,7 +92,7 @@ class DagBag(LoggingMixin):
 
     def __init__(
         self,
-        dag_folder: Optional[str] = None,
+        dag_folder: Union[str, "pathlib.Path", None] = None,
         include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'),
         include_smart_sensor: bool = conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'),
         safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
@@ -424,11 +427,11 @@ class DagBag(LoggingMixin):
 
     def collect_dags(
         self,
-        dag_folder=None,
-        only_if_updated=True,
-        include_examples=conf.getboolean('core', 'LOAD_EXAMPLES'),
-        include_smart_sensor=conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'),
-        safe_mode=conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
+        dag_folder: Union[str, "pathlib.Path", None] = None,
+        only_if_updated: bool = True,
+        include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'),
+        include_smart_sensor: bool = conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'),
+        safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
     ):
         """
         Given a file path or a folder, this method looks for python modules,
@@ -450,7 +453,8 @@ class DagBag(LoggingMixin):
         # Used to store stats around DagBag processing
         stats = []
 
-        dag_folder = correct_maybe_zipped(dag_folder)
+        # Ensure dag_folder is a str -- it may have been a pathlib.Path
+        dag_folder = correct_maybe_zipped(str(dag_folder))
         for filepath in list_py_file_paths(
             dag_folder,
             safe_mode=safe_mode,
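
The `if TYPE_CHECKING:` import plus quoted annotation used above is the standard
typing pattern for naming a module only during static analysis. A self-contained
sketch of the same pattern (the helper `as_str_path` is hypothetical, not part of
the commit):

    from typing import TYPE_CHECKING, Union

    if TYPE_CHECKING:
        # Seen only by type checkers such as mypy; skipped at runtime.
        import pathlib

    def as_str_path(path: Union[str, "pathlib.Path"]) -> str:
        # Coercing to str up front lets downstream os.path.* and other
        # string-based helpers treat both input types identically.
        return str(path)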
diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py
index 7e35150..cf27b40 100644
--- a/airflow/utils/dag_processing.py
+++ b/airflow/utils/dag_processing.py
@@ -31,7 +31,7 @@ from collections import defaultdict
 from datetime import datetime, timedelta
 from importlib import import_module
 from multiprocessing.connection import Connection as MultiprocessingConnection
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union, cast
 
 from setproctitle import setproctitle  # pylint: disable=no-name-in-module
 from sqlalchemy import or_
@@ -54,6 +54,9 @@ from airflow.utils.process_utils import kill_child_processes_by_pids, reap_proce
 from airflow.utils.session import provide_session
 from airflow.utils.state import State
 
+if TYPE_CHECKING:
+    import pathlib
+
 
 class AbstractDagFileProcessorProcess(metaclass=ABCMeta):
     """Processes a DAG file. See SchedulerJob.process_file() for more details."""
@@ -491,7 +494,7 @@ class DagFileProcessorManager(LoggingMixin):  # pylint: disable=too-many-instance-attributes
 
     def __init__(
         self,
-        dag_directory: str,
+        dag_directory: Union[str, "pathlib.Path"],
         max_runs: int,
         processor_factory: Callable[[str, List[CallbackRequest]], AbstractDagFileProcessorProcess],
         processor_timeout: timedelta,
diff --git a/airflow/utils/file.py b/airflow/utils/file.py
index 41c6b32..96515a0 100644
--- a/airflow/utils/file.py
+++ b/airflow/utils/file.py
@@ -21,10 +21,13 @@ import os
 import re
 import zipfile
 from pathlib import Path
-from typing import Dict, Generator, List, Optional, Pattern
+from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Pattern, Union
 
 from airflow.configuration import conf
 
+if TYPE_CHECKING:
+    import pathlib
+
 log = logging.getLogger(__name__)
 
 
@@ -131,7 +134,7 @@ def find_path_from_directory(base_dir_path: str, ignore_file_name: str) -> Gener
 
 
 def list_py_file_paths(
-    directory: str,
+    directory: Union[str, "pathlib.Path"],
     safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE', fallback=True),
     include_examples: Optional[bool] = None,
     include_smart_sensor: Optional[bool] = conf.getboolean('smart_sensor', 'use_smart_sensor'),
@@ -159,7 +162,7 @@ def list_py_file_paths(
     if directory is None:
         file_paths = []
     elif os.path.isfile(directory):
-        file_paths = [directory]
+        file_paths = [str(directory)]
     elif os.path.isdir(directory):
         file_paths.extend(find_dag_file_paths(directory, safe_mode))
     if include_examples:
@@ -175,11 +178,11 @@ def list_py_file_paths(
     return file_paths
 
 
-def find_dag_file_paths(directory: str, safe_mode: bool) -> List[str]:
+def find_dag_file_paths(directory: Union[str, "pathlib.Path"], safe_mode: bool) -> List[str]:
     """Finds file paths of all DAG files."""
     file_paths = []
 
-    for file_path in find_path_from_directory(directory, ".airflowignore"):
+    for file_path in find_path_from_directory(str(directory), ".airflowignore"):
         try:
             if not os.path.isfile(file_path):
                 continue
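
With the str() coercions above, list_py_file_paths accepts either a str or a
pathlib.Path while its return value stays a list of plain str paths. A usage
sketch (the directory shown is hypothetical):

    import pathlib
    from airflow.utils.file import list_py_file_paths

    dag_dir = pathlib.Path("/opt/airflow/dags")  # hypothetical location
    paths = list_py_file_paths(dag_dir, include_examples=False)
    assert all(isinstance(p, str) for p in paths)  # results are still str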
diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py
index 67936cd..b8a7953 100644
--- a/tests/utils/test_dag_processing.py
+++ b/tests/utils/test_dag_processing.py
@@ -18,6 +18,7 @@
 
 import multiprocessing
 import os
+import pathlib
 import random
 import sys
 import unittest
@@ -50,7 +51,7 @@ from tests.core.test_logging_config import SETTINGS_FILE_VALID, settings_context
 from tests.test_utils.config import conf_vars
 from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags
 
-TEST_DAG_FOLDER = os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, 'dags')
+TEST_DAG_FOLDER = pathlib.Path(__file__).parent.parent / 'dags'
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
 
@@ -372,7 +373,7 @@ class TestDagFileProcessorManager(unittest.TestCase):
         Check that the same set of failure callback with zombies are passed to the dag
         file processors until the next zombie detection logic is invoked.
         """
-        test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py')
+        test_dag_path = TEST_DAG_FOLDER / 'test_example_bash_operator.py'
         with conf_vars({('scheduler', 'parsing_processes'): '1', ('core', 'load_examples'): 'False'}):
             dagbag = DagBag(test_dag_path, read_dags_from_db=False)
             with create_session() as session:
@@ -401,7 +402,7 @@ class TestDagFileProcessorManager(unittest.TestCase):
                     )
                 ]
 
-            test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_example_bash_operator.py')
+            test_dag_path = TEST_DAG_FOLDER / 'test_example_bash_operator.py'
 
             child_pipe, parent_pipe = multiprocessing.Pipe()
             async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')
@@ -430,12 +431,12 @@ class TestDagFileProcessorManager(unittest.TestCase):
             if async_mode:
                 # Once for initial parse, and then again for the add_callback_to_queue
                 assert len(fake_processors) == 2
-                assert fake_processors[0]._file_path == test_dag_path
+                assert fake_processors[0]._file_path == str(test_dag_path)
                 assert fake_processors[0]._callback_requests == []
             else:
                 assert len(fake_processors) == 1
 
-            assert fake_processors[-1]._file_path == test_dag_path
+            assert fake_processors[-1]._file_path == str(test_dag_path)
             callback_requests = fake_processors[-1]._callback_requests
             assert {zombie.simple_task_instance.key for zombie in expected_failure_callback_requests} == {
                 result.simple_task_instance.key for result in callback_requests
@@ -499,7 +500,7 @@ class TestDagFileProcessorManager(unittest.TestCase):
         from airflow.jobs.scheduler_job import SchedulerJob
 
         dag_id = 'exit_test_dag'
-        dag_directory = os.path.normpath(os.path.join(TEST_DAG_FOLDER, os.pardir, "dags_with_system_exit"))
+        dag_directory = TEST_DAG_FOLDER.parent / 'dags_with_system_exit'
 
         # Delete the one valid DAG/SerializedDAG, and check that it gets re-created
         clear_db_dags()
@@ -561,7 +562,7 @@ class TestDagFileProcessorAgent(unittest.TestCase):
         with settings_context(SETTINGS_FILE_VALID):
             # Launch a process through DagFileProcessorAgent, which will try
             # reload the logging module.
-            test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py')
+            test_dag_path = TEST_DAG_FOLDER / 'test_scheduler_dags.py'
             async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')
             log_file_loc = conf.get('logging', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION')
 
@@ -589,7 +590,7 @@ class TestDagFileProcessorAgent(unittest.TestCase):
         clear_db_serialized_dags()
         clear_db_dags()
 
-        test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py')
+        test_dag_path = TEST_DAG_FOLDER / 'test_scheduler_dags.py'
         async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')
         processor_agent = DagFileProcessorAgent(
             test_dag_path, 1, type(self)._processor_factory, timedelta.max, [], False, async_mode
@@ -613,7 +614,7 @@ class TestDagFileProcessorAgent(unittest.TestCase):
             assert dag_ids == [('test_start_date_scheduling',), ('test_task_start_date_scheduling',)]
 
     def test_launch_process(self):
-        test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py')
+        test_dag_path = TEST_DAG_FOLDER / 'test_scheduler_dags.py'
         async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')
 
         log_file_loc = conf.get('logging', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION')
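
For reference, the two path idioms the tests switch between, shown side by side
as a standalone sketch (note the symlink caveat in the comment):

    import os
    import pathlib

    # Old style, as the tests were written before this commit:
    old = os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, 'dags')

    # New style, as the tests read after this commit:
    new = pathlib.Path(__file__).parent.parent / 'dags'

    # Caveat: os.path.realpath resolves symlinks, while plain pathlib
    # navigation does not; call .resolve() if that behaviour matters.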