You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/15 03:32:56 UTC
[airflow] 21/47: Update Serialized DAGs in Webserver when DAGs are
Updated (#9851)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit b95126ba14c861d7bcbb1dbf9248aec73a8b343b
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Mon Jul 20 12:45:18 2020 +0100
Update Serialized DAGs in Webserver when DAGs are Updated (#9851)
Before this change, if DAG Serialization was enabled, the Webserver would not update the DAGs once they had been fetched from the DB. The default worker_refresh_interval was `30`, so whenever the gunicorn workers were restarted, they would pull the updated DAGs when needed.
This change will allow us to have a larger worker_refresh_interval (e.g. 30 mins or even 1 day).
(cherry picked from commit 84b85d8acc181edfe1fdd21b82c1773c19c47044)
---
airflow/config_templates/config.yml | 8 +++
airflow/config_templates/default_airflow.cfg | 4 ++
airflow/models/dagbag.py | 40 +++++++++++----
airflow/models/serialized_dag.py | 14 ++++++
airflow/settings.py | 5 ++
docs/dag-serialization.rst | 11 ++++-
tests/models/test_dagbag.py | 45 +++++++++++++++++
tests/test_utils/asserts.py | 73 ++++++++++++++++++++++++++++
8 files changed, 188 insertions(+), 12 deletions(-)
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 75c47cb..9535d5b 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -447,6 +447,14 @@
type: string
example: ~
default: "30"
+ - name: min_serialized_dag_fetch_interval
+ description: |
+ Fetching serialized DAG can not be faster than a minimum interval to reduce database
+ read rate. This config controls when your DAGs are updated in the Webserver
+ version_added: 1.10.12
+ type: string
+ example: ~
+ default: "10"
- name: store_dag_code
description: |
Whether to persist DAG files code in DB.
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 3a9bba2..9729403 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -234,6 +234,10 @@ store_serialized_dags = False
# Updating serialized DAG can not be faster than a minimum interval to reduce database write rate.
min_serialized_dag_update_interval = 30
+# Fetching serialized DAG can not be faster than a minimum interval to reduce database
+# read rate. This config controls when your DAGs are updated in the Webserver
+min_serialized_dag_fetch_interval = 10
+
# Whether to persist DAG files code in DB.
# If set to True, Webserver reads file contents from DB instead of
# trying to access files in a DAG folder. Defaults to same as the
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 48cbd3e..1b8be89 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -28,7 +28,7 @@ import sys
import textwrap
import zipfile
from collections import namedtuple
-from datetime import datetime
+from datetime import datetime, timedelta
from croniter import CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError, croniter
import six
@@ -102,6 +102,7 @@ class DagBag(BaseDagBag, LoggingMixin):
self.import_errors = {}
self.has_logged = False
self.store_serialized_dags = store_serialized_dags
+ self.dags_last_fetched = {}
self.collect_dags(
dag_folder=dag_folder,
@@ -127,20 +128,26 @@ class DagBag(BaseDagBag, LoggingMixin):
"""
from airflow.models.dag import DagModel # Avoid circular import
- # Only read DAGs from DB if this dagbag is store_serialized_dags.
if self.store_serialized_dags:
# Import here so that serialized dag is only imported when serialization is enabled
from airflow.models.serialized_dag import SerializedDagModel
if dag_id not in self.dags:
# Load from DB if not (yet) in the bag
- row = SerializedDagModel.get(dag_id)
- if not row:
- return None
-
- dag = row.dag
- for subdag in dag.subdags:
- self.dags[subdag.dag_id] = subdag
- self.dags[dag.dag_id] = dag
+ self._add_dag_from_db(dag_id=dag_id)
+ return self.dags.get(dag_id)
+
+ # If DAG is in the DagBag, check the following
+ # 1. if time has come to check if DAG is updated (controlled by min_serialized_dag_fetch_secs)
+ # 2. check the last_updated column in SerializedDag table to see if Serialized DAG is updated
+ # 3. if (2) is yes, fetch the Serialized DAG.
+ min_serialized_dag_fetch_secs = timedelta(seconds=settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL)
+ if (
+ dag_id in self.dags_last_fetched and
+ timezone.utcnow() > self.dags_last_fetched[dag_id] + min_serialized_dag_fetch_secs
+ ):
+ sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime(dag_id=dag_id)
+ if sd_last_updated_datetime > self.dags_last_fetched[dag_id]:
+ self._add_dag_from_db(dag_id=dag_id)
return self.dags.get(dag_id)
@@ -178,6 +185,19 @@ class DagBag(BaseDagBag, LoggingMixin):
del self.dags[dag_id]
return self.dags.get(dag_id)
+ def _add_dag_from_db(self, dag_id):
+ """Add DAG to DagBag from DB"""
+ from airflow.models.serialized_dag import SerializedDagModel
+ row = SerializedDagModel.get(dag_id)
+ if not row:
+ raise ValueError("DAG '{}' not found in serialized_dag table".format(dag_id))
+
+ dag = row.dag
+ for subdag in dag.subdags:
+ self.dags[subdag.dag_id] = subdag
+ self.dags[dag.dag_id] = dag
+ self.dags_last_fetched[dag.dag_id] = timezone.utcnow()
+
def process_file(self, filepath, only_if_updated=True, safe_mode=True):
"""
Given a path to a python module or zip file, this method imports
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index 1313cac..d29e43c 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -219,3 +219,17 @@ class SerializedDagModel(Base):
DagModel.root_dag_id).filter(DagModel.dag_id == dag_id).scalar()
return session.query(cls).filter(cls.dag_id == root_dag_id).one_or_none()
+
+ @classmethod
+ @db.provide_session
+ def get_last_updated_datetime(cls, dag_id, session):
+ """
+ Get the date when the Serialized DAG associated to DAG was last updated
+ in serialized_dag table
+
+ :param dag_id: DAG ID
+ :type dag_id: str
+ :param session: ORM Session
+ :type session: Session
+ """
+ return session.query(cls.last_updated).filter(cls.dag_id == dag_id).scalar()
diff --git a/airflow/settings.py b/airflow/settings.py
index 0158ec8..e39c960 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -428,6 +428,11 @@ STORE_SERIALIZED_DAGS = conf.getboolean('core', 'store_serialized_dags', fallbac
MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint(
'core', 'min_serialized_dag_update_interval', fallback=30)
+# Fetching serialized DAG can not be faster than a minimum interval to reduce database
+# read rate. This config controls when your DAGs are updated in the Webserver
+MIN_SERIALIZED_DAG_FETCH_INTERVAL = conf.getint(
+ 'core', 'min_serialized_dag_fetch_interval', fallback=10)
+
# Whether to persist DAG files code in DB. If set to True, Webserver reads file contents
# from DB instead of trying to access files in a DAG folder.
# Defaults to same as the store_serialized_dags setting.
diff --git a/docs/dag-serialization.rst b/docs/dag-serialization.rst
index 0edd644..e2fcf14 100644
--- a/docs/dag-serialization.rst
+++ b/docs/dag-serialization.rst
@@ -57,14 +57,21 @@ Add the following settings in ``airflow.cfg``:
[core]
store_serialized_dags = True
+ store_dag_code = True
+
+ # You can also update the following default configurations based on your needs
min_serialized_dag_update_interval = 30
+ min_serialized_dag_fetch_interval = 10
* ``store_serialized_dags``: This flag decides whether to serialise DAGs and persist them in DB.
If set to True, Webserver reads from DB instead of parsing DAG files
-* ``min_serialized_dag_update_interval``: This flag sets the minimum interval (in seconds) after which
- the serialized DAG in DB should be updated. This helps in reducing database write rate.
* ``store_dag_code``: This flag decides whether to persist DAG files code in DB.
If set to True, Webserver reads file contents from DB instead of trying to access files in a DAG folder.
+* ``min_serialized_dag_update_interval``: This flag sets the minimum interval (in seconds) after which
+ the serialized DAG in DB should be updated. This helps in reducing database write rate.
+* ``min_serialized_dag_fetch_interval``: This flag controls how often a SerializedDAG will be re-fetched
+ from the DB when it's already loaded in the DagBag in the Webserver. Setting this higher will reduce
+ load on the DB, but at the expense of displaying a possibly stale cached version of the DAG.
If you are updating Airflow from <1.10.7, please do not forget to run ``airflow upgradedb``.
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index 04c2372..b9d18ac 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -19,6 +19,7 @@
import inspect
import os
+import six
import shutil
import textwrap
import unittest
@@ -26,15 +27,19 @@ from datetime import datetime
from tempfile import NamedTemporaryFile, mkdtemp
from mock import patch, ANY
+from freezegun import freeze_time
from airflow import models
from airflow.configuration import conf
from airflow.utils.dag_processing import SimpleTaskInstance
from airflow.models import DagModel, DagBag, TaskInstance as TI
+from airflow.models.serialized_dag import SerializedDagModel
+from airflow.utils.dates import timezone as tz
from airflow.utils.db import create_session
from airflow.utils.state import State
from airflow.utils.timezone import utc
from tests.models import TEST_DAGS_FOLDER, DEFAULT_DATE
+from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.config import conf_vars
import airflow.example_dags
@@ -650,3 +655,43 @@ class DagBagTest(unittest.TestCase):
# clean up
with create_session() as session:
session.query(DagModel).filter(DagModel.dag_id == 'test_deactivate_unknown_dags').delete()
+
+ @patch("airflow.models.dagbag.settings.STORE_SERIALIZED_DAGS", True)
+ @patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL", 5)
+ @patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL", 5)
+ def test_get_dag_with_dag_serialization(self):
+ """
+ Test that Serialized DAG is updated in DagBag when it is updated in
+ Serialized DAG table after 'min_serialized_dag_fetch_interval' seconds are passed.
+ """
+
+ with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 0)):
+ example_bash_op_dag = DagBag(include_examples=True).dags.get("example_bash_operator")
+ SerializedDagModel.write_dag(dag=example_bash_op_dag)
+
+ dag_bag = DagBag(store_serialized_dags=True)
+ ser_dag_1 = dag_bag.get_dag("example_bash_operator")
+ ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"]
+ self.assertEqual(example_bash_op_dag.tags, ser_dag_1.tags)
+ self.assertEqual(ser_dag_1_update_time, tz.datetime(2020, 1, 5, 0, 0, 0))
+
+ # Check that if min_serialized_dag_fetch_interval has not passed we do not fetch the DAG
+ # from DB
+ with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 4)):
+ with assert_queries_count(0):
+ self.assertEqual(dag_bag.get_dag("example_bash_operator").tags, ["example"])
+
+ # Make a change in the DAG and write Serialized DAG to the DB
+ with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 6)):
+ example_bash_op_dag.tags += ["new_tag"]
+ SerializedDagModel.write_dag(dag=example_bash_op_dag)
+
+ # Since min_serialized_dag_fetch_interval is passed verify that calling 'dag_bag.get_dag'
+ # fetches the Serialized DAG from DB
+ with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 8)):
+ with assert_queries_count(2):
+ updated_ser_dag_1 = dag_bag.get_dag("example_bash_operator")
+ updated_ser_dag_1_update_time = dag_bag.dags_last_fetched["example_bash_operator"]
+
+ six.assertCountEqual(self, updated_ser_dag_1.tags, ["example", "new_tag"])
+ self.assertGreater(updated_ser_dag_1_update_time, ser_dag_1_update_time)
diff --git a/tests/test_utils/asserts.py b/tests/test_utils/asserts.py
new file mode 100644
index 0000000..ca3cf2f
--- /dev/null
+++ b/tests/test_utils/asserts.py
@@ -0,0 +1,73 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import logging
+import re
+from contextlib import contextmanager
+
+from sqlalchemy import event
+
+# Long import to not create a copy of the reference, but to refer to one place.
+import airflow.settings
+
+log = logging.getLogger(__name__)
+
+
+def assert_equal_ignore_multiple_spaces(case, first, second, msg=None):
+ def _trim(s):
+ return re.sub(r"\s+", " ", s.strip())
+ return case.assertEqual(_trim(first), _trim(second), msg)
+
+
+class CountQueriesResult:
+ def __init__(self):
+ self.count = 0
+
+
+class CountQueries:
+ """
+ Counts the number of queries sent to Airflow Database in a given context.
+
+ Does not support multiple processes. When a new process is started in context, its queries will
+ not be included.
+ """
+ def __init__(self):
+ self.result = CountQueriesResult()
+
+ def __enter__(self):
+ event.listen(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute)
+ return self.result
+
+ def __exit__(self, type_, value, traceback):
+ event.remove(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute)
+ log.debug("Queries count: %d", self.result.count)
+
+ def after_cursor_execute(self, *args, **kwargs):
+ self.result.count += 1
+
+
+count_queries = CountQueries # pylint: disable=invalid-name
+
+
+@contextmanager
+def assert_queries_count(expected_count, message_fmt=None):
+ with count_queries() as result:
+ yield None
+ message_fmt = message_fmt or "The expected number of db queries is {expected_count}. " \
+ "The current number is {current_count}."
+ message = message_fmt.format(current_count=result.count, expected_count=expected_count)
+ assert expected_count == result.count, message