You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by je...@apache.org on 2021/11/05 23:24:55 UTC

[airflow] 03/06: Fix serialization of Params with set data type (#19267)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 157a864d67627aacd960848791af913720ad45eb
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Fri Nov 5 10:19:50 2021 -0700

    Fix serialization of Params with set data type (#19267)
    
    This is a solution for https://github.com/apache/airflow/issues/19096
    
    Previously, the serialization of params did not run the param value through the `_serialize` function, resulting in non-json-serializable dictionaries.  This manifested when a user, for example, tried to use params with a default value of type `set`.
    
    Here we change the logic to run the param value through the serialization process.  And I add a test for the `set` case.
    
    closes https://github.com/apache/airflow/issues/19096
    
    (cherry picked from commit 8512e0507263495ddd326e27699c45cafd31a5e1)
---
 airflow/models/param.py                       |  4 +-
 airflow/serialization/schema.json             | 22 +++++++-
 airflow/serialization/serialized_objects.py   | 50 ++++++++++++++-----
 tests/serialization/test_dag_serialization.py | 72 ++++++++++++++++++++++++---
 4 files changed, 125 insertions(+), 23 deletions(-)

diff --git a/airflow/models/param.py b/airflow/models/param.py
index 1ae01dc..53ac79a 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -14,7 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 from typing import Any, Dict, Optional
 
 import jsonschema
@@ -49,6 +48,7 @@ class Param:
     """
 
     __NO_VALUE_SENTINEL = NoValueSentinel()
+    CLASS_IDENTIFIER = '__class'
 
     def __init__(self, default: Any = __NO_VALUE_SENTINEL, description: str = None, **kwargs):
         self.value = default
@@ -90,7 +90,7 @@ class Param:
 
     def dump(self) -> dict:
         """Dump the Param as a dictionary"""
-        out_dict = {'__class': f'{self.__module__}.{self.__class__.__name__}'}
+        out_dict = {self.CLASS_IDENTIFIER: f'{self.__module__}.{self.__class__.__name__}'}
         out_dict.update(self.__dict__)
         return out_dict
 
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index b4a64b4..6d25c1e 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -78,7 +78,7 @@
     "dag": {
       "type": "object",
       "properties": {
-        "params": { "$ref": "#/definitions/dict" },
+        "params": { "$ref": "#/definitions/params_dict" },
         "_dag_id": { "type": "string" },
         "tasks": {  "$ref": "#/definitions/tasks" },
         "timezone": { "$ref": "#/definitions/timezone" },
@@ -135,6 +135,24 @@
       "type": "array",
       "additionalProperties": { "$ref": "#/definitions/operator" }
     },
+    "params_dict": {
+      "type": "object",
+      "additionalProperties": {"$ref": "#/definitions/param" }
+    },
+    "param": {
+      "$comment": "A param for a dag / operator",
+      "type": "object",
+      "required": [
+        "__class",
+        "default"
+      ],
+      "properties": {
+        "__class": { "type": "string" },
+        "default": {},
+        "description": {"anyOf": [{"type":"string"}, {"type":"null"}]},
+        "schema": { "$ref": "#/definitions/dict" }
+      }
+    },
     "operator": {
       "$comment": "A task/operator in a DAG",
       "type": "object",
@@ -166,7 +184,7 @@
         "retry_delay": { "$ref": "#/definitions/timedelta" },
         "retry_exponential_backoff": { "type": "boolean" },
         "max_retry_delay": { "$ref": "#/definitions/timedelta" },
-        "params": { "$ref": "#/definitions/dict" },
+        "params": { "$ref": "#/definitions/params_dict" },
         "priority_weight": { "type": "number" },
         "weight_rule": { "type": "string" },
         "executor_config": { "$ref": "#/definitions/dict" },
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 6ec9770..c451695 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -55,7 +55,6 @@ try:
 except ImportError:
     HAS_KUBERNETES = False
 
-
 if TYPE_CHECKING:
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 
@@ -325,7 +324,7 @@ class BaseSerialization:
         elif isinstance(var, TaskGroup):
             return SerializedTaskGroup.serialize_task_group(var)
         elif isinstance(var, Param):
-            return cls._encode(var.dump(), type_=DAT.PARAM)
+            return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
         else:
             log.debug('Cast type %s to str in serialization.', type(var))
             return str(var)
@@ -368,9 +367,7 @@ class BaseSerialization:
         elif type_ == DAT.TUPLE:
             return tuple(cls._deserialize(v) for v in var)
         elif type_ == DAT.PARAM:
-            param_class = import_string(var['_type'])
-            del var['_type']
-            return param_class(**var)
+            return cls._deserialize_param(var)
         else:
             raise TypeError(f'Invalid type {type_!s} in deserialization.')
 
@@ -410,29 +407,58 @@ class BaseSerialization:
         return False
 
     @classmethod
+    def _serialize_param(cls, param: Param):
+        return dict(
+            __class=f"{param.__module__}.{param.__class__.__name__}",
+            default=cls._serialize(param.value),
+            description=cls._serialize(param.description),
+            schema=cls._serialize(param.schema),
+        )
+
+    @classmethod
+    def _deserialize_param(cls, param_dict: Dict):
+        """
+        In 2.2.0, Param attrs were assumed to be json-serializable and were not run through
+        this class's ``_serialize`` method.  So before running through ``_deserialize``,
+        we first verify that it's necessary to do so.
+        """
+        class_name = param_dict['__class']
+        class_ = import_string(class_name)  # type: Type[Param]
+        attrs = ('default', 'description', 'schema')
+        kwargs = {}
+        for attr in attrs:
+            if attr not in param_dict:
+                continue
+            val = param_dict[attr]
+            is_serialized = isinstance(val, dict) and '__type' in val
+            if is_serialized:
+                deserialized_val = cls._deserialize(param_dict[attr])
+                kwargs[attr] = deserialized_val
+            else:
+                kwargs[attr] = val
+        return class_(**kwargs)
+
+    @classmethod
     def _serialize_params_dict(cls, params: ParamsDict):
         """Serialize Params dict for a DAG/Task"""
         serialized_params = {}
         for k, v in params.items():
             # TODO: As of now, we would allow serialization of params which are of type Param only
             if f'{v.__module__}.{v.__class__.__name__}' == 'airflow.models.param.Param':
-                kwargs = v.dump()
-                kwargs['default'] = kwargs.pop('value')
-                serialized_params[k] = kwargs
+                serialized_params[k] = cls._serialize_param(v)
             else:
                 raise ValueError('Params to a DAG or a Task can be only of type airflow.models.param.Param')
         return serialized_params
 
     @classmethod
     def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict:
-        """Deserialize a DAGs Params dict"""
+        """Deserialize a DAG's Params dict"""
         op_params = {}
         for k, v in encoded_params.items():
             if isinstance(v, dict) and "__class" in v:
-                param_class = import_string(v['__class'])
-                op_params[k] = param_class(**v)
+                op_params[k] = cls._deserialize_param(v)
             else:
-                # Old style params, upgrade it
+                # Old style params, convert it
                 op_params[k] = Param(v)
 
         return ParamsDict(op_params)
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index afba96a..6fec7f9 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -21,6 +21,7 @@
 import copy
 import importlib
 import importlib.util
+import json
 import multiprocessing
 import os
 from datetime import datetime, timedelta
@@ -724,6 +725,7 @@ class TestStringifiedDAGs:
         [
             (None, {}),
             ({"param_1": "value_1"}, {"param_1": "value_1"}),
+            ({"param_1": {1, 2, 3}}, {"param_1": {1, 2, 3}}),
         ],
     )
     def test_dag_params_roundtrip(self, val, expected_val):
@@ -733,7 +735,10 @@ class TestStringifiedDAGs:
         dag = DAG(dag_id='simple_dag', params=val)
         BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))
 
