You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/11/19 13:24:54 UTC

[airflow] 01/03: Fix operator field update for SerializedBaseOperator (#10924)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit a22dba100401f934f907eab150f12efaae428819
Author: Denis Evseev <xO...@gmail.com>
AuthorDate: Wed Sep 16 01:40:41 2020 +0300

    Fix operator field update for SerializedBaseOperator (#10924)
    
    Co-authored-by: Denis Evseev <xO...@gmail.com>
    Co-authored-by: Kaxil Naik <ka...@gmail.com>
    (cherry picked from commit f7da7d94b4ac6dc59fb50a4f4abba69776aac798)
    (cherry picked from commit cfc9732d71ae5b4b65077f2ba9cd51180a6c4548)
---
 airflow/models/taskinstance.py             |  2 +-
 airflow/sensors/external_task_sensor.py    | 14 ++++++++++++++
 tests/models/test_taskinstance.py          | 18 ++++++++++++++++++
 tests/sensors/test_external_task_sensor.py | 21 +++++++++++++++++++++
 4 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index ae296ba..7c1caef 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -512,7 +512,7 @@ class TaskInstance(Base, LoggingMixin):
         self.run_as_user = task.run_as_user
         self.max_tries = task.retries
         self.executor_config = task.executor_config
-        self.operator = task.__class__.__name__
+        self.operator = task.task_type
 
     @provide_session
     def clear_xcom_data(self, session=None):
diff --git a/airflow/sensors/external_task_sensor.py b/airflow/sensors/external_task_sensor.py
index 2dc0875..b759a71 100644
--- a/airflow/sensors/external_task_sensor.py
+++ b/airflow/sensors/external_task_sensor.py
@@ -201,6 +201,9 @@ class ExternalTaskMarker(DummyOperator):
     template_fields = ['external_dag_id', 'external_task_id', 'execution_date']
     ui_color = '#19647e'
 
+    # __serialized_fields is lazily populated the first time get_serialized_fields() is called
+    __serialized_fields = None
+
     @apply_defaults
     def __init__(self,
                  external_dag_id,
@@ -222,3 +225,14 @@ class ExternalTaskMarker(DummyOperator):
         if recursion_depth <= 0:
             raise ValueError("recursion_depth should be a positive integer")
         self.recursion_depth = recursion_depth
+
+    @classmethod
+    def get_serialized_fields(cls):
+        """Serialized ExternalTaskMarker contains exactly these fields + template_fields."""
+        if not cls.__serialized_fields:
+            cls.__serialized_fields = frozenset(
+                super(ExternalTaskMarker, cls).get_serialized_fields() | {
+                    "recursion_depth"
+                }
+            )
+        return cls.__serialized_fields
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 0c416a4..3b05bbb 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -40,6 +40,7 @@ from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.python_operator import PythonOperator
 from airflow.sensors.base_sensor_operator import BaseSensorOperator
 from airflow.ti_deps.dep_context import REQUEUEABLE_DEPS, RUNNABLE_STATES, RUNNING_DEPS
+from airflow.serialization.serialized_objects import SerializedBaseOperator
 from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
 from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
 from airflow.utils import timezone
@@ -1560,6 +1561,23 @@ class TaskInstanceTest(unittest.TestCase):
         with create_session() as session:
             session.query(RenderedTaskInstanceFields).delete()
 
+    def test_operator_field_with_serialization(self):
+
+        dag = DAG('test_queries', start_date=DEFAULT_DATE)
+        task = DummyOperator(task_id='op', dag=dag)
+        self.assertEqual(task.task_type, 'DummyOperator')
+
+        # Verify that ti.operator field renders correctly "without" Serialization
+        ti = TI(task=task, execution_date=datetime.datetime.now())
+        self.assertEqual(ti.operator, "DummyOperator")
+
+        serialized_op = SerializedBaseOperator.serialize_operator(task)
+        deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
+        self.assertEqual(deserialized_op.task_type, 'DummyOperator')
+        # Verify that ti.operator field renders correctly "with" Serialization
+        ser_ti = TI(task=deserialized_op, execution_date=datetime.datetime.now())
+        self.assertEqual(ser_ti.operator, "DummyOperator")
+
 
 @pytest.mark.parametrize("pool_override", [None, "test_pool2"])
 def test_refresh_from_task(pool_override):
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index e2a58ec..00d5835 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -27,6 +27,7 @@ from airflow.operators.bash_operator import BashOperator
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.sensors.external_task_sensor import ExternalTaskMarker, ExternalTaskSensor
 from airflow.sensors.time_sensor import TimeSensor
+from airflow.serialization.serialized_objects import SerializedBaseOperator
 from airflow.utils.state import State
 from airflow.utils.timezone import datetime
 
@@ -339,6 +340,26 @@ exit 0
             )
 
 
+class TestExternalTaskMarker(unittest.TestCase):
+    def test_serialized_fields(self):
+        self.assertTrue({"recursion_depth"}.issubset(ExternalTaskMarker.get_serialized_fields()))
+
+    def test_serialized_external_task_marker(self):
+        dag = DAG('test_serialized_external_task_marker', start_date=DEFAULT_DATE)
+        task = ExternalTaskMarker(
+            task_id="parent_task",
+            external_dag_id="external_task_marker_child",
+            external_task_id="child_task1",
+            dag=dag
+        )
+
+        serialized_op = SerializedBaseOperator.serialize_operator(task)
+        deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
+        self.assertEqual(deserialized_op.task_type, 'ExternalTaskMarker')
+        self.assertEqual(getattr(deserialized_op, 'external_dag_id'), 'external_task_marker_child')
+        self.assertEqual(getattr(deserialized_op, 'external_task_id'), 'child_task1')
+
+
 @pytest.fixture
 def dag_bag_ext():
     """