You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by je...@apache.org on 2021/11/05 23:24:55 UTC
[airflow] 03/06: Fix serialization of Params with set data type
(#19267)
This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 157a864d67627aacd960848791af913720ad45eb
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Fri Nov 5 10:19:50 2021 -0700
Fix serialization of Params with set data type (#19267)
This is a solution for https://github.com/apache/airflow/issues/19096
Previously, the serialization of params did not run the param value through the `_serialize` function, resulting in non-json-serializable dictionaries. This manifested when a user, for example, tried to use params with a default value of type `set`.
Here we change the logic to run the param value through the serialization process, and add a test for the `set` case.
closes https://github.com/apache/airflow/issues/19096
(cherry picked from commit 8512e0507263495ddd326e27699c45cafd31a5e1)
---
airflow/models/param.py | 4 +-
airflow/serialization/schema.json | 22 +++++++-
airflow/serialization/serialized_objects.py | 50 ++++++++++++++-----
tests/serialization/test_dag_serialization.py | 72 ++++++++++++++++++++++++---
4 files changed, 125 insertions(+), 23 deletions(-)
diff --git a/airflow/models/param.py b/airflow/models/param.py
index 1ae01dc..53ac79a 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
from typing import Any, Dict, Optional
import jsonschema
@@ -49,6 +48,7 @@ class Param:
"""
__NO_VALUE_SENTINEL = NoValueSentinel()
+ CLASS_IDENTIFIER = '__class'
def __init__(self, default: Any = __NO_VALUE_SENTINEL, description: str = None, **kwargs):
self.value = default
@@ -90,7 +90,7 @@ class Param:
def dump(self) -> dict:
"""Dump the Param as a dictionary"""
- out_dict = {'__class': f'{self.__module__}.{self.__class__.__name__}'}
+ out_dict = {self.CLASS_IDENTIFIER: f'{self.__module__}.{self.__class__.__name__}'}
out_dict.update(self.__dict__)
return out_dict
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index b4a64b4..6d25c1e 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -78,7 +78,7 @@
"dag": {
"type": "object",
"properties": {
- "params": { "$ref": "#/definitions/dict" },
+ "params": { "$ref": "#/definitions/params_dict" },
"_dag_id": { "type": "string" },
"tasks": { "$ref": "#/definitions/tasks" },
"timezone": { "$ref": "#/definitions/timezone" },
@@ -135,6 +135,24 @@
"type": "array",
"additionalProperties": { "$ref": "#/definitions/operator" }
},
+ "params_dict": {
+ "type": "object",
+ "additionalProperties": {"$ref": "#/definitions/param" }
+ },
+ "param": {
+ "$comment": "A param for a dag / operator",
+ "type": "object",
+ "required": [
+ "__class",
+ "default"
+ ],
+ "properties": {
+ "__class": { "type": "string" },
+ "default": {},
+ "description": {"anyOf": [{"type":"string"}, {"type":"null"}]},
+ "schema": { "$ref": "#/definitions/dict" }
+ }
+ },
"operator": {
"$comment": "A task/operator in a DAG",
"type": "object",
@@ -166,7 +184,7 @@
"retry_delay": { "$ref": "#/definitions/timedelta" },
"retry_exponential_backoff": { "type": "boolean" },
"max_retry_delay": { "$ref": "#/definitions/timedelta" },
- "params": { "$ref": "#/definitions/dict" },
+ "params": { "$ref": "#/definitions/params_dict" },
"priority_weight": { "type": "number" },
"weight_rule": { "type": "string" },
"executor_config": { "$ref": "#/definitions/dict" },
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 6ec9770..c451695 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -55,7 +55,6 @@ try:
except ImportError:
HAS_KUBERNETES = False
-
if TYPE_CHECKING:
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
@@ -325,7 +324,7 @@ class BaseSerialization:
elif isinstance(var, TaskGroup):
return SerializedTaskGroup.serialize_task_group(var)
elif isinstance(var, Param):
- return cls._encode(var.dump(), type_=DAT.PARAM)
+ return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
else:
log.debug('Cast type %s to str in serialization.', type(var))
return str(var)
@@ -368,9 +367,7 @@ class BaseSerialization:
elif type_ == DAT.TUPLE:
return tuple(cls._deserialize(v) for v in var)
elif type_ == DAT.PARAM:
- param_class = import_string(var['_type'])
- del var['_type']
- return param_class(**var)
+ return cls._deserialize_param(var)
else:
raise TypeError(f'Invalid type {type_!s} in deserialization.')
@@ -410,29 +407,58 @@ class BaseSerialization:
return False
@classmethod
+ def _serialize_param(cls, param: Param):
+ return dict(
+ __class=f"{param.__module__}.{param.__class__.__name__}",
+ default=cls._serialize(param.value),
+ description=cls._serialize(param.description),
+ schema=cls._serialize(param.schema),
+ )
+
+ @classmethod
+ def _deserialize_param(cls, param_dict: Dict):
+ """
+ In 2.2.0, Param attrs were assumed to be json-serializable and were not run through
+ this class's ``_serialize`` method. So before running a value through ``_deserialize``,
+ we first check whether it was actually serialized (i.e. it is a dict with a ``__type`` key).
+ """
+ class_name = param_dict['__class']
+ class_ = import_string(class_name) # type: Type[Param]
+ attrs = ('default', 'description', 'schema')
+ kwargs = {}
+ for attr in attrs:
+ if attr not in param_dict:
+ continue
+ val = param_dict[attr]
+ is_serialized = isinstance(val, dict) and '__type' in val
+ if is_serialized:
+ deserialized_val = cls._deserialize(param_dict[attr])
+ kwargs[attr] = deserialized_val
+ else:
+ kwargs[attr] = val
+ return class_(**kwargs)
+
+ @classmethod
def _serialize_params_dict(cls, params: ParamsDict):
"""Serialize Params dict for a DAG/Task"""
serialized_params = {}
for k, v in params.items():
# TODO: As of now, we would allow serialization of params which are of type Param only
if f'{v.__module__}.{v.__class__.__name__}' == 'airflow.models.param.Param':
- kwargs = v.dump()
- kwargs['default'] = kwargs.pop('value')
- serialized_params[k] = kwargs
+ serialized_params[k] = cls._serialize_param(v)
else:
raise ValueError('Params to a DAG or a Task can be only of type airflow.models.param.Param')
return serialized_params
@classmethod
def _deserialize_params_dict(cls, encoded_params: Dict) -> ParamsDict:
- """Deserialize a DAGs Params dict"""
+ """Deserialize a DAG's Params dict"""
op_params = {}
for k, v in encoded_params.items():
if isinstance(v, dict) and "__class" in v:
- param_class = import_string(v['__class'])
- op_params[k] = param_class(**v)
+ op_params[k] = cls._deserialize_param(v)
else:
- # Old style params, upgrade it
+ # Old style params, convert it
op_params[k] = Param(v)
return ParamsDict(op_params)
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index afba96a..6fec7f9 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -21,6 +21,7 @@
import copy
import importlib
import importlib.util
+import json
import multiprocessing
import os
from datetime import datetime, timedelta
@@ -724,6 +725,7 @@ class TestStringifiedDAGs:
[
(None, {}),
({"param_1": "value_1"}, {"param_1": "value_1"}),
+ ({"param_1": {1, 2, 3}}, {"param_1": {1, 2, 3}}),
],
)
def test_dag_params_roundtrip(self, val, expected_val):
@@ -733,7 +735,10 @@ class TestStringifiedDAGs:
dag = DAG(dag_id='simple_dag', params=val)
BaseOperator(task_id='simple_task', dag=dag, start_date=datetime(2019, 8, 1))
- serialized_dag = SerializedDAG.to_dict(dag)
+ serialized_dag_json = SerializedDAG.to_json(dag)
+
+ serialized_dag = json.loads(serialized_dag_json)
+
assert "params" in serialized_dag["dag"]
deserialized_dag = SerializedDAG.from_dict(serialized_dag)
@@ -764,14 +769,37 @@ class TestStringifiedDAGs:
params={'path': S3Param('s3://my_bucket/my_path')},
)
- with pytest.raises(SerializationError):
- SerializedDAG.to_dict(dag)
+ @pytest.mark.parametrize(
+ 'param',
+ [
+ Param('my value', description='hello', schema={'type': 'string'}),
+ Param('my value', description='hello'),
+ Param(None, description=None),
+ ],
+ )
+ def test_full_param_roundtrip(self, param):
+ """
+ Test that a Param's value, description and schema survive a JSON serialization roundtrip
+ """
+
+ dag = DAG(dag_id='simple_dag', params={'my_param': param})
+ serialized_json = SerializedDAG.to_json(dag)
+ serialized = json.loads(serialized_json)
+ SerializedDAG.validate_schema(serialized)
+ dag = SerializedDAG.from_dict(serialized)
+
+ assert dag.params["my_param"] == param.value
+ observed_param = dict.get(dag.params, 'my_param')
+ assert isinstance(observed_param, Param)
+ assert observed_param.description == param.description
+ assert observed_param.schema == param.schema
@pytest.mark.parametrize(
"val, expected_val",
[
(None, {}),
({"param_1": "value_1"}, {"param_1": "value_1"}),
+ ({"param_1": {1, 2, 3}}, {"param_1": {1, 2, 3}}),
],
)
def test_task_params_roundtrip(self, val, expected_val):
@@ -1433,29 +1461,32 @@ class TestStringifiedDAGs:
assert serialized_obj == expected_output
def test_params_upgrade(self):
+ """When a pre-2.2.0 param (i.e. a primitive) is deserialized, we convert it to a Param."""
serialized = {
"__version": 1,
"dag": {
"_dag_id": "simple_dag",
- "fileloc": __file__,
+ "fileloc": '/path/to/file.py',
"tasks": [],
"timezone": "UTC",
"params": {"none": None, "str": "str", "dict": {"a": "b"}},
},
}
- SerializedDAG.validate_schema(serialized)
dag = SerializedDAG.from_dict(serialized)
assert dag.params["none"] is None
assert isinstance(dict.__getitem__(dag.params, "none"), Param)
assert dag.params["str"] == "str"
- def test_params_serialize_default(self):
+ def test_params_serialize_default_2_2_0(self):
+ """In 2.2.0, param ``default`` was assumed to be a json-serializable object and was not run through
+ the standard serializer function. In 2.2.2 we serialize param ``default``. We keep this
+ test only to ensure that params stored in 2.2.0 can still be parsed correctly."""
serialized = {
"__version": 1,
"dag": {
"_dag_id": "simple_dag",
- "fileloc": __file__,
+ "fileloc": '/path/to/file.py',
"tasks": [],
"timezone": "UTC",
"params": {"str": {"__class": "airflow.models.param.Param", "default": "str"}},
@@ -1467,6 +1498,33 @@ class TestStringifiedDAGs:
assert isinstance(dict.__getitem__(dag.params, "str"), Param)
assert dag.params["str"] == "str"
+ def test_params_serialize_default(self):
+ serialized = {
+ "__version": 1,
+ "dag": {
+ "_dag_id": "simple_dag",
+ "fileloc": '/path/to/file.py',
+ "tasks": [],
+ "timezone": "UTC",
+ "params": {
+ "my_param": {
+ "default": "a string value",
+ "description": "hello",
+ "schema": {"__var": {"type": "string"}, "__type": "dict"},
+ "__class": "airflow.models.param.Param",
+ }
+ },
+ },
+ }
+ SerializedDAG.validate_schema(serialized)
+ dag = SerializedDAG.from_dict(serialized)
+
+ assert dag.params["my_param"] == "a string value"
+ param = dict.get(dag.params, 'my_param')
+ assert isinstance(param, Param)
+ assert param.description == 'hello'
+ assert param.schema == {'type': 'string'}
+
def test_kubernetes_optional():
"""Serialisation / deserialisation continues to work without kubernetes installed"""