-        serialized_dag = SerializedDAG.to_dict(dag)
+        serialized_dag_json = SerializedDAG.to_json(dag)
+
+        serialized_dag = json.loads(serialized_dag_json)
+
         assert "params" in serialized_dag["dag"]
 
         deserialized_dag = SerializedDAG.from_dict(serialized_dag)
@@ -764,14 +769,37 @@ class TestStringifiedDAGs:
             params={'path': S3Param('s3://my_bucket/my_path')},
         )
 
-        with pytest.raises(SerializationError):
-            SerializedDAG.to_dict(dag)
+    @pytest.mark.parametrize(
+        'param',
+        [
+            Param('my value', description='hello', schema={'type': 'string'}),
+            Param('my value', description='hello'),
+            Param(None, description=None),
+        ],
+    )
+    def test_full_param_roundtrip(self, param):
+        """
+        Test that a Param's value, description and schema survive a JSON serialization roundtrip
+        """
+
+        dag = DAG(dag_id='simple_dag', params={'my_param': param})
+        serialized_json = SerializedDAG.to_json(dag)
+        serialized = json.loads(serialized_json)
+        SerializedDAG.validate_schema(serialized)
+        dag = SerializedDAG.from_dict(serialized)
+
+        assert dag.params["my_param"] == param.value
+        observed_param = dict.get(dag.params, 'my_param')
+        assert isinstance(observed_param, Param)
+        assert observed_param.description == param.description
+        assert observed_param.schema == param.schema
 
     @pytest.mark.parametrize(
         "val, expected_val",
         [
             (None, {}),
             ({"param_1": "value_1"}, {"param_1": "value_1"}),
+            ({"param_1": {1, 2, 3}}, {"param_1": {1, 2, 3}}),
         ],
     )
     def test_task_params_roundtrip(self, val, expected_val):
