Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/12/27 09:24:54 UTC

[GitHub] stale[bot] closed pull request #3138: [AIRFLOW-2221] Create DagFetcher abstraction

stale[bot] closed pull request #3138: [AIRFLOW-2221] Create DagFetcher abstraction
URL: https://github.com/apache/incubator-airflow/pull/3138
 
 
   

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
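
For context, a minimal usage sketch of the abstraction this PR introduces, based on the
factory and the tests in the diff below; the paths and bucket names are placeholders, and
only FileSystemDagFetcher is actually implemented in this PR (the others are stubs):

    from airflow import models
    from airflow.dag.fetchers import get_dag_fetcher

    dagbag = models.DagBag(include_examples=False)

    # The URI scheme picks the fetcher; anything without a recognised
    # scheme falls back to FileSystemDagFetcher.
    local_fetcher = get_dag_fetcher(dagbag, '/path/to/dags')       # FileSystemDagFetcher
    s3_fetcher = get_dag_fetcher(dagbag, 's3://bucket/dags')       # S3DagFetcher (stub)
    git_fetcher = get_dag_fetcher(dagbag, 'git://github.com/apache/airflow.git')  # GitDagFetcher (stub)

    # fetch() walks the DAG location, bags any DAGs it finds into the
    # DagBag, and returns a list of FileLoadStat namedtuples.
    stats = local_fetcher.fetch(only_if_updated=False)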

diff --git a/airflow/__init__.py b/airflow/__init__.py
index 4c4509e00e..c4d6b238dd 100644
--- a/airflow/__init__.py
+++ b/airflow/__init__.py
@@ -75,14 +75,17 @@ class AirflowMacroPlugin(object):
     def __init__(self, namespace):
         self.namespace = namespace
 
-from airflow import operators
+
+from airflow import operators  # noqa: E402
 from airflow import sensors  # noqa: E402
-from airflow import hooks
-from airflow import executors
-from airflow import macros
+from airflow import hooks  # noqa: E402
+from airflow import executors  # noqa: E402
+from airflow import macros  # noqa: E402
+from airflow.dag import fetchers  # noqa: E402
 
 operators._integrate_plugins()
-sensors._integrate_plugins()  # noqa: E402
+sensors._integrate_plugins()
 hooks._integrate_plugins()
 executors._integrate_plugins()
 macros._integrate_plugins()
+fetchers._integrate_plugins()
diff --git a/airflow/dag/fetchers/__init__.py b/airflow/dag/fetchers/__init__.py
new file mode 100644
index 0000000000..861e63751a
--- /dev/null
+++ b/airflow/dag/fetchers/__init__.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.dag.fetchers.filesystem import FileSystemDagFetcher
+from airflow.dag.fetchers.hdfs import HDFSDagFetcher
+from airflow.dag.fetchers.s3 import S3DagFetcher
+from airflow.dag.fetchers.gcs import GCSDagFetcher
+from airflow.dag.fetchers.git import GitDagFetcher
+
+
+def get_dag_fetcher(dagbag, dags_uri):
+    """
+    Factory method that returns an instance of the right
+    DagFetcher, based on the dags_uri prefix.
+
+    Any prefix that does not match keys in the dag_fetchers
+    dict (or no prefix at all) defaults to FileSystemDagFetcher.
+    """
+    log = LoggingMixin().log
+
+    dag_fetchers = dict(
+        hdfs=HDFSDagFetcher,
+        s3=S3DagFetcher,
+        gcs=GCSDagFetcher,
+        git=GitDagFetcher)
+
+    uri_schema = dags_uri.split(':')[0]
+
+    if uri_schema not in dag_fetchers:
+        log.debug('Defaulting to FileSystemDagFetcher')
+        return FileSystemDagFetcher(dagbag, dags_uri)
+
+    return dag_fetchers[uri_schema](dagbag, dags_uri)
+
+
+def _integrate_plugins():
+    """Integrate plugins to the context."""
+    from airflow.plugins_manager import dag_fetchers_modules
+    for dag_fetchers_module in dag_fetchers_modules:
+        sys.modules[dag_fetchers_module.__name__] = dag_fetchers_module
+        globals()[dag_fetchers_module._name] = dag_fetchers_module
diff --git a/airflow/dag/fetchers/base.py b/airflow/dag/fetchers/base.py
new file mode 100644
index 0000000000..93faae8af3
--- /dev/null
+++ b/airflow/dag/fetchers/base.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+class BaseDagFetcher(LoggingMixin):
+    """
+    Abstract base class for all DagFetchers.
+
+    A DagFetcher's responsibility is to find the DAGs in
+    the dags_uri and add them to the dagbag.
+
+    Every DagFetcher must implement the fetch method, which
+    returns a list of per-file load statistics (FileLoadStat).
+    It must also implement a process_file method, which is used
+    to reprocess a single DAG file.
+
+    :param dagbag: a DagBag instance, which we will populate
+    :type dagbag: DagBag
+    :param dags_uri: the URI for the dags folder. The URI scheme
+        prefix determines which DagFetcher subclass is instantiated
+    :type dags_uri: string
+    :param safe_mode: if dag files should be processed with safe_mode
+    :type safe_mode: boolean
+    """
+    FileLoadStat = namedtuple(
+        'FileLoadStat', 'file duration dag_num task_num dags')
+
+    def __init__(self, dagbag, dags_uri=None, safe_mode=True):
+        self.found_dags = []
+        self.stats = []
+        self.dagbag = dagbag
+        self.dags_uri = dags_uri
+        self.safe_mode = safe_mode
+
+    def process_file(self, filepath, only_if_updated=True):
+        """
+        This method is used to process/reprocess a single file and
+        must be implemented by all DagFetchers.
+
+        Must return the dags in the file.
+        """
+        raise NotImplementedError()
+
+    def fetch(self, only_if_updated=True):
+        """
+        This is the main method to derive when creating a DagFetcher.
+        """
+        raise NotImplementedError()
diff --git a/airflow/dag/fetchers/filesystem.py b/airflow/dag/fetchers/filesystem.py
new file mode 100644
index 0000000000..57d202db2b
--- /dev/null
+++ b/airflow/dag/fetchers/filesystem.py
@@ -0,0 +1,194 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from datetime import datetime
+
+import sys
+import os
+import re
+import zipfile
+import hashlib
+import imp
+import importlib
+
+import airflow
+from airflow import configuration
+from airflow.utils import timezone
+from airflow.utils.timeout import timeout
+from airflow.exceptions import AirflowDagCycleException
+from airflow.dag.fetchers.base import BaseDagFetcher
+
+
+class FileSystemDagFetcher(BaseDagFetcher):
+    """
+    Fetches DAGs from the local file system by walking the dags_uri
+    folder on the local disk, looking for .py and .zip files.
+
+    :param dagbag: a DagBag instance, which we will populate
+    :type dagbag: DagBag
+    :param dags_uri: the URI for the dags folder; for this fetcher,
+        a path on the local file system
+    :type dags_uri: string
+    :param safe_mode: if dag files should be processed with safe_mode
+    :type safe_mode: boolean
+    """
+    def process_file(self, filepath, only_if_updated=True):
+        """
+        Given a path to a python module or zip file, this method imports
+        the module and looks for DAG objects within it.
+        """
+        found_dags = []
+        # if the source file no longer exists in the DB or in the filesystem,
+        # return an empty list
+        # todo: raise exception?
+        if filepath is None or not os.path.isfile(filepath):
+            return found_dags
+
+        try:
+            # This failed before in what may have been a git sync
+            # race condition
+            file_last_changed = datetime.fromtimestamp(
+                os.path.getmtime(filepath))
+            if only_if_updated \
+                    and filepath in self.dagbag.file_last_changed \
+                    and file_last_changed == self.dagbag.file_last_changed[filepath]:
+                return found_dags
+
+        except Exception as e:
+            self.log.exception(e)
+            return found_dags
+
+        mods = []
+        if not zipfile.is_zipfile(filepath):
+            if self.safe_mode and os.path.isfile(filepath):
+                with open(filepath, 'rb') as f:
+                    content = f.read()
+                    if not all([s in content for s in (b'DAG', b'airflow')]):
+                        self.dagbag.file_last_changed[filepath] = file_last_changed
+                        return found_dags
+
+            self.log.debug("Importing %s", filepath)
+            org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
+            mod_name = ('unusual_prefix_' +
+                        hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
+                        '_' + org_mod_name)
+
+            if mod_name in sys.modules:
+                del sys.modules[mod_name]
+
+            with timeout(configuration.getint('core', "DAGBAG_IMPORT_TIMEOUT")):
+                try:
+                    m = imp.load_source(mod_name, filepath)
+                    mods.append(m)
+                except Exception as e:
+                    self.log.exception("Failed to import: %s", filepath)
+                    self.dagbag.import_errors[filepath] = str(e)
+                    self.dagbag.file_last_changed[filepath] = file_last_changed
+
+        else:
+            zip_file = zipfile.ZipFile(filepath)
+            for mod in zip_file.infolist():
+                head, _ = os.path.split(mod.filename)
+                mod_name, ext = os.path.splitext(mod.filename)
+                if not head and (ext == '.py' or ext == '.pyc'):
+                    if mod_name == '__init__':
+                        self.log.warning("Found __init__.%s at root of %s", ext, filepath)
+                    if self.safe_mode:
+                        with zip_file.open(mod.filename) as zf:
+                            self.log.debug("Reading %s from %s", mod.filename, filepath)
+                            content = zf.read()
+                            if not all([s in content for s in (b'DAG', b'airflow')]):
+                                self.dagbag.file_last_changed[filepath] = (
+                                    file_last_changed)
+                                # todo: create ignore list
+                                return found_dags
+
+                    if mod_name in sys.modules:
+                        del sys.modules[mod_name]
+
+                    try:
+                        sys.path.insert(0, filepath)
+                        m = importlib.import_module(mod_name)
+                        mods.append(m)
+                    except Exception as e:
+                        self.log.exception("Failed to import: %s", filepath)
+                        self.dagbag.import_errors[filepath] = str(e)
+                        self.dagbag.file_last_changed[filepath] = file_last_changed
+
+        for m in mods:
+            for dag in list(m.__dict__.values()):
+                if isinstance(dag, airflow.models.DAG):
+                    if not dag.full_filepath:
+                        dag.full_filepath = filepath
+                        if dag.fileloc != filepath:
+                            dag.fileloc = filepath
+                    try:
+                        dag.is_subdag = False
+                        self.dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
+                        found_dags.append(dag)
+                        found_dags += dag.subdags
+                    except AirflowDagCycleException as cycle_exception:
+                        self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
+                        self.dagbag.import_errors[dag.full_filepath] = \
+                            str(cycle_exception)
+                        self.dagbag.file_last_changed[dag.full_filepath] = \
+                            file_last_changed
+
+        self.dagbag.file_last_changed[filepath] = file_last_changed
+        return found_dags
+
+    def fetch(self, only_if_updated=True):
+        """
+        Walks the dags_folder (self.dags_uri) looking for files to process
+        """
+        if os.path.isfile(self.dags_uri):
+            self.process_file(self.dags_uri, only_if_updated=only_if_updated)
+        elif os.path.isdir(self.dags_uri):
+            patterns = []
+            for root, dirs, files in os.walk(self.dags_uri, followlinks=True):
+                ignore_file = [f for f in files if f == '.airflowignore']
+                if ignore_file:
+                    f = open(os.path.join(root, ignore_file[0]), 'r')
+                    patterns += [p for p in f.read().split('\n') if p]
+                    f.close()
+                for f in files:
+                    try:
+                        filepath = os.path.join(root, f)
+                        if not os.path.isfile(filepath):
+                            continue
+                        mod_name, file_ext = os.path.splitext(
+                            os.path.split(filepath)[-1])
+                        if file_ext != '.py' and not zipfile.is_zipfile(filepath):
+                            continue
+                        if not any(
+                                [re.findall(p, filepath) for p in patterns]):
+                            ts = timezone.utcnow()
+                            found_dags = self.process_file(
+                                filepath, only_if_updated=only_if_updated)
+
+                            td = timezone.utcnow() - ts
+                            td = td.total_seconds() + (
+                                float(td.microseconds) / 1000000)
+                            self.stats.append(self.FileLoadStat(
+                                filepath.replace(self.dags_uri, ''),
+                                td,
+                                len(found_dags),
+                                sum([len(dag.tasks) for dag in found_dags]),
+                                str([dag.dag_id for dag in found_dags]),
+                            ))
+                    except Exception as e:
+                        self.log.exception(e)
+
+        return self.stats
diff --git a/airflow/dag/fetchers/gcs.py b/airflow/dag/fetchers/gcs.py
new file mode 100644
index 0000000000..4e22403a72
--- /dev/null
+++ b/airflow/dag/fetchers/gcs.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from airflow.dag.fetchers.base import BaseDagFetcher
+
+
+class GCSDagFetcher(BaseDagFetcher):
+    """
+    GCSDagFetcher - Not Implemented
+    """
diff --git a/airflow/dag/fetchers/git.py b/airflow/dag/fetchers/git.py
new file mode 100644
index 0000000000..5b1a7b2269
--- /dev/null
+++ b/airflow/dag/fetchers/git.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from airflow.dag.fetchers.base import BaseDagFetcher
+
+
+class GitDagFetcher(BaseDagFetcher):
+    """
+    GitDagFetcher - Not Implemented
+    """
diff --git a/airflow/dag/fetchers/hdfs.py b/airflow/dag/fetchers/hdfs.py
new file mode 100644
index 0000000000..0495b1d2c9
--- /dev/null
+++ b/airflow/dag/fetchers/hdfs.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from airflow.dag.fetchers.base import BaseDagFetcher
+
+
+class HDFSDagFetcher(BaseDagFetcher):
+    """
+    HDFSDagFetcher - Not Implemented
+    """
diff --git a/airflow/dag/fetchers/s3.py b/airflow/dag/fetchers/s3.py
new file mode 100644
index 0000000000..56eacd3d1e
--- /dev/null
+++ b/airflow/dag/fetchers/s3.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from airflow.dag.fetchers.base import BaseDagFetcher
+
+
+class S3DagFetcher(BaseDagFetcher):
+    """
+    S3DagFetcher - Not Implemented
+    """
diff --git a/airflow/models.py b/airflow/models.py
index c1b608afbb..33fd5f3aa1 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -22,16 +22,13 @@
 from builtins import str
 from builtins import object, bytes
 import copy
-from collections import namedtuple, defaultdict
+from collections import defaultdict
 from datetime import timedelta
 
 import dill
 import functools
 import getpass
-import imp
-import importlib
 import itertools
-import zipfile
 import jinja2
 import json
 import logging
@@ -63,6 +60,7 @@
 from airflow import settings, utils
 from airflow.executors import GetDefaultExecutor, LocalExecutor
 from airflow import configuration
+from airflow.dag.fetchers import get_dag_fetcher
 from airflow.exceptions import (
     AirflowDagCycleException, AirflowException, AirflowSkipException, AirflowTaskTimeout
 )
@@ -206,13 +204,13 @@ def __init__(
         self.file_last_changed = {}
         self.executor = executor
         self.import_errors = {}
-
+        self.dag_fetcher = get_dag_fetcher(self, dag_folder)
         if include_examples:
             example_dag_folder = os.path.join(
                 os.path.dirname(__file__),
                 'example_dags')
-            self.collect_dags(example_dag_folder)
-        self.collect_dags(dag_folder)
+            self.collect_dags(get_dag_fetcher(self, example_dag_folder))
+        self.collect_dags(self.dag_fetcher)
 
     def size(self):
         """
@@ -241,8 +239,8 @@ def get_dag(self, dag_id):
                 )
         ):
             # Reprocess source file
-            found_dags = self.process_file(
-                filepath=orm_dag.fileloc, only_if_updated=False)
+            found_dags = self.dag_fetcher.process_file(
+                orm_dag.fileloc, only_if_updated=False)
 
             # If the source file no longer exports `dag_id`, delete it from self.dags
             if found_dags and dag_id in [dag.dag_id for dag in found_dags]:
@@ -251,111 +249,6 @@ def get_dag(self, dag_id):
                 del self.dags[dag_id]
         return self.dags.get(dag_id)
 
-    def process_file(self, filepath, only_if_updated=True, safe_mode=True):
-        """
-        Given a path to a python module or zip file, this method imports
-        the module and look for dag objects within it.
-        """
-        found_dags = []
-
-        # if the source file no longer exists in the DB or in the filesystem,
-        # return an empty list
-        # todo: raise exception?
-        if filepath is None or not os.path.isfile(filepath):
-            return found_dags
-
-        try:
-            # This failed before in what may have been a git sync
-            # race condition
-            file_last_changed_on_disk = datetime.fromtimestamp(os.path.getmtime(filepath))
-            if only_if_updated \
-                    and filepath in self.file_last_changed \
-                    and file_last_changed_on_disk == self.file_last_changed[filepath]:
-                return found_dags
-
-        except Exception as e:
-            self.log.exception(e)
-            return found_dags
-
-        mods = []
-        if not zipfile.is_zipfile(filepath):
-            if safe_mode and os.path.isfile(filepath):
-                with open(filepath, 'rb') as f:
-                    content = f.read()
-                    if not all([s in content for s in (b'DAG', b'airflow')]):
-                        self.file_last_changed[filepath] = file_last_changed_on_disk
-                        return found_dags
-
-            self.log.debug("Importing %s", filepath)
-            org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
-            mod_name = ('unusual_prefix_' +
-                        hashlib.sha1(filepath.encode('utf-8')).hexdigest() +
-                        '_' + org_mod_name)
-
-            if mod_name in sys.modules:
-                del sys.modules[mod_name]
-
-            with timeout(configuration.getint('core', "DAGBAG_IMPORT_TIMEOUT")):
-                try:
-                    m = imp.load_source(mod_name, filepath)
-                    mods.append(m)
-                except Exception as e:
-                    self.log.exception("Failed to import: %s", filepath)
-                    self.import_errors[filepath] = str(e)
-                    self.file_last_changed[filepath] = file_last_changed_on_disk
-
-        else:
-            zip_file = zipfile.ZipFile(filepath)
-            for mod in zip_file.infolist():
-                head, _ = os.path.split(mod.filename)
-                mod_name, ext = os.path.splitext(mod.filename)
-                if not head and (ext == '.py' or ext == '.pyc'):
-                    if mod_name == '__init__':
-                        self.log.warning("Found __init__.%s at root of %s", ext, filepath)
-                    if safe_mode:
-                        with zip_file.open(mod.filename) as zf:
-                            self.log.debug("Reading %s from %s", mod.filename, filepath)
-                            content = zf.read()
-                            if not all([s in content for s in (b'DAG', b'airflow')]):
-                                self.file_last_changed[filepath] = (
-                                    file_last_changed_on_disk)
-                                # todo: create ignore list
-                                return found_dags
-
-                    if mod_name in sys.modules:
-                        del sys.modules[mod_name]
-
-                    try:
-                        sys.path.insert(0, filepath)
-                        m = importlib.import_module(mod_name)
-                        mods.append(m)
-                    except Exception as e:
-                        self.log.exception("Failed to import: %s", filepath)
-                        self.import_errors[filepath] = str(e)
-                        self.file_last_changed[filepath] = file_last_changed_on_disk
-
-        for m in mods:
-            for dag in list(m.__dict__.values()):
-                if isinstance(dag, DAG):
-                    if not dag.full_filepath:
-                        dag.full_filepath = filepath
-                        if dag.fileloc != filepath:
-                            dag.fileloc = filepath
-                    try:
-                        dag.is_subdag = False
-                        self.bag_dag(dag, parent_dag=dag, root_dag=dag)
-                        found_dags.append(dag)
-                        found_dags += dag.subdags
-                    except AirflowDagCycleException as cycle_exception:
-                        self.log.exception("Failed to bag_dag: %s", dag.full_filepath)
-                        self.import_errors[dag.full_filepath] = str(cycle_exception)
-                        self.file_last_changed[dag.full_filepath] = \
-                            file_last_changed_on_disk
-
-
-        self.file_last_changed[filepath] = file_last_changed_on_disk
-        return found_dags
-
     @provide_session
     def kill_zombies(self, session=None):
         """
@@ -427,10 +320,9 @@ def bag_dag(self, dag, parent_dag, root_dag):
                         del self.dags[subdag.dag_id]
             raise cycle_exception
 
-
     def collect_dags(
             self,
-            dag_folder=None,
+            dag_fetcher=None,
             only_if_updated=True):
         """
         Given a file path or a folder, this method looks for python modules,
@@ -442,49 +334,10 @@ def collect_dags(
         in the file.
         """
         start_dttm = timezone.utcnow()
-        dag_folder = dag_folder or self.dag_folder
-
-        # Used to store stats around DagBag processing
-        stats = []
-        FileLoadStat = namedtuple(
-            'FileLoadStat', "file duration dag_num task_num dags")
-        if os.path.isfile(dag_folder):
-            self.process_file(dag_folder, only_if_updated=only_if_updated)
-        elif os.path.isdir(dag_folder):
-            patterns = []
-            for root, dirs, files in os.walk(dag_folder, followlinks=True):
-                ignore_file = [f for f in files if f == '.airflowignore']
-                if ignore_file:
-                    f = open(os.path.join(root, ignore_file[0]), 'r')
-                    patterns += [p for p in f.read().split('\n') if p]
-                    f.close()
-                for f in files:
-                    try:
-                        filepath = os.path.join(root, f)
-                        if not os.path.isfile(filepath):
-                            continue
-                        mod_name, file_ext = os.path.splitext(
-                            os.path.split(filepath)[-1])
-                        if file_ext != '.py' and not zipfile.is_zipfile(filepath):
-                            continue
-                        if not any(
-                                [re.findall(p, filepath) for p in patterns]):
-                            ts = timezone.utcnow()
-                            found_dags = self.process_file(
-                                filepath, only_if_updated=only_if_updated)
-
-                            td = timezone.utcnow() - ts
-                            td = td.total_seconds() + (
-                                float(td.microseconds) / 1000000)
-                            stats.append(FileLoadStat(
-                                filepath.replace(dag_folder, ''),
-                                td,
-                                len(found_dags),
-                                sum([len(dag.tasks) for dag in found_dags]),
-                                str([dag.dag_id for dag in found_dags]),
-                            ))
-                    except Exception as e:
-                        self.log.exception(e)
+        dag_fetcher = dag_fetcher or self.dag_fetcher
+
+        stats = dag_fetcher.fetch(only_if_updated=only_if_updated)
+
         Stats.gauge(
             'collect_dags', (timezone.utcnow() - start_dttm).total_seconds(), 1)
         Stats.gauge(
diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index aaae4230b7..939f025a24 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -40,6 +40,7 @@ class AirflowPlugin(object):
     hooks = []
     executors = []
     macros = []
+    dag_fetchers = []
     admin_views = []
     flask_blueprints = []
     menu_links = []
@@ -108,6 +109,7 @@ def make_module(name, objects):
 hooks_modules = []
 executors_modules = []
 macros_modules = []
+dag_fetchers_modules = []
 
 # Plugin components to integrate directly
 admin_views = []
@@ -124,6 +126,8 @@ def make_module(name, objects):
     executors_modules.append(
         make_module('airflow.executors.' + p.name, p.executors))
     macros_modules.append(make_module('airflow.macros.' + p.name, p.macros))
+    dag_fetchers_modules.append(
+        make_module('airflow.dag.fetchers.' + p.name, p.dag_fetchers))
 
     admin_views.extend(p.admin_views)
     flask_blueprints.extend(p.flask_blueprints)
diff --git a/tests/models.py b/tests/models.py
index 5d8184c575..42259ed238 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -38,6 +38,12 @@
 from airflow.models import clear_task_instances
 from airflow.models import XCom
 from airflow.models import Connection
+from airflow.dag.fetchers import FileSystemDagFetcher
+from airflow.dag.fetchers import HDFSDagFetcher
+from airflow.dag.fetchers import S3DagFetcher
+from airflow.dag.fetchers import GCSDagFetcher
+from airflow.dag.fetchers import GitDagFetcher
+from airflow.dag.fetchers import get_dag_fetcher
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.bash_operator import BashOperator
 from airflow.operators.python_operator import PythonOperator
@@ -920,7 +926,7 @@ def test_get_non_existing_dag(self):
         non_existing_dag_id = "non_existing_dag_id"
         self.assertIsNone(dagbag.get_dag(non_existing_dag_id))
 
-    def test_process_file_that_contains_multi_bytes_char(self):
+    def test_process_local_file_that_contains_multi_bytes_char(self):
         """
         test that we're able to parse file that contains multi-byte char
         """
@@ -929,42 +935,35 @@ def test_process_file_that_contains_multi_bytes_char(self):
         f.flush()
 
         dagbag = models.DagBag(include_examples=True)
-        self.assertEqual([], dagbag.process_file(f.name))
+        fsfetcher = FileSystemDagFetcher(dagbag)
+
+        self.assertEqual([], fsfetcher.process_file(f.name))
 
     def test_zip(self):
         """
         test the loading of a DAG within a zip file that includes dependencies
         """
         dagbag = models.DagBag()
-        dagbag.process_file(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip"))
+        fsfetcher = FileSystemDagFetcher(dagbag)
+        fsfetcher.process_file(os.path.join(TEST_DAGS_FOLDER, "test_zip.zip"))
         self.assertTrue(dagbag.get_dag("test_zip_dag"))
 
-    @patch.object(DagModel,'get_current')
-    def test_get_dag_without_refresh(self, mock_dagmodel):
+    def test_get_dag_fetcher(self):
         """
-        Test that, once a DAG is loaded, it doesn't get refreshed again if it
-        hasn't been expired.
+        Test that get_dag_fetcher returns the correct dag fetchers.
         """
-        dag_id = 'example_bash_operator'
-
-        mock_dagmodel.return_value = DagModel()
-        mock_dagmodel.return_value.last_expired = None
-        mock_dagmodel.return_value.fileloc = 'foo'
-
-        class TestDagBag(models.DagBag):
-            process_file_calls = 0
-            def process_file(self, filepath, only_if_updated=True, safe_mode=True):
-                if 'example_bash_operator.py' == os.path.basename(filepath):
-                    TestDagBag.process_file_calls += 1
-                super(TestDagBag, self).process_file(filepath, only_if_updated, safe_mode)
-
-        dagbag = TestDagBag(include_examples=True)
-        processed_files = dagbag.process_file_calls
-
-        # Should not call process_file agani, since it's already loaded during init.
-        self.assertEqual(1, dagbag.process_file_calls)
-        self.assertIsNotNone(dagbag.get_dag(dag_id))
-        self.assertEqual(1, dagbag.process_file_calls)
+        dagbag = models.DagBag()
+        default_fetcher = get_dag_fetcher(dagbag, '/a/local/path/without/schema/dags')
+        hdfs_fetcher = get_dag_fetcher(dagbag, 'hdfs://host:optional-port/dags')
+        s3_fetcher = get_dag_fetcher(dagbag, 's3://bucket/dags')
+        gcs_fetcher = get_dag_fetcher(dagbag, 'gcs://bucket/dags')
+        git_fetcher = get_dag_fetcher(dagbag, 'git://github.com/apache/airflow.git')
+
+        self.assertIsInstance(default_fetcher, FileSystemDagFetcher)
+        self.assertIsInstance(hdfs_fetcher, HDFSDagFetcher)
+        self.assertIsInstance(s3_fetcher, S3DagFetcher)
+        self.assertIsInstance(gcs_fetcher, GCSDagFetcher)
+        self.assertIsInstance(git_fetcher, GitDagFetcher)
 
     def test_get_dag_fileloc(self):
         """
@@ -996,7 +995,8 @@ def process_dag(self, create_dag):
         f.flush()
 
         dagbag = models.DagBag(include_examples=False)
-        found_dags = dagbag.process_file(f.name)
+        fsfetcher = FileSystemDagFetcher(dagbag)
+        found_dags = fsfetcher.process_file(f.name)
         return (dagbag, found_dags, f.name)
 
     def validate_dags(self, expected_parent_dag, actual_found_dags, actual_dagbag,
@@ -1301,13 +1301,14 @@ def subdag_1():
         self.validate_dags(testDag, found_dags, dagbag, should_be_found=False)
         self.assertIn(file_path, dagbag.import_errors)
 
-    def test_process_file_with_none(self):
+    def test_process_local_file_with_none(self):
         """
         test that process_file can handle Nones
         """
         dagbag = models.DagBag(include_examples=True)
+        fsfetcher = FileSystemDagFetcher(dagbag)
 
-        self.assertEqual([], dagbag.process_file(None))
+        self.assertEqual([], fsfetcher.process_file(None))
 
 
 class TaskInstanceTest(unittest.TestCase):
diff --git a/tests/plugins/test_plugin.py b/tests/plugins/test_plugin.py
index 49325e68b7..fac674b214 100644
--- a/tests/plugins/test_plugin.py
+++ b/tests/plugins/test_plugin.py
@@ -24,11 +24,14 @@
 from airflow.models import BaseOperator
 from airflow.sensors.base_sensor_operator import BaseSensorOperator
 from airflow.executors.base_executor import BaseExecutor
+from airflow.dag.fetchers.base import BaseDagFetcher
+
 
 # Will show up under airflow.hooks.test_plugin.PluginHook
 class PluginHook(BaseHook):
     pass
 
+
 # Will show up under airflow.operators.test_plugin.PluginOperator
 class PluginOperator(BaseOperator):
     pass
@@ -43,10 +46,17 @@ class PluginSensorOperator(BaseSensorOperator):
 class PluginExecutor(BaseExecutor):
     pass
 
+
 # Will show up under airflow.macros.test_plugin.plugin_macro
 def plugin_macro():
     pass
 
+
+# Will show up under airflow.dag.fetchers.test_plugin.PluginDagFetcher
+class PluginDagFetcher(BaseDagFetcher):
+    pass
+
+
 # Creating a flask admin BaseView
 class TestView(BaseView):
     @expose('/')
@@ -76,6 +86,7 @@ class AirflowTestPlugin(AirflowPlugin):
     hooks = [PluginHook]
     executors = [PluginExecutor]
     macros = [plugin_macro]
+    dag_fetchers = [PluginDagFetcher]
     admin_views = [v]
     flask_blueprints = [bp]
     menu_links = [ml]
diff --git a/tests/plugins_manager.py b/tests/plugins_manager.py
index a00d476f03..9b76bf7d31 100644
--- a/tests/plugins_manager.py
+++ b/tests/plugins_manager.py
@@ -26,9 +26,10 @@
 from flask_admin.menu import MenuLink, MenuView
 
 from airflow.hooks.base_hook import BaseHook
-from airflow.models import  BaseOperator
+from airflow.models import BaseOperator
 from airflow.sensors.base_sensor_operator import BaseSensorOperator
 from airflow.executors.base_executor import BaseExecutor
+from airflow.dag.fetchers.base import BaseDagFetcher
 from airflow.www.app import cached_app
 
 
@@ -62,6 +63,10 @@ def test_macros(self):
         from airflow.macros.test_plugin import plugin_macro
         self.assertTrue(callable(plugin_macro))
 
+    def test_fetchers(self):
+        from airflow.dag.fetchers.test_plugin import PluginDagFetcher
+        self.assertTrue(issubclass(PluginDagFetcher, BaseDagFetcher))
+
     def test_admin_views(self):
         app = cached_app()
         [admin] = app.extensions['admin']
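
As a further illustration (not part of the diff), a plugin could ship its own fetcher under
this abstraction, mirroring the PluginDagFetcher example in tests/plugins/test_plugin.py
above; the HttpDagFetcher name and its HTTP semantics below are purely hypothetical:

    from airflow.dag.fetchers.base import BaseDagFetcher
    from airflow.plugins_manager import AirflowPlugin


    class HttpDagFetcher(BaseDagFetcher):
        """Hypothetical fetcher that would pull DAG files over HTTP."""

        def process_file(self, filepath, only_if_updated=True):
            # Parse a single DAG file, call self.dagbag.bag_dag(...) for each
            # DAG object found, and return the list of found DAGs.
            raise NotImplementedError()

        def fetch(self, only_if_updated=True):
            # Enumerate DAG files under self.dags_uri, process each one via
            # self.process_file, record one FileLoadStat per file in
            # self.stats, and return self.stats for collect_dags to report.
            raise NotImplementedError()


    class HttpFetcherPlugin(AirflowPlugin):
        name = 'http_fetcher_plugin'
        # Once the plugin is loaded, the fetcher is importable as
        # airflow.dag.fetchers.http_fetcher_plugin.HttpDagFetcher
        dag_fetchers = [HttpDagFetcher]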


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services