You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ep...@apache.org on 2022/08/15 18:45:04 UTC

[airflow] 28/45: Fix Serialization error in TaskCallbackRequest (#25471)

This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-3-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 385f04ba345e872dc31de62113e2f46e01fd1d4a
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Tue Aug 2 22:50:40 2022 +0100

    Fix Serialization error in TaskCallbackRequest (#25471)
    
    The way we serialize `SimpleTaskInstance` in the `TaskCallbackRequest` class leads to a JSON serialization error when the task instance has a `start_date` or `end_date`. Since TIs always have a `start_date`, serialization would always fail.
    This PR fixes this with a new `as_dict` method on `SimpleTaskInstance` that converts `start_date`/`end_date` to ISO-format strings for serialization.
    
    (cherry picked from commit d7e14ba0d612d8315238f9d0cba4ef8c44b6867c)
---
 airflow/callbacks/callback_requests.py    |  2 +-
 airflow/models/taskinstance.py            | 10 ++++++++++
 tests/callbacks/test_callback_requests.py | 21 +++++++++++++++++----
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/airflow/callbacks/callback_requests.py b/airflow/callbacks/callback_requests.py
index 8112589cd0..b04a201c08 100644
--- a/airflow/callbacks/callback_requests.py
+++ b/airflow/callbacks/callback_requests.py
@@ -75,7 +75,7 @@ class TaskCallbackRequest(CallbackRequest):
 
     def to_json(self) -> str:
         dict_obj = self.__dict__.copy()
-        dict_obj["simple_task_instance"] = dict_obj["simple_task_instance"].__dict__
+        dict_obj["simple_task_instance"] = self.simple_task_instance.as_dict()
         return json.dumps(dict_obj)
 
     @classmethod
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index debd0aa6b0..33fe7a3f53 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2650,6 +2650,16 @@ class SimpleTaskInstance:
             return self.__dict__ == other.__dict__
         return NotImplemented
 
+    def as_dict(self):
+        new_dict = dict(self.__dict__)
+        for key in new_dict:
+            if key in ['start_date', 'end_date']:
+                val = new_dict[key]
+                if not val or isinstance(val, str):
+                    continue
+                new_dict.update({key: val.isoformat()})
+        return new_dict
+
     @classmethod
     def from_ti(cls, ti: TaskInstance):
         return cls(
diff --git a/tests/callbacks/test_callback_requests.py b/tests/callbacks/test_callback_requests.py
index 286d64eaa1..3764f19c4c 100644
--- a/tests/callbacks/test_callback_requests.py
+++ b/tests/callbacks/test_callback_requests.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import unittest
 from datetime import datetime
 
 from parameterized import parameterized
@@ -29,6 +28,7 @@ from airflow.callbacks.callback_requests import (
 from airflow.models.dag import DAG
 from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
 from airflow.operators.bash import BashOperator
+from airflow.utils import timezone
 from airflow.utils.state import State
 
 TI = TaskInstance(
@@ -38,7 +38,7 @@ TI = TaskInstance(
 )
 
 
-class TestCallbackRequest(unittest.TestCase):
+class TestCallbackRequest:
     @parameterized.expand(
         [
             (CallbackRequest(full_filepath="filepath", msg="task_failure"), CallbackRequest),
@@ -64,7 +64,20 @@ class TestCallbackRequest(unittest.TestCase):
     )
     def test_from_json(self, input, request_class):
         json_str = input.to_json()
-
         result = request_class.from_json(json_str=json_str)
+        assert result == input
 
-        self.assertEqual(result, input)
+    def test_taskcallback_to_json_with_start_date_and_end_date(self, session, create_task_instance):
+        ti = create_task_instance()
+        ti.start_date = timezone.utcnow()
+        ti.end_date = timezone.utcnow()
+        session.merge(ti)
+        session.flush()
+        input = TaskCallbackRequest(
+            full_filepath="filepath",
+            simple_task_instance=SimpleTaskInstance.from_ti(ti),
+            is_failure_callback=True,
+        )
+        json_str = input.to_json()
+        result = TaskCallbackRequest.from_json(json_str)
+        assert input == result