You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/08/19 10:00:59 UTC

[airflow] branch master updated: Add back 'refresh_all' method in airflow/www/views.py (#10328)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 3bc3701  Add back 'refresh_all' method in airflow/www/views.py (#10328)
3bc3701 is described below

commit 3bc37013f6efc5cf758821f06257a02cae6e2d52
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Wed Aug 19 10:59:36 2020 +0100

    Add back 'refresh_all' method in airflow/www/views.py (#10328)
    
    closes https://github.com/apache/airflow/issues/9749
---
 airflow/models/dagbag.py    | 21 +++++++++++++++++++++
 airflow/www/views.py        | 15 +++++++++++++++
 tests/models/test_dagbag.py | 20 ++++++++++++++++++++
 tests/www/test_views.py     | 18 ++++++++++++++++++
 4 files changed, 74 insertions(+)

diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 5f78215..a9def70 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -450,6 +450,27 @@ class DagBag(BaseDagBag, LoggingMixin):
                          format(filename),
                          file_stat.duration)
 
+    def collect_dags_from_db(self):
+        """Collects DAGs from database."""
+        from airflow.models.serialized_dag import SerializedDagModel
+        start_dttm = timezone.utcnow()
+        self.log.info("Filling up the DagBag from database")
+
+        # The dagbag contains all rows in the serialized_dag table. Deleted DAGs are
+        # removed from that table by the scheduler job.
+        self.dags = SerializedDagModel.read_all_dags()
+
+        # Adds subdags.
+        # DAG post-processing steps such as self.bag_dag and croniter are not needed, as
+        # they are done by the scheduler before serialization.
+        subdags = {}
+        for dag in self.dags.values():
+            for subdag in dag.subdags:
+                subdags[subdag.dag_id] = subdag
+        self.dags.update(subdags)
+
+        Stats.timing('collect_db_dags', timezone.utcnow() - start_dttm)
+
     def dagbag_report(self):
         """Prints a report around DagBag loading stats"""
         stats = self.dagbag_stats
diff --git a/airflow/www/views.py b/airflow/www/views.py
index f649946..4ea826c 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1936,6 +1936,21 @@ class Airflow(AirflowBaseView):  # noqa: D101
         flash("DAG [{}] is now fresh as a daisy".format(dag_id))
         return redirect(request.referrer)
 
+    @expose('/refresh_all', methods=['POST'])
+    @has_access
+    @action_logging
+    def refresh_all(self):
+        if settings.STORE_SERIALIZED_DAGS:
+            current_app.dag_bag.collect_dags_from_db()
+        else:
+            current_app.dag_bag.collect_dags(only_if_updated=False)
+
+        # sync permissions for all dags
+        for dag_id, dag in current_app.dag_bag.dags.items():
+            current_app.appbuilder.sm.sync_perm_for_dag(dag_id, dag.access_control)
+        flash("All DAGs are now up to date")
+        return redirect(url_for('Airflow.index'))
+
     @expose('/gantt')
     @has_dag_access(can_dag_read=True)
     @has_access
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index f895638..9173ad1 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -701,6 +701,26 @@ class TestDagBag(unittest.TestCase):
         self.assertCountEqual(updated_ser_dag_1.tags, ["example", "new_tag"])
         self.assertGreater(updated_ser_dag_1_update_time, ser_dag_1_update_time)
 
+    def test_collect_dags_from_db(self):
+        """DAGs are collected from Database"""
+        example_dags_folder = airflow.example_dags.__path__[0]
+        dagbag = DagBag(example_dags_folder)
+
+        example_dags = dagbag.dags
+        for dag in example_dags.values():
+            SerializedDagModel.write_dag(dag)
+
+        new_dagbag = DagBag(read_dags_from_db=True)
+        self.assertEqual(len(new_dagbag.dags), 0)
+        new_dagbag.collect_dags_from_db()
+        new_dags = new_dagbag.dags
+        self.assertEqual(len(example_dags), len(new_dags))
+        for dag_id, dag in example_dags.items():
+            serialized_dag = new_dags[dag_id]
+
+            self.assertEqual(serialized_dag.dag_id, dag.dag_id)
+            self.assertEqual(set(serialized_dag.task_dict), set(dag.task_dict))
+
     @patch("airflow.settings.policy", cluster_policies.cluster_policy)
     def test_cluster_policy_violation(self):
         """test that file processing results in import error when task does not
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index 11fdac2..db3321b 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -964,6 +964,24 @@ class TestAirflowBaseViews(TestBase):
         resp = self.client.post('refresh?dag_id=example_bash_operator')
         self.check_content_in_response('', resp, resp_code=302)
 
+    @parameterized.expand([(True,), (False,)])
+    def test_refresh_all(self, dag_serialization):
+        with mock.patch('airflow.www.views.settings.STORE_SERIALIZED_DAGS', dag_serialization):
+            if dag_serialization:
+                with mock.patch.object(
+                    self.app.dag_bag, 'collect_dags_from_db'
+                ) as collect_dags_from_db:
+                    resp = self.client.post("/refresh_all", follow_redirects=True)
+                    self.check_content_in_response('', resp)
+                    collect_dags_from_db.assert_called_once_with()
+            else:
+                with mock.patch.object(
+                    self.app.dag_bag, 'collect_dags'
+                ) as collect_dags:
+                    resp = self.client.post("/refresh_all", follow_redirects=True)
+                    self.check_content_in_response('', resp)
+                    collect_dags.assert_called_once_with(only_if_updated=False)
+
     def test_delete_dag_button_normal(self):
         resp = self.client.get('/', follow_redirects=True)
         self.check_content_in_response('/delete?dag_id=example_bash_operator', resp)