Posted to commits@airflow.apache.org by ka...@apache.org on 2020/06/26 12:53:42 UTC

[airflow] branch v1-10-test updated (686716b -> 45bf7f4)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a change to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git.


    from 686716b  [AIRFLOW-6778] Add a configurable DAGs volume mount path for Kubernetes (#8147)
     new 340e43d  Correctly store non-default Nones in serialized tasks/dags (#8772)
     new 45bf7f4  Correctly restore upstream_task_ids when deserializing Operators (#8775)

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 airflow/serialization/serialized_objects.py   |  26 +++-
 tests/serialization/test_dag_serialization.py | 179 ++++++++++++++++++--------
 2 files changed, 143 insertions(+), 62 deletions(-)


[airflow] 01/02: Correctly store non-default Nones in serialized tasks/dags (#8772)

Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 340e43d86b0f9592553c8e6eae9b1f14e6f72179
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Sun May 10 08:57:21 2020 +0100

    Correctly store non-default Nones in serialized tasks/dags (#8772)
    
    The default schedule_interval for a DAG is `@daily`, so
    `schedule_interval=None` is actually not the default, but we were not
    storing _any_ null attributes previously.
    
    This meant that upon re-inflating the DAG the schedule_interval would
    become @daily.
    
    This fixes that problem, and extends our round-trip tests to check
    _all_ the serialized attributes, rather than just the few that the
    webserver cares about.
    
    It doesn't change the serialization format; it only changes which
    values are stored, and when.
    
    This solution is more complex than I hoped for, but the test case in
    test_operator_subclass_changing_base_defaults is a real one that the
    round-trip tests discovered via the DatabricksSubmitRunOperator -- I
    have captured it in this test in case that specific operator changes
    in future.
    
    (cherry picked from commit a715aa692e88160cb8e9df4effda2440e4778c17)
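
A minimal sketch of the behaviour this commit fixes, assuming the
SerializedDAG round-trip API exercised in the tests below (the dag_id
here is hypothetical):

    from datetime import datetime
    from airflow import DAG
    from airflow.serialization.serialized_objects import SerializedDAG

    # schedule_interval=None is NOT the default (@daily), so it must
    # survive a serialization round trip.
    dag = DAG(dag_id='no_schedule', start_date=datetime(2020, 1, 1),
              schedule_interval=None)

    restored = SerializedDAG.from_json(SerializedDAG.to_json(dag))
    # Before this fix, the None was dropped on serialization and the
    # @daily default was re-applied when the DAG was re-inflated.
    assert restored.schedule_interval is None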
---
 airflow/serialization/serialized_objects.py   |  24 +++-
 tests/serialization/test_dag_serialization.py | 176 ++++++++++++++++++--------
 2 files changed, 139 insertions(+), 61 deletions(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 917d80f..3e564ec 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -122,10 +122,16 @@ class BaseSerialization:
     @classmethod
     def _is_excluded(cls, var, attrname, instance):
         """Types excluded from serialization."""
+
+        if var is None:
+            if not cls._is_constructor_param(attrname, instance):
+                # For any instance attribute that is not a constructor argument, None is the default, so exclude it
+                return True
+
+            return cls._value_is_hardcoded_default(attrname, var, instance)
         return (
-            var is None or
             isinstance(var, cls._excluded_types) or
-            cls._value_is_hardcoded_default(attrname, var)
+            cls._value_is_hardcoded_default(attrname, var, instance)
         )
 
     @classmethod
@@ -259,7 +265,12 @@ class BaseSerialization:
         return datetime.timedelta(seconds=seconds)
 
     @classmethod
-    def _value_is_hardcoded_default(cls, attrname, value):
+    def _is_constructor_param(cls, attrname, instance):
+        # pylint: disable=unused-argument
+        return attrname in cls._CONSTRUCTOR_PARAMS
+
+    @classmethod
+    def _value_is_hardcoded_default(cls, attrname, value, instance):
         """
         Return true if ``value`` is the hard-coded default for the given attribute.
         This takes into account cases where the ``concurrency`` parameter is
@@ -273,8 +284,9 @@ class BaseSerialization:
         to account for the case where the default value of the field is None but has the
         ``field = field or {}`` set.
         """
+        # pylint: disable=unused-argument
         if attrname in cls._CONSTRUCTOR_PARAMS and \
-                (cls._CONSTRUCTOR_PARAMS[attrname].default is value or (value in [{}, []])):
+                (cls._CONSTRUCTOR_PARAMS[attrname] is value or (value in [{}, []])):
             return True
         return False
 
@@ -288,7 +300,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
     _decorated_fields = {'executor_config', }
 
     _CONSTRUCTOR_PARAMS = {
-        k: v for k, v in signature(BaseOperator).parameters.items()
+        k: v.default for k, v in signature(BaseOperator).parameters.items()
         if v.default is not v.empty
     }
 
@@ -511,7 +523,7 @@ class SerializedDAG(DAG, BaseSerialization):
             'access_control': '_access_control',
         }
         return {
-            param_to_attr.get(k, k): v for k, v in signature(DAG).parameters.items()
+            param_to_attr.get(k, k): v.default for k, v in signature(DAG).parameters.items()
             if v.default is not v.empty
         }
 
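
For illustration, a simplified standalone sketch of the exclusion rule
introduced above. The is_excluded helper below is hypothetical; the
dict comprehension mirrors the _CONSTRUCTOR_PARAMS change in the diff:

    from inspect import signature
    from airflow.models.baseoperator import BaseOperator

    # Map constructor parameter names to their default *values*, as the
    # diff now does, rather than to inspect.Parameter objects.
    CONSTRUCTOR_PARAMS = {
        k: v.default for k, v in signature(BaseOperator).parameters.items()
        if v.default is not v.empty
    }

    def is_excluded(attrname, value):
        """Hypothetical standalone version of the _is_excluded logic."""
        if value is None:
            if attrname not in CONSTRUCTOR_PARAMS:
                # Non-constructor instance attributes treat None as the
                # implicit default, so never store them.
                return True
            # Store None only when it differs from the constructor default.
            return CONSTRUCTOR_PARAMS[attrname] is None
        return CONSTRUCTOR_PARAMS.get(attrname) is value or value in ({}, [])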
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 52f6e1a..e28e2b2 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -18,7 +18,9 @@
 # under the License.
 
 """Unit tests for stringified DAGs."""
+from glob import glob
 import multiprocessing
+import os
 import unittest
 
 import six
@@ -29,13 +31,10 @@ from datetime import datetime, timedelta
 from parameterized import parameterized
 from dateutil.relativedelta import relativedelta, FR
 
-from airflow import example_dags
-from airflow.contrib import example_dags as contrib_example_dags
 from airflow.hooks.base_hook import BaseHook
 from airflow.models import DAG, Connection, DagBag, TaskInstance
 from airflow.models.baseoperator import BaseOperator
 from airflow.operators.bash_operator import BashOperator
-from airflow.operators.subdag_operator import SubDagOperator
 from airflow.serialization.json_schema import load_dag_schema_dict
 from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
 from airflow.utils.tests import CustomOperator, CustomOpLink, GoogleLink
@@ -110,10 +109,14 @@ serialized_simple_dag_ground_truth = {
     },
 }
 
+ROOT_FOLDER = os.path.realpath(
+    os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir)
+)
 
-def make_example_dags(module):
+
+def make_example_dags(module_path):
     """Loads DAGs from a module for test."""
-    dagbag = DagBag(module.__path__[0])
+    dagbag = DagBag(module_path)
     return dagbag.dags
 
 
@@ -170,22 +173,34 @@ def make_user_defined_macro_filter_dag():
     return {dag.dag_id: dag}
 
 
-def collect_dags():
+def collect_dags(dag_folder=None):
     """Collects DAGs to test."""
     dags = {}
     dags.update(make_simple_dag())
     dags.update(make_user_defined_macro_filter_dag())
-    dags.update(make_example_dags(example_dags))
-    dags.update(make_example_dags(contrib_example_dags))
+
+    if dag_folder:
+        if isinstance(dag_folder, (list, tuple)):
+            patterns = dag_folder
+        else:
+            patterns = [dag_folder]
+    else:
+        patterns = [
+            "airflow/example_dags",
+            "airflow/contrib/example_dags",
+        ]
+    for pattern in patterns:
+        for directory in glob(ROOT_FOLDER + "/" + pattern):
+            dags.update(make_example_dags(directory))
 
     # Filter subdags as they are stored in same row in Serialized Dag table
     dags = {dag_id: dag for dag_id, dag in dags.items() if not dag.is_subdag}
     return dags
 
 
-def serialize_subprocess(queue):
+def serialize_subprocess(queue, dag_folder):
     """Validate pickle in a subprocess."""
-    dags = collect_dags()
+    dags = collect_dags(dag_folder)
     for dag in dags.values():
         queue.put(SerializedDAG.to_json(dag))
     queue.put(None)
@@ -242,14 +257,17 @@ class TestStringifiedDAGs(unittest.TestCase):
             )
             return dag_dict
 
-        self.assertEqual(sorted_serialized_dag(ground_truth_dag),
-                         sorted_serialized_dag(json_dag))
+        assert sorted_serialized_dag(ground_truth_dag) == sorted_serialized_dag(json_dag)
 
-    def test_deserialization(self):
+    def test_deserialization_across_process(self):
         """A serialized DAG can be deserialized in another process."""
+
+        # Since we need to parse the dags twice here (once in the subprocess,
+        # and once here to get a DAG to compare to) we don't want to load all
+        # dags.
         queue = multiprocessing.Queue()
         proc = multiprocessing.Process(
-            target=serialize_subprocess, args=(queue,))
+            target=serialize_subprocess, args=(queue, "airflow/example_dags"))
         proc.daemon = True
         proc.start()
 
@@ -262,69 +280,100 @@ class TestStringifiedDAGs(unittest.TestCase):
             self.assertTrue(isinstance(dag, DAG))
             stringified_dags[dag.dag_id] = dag
 
-        dags = collect_dags()
-        self.assertTrue(set(stringified_dags.keys()) == set(dags.keys()))
+        dags = collect_dags("airflow/example_dags")
+        assert set(stringified_dags.keys()) == set(dags.keys())
 
         # Verify deserialized DAGs.
         for dag_id in stringified_dags:
             self.validate_deserialized_dag(stringified_dags[dag_id], dags[dag_id])
 
-        example_skip_dag = stringified_dags['example_skip_dag']
-        skip_operator_1_task = example_skip_dag.task_dict['skip_operator_1']
-        self.validate_deserialized_task(
-            skip_operator_1_task, 'DummySkipOperator', '#e8b7e4', '#000')
+    def test_roundtrip_provider_example_dags(self):
+        dags = collect_dags([
+            "airflow/providers/*/example_dags",
+            "airflow/providers/*/*/example_dags",
+        ])
 
-        # Verify that the DAG object has 'full_filepath' attribute
-        # and is equal to fileloc
-        self.assertTrue(hasattr(example_skip_dag, 'full_filepath'))
-        self.assertEqual(example_skip_dag.full_filepath, example_skip_dag.fileloc)
-
-        example_subdag_operator = stringified_dags['example_subdag_operator']
-        section_1_task = example_subdag_operator.task_dict['section-1']
-        self.validate_deserialized_task(
-            section_1_task,
-            SubDagOperator.__name__,
-            SubDagOperator.ui_color,
-            SubDagOperator.ui_fgcolor
-        )
+        # Verify deserialized DAGs.
+        for dag in dags.values():
+            serialized_dag = SerializedDAG.from_json(SerializedDAG.to_json(dag))
+            self.validate_deserialized_dag(serialized_dag, dag)
 
     def validate_deserialized_dag(self, serialized_dag, dag):
         """
         Verify that all example DAGs work with DAG Serialization by
         checking fields between Serialized Dags & non-Serialized Dags
         """
-        fields_to_check = [
-            "params", "fileloc", "max_active_runs", "concurrency",
-            "is_paused_upon_creation", "doc_md", "safe_dag_id", "is_subdag",
-            "catchup", "description", "start_date", "end_date", "parent_dag",
-            "template_searchpath", "_access_control", "dagrun_timeout"
-        ]
+        fields_to_check = dag.get_serialized_fields() - {
+            # Doesn't implement __eq__ properly. Check manually
+            'timezone',
 
-        # fields_to_check = dag.get_serialized_fields()
+            # Need to check fields in it, to exclude functions
+            'default_args',
+        }
         for field in fields_to_check:
-            self.assertEqual(getattr(serialized_dag, field), getattr(dag, field))
+            assert getattr(serialized_dag, field) == getattr(dag, field), \
+                '{}.{} does not match'.format(dag.dag_id, field)
 
-        self.assertEqual(
-            sorted(serialized_dag.task_ids),
-            sorted([str(task) for task in dag.task_ids]))
+        if dag.default_args:
+            for k, v in dag.default_args.items():
+                if callable(v):
+                    # Check we stored _something_.
+                    assert k in serialized_dag.default_args
+                else:
+                    assert v == serialized_dag.default_args[k], \
+                        '{}.default_args[{}] does not match'.format(dag.dag_id, k)
+
+        assert serialized_dag.timezone.name == dag.timezone.name
+
+        for task_id in dag.task_ids:
+            self.validate_deserialized_task(serialized_dag.get_task(task_id), dag.get_task(task_id))
+
+        # Verify that the DAG object has 'full_filepath' attribute
+        # and is equal to fileloc
+        assert serialized_dag.full_filepath == dag.fileloc
 
-    def validate_deserialized_task(self, task, task_type, ui_color, ui_fgcolor):
+    def validate_deserialized_task(self, serialized_task, task):
         """Verify non-airflow operators are casted to BaseOperator."""
-        self.assertTrue(isinstance(task, SerializedBaseOperator))
-        # Verify the original operator class is recorded for UI.
-        self.assertTrue(task.task_type == task_type)
-        self.assertTrue(task.ui_color == ui_color)
-        self.assertTrue(task.ui_fgcolor == ui_fgcolor)
+        assert isinstance(serialized_task, SerializedBaseOperator)
+        assert not isinstance(task, SerializedBaseOperator)
+        assert isinstance(task, BaseOperator)
+
+        fields_to_check = task.get_serialized_fields() - {
+            # Checked separately
+            '_task_type', 'subdag',
+
+            # Type is excluded, so don't check it
+            '_log',
+
+            # List vs tuple. Check separately
+            'template_fields',
+
+            # We store the string, real dag has the actual code
+            'on_failure_callback', 'on_success_callback', 'on_retry_callback',
+
+            # Checked separately
+            'resources',
+        }
+
+        assert serialized_task.task_type == task.task_type
+        assert set(serialized_task.template_fields) == set(task.template_fields)
+
+        for field in fields_to_check:
+            assert getattr(serialized_task, field) == getattr(task, field), \
+                '{}.{}.{} does not match'.format(task.dag.dag_id, task.task_id, field)
+
+        if serialized_task.resources is None:
+            assert task.resources is None or task.resources == []
+        else:
+            assert serialized_task.resources == task.resources
 
         # Check that for a deserialized task, task.subdag is None for all
         # operators except SubDagOperator, where task.subdag is a DAG instance
         if task.task_type == "SubDagOperator":
-            self.assertIsNotNone(task.subdag)
-            self.assertTrue(isinstance(task.subdag, DAG))
+            assert serialized_task.subdag is not None
+            assert isinstance(serialized_task.subdag, DAG)
         else:
-            self.assertIsNone(task.subdag)
-        self.assertEqual({}, task.params)
-        self.assertEqual({}, task.executor_config)
+            assert serialized_task.subdag is None
 
     @parameterized.expand([
         (datetime(2019, 8, 1), None, datetime(2019, 8, 1)),
@@ -650,6 +699,23 @@ class TestStringifiedDAGs(unittest.TestCase):
         dag_params = set(dag_schema.keys()) - ignored_keys
         self.assertEqual(set(DAG.get_serialized_fields()), dag_params)
 
+    def test_operator_subclass_changing_base_defaults(self):
+        assert BaseOperator(task_id='dummy').do_xcom_push is True, \
+            "Precondition check! If this fails the test won't make sense"
+
+        class MyOperator(BaseOperator):
+            def __init__(self, do_xcom_push=False, **kwargs):
+                super(MyOperator, self).__init__(**kwargs)
+                self.do_xcom_push = do_xcom_push
+
+        op = MyOperator(task_id='dummy')
+        assert op.do_xcom_push is False
+
+        blob = SerializedBaseOperator.serialize_operator(op)
+        serialized_op = SerializedBaseOperator.deserialize_operator(blob)
+
+        assert serialized_op.do_xcom_push is False
+
     def test_no_new_fields_added_to_base_operator(self):
         """
         This test verifies that there are no new fields added to BaseOperator. And reminds that


[airflow] 02/02: Correctly restore upstream_task_ids when deserializing Operators (#8775)

Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v1-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 45bf7f4a557e38201fdffcd7f535b08d9835d9aa
Author: Ash Berlin-Taylor <as...@firemirror.com>
AuthorDate: Sun May 10 11:41:47 2020 +0100

    Correctly restore upstream_task_ids when deserializing Operators (#8775)
    
    This test exposed a bug in one of the example dags that wasn't caught
    by #6549. That will be fixed in a separate issue, but it caused the
    round-trip tests to fail here.
    
    Fixes #8720
    
    (cherry picked from commit 280f1f0c4cc49aba1b2f8b456326795733769d18)
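
A minimal reproduction sketch of #8720, using the same round-trip API
as the tests below (the dag and task ids are hypothetical, and the
DummyOperator import path assumes Airflow 1.10):

    from datetime import datetime
    from airflow import DAG
    from airflow.operators.dummy_operator import DummyOperator
    from airflow.serialization.serialized_objects import SerializedDAG

    with DAG(dag_id='two_tasks', start_date=datetime(2020, 1, 1)) as dag:
        a = DummyOperator(task_id='a')
        b = DummyOperator(task_id='b')
        a >> b  # b's upstream_task_ids should come back as {'a'}

    restored = SerializedDAG.from_json(SerializedDAG.to_json(dag))
    # Before this fix, each task's upstream set gained its own id
    # instead of the real upstream task's id.
    assert restored.get_task('b').upstream_task_ids == {'a'}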
---
 airflow/serialization/serialized_objects.py   | 2 +-
 tests/serialization/test_dag_serialization.py | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 3e564ec..8d261aa 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -591,7 +591,7 @@ class SerializedDAG(DAG, BaseSerialization):
             for task_id in serializable_task.downstream_task_ids:
                 # Bypass set_upstream etc here - it does more than we want
                 # noinspection PyProtectedMember
-                dag.task_dict[task_id]._upstream_task_ids.add(task_id)  # pylint: disable=protected-access
+                dag.task_dict[task_id]._upstream_task_ids.add(serializable_task.task_id)  # noqa: E501 # pylint: disable=protected-access
 
         return dag
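
A tiny sketch of why the one-line fix above matters (task ids are
hypothetical): when rebuilding edges from each task's downstream list,
the downstream task's upstream set must gain the *upstream* task's id,
not its own.

    # task_id -> _upstream_task_ids, for a single edge a -> b
    upstream = {'a': set(), 'b': set()}
    downstream = {'a': ['b'], 'b': []}  # serialized downstream_task_ids

    for task_id, down in downstream.items():
        for d in down:
            upstream[d].add(task_id)  # the fix: add task_id, not d

    assert upstream['b'] == {'a'}  # the old code added 'b' to its own set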
 
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index e28e2b2..6b714a8 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -358,6 +358,9 @@ class TestStringifiedDAGs(unittest.TestCase):
         assert serialized_task.task_type == task.task_type
         assert set(serialized_task.template_fields) == set(task.template_fields)
 
+        assert serialized_task.upstream_task_ids == task.upstream_task_ids
+        assert serialized_task.downstream_task_ids == task.downstream_task_ids
+
         for field in fields_to_check:
             assert getattr(serialized_task, field) == getattr(task, field), \
                 '{}.{}.{} does not match'.format(task.dag.dag_id, task.task_id, field)