You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2020/06/26 12:53:42 UTC
[airflow] branch v1-10-test updated (686716b -> 45bf7f4)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a change to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git.
from 686716b [AIRFLOW-6778] Add a configurable DAGs volume mount path for Kubernetes (#8147)
new 340e43d Correctly store non-default Nones in serialized tasks/dags (#8772)
new 45bf7f4 Correctly restore upstream_task_ids when deserializing Operators (#8775)
The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails. The revisions
listed as "add" were already present in the repository and have only
been added to this reference.
Summary of changes:
airflow/serialization/serialized_objects.py | 26 +++-
tests/serialization/test_dag_serialization.py | 179 ++++++++++++++++++--------
2 files changed, 143 insertions(+), 62 deletions(-)
[airflow] 01/02: Correctly store non-default Nones in serialized
tasks/dags (#8772)
Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 340e43d86b0f9592553c8e6eae9b1f14e6f72179
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Sun May 10 08:57:21 2020 +0100
Correctly store non-default Nones in serialized tasks/dags (#8772)
The default schedule_interval for a DAG is `@daily`, so
`schedule_interval=None` is actually not the default, but we were not
storing _any_ null attributes previously.
This meant that upon re-inflating the DAG the schedule_interval would
become @daily.
This fixes that problem, and extends the test to look at _all_ the
serialized attributes in our round-trip tests, rather than just the few
that the webserver cared about.
It doesn't change the serialization format, it just changes what/when
values were stored.
This solution was more complex than I hoped for, but the test case in
test_operator_subclass_changing_base_defaults is a real one that the
round trip tests discovered from the DatabricksSubmitRunOperator -- I
have just captured it in this test in case that specific operator
changes in future.
(cherry picked from commit a715aa692e88160cb8e9df4effda2440e4778c17)
---
airflow/serialization/serialized_objects.py | 24 +++-
tests/serialization/test_dag_serialization.py | 176 ++++++++++++++++++--------
2 files changed, 139 insertions(+), 61 deletions(-)
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 917d80f..3e564ec 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -122,10 +122,16 @@ class BaseSerialization:
@classmethod
def _is_excluded(cls, var, attrname, instance):
"""Types excluded from serialization."""
+
+ if var is None:
+ if not cls._is_constructor_param(attrname, instance):
+ # Any instance attribute, that is not a constructor argument, we exclude None as the default
+ return True
+
+ return cls._value_is_hardcoded_default(attrname, var, instance)
return (
- var is None or
isinstance(var, cls._excluded_types) or
- cls._value_is_hardcoded_default(attrname, var)
+ cls._value_is_hardcoded_default(attrname, var, instance)
)
@classmethod
@@ -259,7 +265,12 @@ class BaseSerialization:
return datetime.timedelta(seconds=seconds)
@classmethod
- def _value_is_hardcoded_default(cls, attrname, value):
+ def _is_constructor_param(cls, attrname, instance):
+ # pylint: disable=unused-argument
+ return attrname in cls._CONSTRUCTOR_PARAMS
+
+ @classmethod
+ def _value_is_hardcoded_default(cls, attrname, value, instance):
"""
Return true if ``value`` is the hard-coded default for the given attribute.
This takes in to account cases where the ``concurrency`` parameter is
@@ -273,8 +284,9 @@ class BaseSerialization:
to account for the case where the default value of the field is None but has the
``field = field or {}`` set.
"""
+ # pylint: disable=unused-argument
if attrname in cls._CONSTRUCTOR_PARAMS and \
- (cls._CONSTRUCTOR_PARAMS[attrname].default is value or (value in [{}, []])):
+ (cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])):
return True
return False
@@ -288,7 +300,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
_decorated_fields = {'executor_config', }
_CONSTRUCTOR_PARAMS = {
- k: v for k, v in signature(BaseOperator).parameters.items()
+ k: v.default for k, v in signature(BaseOperator).parameters.items()
if v.default is not v.empty
}
@@ -511,7 +523,7 @@ class SerializedDAG(DAG, BaseSerialization):
'access_control': '_access_control',
}
return {
- param_to_attr.get(k, k): v for k, v in signature(DAG).parameters.items()
+ param_to_attr.get(k, k): v.default for k, v in signature(DAG).parameters.items()
if v.default is not v.empty
}
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 52f6e1a..e28e2b2 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -18,7 +18,9 @@
# under the License.
"""Unit tests for stringified DAGs."""
+from glob import glob
import multiprocessing
+import os
import unittest
import six
@@ -29,13 +31,10 @@ from datetime import datetime, timedelta
from parameterized import parameterized
from dateutil.relativedelta import relativedelta, FR
-from airflow import example_dags
-from airflow.contrib import example_dags as contrib_example_dags
from airflow.hooks.base_hook import BaseHook
from airflow.models import DAG, Connection, DagBag, TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.operators.bash_operator import BashOperator
-from airflow.operators.subdag_operator import SubDagOperator
from airflow.serialization.json_schema import load_dag_schema_dict
from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
from airflow.utils.tests import CustomOperator, CustomOpLink, GoogleLink
@@ -110,10 +109,14 @@ serialized_simple_dag_ground_truth = {
},
}
+ROOT_FOLDER = os.path.realpath(
+ os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir)
+)
-def make_example_dags(module):
+
+def make_example_dags(module_path):
"""Loads DAGs from a module for test."""
- dagbag = DagBag(module.__path__[0])
+ dagbag = DagBag(module_path)
return dagbag.dags
@@ -170,22 +173,34 @@ def make_user_defined_macro_filter_dag():
return {dag.dag_id: dag}
-def collect_dags():
+def collect_dags(dag_folder=None):
"""Collects DAGs to test."""
dags = {}
dags.update(make_simple_dag())
dags.update(make_user_defined_macro_filter_dag())
- dags.update(make_example_dags(example_dags))
- dags.update(make_example_dags(contrib_example_dags))
+
+ if dag_folder:
+ if isinstance(dag_folder, (list, tuple)):
+ patterns = dag_folder
+ else:
+ patterns = [dag_folder]
+ else:
+ patterns = [
+ "airflow/example_dags",
+ "airflow/contrib/example_dags",
+ ]
+ for pattern in patterns:
+ for directory in glob(ROOT_FOLDER + "/" + pattern):
+ dags.update(make_example_dags(directory))
# Filter subdags as they are stored in same row in Serialized Dag table
dags = {dag_id: dag for dag_id, dag in dags.items() if not dag.is_subdag}
return dags
-def serialize_subprocess(queue):
+def serialize_subprocess(queue, dag_folder):
"""Validate pickle in a subprocess."""
- dags = collect_dags()
+ dags = collect_dags(dag_folder)
for dag in dags.values():
queue.put(SerializedDAG.to_json(dag))
queue.put(None)
@@ -242,14 +257,17 @@ class TestStringifiedDAGs(unittest.TestCase):
)
return dag_dict
- self.assertEqual(sorted_serialized_dag(ground_truth_dag),
- sorted_serialized_dag(json_dag))
+ assert sorted_serialized_dag(ground_truth_dag) == sorted_serialized_dag(json_dag)
- def test_deserialization(self):
+ def test_deserialization_across_process(self):
"""A serialized DAG can be deserialized in another process."""
+
+ # Since we need to parse the dags twice here (once in the subprocess,
+ # and once here to get a DAG to compare to) we don't want to load all
+ # dags.
queue = multiprocessing.Queue()
proc = multiprocessing.Process(
- target=serialize_subprocess, args=(queue,))
+ target=serialize_subprocess, args=(queue, "airflow/example_dags"))
proc.daemon = True
proc.start()
@@ -262,69 +280,100 @@ class TestStringifiedDAGs(unittest.TestCase):
self.assertTrue(isinstance(dag, DAG))
stringified_dags[dag.dag_id] = dag
- dags = collect_dags()
- self.assertTrue(set(stringified_dags.keys()) == set(dags.keys()))
+ dags = collect_dags("airflow/example_dags")
+ assert set(stringified_dags.keys()) == set(dags.keys())
# Verify deserialized DAGs.
for dag_id in stringified_dags:
self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id])
- example_skip_dag = stringified_dags['example_skip_dag']
- skip_operator_1_task = example_skip_dag.task_dict['skip_operator_1']
- self.validate_deserialized_task(
- skip_operator_1_task, 'DummySkipOperator', '#e8b7e4', '#000')
+ def test_roundtrip_provider_example_dags(self):
+ dags = collect_dags([
+ "airflow/providers/*/example_dags",
+ "airflow/providers/*/*/example_dags",
+ ])
- # Verify that the DAG object has 'full_filepath' attribute
- # and is equal to fileloc
- self.assertTrue(hasattr(example_skip_dag, 'full_filepath'))
- self.assertEqual(example_skip_dag.full_filepath, example_skip_dag.fileloc)
-
- example_subdag_operator = stringified_dags['example_subdag_operator']
- section_1_task = example_subdag_operator.task_dict['section-1']
- self.validate_deserialized_task(
- section_1_task,
- SubDagOperator.__name__,
- SubDagOperator.ui_color,
- SubDagOperator.ui_fgcolor
- )
+ # Verify deserialized DAGs.
+ for dag in dags.values():
+ serialized_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
+ self.validate_deserialized_dag(serialized_dag, dag)
def validate_deserialized_dag(self, serialized_dag, dag):
"""
Verify that all example DAGs work with DAG Serialization by
checking fields between Serialized Dags & non-Serialized Dags
"""
- fields_to_check = [
- "params", "fileloc", "max_active_runs", "concurrency",
- "is_paused_upon_creation", "doc_md", "safe_dag_id", "is_subdag",
- "catchup", "description", "start_date", "end_date", "parent_dag",
- "template_searchpath", "_access_control", "dagrun_timeout"
- ]
+ fields_to_check = dag.get_serialized_fields() - {
+ # Doesn't implement __eq__ properly. Check manually
+ 'timezone',
- # fields_to_check = dag.get_serialized_fields()
+ # Need to check fields in it, to exclude functions
+ 'default_args',
+ }
for field in fields_to_check:
- self.assertEqual(getattr(serialized_dag, field), getattr(dag, field))
+ assert getattr(serialized_dag, field) == getattr(dag, field), \
+ '{}.{} does not match'.format(dag.dag_id, field)
- self.assertEqual(
- sorted(serialized_dag.task_ids),
- sorted([str(task) for task in dag.task_ids]))
+ if dag.default_args:
+ for k, v in dag.default_args.items():
+ if callable(v):
+ # Check we stored _something_.
+ assert k in serialized_dag.default_args
+ else:
+ assert v == serialized_dag.default_args[k], \
+ '{}.default_args[{}] does not match'.format(dag.dag_id, k)
+
+ assert serialized_dag.timezone.name == dag.timezone.name
+
+ for task_id in dag.task_ids:
+ self.validate_deserialized_task(serialized_dag.get_task(task_id), dag.get_task(task_id))
+
+ # Verify that the DAG object has 'full_filepath' attribute
+ # and is equal to fileloc
+ assert serialized_dag.full_filepath == dag.fileloc
- def validate_deserialized_task(self, task, task_type, ui_color, ui_fgcolor):
+ def validate_deserialized_task(self, serialized_task, task,):
"""Verify non-airflow operators are casted to BaseOperator."""
- self.assertTrue(isinstance(task, SerializedBaseOperator))
- # Verify the original operator class is recorded for UI.
- self.assertTrue(task.task_type == task_type)
- self.assertTrue(task.ui_color == ui_color)
- self.assertTrue(task.ui_fgcolor == ui_fgcolor)
+ assert isinstance(serialized_task, SerializedBaseOperator)
+ assert not isinstance(task, SerializedBaseOperator)
+ assert isinstance(task, BaseOperator)
+
+ fields_to_check = task.get_serialized_fields() - {
+ # Checked separately
+ '_task_type', 'subdag',
+
+ # Type is excluded, so don't check it
+ '_log',
+
+ # List vs tuple. Check separately
+ 'template_fields',
+
+ # We store the string, real dag has the actual code
+ 'on_failure_callback', 'on_success_callback', 'on_retry_callback',
+
+ # Checked separately
+ 'resources',
+ }
+
+ assert serialized_task.task_type == task.task_type
+ assert set(serialized_task.template_fields) == set(task.template_fields)
+
+ for field in fields_to_check:
+ assert getattr(serialized_task, field) == getattr(task, field), \
+ '{}.{}.{} does not match'.format(task.dag.dag_id, task.task_id, field)
+
+ if serialized_task.resources is None:
+ assert task.resources is None or task.resources == []
+ else:
+ assert serialized_task.resources == task.resources
# Check that for Deserialised task, task.subdag is None for all other Operators
# except for the SubDagOperator where task.subdag is an instance of DAG object
if task.task_type == "SubDagOperator":
- self.assertIsNotNone(task.subdag)
- self.assertTrue(isinstance(task.subdag, DAG))
+ assert serialized_task.subdag is not None
+ assert isinstance(serialized_task.subdag, DAG)
else:
- self.assertIsNone(task.subdag)
- self.assertEqual({}, task.params)
- self.assertEqual({}, task.executor_config)
+ assert serialized_task.subdag is None
@parameterized.expand([
(datetime(2019, 8, 1), None, datetime(2019, 8, 1)),
@@ -650,6 +699,23 @@ class TestStringifiedDAGs(unittest.TestCase):
dag_params = set(dag_schema.keys()) - ignored_keys
self.assertEqual(set(DAG.get_serialized_fields()), dag_params)
+ def test_operator_subclass_changing_base_defaults(self):
+ assert BaseOperator(task_id='dummy').do_xcom_push is True, \
+ "Precondition check! If this fails the test won't make sense"
+
+ class MyOperator(BaseOperator):
+ def __init__(self, do_xcom_push=False, **kwargs):
+ super(MyOperator, self).__init__(**kwargs)
+ self.do_xcom_push = do_xcom_push
+
+ op = MyOperator(task_id='dummy')
+ assert op.do_xcom_push is False
+
+ blob = SerializedBaseOperator.serialize_operator(op)
+ serialized_op = SerializedBaseOperator.deserialize_operator(blob)
+
+ assert serialized_op.do_xcom_push is False
+
def test_no_new_fields_added_to_base_operator(self):
"""
This test verifies that there are no new fields added to BaseOperator. And reminds that
[airflow] 02/02: Correctly restore upstream_task_ids when
deserializing Operators (#8775)
Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 45bf7f4a557e38201fdffcd7f535b08d9835d9aa
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Sun May 10 11:41:47 2020 +0100
Correctly restore upstream_task_ids when deserializing Operators (#8775)
This test exposed a bug in one of the example dags that wasn't caught
by #6549. That will be fixed in a separate issue, but it caused the
round-trip tests to fail here.
Fixes #8720
(cherry picked from commit 280f1f0c4cc49aba1b2f8b456326795733769d18)
---
airflow/serialization/serialized_objects.py | 2 +-
tests/serialization/test_dag_serialization.py | 3 +++
2 files changed, 4 insertions(+), 1 deletion(-)
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 3e564ec..8d261aa 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -591,7 +591,7 @@ class SerializedDAG(DAG, BaseSerialization):
for task_id in serializable_task.downstream_task_ids:
# Bypass set_upstream etc here - it does more than we want
# noinspection PyProtectedMember
- dag.task_dict[task_id]._upstream_task_ids.add(task_id) # pylint: disable=protected-access
+ dag.task_dict[task_id]._upstream_task_ids.add(serializable_task.task_id) # noqa: E501 # pylint: disable=protected-access
return dag
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index e28e2b2..6b714a8 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -358,6 +358,9 @@ class TestStringifiedDAGs(unittest.TestCase):
assert serialized_task.task_type == task.task_type
assert set(serialized_task.template_fields) == set(task.template_fields)
+ assert serialized_task.upstream_task_ids == task.upstream_task_ids
+ assert serialized_task.downstream_task_ids == task.downstream_task_ids
+
for field in fields_to_check:
assert getattr(serialized_task, field) == getattr(task, field), \
'{}.{}.{} does not match'.format(task.dag.dag_id, task.task_id, field)