Posted to commits@airflow.apache.org by po...@apache.org on 2020/06/29 13:21:32 UTC
[airflow] 26/37: Correctly store non-default Nones in serialized tasks/dags (#8772)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit a7ab95cee2da2f9d89ec1aa6dc014122b81f6456
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Sun May 10 08:57:21 2020 +0100
Correctly store non-default Nones in serialized tasks/dags (#8772)
The default schedule_interval for a DAG is `@daily`, so
`schedule_interval=None` is actually not the default, but we were not
storing _any_ null attributes previously.
This meant that upon re-inflating the DAG, the schedule_interval would
become `@daily`.
This fixes that problem, and extends the test to look at _all_ the
serialized attributes in our round-trip tests, rather than just the few
that the webserver cared about.
It doesn't change the serialization format; it just changes which
values are stored, and when.
This solution was more complex than I hoped for, but the test case in
test_operator_subclass_changing_base_defaults is a real one that the
round-trip tests discovered from the DatabricksSubmitRunOperator -- I
have just captured it in this test in case that specific operator
changes in the future.
(cherry picked from commit a715aa692e88160cb8e9df4effda2440e4778c17)
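To make the failure mode concrete, here is a minimal sketch (not part of
the patch) of the round trip that used to go wrong. It assumes only the
SerializedDAG.to_json/from_json entry points that the tests below already
use:

    from airflow.models import DAG
    from airflow.serialization.serialized_objects import SerializedDAG

    # schedule_interval=None is an explicit, non-default value here,
    # because the DAG default is `@daily`.
    dag = DAG(dag_id="no_schedule", schedule_interval=None)

    # Before this commit, None attributes were never serialized, so the
    # re-inflated DAG silently fell back to the `@daily` default.
    roundtripped = SerializedDAG.from_json(SerializedDAG.to_json(dag))
    assert roundtripped.schedule_interval is None  # failed before this fix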
---
airflow/serialization/serialized_objects.py | 24 +++-
tests/serialization/test_dag_serialization.py | 176 ++++++++++++++++++--------
2 files changed, 139 insertions(+), 61 deletions(-)
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 917d80f..3e564ec 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -122,10 +122,16 @@ class BaseSerialization:
@classmethod
def _is_excluded(cls, var, attrname, instance):
"""Types excluded from serialization."""
+
+ if var is None:
+ if not cls._is_constructor_param(attrname, instance):
+ # For any instance attribute that is not a constructor argument, None is the implicit default, so exclude it
+ return True
+
+ return cls._value_is_hardcoded_default(attrname, var, instance)
return (
- var is None or
isinstance(var, cls._excluded_types) or
- cls._value_is_hardcoded_default(attrname, var)
+ cls._value_is_hardcoded_default(attrname, var, instance)
)
@classmethod
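In effect, the new `_is_excluded` drops a None only when it carries no
information. Here is a simplified, standalone sketch of that rule (it
ignores the `_excluded_types` and mutable-default branches, and uses
made-up defaults in place of cls._CONSTRUCTOR_PARAMS):

    # Hypothetical defaults, standing in for cls._CONSTRUCTOR_PARAMS.
    CONSTRUCTOR_PARAMS = {"schedule_interval": "@daily", "doc_md": None}

    def is_excluded(attrname, value):
        """Return True if the attribute can be omitted when serializing."""
        if value is None:
            if attrname not in CONSTRUCTOR_PARAMS:
                # Plain instance attribute: None is its implicit default.
                return True
            # Constructor param: omit None only if None is also its default.
            return CONSTRUCTOR_PARAMS[attrname] is None
        return False

    assert is_excluded("_log", None)                   # not a constructor param
    assert is_excluded("doc_md", None)                 # None is the default
    assert not is_excluded("schedule_interval", None)  # non-default None: store it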
@@ -259,7 +265,12 @@ class BaseSerialization:
return datetime.timedelta(seconds=seconds)
@classmethod
- def _value_is_hardcoded_default(cls, attrname, value):
+ def _is_constructor_param(cls, attrname, instance):
+ # pylint: disable=unused-argument
+ return attrname in cls._CONSTRUCTOR_PARAMS
+
+ @classmethod
+ def _value_is_hardcoded_default(cls, attrname, value, instance):
"""
Return true if ``value`` is the hard-coded default for the given attribute.
This takes into account cases where the ``concurrency`` parameter is
@@ -273,8 +284,9 @@ class BaseSerialization:
to account for the case where the default value of the field is None but has the
``field = field or {}`` set.
"""
+ # pylint: disable=unused-argument
if attrname in cls._CONSTRUCTOR_PARAMS and \
- (cls._CONSTRUCTOR_PARAMS[attrname].default is value or (value in [{}, []])):
+ (cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])):
return True
return False
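The `value in [{}, []]` branch covers constructors whose declared default
is None but which normalize it to an empty container, so a fresh instance
never actually holds None. A simplified example of the pattern the
docstring describes:

    class Example(object):
        def __init__(self, params=None):
            # The signature default is None, but the stored attribute is {}.
            self.params = params or {}

    # The instance attribute is {} even though the declared default is
    # None; treating {} (and []) as the hard-coded default keeps such
    # attributes out of the serialized blob.
    assert Example().params == {}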
@@ -288,7 +300,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
_decorated_fields = {'executor_config', }
_CONSTRUCTOR_PARAMS = {
- k: v for k, v in signature(BaseOperator).parameters.items()
+ k: v.default for k, v in signature(BaseOperator).parameters.items()
if v.default is not v.empty
}
@@ -511,7 +523,7 @@ class SerializedDAG(DAG, BaseSerialization):
'access_control': '_access_control',
}
return {
- param_to_attr.get(k, k): v for k, v in signature(DAG).parameters.items()
+ param_to_attr.get(k, k): v.default for k, v in signature(DAG).parameters.items()
if v.default is not v.empty
}
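With this change, `_CONSTRUCTOR_PARAMS` maps parameter names directly to
their default values instead of to `inspect.Parameter` objects, which is
what makes the `_CONSTRUCTOR_PARAMS[attrname] is value` comparison above
work. A small sketch using a stand-in function rather than BaseOperator
or DAG:

    from inspect import signature

    def make_dag(dag_id, schedule_interval="@daily", catchup=True):
        pass

    CONSTRUCTOR_PARAMS = {
        k: v.default for k, v in signature(make_dag).parameters.items()
        if v.default is not v.empty  # skip params without defaults (dag_id)
    }
    assert CONSTRUCTOR_PARAMS == {"schedule_interval": "@daily", "catchup": True}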
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 52f6e1a..e28e2b2 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -18,7 +18,9 @@
# under the License.
"""Unit tests for stringified DAGs."""
+from glob import glob
import multiprocessing
+import os
import unittest
import six
@@ -29,13 +31,10 @@ from datetime import datetime, timedelta
from parameterized import parameterized
from dateutil.relativedelta import relativedelta, FR
-from airflow import example_dags
-from airflow.contrib import example_dags as contrib_example_dags
from airflow.hooks.base_hook import BaseHook
from airflow.models import DAG, Connection, DagBag, TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.operators.bash_operator import BashOperator
-from airflow.operators.subdag_operator import SubDagOperator
from airflow.serialization.json_schema import load_dag_schema_dict
from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
from airflow.utils.tests import CustomOperator, CustomOpLink, GoogleLink
@@ -110,10 +109,14 @@ serialized_simple_dag_ground_truth = {
},
}
+ROOT_FOLDER = os.path.realpath(
+ os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir)
+)
-def make_example_dags(module):
+
+def make_example_dags(module_path):
"""Loads DAGs from a module for test."""
- dagbag = DagBag(module.__path__[0])
+ dagbag = DagBag(module_path)
return dagbag.dags
@@ -170,22 +173,34 @@ def make_user_defined_macro_filter_dag():
return {dag.dag_id: dag}
-def collect_dags():
+def collect_dags(dag_folder=None):
"""Collects DAGs to test."""
dags = {}
dags.update(make_simple_dag())
dags.update(make_user_defined_macro_filter_dag())
- dags.update(make_example_dags(example_dags))
- dags.update(make_example_dags(contrib_example_dags))
+
+ if dag_folder:
+ if isinstance(dag_folder, (list, tuple)):
+ patterns = dag_folder
+ else:
+ patterns = [dag_folder]
+ else:
+ patterns = [
+ "airflow/example_dags",
+ "airflow/contrib/example_dags",
+ ]
+ for pattern in patterns:
+ for directory in glob(ROOT_FOLDER + "/" + pattern):
+ dags.update(make_example_dags(directory))
# Filter subdags as they are stored in same row in Serialized Dag table
dags = {dag_id: dag for dag_id, dag in dags.items() if not dag.is_subdag}
return dags
-def serialize_subprocess(queue):
+def serialize_subprocess(queue, dag_folder):
"""Validate pickle in a subprocess."""
- dags = collect_dags()
+ dags = collect_dags(dag_folder)
for dag in dags.values():
queue.put(SerializedDAG.to_json(dag))
queue.put(None)
@@ -242,14 +257,17 @@ class TestStringifiedDAGs(unittest.TestCase):
)
return dag_dict
- self.assertEqual(sorted_serialized_dag(ground_truth_dag),
- sorted_serialized_dag(json_dag))
+ assert sorted_serialized_dag(ground_truth_dag) == sorted_serialized_dag(json_dag)
- def test_deserialization(self):
+ def test_deserialization_across_process(self):
"""A serialized DAG can be deserialized in another process."""
+
+ # Since we need to parse the dags twice here (once in the subprocess,
+ # and once here to get a DAG to compare to) we don't want to load all
+ # dags.
queue = multiprocessing.Queue()
proc = multiprocessing.Process(
- target=serialize_subprocess, args=(queue,))
+ target=serialize_subprocess, args=(queue, "airflow/example_dags"))
proc.daemon = True
proc.start()
@@ -262,69 +280,100 @@ class TestStringifiedDAGs(unittest.TestCase):
self.assertTrue(isinstance(dag, DAG))
stringified_dags[dag.dag_id] = dag
- dags = collect_dags()
- self.assertTrue(set(stringified_dags.keys()) == set(dags.keys()))
+ dags = collect_dags("airflow/example_dags")
+ assert set(stringified_dags.keys()) == set(dags.keys())
# Verify deserialized DAGs.
for dag_id in stringified_dags:
self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id])
- example_skip_dag = stringified_dags['example_skip_dag']
- skip_operator_1_task = example_skip_dag.task_dict['skip_operator_1']
- self.validate_deserialized_task(
- skip_operator_1_task, 'DummySkipOperator', '#e8b7e4', '#000')
+ def test_roundtrip_provider_example_dags(self):
+ dags = collect_dags([
+ "airflow/providers/*/example_dags",
+ "airflow/providers/*/*/example_dags",
+ ])
- # Verify that the DAG object has 'full_filepath' attribute
- # and is equal to fileloc
- self.assertTrue(hasattr(example_skip_dag, 'full_filepath'))
- self.assertEqual(example_skip_dag.full_filepath, example_skip_dag.fileloc)
-
- example_subdag_operator = stringified_dags['example_subdag_operator']
- section_1_task = example_subdag_operator.task_dict['section-1']
- self.validate_deserialized_task(
- section_1_task,
- SubDagOperator.__name__,
- SubDagOperator.ui_color,
- SubDagOperator.ui_fgcolor
- )
+ # Verify deserialized DAGs.
+ for dag in dags.values():
+ serialized_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
+ self.validate_deserialized_dag(serialized_dag, dag)
def validate_deserialized_dag(self, serialized_dag, dag):
"""
Verify that all example DAGs work with DAG Serialization by
checking fields between Serialized Dags & non-Serialized Dags
"""
- fields_to_check = [
- "params", "fileloc", "max_active_runs", "concurrency",
- "is_paused_upon_creation", "doc_md", "safe_dag_id", "is_subdag",
- "catchup", "description", "start_date", "end_date", "parent_dag",
- "template_searchpath", "_access_control", "dagrun_timeout"
- ]
+ fields_to_check = dag.get_serialized_fields() - {
+ # Doesn't implement __eq__ properly. Check manually
+ 'timezone',
- # fields_to_check = dag.get_serialized_fields()
+ # Need to check fields in it, to exclude functions
+ 'default_args',
+ }
for field in fields_to_check:
- self.assertEqual(getattr(serialized_dag, field), getattr(dag, field))
+ assert getattr(serialized_dag, field) == getattr(dag, field), \
+ '{}.{} does not match'.format(dag.dag_id, field)
- self.assertEqual(
- sorted(serialized_dag.task_ids),
- sorted([str(task) for task in dag.task_ids]))
+ if dag.default_args:
+ for k, v in dag.default_args.items():
+ if callable(v):
+ # Check we stored _something_.
+ assert k in serialized_dag.default_args
+ else:
+ assert v == serialized_dag.default_args[k], \
+ '{}.default_args[{}] does not match'.format(dag.dag_id, k)
+
+ assert serialized_dag.timezone.name == dag.timezone.name
+
+ for task_id in dag.task_ids:
+ self.validate_deserialized_task(serialized_dag.get_task(task_id), dag.get_task(task_id))
+
+ # Verify that the DAG object has 'full_filepath' attribute
+ # and is equal to fileloc
+ assert serialized_dag.full_filepath == dag.fileloc
- def validate_deserialized_task(self, task, task_type, ui_color, ui_fgcolor):
+ def validate_deserialized_task(self, serialized_task, task):
"""Verify non-airflow operators are casted to BaseOperator."""
- self.assertTrue(isinstance(task, SerializedBaseOperator))
- # Verify the original operator class is recorded for UI.
- self.assertTrue(task.task_type == task_type)
- self.assertTrue(task.ui_color == ui_color)
- self.assertTrue(task.ui_fgcolor == ui_fgcolor)
+ assert isinstance(serialized_task, SerializedBaseOperator)
+ assert not isinstance(task, SerializedBaseOperator)
+ assert isinstance(task, BaseOperator)
+
+ fields_to_check = task.get_serialized_fields() - {
+ # Checked separately
+ '_task_type', 'subdag',
+
+ # Type is excluded, so don't check it
+ '_log',
+
+ # List vs tuple. Check separately
+ 'template_fields',
+
+ # We store the string, real dag has the actual code
+ 'on_failure_callback', 'on_success_callback', 'on_retry_callback',
+
+ # Checked separately
+ 'resources',
+ }
+
+ assert serialized_task.task_type == task.task_type
+ assert set(serialized_task.template_fields) == set(task.template_fields)
+
+ for field in fields_to_check:
+ assert getattr(serialized_task, field) == getattr(task, field), \
+ '{}.{}.{} does not match'.format(task.dag.dag_id, task.task_id, field)
+
+ if serialized_task.resources is None:
+ assert task.resources is None or task.resources == []
+ else:
+ assert serialized_task.resources == task.resources
# Check that for a deserialized task, task.subdag is None for all operators
# except SubDagOperator, where task.subdag is an instance of the DAG object
if task.task_type == "SubDagOperator":
- self.assertIsNotNone(task.subdag)
- self.assertTrue(isinstance(task.subdag, DAG))
+ assert serialized_task.subdag is not None
+ assert isinstance(serialized_task.subdag, DAG)
else:
- self.assertIsNone(task.subdag)
- self.assertEqual({}, task.params)
- self.assertEqual({}, task.executor_config)
+ assert serialized_task.subdag is None
@parameterized.expand([
(datetime(2019, 8, 1), None, datetime(2019, 8, 1)),
@@ -650,6 +699,23 @@ class TestStringifiedDAGs(unittest.TestCase):
dag_params = set(dag_schema.keys()) - ignored_keys
self.assertEqual(set(DAG.get_serialized_fields()), dag_params)
+ def test_operator_subclass_changing_base_defaults(self):
+ assert BaseOperator(task_id='dummy').do_xcom_push is True, \
+ "Precondition check! If this fails the test won't make sense"
+
+ class MyOperator(BaseOperator):
+ def __init__(self, do_xcom_push=False, **kwargs):
+ super(MyOperator, self).__init__(**kwargs)
+ self.do_xcom_push = do_xcom_push
+
+ op = MyOperator(task_id='dummy')
+ assert op.do_xcom_push is False
+
+ blob = SerializedBaseOperator.serialize_operator(op)
+ serialized_op = SerializedBaseOperator.deserialize_operator(blob)
+
+ assert serialized_op.do_xcom_push is False
+
def test_no_new_fields_added_to_base_operator(self):
"""
This test verifies that there are no new fields added to BaseOperator. And reminds that