Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/15 03:32:56 UTC

[airflow] 21/47: Update Serialized DAGs in Webserver when DAGs are Updated (#9851)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit b95126ba14c861d7bcbb1dbf9248aec73a8b343b
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Mon Jul 20 12:45:18 2020 +0100

    Update Serialized DAGs in Webserver when DAGs are Updated (#9851)
    
    Before this change, if DAG Serialization was enabled the Webserver would not update the DAGs once they had been fetched from the DB. The default worker_refresh_interval was `30` seconds, so updated DAGs were only picked up whenever the gunicorn workers were restarted.
    
    This change will allow us to have a larger worker_refresh_interval (e.g. 30 mins or even 1 day).
    
    (cherry picked from commit 84b85d8acc181edfe1fdd21b82c1773c19c47044)
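
For context, here is a minimal airflow.cfg sketch of the settings this change interacts with. The values are illustrative; `worker_refresh_interval` lives under `[webserver]`, and per the commit message it can now be raised substantially:

    [core]
    store_serialized_dags = True
    # New in this commit: minimum interval between DB reads for serialized DAGs
    min_serialized_dag_fetch_interval = 10

    [webserver]
    # Can now be much larger (e.g. 30 mins), since DAGs refresh via the DB
    worker_refresh_interval = 1800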
---
 airflow/config_templates/config.yml          |  8 +++
 airflow/config_templates/default_airflow.cfg |  4 ++
 airflow/models/dagbag.py                     | 40 +++++++++++----
 airflow/models/serialized_dag.py             | 14 ++++++
 airflow/settings.py                          |  5 ++
 docs/dag-serialization.rst                   | 11 ++++-
 tests/models/test_dagbag.py                  | 45 +++++++++++++++++
 tests/test_utils/asserts.py                  | 73 ++++++++++++++++++++++++++++
 8 files changed, 188 insertions(+), 12 deletions(-)

diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 75c47cb..9535d5b 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -447,6 +447,14 @@
       type: string
       example: ~
       default: "30"
+    - name: min_serialized_dag_fetch_interval
+      description: |
+        Fetching a serialized DAG from the DB cannot happen more often than this minimum
+        interval, to reduce the database read rate. This config controls how quickly DAG updates appear in the Webserver
+      version_added: 1.10.12
+      type: string
+      example: ~
+      default: "10"
     - name: store_dag_code
       description: |
         Whether to persist DAG files code in DB.
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 3a9bba2..9729403 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -234,6 +234,10 @@ store_serialized_dags = False
 # Updating serialized DAG can not be faster than a minimum interval to reduce database write rate.
 min_serialized_dag_update_interval = 30
 
+# Fetching a serialized DAG from the DB cannot happen more often than this minimum
+# interval, to reduce the database read rate. This config controls how quickly DAG updates appear in the Webserver
+min_serialized_dag_fetch_interval = 10
+
 # Whether to persist DAG files code in DB.
 # If set to True, Webserver reads file contents from DB instead of
 # trying to access files in a DAG folder. Defaults to same as the
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 48cbd3e..1b8be89 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -28,7 +28,7 @@ import sys
 import textwrap
 import zipfile
 from collections import namedtuple
-from datetime import datetime
+from datetime import datetime, timedelta
 
 from croniter import CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError, croniter
 import six
@@ -102,6 +102,7 @@ class DagBag(BaseDagBag, LoggingMixin):
         self.import_errors = {}
         self.has_logged = False
         self.store_serialized_dags = store_serialized_dags
+        self.dags_last_fetched = {}
 
         self.collect_dags(
             dag_folder=dag_folder,
@@ -127,20 +128,26 @@ class DagBag(BaseDagBag, LoggingMixin):
         """
         from airflow.models.dag import DagModel  # Avoid circular import
 
-        # Only read DAGs from DB if this dagbag is store_serialized_dags.
         if self.store_serialized_dags:
             # Import here so that serialized dag is only imported when serialization is enabled
             from airflow.models.serialized_dag import SerializedDagModel
             if dag_id not in self.dags:
                 # Load from DB if not (yet) in the bag
-                row = SerializedDagModel.get(dag_id)
-                if not row:
-                    return None
-
-                dag = row.dag
-                for subdag in dag.subdags:
-                    self.dags[subdag.dag_id] = subdag
-                self.dags[dag.dag_id] = dag
+                self._add_dag_from_db(dag_id=dag_id)
+                return self.dags.get(dag_id)
+
+            # If DAG is in the DagBag, check the following
+            # 1. if time has come to check if DAG is updated (controlled by min_serialized_dag_fetch_secs)
+            # 2. check the last_updated column in SerializedDag table to see if Serialized DAG is updated
+            # 3. if (2) is yes, fetch the Serialized DAG.
+            min_serialized_dag_fetch_secs = timedelta(seconds=settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL)
+            if (
+                dag_id in self.dags_last_fetched and
+                timezone.utcnow() > self.dags_last_fetched[dag_id] + min_serialized_dag_fetch_secs
+            ):
+                sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime(dag_id=dag_id)
+                if sd_last_updated_datetime > self.dags_last_fetched[dag_id]:
+                    self._add_dag_from_db(dag_id=dag_id)
 
             return self.dags.get(dag_id)
 
@@ -178,6 +185,19 @@ class DagBag(BaseDagBag, LoggingMixin):
                 del self.dags[dag_id]
         return self.dags.get(dag_id)
 
+    def _add_dag_from_db(self, dag_id):
+        """Add DAG to DagBag from DB"""
+        from airflow.models.serialized_dag import SerializedDagModel
+        row = SerializedDagModel.get(dag_id)
+        if not row:
+            raise ValueError("DAG '{}' not found in serialized_dag table".format(dag_id))
+
+        dag = row.dag
+        for subdag in dag.subdags:
+            self.dags[subdag.dag_id] = subdag
+        self.dags[dag.dag_id] = dag
+        self.dags_last_fetched[dag.dag_id] = timezone.utcnow()
+
     def process_file(self, filepath, only_if_updated=True, safe_mode=True):
         """
         Given a path to a python module or zip file, this method imports
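
To make the new caching behaviour concrete, the following is a minimal, self-contained sketch of the same check-then-refetch pattern in plain Python. The names (`FetchCache`, `load_from_db`, `last_updated_in_db`) are hypothetical stand-ins, not the DagBag API; only the interval logic mirrors the committed `get_dag` code:

    from datetime import datetime, timedelta

    class FetchCache(object):
        """Hypothetical stand-in for DagBag's serialized-DAG cache (sketch only)."""

        def __init__(self, load_from_db, last_updated_in_db, min_fetch_secs=10):
            self._load = load_from_db                # key -> fresh object from the DB
            self._last_updated = last_updated_in_db  # key -> datetime of last DB write
            self._min_interval = timedelta(seconds=min_fetch_secs)
            self._items = {}
            self._last_fetched = {}

        def get(self, key):
            if key not in self._items:
                self._refresh(key)
                return self._items.get(key)
            # 1. Only consider hitting the DB once the minimum fetch interval has elapsed
            if datetime.utcnow() > self._last_fetched[key] + self._min_interval:
                # 2. Re-fetch only if the stored copy is actually newer than ours
                if self._last_updated(key) > self._last_fetched[key]:
                    self._refresh(key)
            return self._items.get(key)

        def _refresh(self, key):
            self._items[key] = self._load(key)
            self._last_fetched[key] = datetime.utcnow()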
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index 1313cac..d29e43c 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -219,3 +219,17 @@ class SerializedDagModel(Base):
             DagModel.root_dag_id).filter(DagModel.dag_id == dag_id).scalar()
 
         return session.query(cls).filter(cls.dag_id == root_dag_id).one_or_none()
+
+    @classmethod
+    @db.provide_session
+    def get_last_updated_datetime(cls, dag_id, session):
+        """
+        Get the date when the Serialized DAG associated to DAG was last updated
+        in serialized_dag table
+
+        :param dag_id: DAG ID
+        :type dag_id: str
+        :param session: ORM Session
+        :type session: Session
+        """
+        return session.query(cls.last_updated).filter(cls.dag_id == dag_id).scalar()
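
A brief usage sketch for the new classmethod (hedged: `session` is injected by the `@db.provide_session` decorator when omitted, and SQLAlchemy's `Query.scalar()` returns `None` when no row matches, e.g. for a DAG that was never serialized):

    from airflow.models.serialized_dag import SerializedDagModel

    last_updated = SerializedDagModel.get_last_updated_datetime(dag_id="example_bash_operator")
    if last_updated is None:
        # No serialized_dag row: the DAG was never serialized or has been removed
        pass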
diff --git a/airflow/settings.py b/airflow/settings.py
index 0158ec8..e39c960 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -428,6 +428,11 @@ STORE_SERIALIZED_DAGS = conf.getboolean('core', 'store_serialized_dags', fallbac
 MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint(
     'core', 'min_serialized_dag_update_interval', fallback=30)
 
+# Fetching a serialized DAG from the DB cannot happen more often than this minimum
+# interval, to reduce the database read rate. This config controls how quickly DAG updates appear in the Webserver
+MIN_SERIALIZED_DAG_FETCH_INTERVAL = conf.getint(
+    'core', 'min_serialized_dag_fetch_interval', fallback=10)
+
 # Whether to persist DAG files code in DB. If set to True, Webserver reads file contents
 # from DB instead of trying to access files in a DAG folder.
 # Defaults to same as the store_serialized_dags setting.
diff --git a/docs/dag-serialization.rst b/docs/dag-serialization.rst
index 0edd644..e2fcf14 100644
--- a/docs/dag-serialization.rst
+++ b/docs/dag-serialization.rst
@@ -57,14 +57,21 @@ Add the following settings in ``airflow.cfg``:
 
     [core]
     store_serialized_dags = True
+    store_dag_code = True
+
+    # You can also update the following default configurations based on your needs
     min_serialized_dag_update_interval = 30
+    min_serialized_dag_fetch_interval = 10
 
 *   ``store_serialized_dags``: This flag decides whether to serialise DAGs and persist them in DB.
     If set to True, Webserver reads from DB instead of parsing DAG files
-*   ``min_serialized_dag_update_interval``: This flag sets the minimum interval (in seconds) after which
-    the serialized DAG in DB should be updated. This helps in reducing database write rate.
 *   ``store_dag_code``: This flag decides whether to persist DAG files code in DB.
     If set to True, Webserver reads file contents from DB instead of trying to access files in a DAG folder.
+*   ``min_serialized_dag_update_interval``: This flag sets the minimum interval (in seconds) after which
+    the serialized DAG in DB should be updated. This helps in reducing database write rate.
+*   ``min_serialized_dag_fetch_interval``: This flag controls how often a SerializedDAG will be re-fetched
+    from the DB when it's already loaded in the DagBag in the Webserver. Setting this higher will reduce
+    load on the DB, but at the expense of displaying a possibly stale cached version of the DAG.
 
 If you are updating Airflow from <1.10.7, please do not forget to run ``airflow upgradedb``.
 
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index 04c2372..b9d18ac 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -19,6 +19,7 @@
 
 import inspect
 import os
+import six
 import shutil
 import textwrap
 import unittest
@@ -26,15 +27,19 @@ from datetime import datetime
 from tempfile import NamedTemporaryFile, mkdtemp
 
 from mock import patch, ANY
+from freezegun import freeze_time
 
 from airflow import models
 from airflow.configuration import conf
 from airflow.utils.dag_processing import SimpleTaskInstance
 from airflow.models import DagModel, DagBag, TaskInstance as TI
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.utils.dates import timezone as tz
 from airflow.utils.db import create_session
 from airflow.utils.state import State
 from airflow.utils.timezone import utc
 from tests.models import TEST_DAGS_FOLDER, DEFAULT_DATE
+from tests.test_utils.asserts import assert_queries_count
 from tests.test_utils.config import conf_vars
 import airflow.example_dags
 
@@ -650,3 +655,43 @@ class DagBagTest(unittest.TestCase):
         # clean up
         with create_session() as session:
             session.query(DagModel).filter(DagModel.dag_id == 'test_deactivate_unknown_dags').delete()
+
+    @patch("airflow.models.dagbag.settings.STORE_SERIALIZED_DAGS", True)
+    @patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL", 5)
+    @patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL", 5)
+    def test_get_dag_with_dag_serialization(self):
+        """
+        Test that Serialized DAG is updated in DagBag when it is updated in
+        Serialized DAG table after 'min_serialized_dag_fetch_interval' seconds are passed.
+        """
+
+        with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 0)):
+            example_bash_op_dag = DagBag(include_examples=True).dags.get("example_bash_operator")
+            SerializedDagModel.write_dag(dag=example_bash_op_dag)
+
+            dag_bag = DagBag(store_serialized_dags=True)
+            ser_dag_1 = dag_bag.get_dag("example_bash_operator")
+            ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"]
+            self.assertEqual(example_bash_op_dag.tags, ser_dag_1.tags)
+            self.assertEqual(ser_dag_1_update_time, tz.datetime(2020, 1, 5, 0, 0, 0))
+
+        # Check that if min_serialized_dag_fetch_interval has not passed we do not fetch the DAG
+        # from DB
+        with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 4)):
+            with assert_queries_count(0):
+                self.assertEqual(dag_bag.get_dag("example_bash_operator").tags, ["example"])
+
+        # Make a change in the DAG and write Serialized DAG to the DB
+        with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 6)):
+            example_bash_op_dag.tags += ["new_tag"]
+            SerializedDagModel.write_dag(dag=example_bash_op_dag)
+
+        # Since min_serialized_dag_fetch_interval is passed verify that calling 'dag_bag.get_dag'
+        # fetches the Serialized DAG from DB
+        with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 8)):
+            with assert_queries_count(2):
+                updated_ser_dag_1 = dag_bag.get_dag("example_bash_operator")
+                updated_ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"]
+
+        six.assertCountEqual(self, updated_ser_dag_1.tags, ["example", "new_tag"])
+        self.assertGreater(updated_ser_dag_1_update_time, ser_dag_1_update_time)
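
For readers unfamiliar with freezegun, a small sketch of what `freeze_time` gives the test above: it patches `datetime` so that "now" is pinned inside the block, letting the test step the clock just short of, or just past, the 5-second fetch interval deterministically:

    from datetime import datetime

    from freezegun import freeze_time

    with freeze_time(datetime(2020, 1, 5, 0, 0, 4)):
        # Inside the block, "now" is frozen at the supplied instant
        assert datetime.utcnow() == datetime(2020, 1, 5, 0, 0, 4)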
diff --git a/tests/test_utils/asserts.py b/tests/test_utils/asserts.py
new file mode 100644
index 0000000..ca3cf2f
--- /dev/null
+++ b/tests/test_utils/asserts.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import logging
+import re
+from contextlib import contextmanager
+
+from sqlalchemy import event
+
+# Long import to not create a copy of the reference, but to refer to one place.
+import airflow.settings
+
+log = logging.getLogger(__name__)
+
+
+def assert_equal_ignore_multiple_spaces(case, first, second, msg=None):
+    def _trim(s):
+        return re.sub(r"\s+", " ", s.strip())
+    return case.assertEqual(_trim(first), _trim(second), msg)
+
+
+class CountQueriesResult:
+    def __init__(self):
+        self.count = 0
+
+
+class CountQueries:
+    """
+    Counts the number of queries sent to Airflow Database in a given context.
+
+    Does not support multiple processes. When a new process is started in context, its queries will
+    not be included.
+    """
+    def __init__(self):
+        self.result = CountQueriesResult()
+
+    def __enter__(self):
+        event.listen(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute)
+        return self.result
+
+    def __exit__(self, type_, value, traceback):
+        event.remove(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute)
+        log.debug("Queries count: %d", self.result.count)
+
+    def after_cursor_execute(self, *args, **kwargs):
+        self.result.count += 1
+
+
+count_queries = CountQueries  # pylint: disable=invalid-name
+
+
+@contextmanager
+def assert_queries_count(expected_count, message_fmt=None):
+    with count_queries() as result:
+        yield None
+    message_fmt = message_fmt or "The expected number of db queries is {expected_count}. " \
+                                 "The current number is {current_count}."
+    message = message_fmt.format(current_count=result.count, expected_count=expected_count)
+    assert expected_count == result.count, message
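
A short usage sketch for the helpers above (assumes an initialized Airflow database, so that `airflow.settings.engine` exists when the event listener is registered):

    from tests.test_utils.asserts import assert_queries_count, count_queries

    # Count queries without asserting on the total
    with count_queries() as result:
        pass  # ORM code here; each cursor execution increments result.count
    print(result.count)

    # Or fail when the count differs from the expectation, as the DagBag test does
    with assert_queries_count(0):
        pass  # nothing in this block may touch the database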