@@ -1433,29 +1461,32 @@ class TestStringifiedDAGs:
         assert serialized_obj == expected_output
 
     def test_params_upgrade(self):
+        """when pre-2.2.0 param (i.e. primitive) is deserialized we convert to Param"""
         serialized = {
             "__version": 1,
             "dag": {
                 "_dag_id": "simple_dag",
-                "fileloc": __file__,
+                "fileloc": '/path/to/file.py',
                 "tasks": [],
                 "timezone": "UTC",
                 "params": {"none": None, "str": "str", "dict": {"a": "b"}},
             },
         }
-        SerializedDAG.validate_schema(serialized)
         dag = SerializedDAG.from_dict(serialized)
 
         assert dag.params["none"] is None
         assert isinstance(dict.__getitem__(dag.params, "none"), Param)
         assert dag.params["str"] == "str"
 
-    def test_params_serialize_default(self):
+    def test_params_serialize_default_2_2_0(self):
+        """In 2.2.0, param ``default`` was assumed to be a json-serializable object and was not run through
+        the standard serializer function.  In 2.2.2 we serialize param ``default``.  We keep this
+        test only to ensure that params stored in 2.2.0 can still be parsed correctly."""
         serialized = {
             "__version": 1,
             "dag": {
                 "_dag_id": "simple_dag",
-                "fileloc": __file__,
+                "fileloc": '/path/to/file.py',
                 "tasks": [],
                 "timezone": "UTC",
                 "params": {"str": {"__class": "airflow.models.param.Param", "default": "str"}},
@@ -1467,6 +1498,33 @@ class TestStringifiedDAGs:
         assert isinstance(dict.__getitem__(dag.params, "str"), Param)
         assert dag.params["str"] == "str"
 
+    def test_params_serialize_default(self):
+        serialized = {
+            "__version": 1,
+            "dag": {
+                "_dag_id": "simple_dag",
+                "fileloc": '/path/to/file.py',
+                "tasks": [],
+                "timezone": "UTC",
+                "params": {
+                    "my_param": {
+                        "default": "a string value",
+                        "description": "hello",
+                        "schema": {"__var": {"type": "string"}, "__type": "dict"},
+                        "__class": "airflow.models.param.Param",
+                    }
+                },
+            },
+        }
+        SerializedDAG.validate_schema(serialized)
+        dag = SerializedDAG.from_dict(serialized)
+
+        assert dag.params["my_param"] == "a string value"
+        param = dict.get(dag.params, 'my_param')
+        assert isinstance(param, Param)
+        assert param.description == 'hello'
+        assert param.schema == {'type': 'string'}
+
 
 def test_kubernetes_optional():
     """Serialisation / deserialisation continues to work without kubernetes installed"""