Posted to commits@airflow.apache.org by po...@apache.org on 2020/06/29 13:21:32 UTC

[airflow] 26/37: Correctly store non-default Nones in serialized tasks/dags (#8772)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit a7ab95cee2da2f9d89ec1aa6dc014122b81f6456
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Sun May 10 08:57:21 2020 +0100

    Correctly store non-default Nones in serialized tasks/dags (#8772)
    
    The default schedule_interval for a DAG is `@daily`, so
    `schedule_interval=None` is actually not the default, but we were not
    storing _any_ null attributes previously.
    
    This meant that upon re-inflating the DAG the schedule_interval would
    become @daily.
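
    For illustration, the broken round trip looked roughly like this (a
    minimal sketch, not part of the change; the dag_id and dates are made up):

        from datetime import datetime

        from airflow.models import DAG
        from airflow.serialization.serialized_objects import SerializedDAG

        # schedule_interval=None is *not* the DAG default, so the explicit
        # None has to survive the serialization round trip.
        dag = DAG("none_schedule_demo", start_date=datetime(2020, 1, 1),
                  schedule_interval=None)

        restored = SerializedDAG.from_json(SerializedDAG.to_json(dag))

        # Before this change None was not stored at all, so the deserialized
        # DAG fell back to @daily; with it, the explicit None is preserved.
        assert restored.schedule_interval is None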
    
    This fixes that problem, and extends the test to look at _all_ the
    serialized attributes in our round-trip tests, rather than just the few
    that the webserver cared about.
    
    It doesn't change the serialization format; it only changes which values
    are stored, and when.
    
    This solution was more complex than I had hoped for, but the test case in
    test_operator_subclass_changing_base_defaults is a real one that the
    round-trip tests uncovered via the DatabricksSubmitRunOperator -- I have
    captured it in this test in case that specific operator changes in
    future.
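
    Roughly, the case has this shape (a condensed, illustrative sketch; the
    operator name is invented, and the full round-trip assertion is the new
    test_operator_subclass_changing_base_defaults added below):

        from airflow.models.baseoperator import BaseOperator

        class SubmitLikeOperator(BaseOperator):
            # The subclass flips a BaseOperator default: do_xcom_push is
            # normally True, so False here is not the hard-coded default and
            # must not be dropped when the task is serialized.
            def __init__(self, do_xcom_push=False, **kwargs):
                super(SubmitLikeOperator, self).__init__(**kwargs)
                self.do_xcom_push = do_xcom_push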
    
    (cherry picked from commit a715aa692e88160cb8e9df4effda2440e4778c17)
---
 airflow/serialization/serialized_objects.py   |  24 +++-
 tests/serialization/test_dag_serialization.py | 176 ++++++++++++++++++--------
 2 files changed, 139 insertions(+), 61 deletions(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 917d80f..3e564ec 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -122,10 +122,16 @@ class BaseSerialization:
     @classmethod
     def _is_excluded(cls, var, attrname, instance):
         """Types excluded from serialization."""
+
+        if var is None:
+            if not cls._is_constructor_param(attrname, instance):
+                # Treat None as the default for any instance attribute that is not a constructor argument
+                return True
+
+            return cls._value_is_hardcoded_default(attrname, var, instance)
         return (
-            var is None or
             isinstance(var, cls._excluded_types) or
-            cls._value_is_hardcoded_default(attrname, var)
+            cls._value_is_hardcoded_default(attrname, var, instance)
         )
 
     @classmethod
@@ -259,7 +265,12 @@ class BaseSerialization:
         return datetime.timedelta(seconds=seconds)
 
     @classmethod
-    def _value_is_hardcoded_default(cls, attrname, value):
+    def _is_constructor_param(cls, attrname, instance):
+        # pylint: disable=unused-argument
+        return attrname in cls._CONSTRUCTOR_PARAMS
+
+    @classmethod
+    def _value_is_hardcoded_default(cls, attrname, value, instance):
         """
         Return true if ``value`` is the hard-coded default for the given attribute.
         This takes into account cases where the ``concurrency`` parameter is
@@ -273,8 +284,9 @@ class BaseSerialization:
         to account for the case where the default value of the field is None but has the
         ``field = field or {}`` set.
         """
+        # pylint: disable=unused-argument
         if attrname in cls._CONSTRUCTOR_PARAMS and \
-                (cls._CONSTRUCTOR_PARAMS[attrname].default is value or (value in [{}, []])):
+                (cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])):
             return True
         return False
 
@@ -288,7 +300,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
     _decorated_fields = {'executor_config', }
 
     _CONSTRUCTOR_PARAMS = {
-        k: v for k, v in signature(BaseOperator).parameters.items()
+        k: v.default for k, v in signature(BaseOperator).parameters.items()
         if v.default is not v.empty
     }
 
@@ -511,7 +523,7 @@ class SerializedDAG(DAG, BaseSerialization):
             'access_control': '_access_control',
         }
         return {
-            param_to_attr.get(k, k): v for k, v in signature(DAG).parameters.items()
+            param_to_attr.get(k, k): v.default for k, v in signature(DAG).parameters.items()
             if v.default is not v.empty
         }
 
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 52f6e1a..e28e2b2 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -18,7 +18,9 @@
 # under the License.
 
 """Unit tests for stringified DAGs."""
+from glob import glob
 import multiprocessing
+import os
 import unittest
 
 import six
@@ -29,13 +31,10 @@ from datetime import datetime, timedelta
 from parameterized import parameterized
 from dateutil.relativedelta import relativedelta, FR
 
-from airflow import example_dags
-from airflow.contrib import example_dags as contrib_example_dags
 from airflow.hooks.base_hook import BaseHook
 from airflow.models import DAG, Connection, DagBag, TaskInstance
 from airflow.models.baseoperator import BaseOperator
 from airflow.operators.bash_operator import BashOperator
-from airflow.operators.subdag_operator import SubDagOperator
 from airflow.serialization.json_schema import load_dag_schema_dict
 from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
 from airflow.utils.tests import CustomOperator, CustomOpLink, GoogleLink
@@ -110,10 +109,14 @@ serialized_simple_dag_ground_truth = {
     },
 }
 
+ROOT_FOLDER = os.path.realpath(
+    os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir)
+)
 
-def make_example_dags(module):
+
+def make_example_dags(module_path):
     """Loads DAGs from a module for test."""
-    dagbag = DagBag(module.__path__[0])
+    dagbag = DagBag(module_path)
     return dagbag.dags
 
 
@@ -170,22 +173,34 @@ def make_user_defined_macro_filter_dag():
     return {dag.dag_id: dag}
 
 
-def collect_dags():
+def collect_dags(dag_folder=None):
     """Collects DAGs to test."""
     dags = {}
     dags.update(make_simple_dag())
     dags.update(make_user_defined_macro_filter_dag())
-    dags.update(make_example_dags(example_dags))
-    dags.update(make_example_dags(contrib_example_dags))
+
+    if dag_folder:
+        if isinstance(dag_folder, (list, tuple)):
+            patterns = dag_folder
+        else:
+            patterns = [dag_folder]
+    else:
+        patterns = [
+            "airflow/example_dags",
+            "airflow/contrib/example_dags",
+        ]
+    for pattern in patterns:
+        for directory in glob(ROOT_FOLDER + "/" + pattern):
+            dags.update(make_example_dags(directory))
 
     # Filter subdags as they are stored in same row in Serialized Dag table
     dags = {dag_id: dag for dag_id, dag in dags.items() if not dag.is_subdag}
     return dags
 
 
-def serialize_subprocess(queue):
+def serialize_subprocess(queue, dag_folder):
     """Validate pickle in a subprocess."""
-    dags = collect_dags()
+    dags = collect_dags(dag_folder)
     for dag in dags.values():
         queue.put(SerializedDAG.to_json(dag))
     queue.put(None)
@@ -242,14 +257,17 @@ class TestStringifiedDAGs(unittest.TestCase):
             )
             return dag_dict
 
-        self.assertEqual(sorted_serialized_dag(ground_truth_dag),
-                         sorted_serialized_dag(json_dag))
+        assert sorted_serialized_dag(ground_truth_dag) == sorted_serialized_dag(json_dag)
 
-    def test_deserialization(self):
+    def test_deserialization_across_process(self):
         """A serialized DAG can be deserialized in another process."""
+
+        # Since we need to parse the dags twice here (once in the subprocess,
+        # and once here to get a DAG to compare to) we don't want to load all
+        # dags.
         queue = multiprocessing.Queue()
         proc = multiprocessing.Process(
-            target=serialize_subprocess, args=(queue,))
+            target=serialize_subprocess, args=(queue, "airflow/example_dags"))
         proc.daemon = True
         proc.start()
 
@@ -262,69 +280,100 @@ class TestStringifiedDAGs(unittest.TestCase):
             self.assertTrue(isinstance(dag, DAG))
             stringified_dags[dag.dag_id] = dag
 
-        dags = collect_dags()
-        self.assertTrue(set(stringified_dags.keys()) == set(dags.keys()))
+        dags = collect_dags("airflow/example_dags")
+        assert set(stringified_dags.keys()) == set(dags.keys())
 
         # Verify deserialized DAGs.
         for dag_id in stringified_dags:
             self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id])
 
-        example_skip_dag = stringified_dags['example_skip_dag']
-        skip_operator_1_task = example_skip_dag.task_dict['skip_operator_1']
-        self.validate_deserialized_task(
-            skip_operator_1_task, 'DummySkipOperator', '#e8b7e4', '#000')
+    def test_roundtrip_provider_example_dags(self):
+        dags = collect_dags([
+            "airflow/providers/*/example_dags",
+            "airflow/providers/*/*/example_dags",
+        ])
 
-        # Verify that the DAG object has 'full_filepath' attribute
-        # and is equal to fileloc
-        self.assertTrue(hasattr(example_skip_dag, 'full_filepath'))
-        self.assertEqual(example_skip_dag.full_filepath, example_skip_dag.fileloc)
-
-        example_subdag_operator = stringified_dags['example_subdag_operator']
-        section_1_task = example_subdag_operator.task_dict['section-1']
-        self.validate_deserialized_task(
-            section_1_task,
-            SubDagOperator.__name__,
-            SubDagOperator.ui_color,
-            SubDagOperator.ui_fgcolor
-        )
+        # Verify deserialized DAGs.
+        for dag in dags.values():
+            serialized_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
+            self.validate_deserialized_dag(serialized_dag, dag)
 
     def validate_deserialized_dag(self, serialized_dag, dag):
         """
         Verify that all example DAGs work with DAG Serialization by
         checking fields between Serialized Dags & non-Serialized Dags
         """
-        fields_to_check = [
-            "params", "fileloc", "max_active_runs", "concurrency",
-            "is_paused_upon_creation", "doc_md", "safe_dag_id", "is_subdag",
-            "catchup", "description", "start_date", "end_date", "parent_dag",
-            "template_searchpath", "_access_control", "dagrun_timeout"
-        ]
+        fields_to_check = dag.get_serialized_fields() - {
+            # Doesn't implement __eq__ properly. Check manually
+            'timezone',
 
-        # fields_to_check = dag.get_serialized_fields()
+            # Need to check fields in it, to exclude functions
+            'default_args',
+        }
         for field in fields_to_check:
-            self.assertEqual(getattr(serialized_dag, field), getattr(dag, field))
+            assert getattr(serialized_dag, field) == getattr(dag, field), \
+                '{}.{} does not match'.format(dag.dag_id, field)
 
-        self.assertEqual(
-            sorted(serialized_dag.task_ids),
-            sorted([str(task) for task in dag.task_ids]))
+        if dag.default_args:
+            for k, v in dag.default_args.items():
+                if callable(v):
+                    # Check we stored _something_.
+                    assert k in serialized_dag.default_args
+                else:
+                    assert v == serialized_dag.default_args[k], \
+                        '{}.default_args[{}] does not match'.format(dag.dag_id, k)
+
+        assert serialized_dag.timezone.name == dag.timezone.name
+
+        for task_id in dag.task_ids:
+            self.validate_deserialized_task(serialized_dag.get_task(task_id), dag.get_task(task_id))
+
+        # Verify that the DAG object has 'full_filepath' attribute
+        # and is equal to fileloc
+        assert serialized_dag.full_filepath == dag.fileloc
 
-    def validate_deserialized_task(self, task, task_type, ui_color, ui_fgcolor):
+    def validate_deserialized_task(self, serialized_task, task,):
         """Verify non-airflow operators are casted to BaseOperator."""
-        self.assertTrue(isinstance(task, SerializedBaseOperator))
-        # Verify the original operator class is recorded for UI.
-        self.assertTrue(task.task_type == task_type)
-        self.assertTrue(task.ui_color == ui_color)
-        self.assertTrue(task.ui_fgcolor == ui_fgcolor)
+        assert isinstance(serialized_task, SerializedBaseOperator)
+        assert not isinstance(task, SerializedBaseOperator)
+        assert isinstance(task, BaseOperator)
+
+        fields_to_check = task.get_serialized_fields() - {
+            # Checked separately
+            '_task_type', 'subdag',
+
+            # Type is excluded, so don't check it
+            '_log',
+
+            # List vs tuple. Check separately
+            'template_fields',
+
+            # We store the string, real dag has the actual code
+            'on_failure_callback', 'on_success_callback', 'on_retry_callback',
+
+            # Checked separately
+            'resources',
+        }
+
+        assert serialized_task.task_type == task.task_type
+        assert set(serialized_task.template_fields) == set(task.template_fields)
+
+        for field in fields_to_check:
+            assert getattr(serialized_task, field) == getattr(task, field), \
+                '{}.{}.{} does not match'.format(task.dag.dag_id, task.task_id, field)
+
+        if serialized_task.resources is None:
+            assert task.resources is None or task.resources == []
+        else:
+            assert serialized_task.resources == task.resources
 
         # Check that for Deserialised task, task.subdag is None for all other Operators
         # except for the SubDagOperator where task.subdag is an instance of DAG object
         if task.task_type == "SubDagOperator":
-            self.assertIsNotNone(task.subdag)
-            self.assertTrue(isinstance(task.subdag, DAG))
+            assert serialized_task.subdag is not None
+            assert isinstance(serialized_task.subdag, DAG)
         else:
-            self.assertIsNone(task.subdag)
-        self.assertEqual({}, task.params)
-        self.assertEqual({}, task.executor_config)
+            assert serialized_task.subdag is None
 
     @parameterized.expand([
         (datetime(2019, 8, 1), None, datetime(2019, 8, 1)),
@@ -650,6 +699,23 @@ class TestStringifiedDAGs(unittest.TestCase):
         dag_params = set(dag_schema.keys()) - ignored_keys
         self.assertEqual(set(DAG.get_serialized_fields()), dag_params)
 
+    def test_operator_subclass_changing_base_defaults(self):
+        assert BaseOperator(task_id='dummy').do_xcom_push is True, \
+            "Precondition check! If this fails the test won't make sense"
+
+        class MyOperator(BaseOperator):
+            def __init__(self, do_xcom_push=False, **kwargs):
+                super(MyOperator, self).__init__(**kwargs)
+                self.do_xcom_push = do_xcom_push
+
+        op = MyOperator(task_id='dummy')
+        assert op.do_xcom_push is False
+
+        blob = SerializedBaseOperator.serialize_operator(op)
+        serialized_op = SerializedBaseOperator.deserialize_operator(blob)
+
+        assert serialized_op.do_xcom_push is False
+
     def test_no_new_fields_added_to_base_operator(self):
         """
         This test verifies that there are no new fields added to BaseOperator. And reminds that