Posted to commits@airflow.apache.org by ka...@apache.org on 2021/01/14 15:47:16 UTC

[airflow] branch master updated: BugFix: Dag-level Callback Requests were not run (#13651)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new c128aa7  BugFix: Dag-level Callback Requests were not run (#13651)
c128aa7 is described below

commit c128aa744e9b282d6046da1531f785359f471085
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Thu Jan 14 15:46:58 2021 +0000

    BugFix: Dag-level Callback Requests were not run (#13651)
    
    In https://github.com/apache/airflow/pull/13163 I attempted to only run
    Callback requests when they are defined on the DAG. However, I have just
    found out that while we were storing the task-level callbacks as strings
    in the Serialized JSON, we were not storing the DAG-level callbacks at
    all, so they defaulted to None when the Serialized DAG was deserialized,
    which meant that the DAG callbacks were never run.
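    
    A minimal sketch of the failure mode (illustrative only; the dag_id and
    lambda below are made up for this example):
    
        from airflow.models.dag import DAG
        from airflow.serialization.serialized_objects import SerializedDAG
    
        dag = DAG(dag_id='example_dag', on_failure_callback=lambda context: None)
    
        # Round-trip through the serializer, as the Scheduler does.
        roundtripped = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
    
        # The callback function itself is never serialized, so before this fix
        # the Scheduler's `elif dag.on_failure_callback:` check was always
        # falsy on the deserialized DAG and no DagCallbackRequest was created.
        assert roundtripped.on_failure_callback is None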
    
    This PR fixes that. We don't need to store DAG-level callbacks as strings:
    we don't display them in the Webserver, and their actual contents are not
    used anywhere in the Scheduler itself. The Scheduler only checks whether a
    callback is defined and, if so, sends a request to DagFileProcessorProcess,
    which runs the callback against the actual DAG file. So instead of storing
    the actual callbacks as strings, which would have resulted in a larger JSON
    blob, I have added properties that record whether a callback is defined
    (`dag.has_on_success_callback` and `dag.has_on_failure_callback`).
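    
    With this change the flag survives the round-trip. A minimal, illustrative
    sketch (same made-up dag_id and lambda as above):
    
        from airflow.models.dag import DAG
        from airflow.serialization.serialized_objects import SerializedDAG
    
        dag = DAG(dag_id='example_dag', on_failure_callback=lambda context: None)
        serialized = SerializedDAG.to_dict(dag)
    
        # The flag is only written when True, keeping the JSON blob small.
        assert serialized['dag']['has_on_failure_callback'] is True
    
        deserialized = SerializedDAG.from_dict(serialized)
        assert deserialized.has_on_failure_callback is True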
    
    Note: SLA callbacks don't have this issue, as we currently check whether
    SLAs are defined on any of the tasks; if so, we send a request to
    DagFileProcessorProcess, which then executes the SLA callback defined on
    the DAG.
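    
    The task-level SLA check is roughly of this shape (a hypothetical sketch,
    not the Scheduler's literal code):
    
        # SLAs live on tasks, so the Scheduler can detect them without
        # needing the DAG-level callback contents at all.
        if any(task.sla is not None for task in dag.tasks):
            ...  # hand the DAG off to DagFileProcessorProcess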
---
 airflow/models/dag.py                         |  9 +++++
 airflow/models/dagrun.py                      |  6 +--
 airflow/serialization/schema.json             |  2 +
 airflow/serialization/serialized_objects.py   | 12 ++++++
 airflow/utils/callback_requests.py            |  4 +-
 tests/models/test_dagrun.py                   | 13 +++++++
 tests/serialization/test_dag_serialization.py | 56 ++++++++++++++++++++++++++-
 7 files changed, 97 insertions(+), 5 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 3a7b5ec..1d0105d 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -348,6 +348,12 @@ class DAG(LoggingMixin):
         self.partial = False
         self.on_success_callback = on_success_callback
         self.on_failure_callback = on_failure_callback
+
+        # To keep it in parity with Serialized DAGs
+        # and identify if DAG has on_*_callback without actually storing them in Serialized JSON
+        self.has_on_success_callback = self.on_success_callback is not None
+        self.has_on_failure_callback = self.on_failure_callback is not None
+
         self.doc_md = doc_md
 
         self._access_control = DAG._upgrade_outdated_dag_access_control(access_control)
@@ -2028,6 +2034,9 @@ class DAG(LoggingMixin):
                 'on_failure_callback',
                 'template_undefined',
                 'jinja_environment_kwargs',
+                # has_on_*_callback are only stored if the value is True, as the default is False
+                'has_on_success_callback',
+                'has_on_failure_callback',
             }
         return cls.__serialized_fields
 
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 50fc514..c71a52f 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -429,7 +429,7 @@ class DagRun(Base, LoggingMixin):
             self.set_state(State.FAILED)
             if execute_callbacks:
                 dag.handle_callback(self, success=False, reason='task_failure', session=session)
-            elif dag.on_failure_callback:
+            elif dag.has_on_failure_callback:
                 callback = callback_requests.DagCallbackRequest(
                     full_filepath=dag.fileloc,
                     dag_id=self.dag_id,
@@ -444,7 +444,7 @@ class DagRun(Base, LoggingMixin):
             self.set_state(State.SUCCESS)
             if execute_callbacks:
                 dag.handle_callback(self, success=True, reason='success', session=session)
-            elif dag.on_success_callback:
+            elif dag.has_on_success_callback:
                 callback = callback_requests.DagCallbackRequest(
                     full_filepath=dag.fileloc,
                     dag_id=self.dag_id,
@@ -459,7 +459,7 @@ class DagRun(Base, LoggingMixin):
             self.set_state(State.FAILED)
             if execute_callbacks:
                 dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session)
-            elif dag.on_failure_callback:
+            elif dag.has_on_failure_callback:
                 callback = callback_requests.DagCallbackRequest(
                     full_filepath=dag.fileloc,
                     dag_id=self.dag_id,
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index 5ad3178..c831334 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -96,6 +96,8 @@
         "_default_view": { "type" : "string"},
         "_access_control": {"$ref": "#/definitions/dict" },
         "is_paused_upon_creation":  { "type": "boolean" },
+        "has_on_success_callback":  { "type": "boolean" },
+        "has_on_failure_callback":  { "type": "boolean" },
         "tags": { "type": "array" },
         "_task_group": {"anyOf": [
           { "type": "null" },
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 403030b..d0cb1a6 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -634,6 +634,12 @@ class SerializedDAG(DAG, BaseSerialization):
 
             serialize_dag["tasks"] = [cls._serialize(task) for _, task in dag.task_dict.items()]
             serialize_dag['_task_group'] = SerializedTaskGroup.serialize_task_group(dag.task_group)
+
+            # has_on_*_callback are only stored if the value is True, as the default is False
+            if dag.has_on_success_callback:
+                serialize_dag['has_on_success_callback'] = True
+            if dag.has_on_failure_callback:
+                serialize_dag['has_on_failure_callback'] = True
             return serialize_dag
         except SerializationError:
             raise
@@ -677,6 +683,12 @@ class SerializedDAG(DAG, BaseSerialization):
                 dag.task_group.add(task)
         # pylint: enable=protected-access
 
+        # Set has_on_*_callbacks to True if they exist in Serialized blob as False is the default
+        if "has_on_success_callback" in encoded_dag:
+            dag.has_on_success_callback = True
+        if "has_on_failure_callback" in encoded_dag:
+            dag.has_on_failure_callback = True
+
         keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys()
         for k in keys_to_set_none:
             setattr(dag, k, None)
diff --git a/airflow/utils/callback_requests.py b/airflow/utils/callback_requests.py
index 5561955..89ffe52 100644
--- a/airflow/utils/callback_requests.py
+++ b/airflow/utils/callback_requests.py
@@ -34,7 +34,9 @@ class CallbackRequest:
         self.msg = msg
 
     def __eq__(self, other):
-        return self.__dict__ == other.__dict__
+        if isinstance(other, CallbackRequest):
+            return self.__dict__ == other.__dict__
+        return False
 
     def __repr__(self):
         return str(self.__dict__)
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index b5c7da0..f27c50a 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -30,6 +30,7 @@ from airflow.models.dagrun import DagRun
 from airflow.operators.bash import BashOperator
 from airflow.operators.dummy import DummyOperator
 from airflow.operators.python import ShortCircuitOperator
+from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.stats import Stats
 from airflow.utils import timezone
 from airflow.utils.callback_requests import DagCallbackRequest
@@ -304,6 +305,9 @@ class TestDagRun(unittest.TestCase):
             'test_state_succeeded2': State.SUCCESS,
         }
 
+        # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG
+        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
         dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
         _, callback = dag_run.update_state()
         self.assertEqual(State.SUCCESS, dag_run.state)
@@ -328,6 +332,9 @@ class TestDagRun(unittest.TestCase):
         }
         dag_task1.set_downstream(dag_task2)
 
+        # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG
+        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
         dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
         _, callback = dag_run.update_state()
         self.assertEqual(State.FAILED, dag_run.state)
@@ -354,6 +361,9 @@ class TestDagRun(unittest.TestCase):
             'test_state_succeeded2': State.SUCCESS,
         }
 
+        # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG
+        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
         dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
 
         _, callback = dag_run.update_state(execute_callbacks=False)
@@ -388,6 +398,9 @@ class TestDagRun(unittest.TestCase):
             'test_state_failed2': State.FAILED,
         }
 
+        # Scheduler uses Serialized DAG -- so use that instead of the Actual DAG
+        dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
         dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states)
 
         _, callback = dag_run.update_state(execute_callbacks=False)
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 37f50e7..eba7f13 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -806,7 +806,7 @@ class TestStringifiedDAGs(unittest.TestCase):
         dag_schema: dict = load_dag_schema_dict()["definitions"]["dag"]["properties"]
 
         # The parameters we add manually in Serialization needs to be ignored
-        ignored_keys: set = {"is_subdag", "tasks"}
+        ignored_keys: set = {"is_subdag", "tasks", "has_on_success_callback", "has_on_failure_callback"}
         dag_params: set = set(dag_schema.keys()) - ignored_keys
         self.assertEqual(set(DAG.get_serialized_fields()), dag_params)
 
@@ -970,6 +970,60 @@ class TestStringifiedDAGs(unittest.TestCase):
 
         assert op.deps == serialized_op.deps
 
+    @parameterized.expand(
+        [
+            ({"on_success_callback": lambda x: print("hi")}, True),
+            ({}, False),
+        ]
+    )
+    def test_dag_on_success_callback_roundtrip(self, passed_success_callback, expected_value):
+        """
+        Test that when on_success_callback is passed to the DAG, has_on_success_callback is stored
+        in the Serialized JSON blob and dag.has_on_success_callback is set to True on deserialization.
+
+        When the callback is not set, has_on_success_callback should not be stored in the Serialized
+        blob and should default to False on deserialization.
+        """
+        dag = DAG(dag_id='test_dag_on_success_callback_roundtrip', **passed_success_callback)
+        BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))
+
+        serialized_dag = SerializedDAG.to_dict(dag)
+        if expected_value:
+            assert "has_on_success_callback" in serialized_dag["dag"]
+        else:
+            assert "has_on_success_callback" not in serialized_dag["dag"]
+
+        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+
+        assert deserialized_dag.has_on_success_callback is expected_value
+
+    @parameterized.expand(
+        [
+            ({"on_failure_callback": lambda x: print("hi")}, True),
+            ({}, False),
+        ]
+    )
+    def test_dag_on_failure_callback_roundtrip(self, passed_failure_callback, expected_value):
+        """
+        Test that when on_failure_callback is passed to the DAG, has_on_failure_callback is stored
+        in the Serialized JSON blob and dag.has_on_failure_callback is set to True on deserialization.
+
+        When the callback is not set, has_on_failure_callback should not be stored in the Serialized
+        blob and should default to False on deserialization.
+        """
+        dag = DAG(dag_id='test_dag_on_failure_callback_roundtrip', **passed_failure_callback)
+        BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))
+
+        serialized_dag = SerializedDAG.to_dict(dag)
+        if expected_value:
+            assert "has_on_failure_callback" in serialized_dag["dag"]
+        else:
+            assert "has_on_failure_callback" not in serialized_dag["dag"]
+
+        deserialized_dag = SerializedDAG.from_dict(serialized_dag)
+
+        assert deserialized_dag.has_on_failure_callback is expected_value
+
 
 def test_kubernetes_optional():
     """Serialisation / deserialisation continues to work without kubernetes installed"""