You are viewing a plain-text version of this content. (The canonical link to the original message was not preserved in this text version.)
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/19 10:00:59 UTC
[airflow] branch master updated: Add back 'refresh_all' method in airflow/www/views.py (#10328)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 3bc3701 Add back 'refresh_all' method in airflow/www/views.py (#10328)
3bc3701 is described below
commit 3bc37013f6efc5cf758821f06257a02cae6e2d52
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Wed Aug 19 10:59:36 2020 +0100
Add back 'refresh_all' method in airflow/www/views.py (#10328)
closes https://github.com/apache/airflow/issues/9749
---
airflow/models/dagbag.py | 21 +++++++++++++++++++++
airflow/www/views.py | 15 +++++++++++++++
tests/models/test_dagbag.py | 20 ++++++++++++++++++++
tests/www/test_views.py | 18 ++++++++++++++++++
4 files changed, 74 insertions(+)
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 5f78215..a9def70 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -450,6 +450,27 @@ class DagBag(BaseDagBag, LoggingMixin):
format(filename),
file_stat.duration)
+ def collect_dags_from_db(self):
+ """Replace ``self.dags`` with every serialized DAG in the DB (plus subdags) and time it."""
+ from airflow.models.serialized_dag import SerializedDagModel
+ start_dttm = timezone.utcnow()
+ self.log.info("Filling up the DagBag from database")
+
+ # The DagBag is replaced by all rows of the serialized_dag table; deleted DAGs
+ # are removed from that table by the scheduler job.
+ self.dags = SerializedDagModel.read_all_dags()
+
+ # Add every DAG's subdags to the bag as well, keyed by their own dag_id.
+ # DAG post-processing steps such as self.bag_dag and croniter are not needed here,
+ # as the scheduler performs them before serialization.
+ subdags = {}
+ for dag in self.dags.values():
+ for subdag in dag.subdags:
+ subdags[subdag.dag_id] = subdag
+ self.dags.update(subdags)
+
+ Stats.timing('collect_db_dags', timezone.utcnow() - start_dttm)
+
def dagbag_report(self):
"""Prints a report around DagBag loading stats"""
stats = self.dagbag_stats
diff --git a/airflow/www/views.py b/airflow/www/views.py
index f649946..4ea826c 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1936,6 +1936,21 @@ class Airflow(AirflowBaseView): # noqa: D101
flash("DAG [{}] is now fresh as a daisy".format(dag_id))
return redirect(request.referrer)
+ @expose('/refresh_all', methods=['POST'])
+ @has_access
+ @action_logging
+ def refresh_all(self):  # POST view: reload all DAGs, re-sync DAG perms, redirect to index
+ if settings.STORE_SERIALIZED_DAGS:  # reload every DAG from the serialized_dag table
+ current_app.dag_bag.collect_dags_from_db()
+ else:  # reload from DAG files; only_if_updated=False presumably includes unchanged files
+ current_app.dag_bag.collect_dags(only_if_updated=False)
+
+ # Re-sync DAG-level permissions (from each DAG's access_control) for every DAG now in the bag
+ for dag_id, dag in current_app.dag_bag.dags.items():
+ current_app.appbuilder.sm.sync_perm_for_dag(dag_id, dag.access_control)
+ flash("All DAGs are now up to date")
+ return redirect(url_for('Airflow.index'))
+
@expose('/gantt')
@has_dag_access(can_dag_read=True)
@has_access
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index f895638..9173ad1 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -701,6 +701,26 @@ class TestDagBag(unittest.TestCase):
self.assertCountEqual(updated_ser_dag_1.tags, ["example", "new_tag"])
self.assertGreater(updated_ser_dag_1_update_time, ser_dag_1_update_time)
+ def test_collect_dags_from_db(self):
+ """Every example DAG written to the DB is collected, by dag_id, by collect_dags_from_db."""
+ example_dags_folder = airflow.example_dags.__path__[0]
+ dagbag = DagBag(example_dags_folder)
+
+ example_dags = dagbag.dags
+ for dag in example_dags.values():
+ SerializedDagModel.write_dag(dag)
+
+ new_dagbag = DagBag(read_dags_from_db=True)
+ self.assertEqual(len(new_dagbag.dags), 0)  # nothing loaded before collect_dags_from_db()
+ new_dagbag.collect_dags_from_db()
+ new_dags = new_dagbag.dags
+ self.assertEqual(len(example_dags), len(new_dags))
+ for dag_id, dag in example_dags.items():
+ serialized_dag = new_dags[dag_id]
+
+ self.assertEqual(serialized_dag.dag_id, dag.dag_id)
+ self.assertEqual(set(serialized_dag.task_dict), set(dag.task_dict))
+
@patch("airflow.settings.policy", cluster_policies.cluster_policy)
def test_cluster_policy_violation(self):
"""test that file processing results in import error when task does not
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index 11fdac2..db3321b 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -964,6 +964,24 @@ class TestAirflowBaseViews(TestBase):
resp = self.client.post('refresh?dag_id=example_bash_operator')
self.check_content_in_response('', resp, resp_code=302)
+ @parameterized.expand([(True,), (False,)])
+ def test_refresh_all(self, dag_serialization):  # loader used must follow STORE_SERIALIZED_DAGS
+ with mock.patch('airflow.www.views.settings.STORE_SERIALIZED_DAGS', dag_serialization):
+ if dag_serialization:  # serialized mode: expect a DB reload
+ with mock.patch.object(
+ self.app.dag_bag, 'collect_dags_from_db'
+ ) as collect_dags_from_db:
+ resp = self.client.post("/refresh_all", follow_redirects=True)
+ self.check_content_in_response('', resp)
+ collect_dags_from_db.assert_called_once_with()
+ else:  # file mode: expect a full collect_dags(only_if_updated=False)
+ with mock.patch.object(
+ self.app.dag_bag, 'collect_dags'
+ ) as collect_dags:
+ resp = self.client.post("/refresh_all", follow_redirects=True)
+ self.check_content_in_response('', resp)
+ collect_dags.assert_called_once_with(only_if_updated=False)
+
def test_delete_dag_button_normal(self):
resp = self.client.get('/', follow_redirects=True)
self.check_content_in_response('/delete?dag_id=example_bash_operator', resp)