You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2020/06/29 13:21:33 UTC

[airflow] 27/37: Correctly restore upstream_task_ids when deserializing Operators (#8775)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit db3f27ef0b95b2d30970c32fa9ad257b5dff0990
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Sun May 10 11:41:47 2020 +0100

    Correctly restore upstream_task_ids when deserializing Operators (#8775)
    
    This test exposed a bug in one of the example DAGs that wasn't caught
    by #6549. That bug will be fixed in a separate issue, but it caused the
    round-trip tests to fail here.
    
    Fixes #8720
    
    (cherry picked from commit 280f1f0c4cc49aba1b2f8b456326795733769d18)
---
 airflow/serialization/serialized_objects.py   | 2 +-
 tests/serialization/test_dag_serialization.py | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 3e564ec..8d261aa 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -591,7 +591,7 @@ class SerializedDAG(DAG, BaseSerialization):
             for task_id in serializable_task.downstream_task_ids:
                 # Bypass set_upstream etc here - it does more than we want
                 # noinspection PyProtectedMember
-                dag.task_dict[task_id]._upstream_task_ids.add(task_id)  # pylint: disable=protected-access
+                dag.task_dict[task_id]._upstream_task_ids.add(serializable_task.task_id)  # noqa: E501 # pylint: disable=protected-access
 
         return dag
 
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index e28e2b2..6b714a8 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -358,6 +358,9 @@ class TestStringifiedDAGs(unittest.TestCase):
         assert serialized_task.task_type == task.task_type
         assert set(serialized_task.template_fields) == set(task.template_fields)
 
+        assert serialized_task.upstream_task_ids == task.upstream_task_ids
+        assert serialized_task.downstream_task_ids == task.downstream_task_ids
+
         for field in fields_to_check:
             assert getattr(serialized_task, field) == getattr(task, field), \
                 '{}.{}.{} does not match'.format(task.dag.dag_id, task.task_id, field)