You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/11/19 13:24:54 UTC
[airflow] 01/03: Fix operator field update for SerializedBaseOperator (#10924)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit a22dba100401f934f907eab150f12efaae428819
Author: Denis Evseev <xO...@gmail.com>
AuthorDate: Wed Sep 16 01:40:41 2020 +0300
Fix operator field update for SerializedBaseOperator (#10924)
Co-authored-by: Denis Evseev <xO...@gmail.com>
Co-authored-by: Kaxil Naik <ka...@gmail.com>
(cherry picked from commit f7da7d94b4ac6dc59fb50a4f4abba69776aac798)
(cherry picked from commit cfc9732d71ae5b4b65077f2ba9cd51180a6c4548)
---
airflow/models/taskinstance.py | 2 +-
airflow/sensors/external_task_sensor.py | 14 ++++++++++++++
tests/models/test_taskinstance.py | 18 ++++++++++++++++++
tests/sensors/test_external_task_sensor.py | 21 +++++++++++++++++++++
4 files changed, 54 insertions(+), 1 deletion(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index ae296ba..7c1caef 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -512,7 +512,7 @@ class TaskInstance(Base, LoggingMixin):
self.run_as_user = task.run_as_user
self.max_tries = task.retries
self.executor_config = task.executor_config
- self.operator = task.__class__.__name__
+ self.operator = task.task_type
@provide_session
def clear_xcom_data(self, session=None):
diff --git a/airflow/sensors/external_task_sensor.py b/airflow/sensors/external_task_sensor.py
index 2dc0875..b759a71 100644
--- a/airflow/sensors/external_task_sensor.py
+++ b/airflow/sensors/external_task_sensor.py
@@ -201,6 +201,9 @@ class ExternalTaskMarker(DummyOperator):
template_fields = ['external_dag_id', 'external_task_id', 'execution_date']
ui_color = '#19647e'
+ # __serialized_fields is lazily populated when the get_serialized_fields() method is called
+ __serialized_fields = None
+
@apply_defaults
def __init__(self,
external_dag_id,
@@ -222,3 +225,14 @@ class ExternalTaskMarker(DummyOperator):
if recursion_depth <= 0:
raise ValueError("recursion_depth should be a positive integer")
self.recursion_depth = recursion_depth
+
+ @classmethod
+ def get_serialized_fields(cls):
+ """Serialized ExternalTaskMarker contains exactly these fields + template_fields."""
+ if not cls.__serialized_fields:
+ cls.__serialized_fields = frozenset(
+ super(ExternalTaskMarker, cls).get_serialized_fields() | {
+ "recursion_depth"
+ }
+ )
+ return cls.__serialized_fields
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 0c416a4..3b05bbb 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -40,6 +40,7 @@ from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import PythonOperator
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.ti_deps.dep_context import REQUEUEABLE_DEPS, RUNNABLE_STATES, RUNNING_DEPS
+from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
@@ -1560,6 +1561,23 @@ class TaskInstanceTest(unittest.TestCase):
with create_session() as session:
session.query(RenderedTaskInstanceFields).delete()
+ def test_operator_field_with_serialization(self):
+
+ dag = DAG('test_queries', start_date=DEFAULT_DATE)
+ task = DummyOperator(task_id='op', dag=dag)
+ self.assertEqual(task.task_type, 'DummyOperator')
+
+ # Verify that ti.operator field renders correctly "without" Serialization
+ ti = TI(task=task, execution_date=datetime.datetime.now())
+ self.assertEqual(ti.operator, "DummyOperator")
+
+ serialized_op = SerializedBaseOperator.serialize_operator(task)
+ deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
+ self.assertEqual(deserialized_op.task_type, 'DummyOperator')
+ # Verify that ti.operator field renders correctly "with" Serialization
+ ser_ti = TI(task=deserialized_op, execution_date=datetime.datetime.now())
+ self.assertEqual(ser_ti.operator, "DummyOperator")
+
@pytest.mark.parametrize("pool_override", [None, "test_pool2"])
def test_refresh_from_task(pool_override):
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index e2a58ec..00d5835 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -27,6 +27,7 @@ from airflow.operators.bash_operator import BashOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.sensors.external_task_sensor import ExternalTaskMarker, ExternalTaskSensor
from airflow.sensors.time_sensor import TimeSensor
+from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils.state import State
from airflow.utils.timezone import datetime
@@ -339,6 +340,26 @@ exit 0
)
+class TestExternalTaskMarker(unittest.TestCase):
+ def test_serialized_fields(self):
+ self.assertTrue({"recursion_depth"}.issubset(ExternalTaskMarker.get_serialized_fields()))
+
+ def test_serialized_external_task_marker(self):
+ dag = DAG('test_serialized_external_task_marker', start_date=DEFAULT_DATE)
+ task = ExternalTaskMarker(
+ task_id="parent_task",
+ external_dag_id="external_task_marker_child",
+ external_task_id="child_task1",
+ dag=dag
+ )
+
+ serialized_op = SerializedBaseOperator.serialize_operator(task)
+ deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
+ self.assertEqual(deserialized_op.task_type, 'ExternalTaskMarker')
+ self.assertEqual(getattr(deserialized_op, 'external_dag_id'), 'external_task_marker_child')
+ self.assertEqual(getattr(deserialized_op, 'external_task_id'), 'child_task1')
+
+
@pytest.fixture
def dag_bag_ext():
"""