You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2022/07/28 10:05:21 UTC

[GitHub] [airflow] ashb commented on a diff in pull request #25176: Implement XComArg.zip(*xcom_args)

ashb commented on code in PR #25176:
URL: https://github.com/apache/airflow/pull/25176#discussion_r932026131


##########
airflow/models/xcom_arg.py:
##########
@@ -379,13 +388,96 @@ def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[in
     @provide_session
     def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
         value = self.arg.resolve(context, session=session)
-        assert isinstance(value, (Sequence, dict))  # Validation was done when XCom was pushed.
+        if not isinstance(value, (Sequence, dict)):
+            raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
         return _MapResult(value, self.callables)
 
 
+class _ZipResult(Sequence):
+    def __init__(self, values: Sequence[Union[Sequence, dict]], *, fillvalue: Any = NOTSET) -> None:
+        self.values = values
+        self.fillvalue = fillvalue
+
+    @staticmethod
+    def _get_or_fill(container: Union[Sequence, dict], index: Any, fillvalue: Any) -> Any:
+        try:
+            return container[index]
+        except (IndexError, KeyError):
+            return fillvalue
+
+    def __getitem__(self, index: Any) -> Any:
+        if index >= len(self):
+            raise IndexError(index)
+        return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values)
+
+    def __len__(self) -> int:
+        lengths = (len(v) for v in self.values)
+        if self.fillvalue is NOTSET:
+            return min(lengths)
+        return max(lengths)
+
+
+class ZipXComArg(XComArg):
+    """An XCom reference with ``zip()`` applied.
+
+    This is constructed from multiple XComArg instances, and presents an
+    iterable that "zips" them together like the built-in ``zip()`` (and
+    ``itertools.zip_longest()`` if ``fillvalue`` is provided).
+    """
+
+    def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
+        if not args:
+            raise ValueError("At least one input is required")
+        self.args = args
+        self.fillvalue = fillvalue
+
+    def __repr__(self) -> str:
+        args_iter = iter(self.args)
+        first = repr(next(args_iter))
+        rest = ", ".join(repr(arg) for arg in args_iter)
+        if self.fillvalue is NOTSET:
+            return f"{first}.zip({rest})"
+        return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})"
+
+    def _serialize(self) -> Dict[str, Any]:
+        args = [serialize_xcom_arg(arg) for arg in self.args]
+        if self.fillvalue is NOTSET:
+            return {"args": args}
+        return {"args": args, "fillvalue": self.fillvalue}
+
+    @classmethod
+    def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg:
+        return cls(
+            [deserialize_xcom_arg(arg, dag) for arg in data["args"]],
+            fillvalue=data.get("fillvalue", NOTSET),
+        )
+
+    def iter_references(self) -> Iterator[Tuple["Operator", str]]:
+        for arg in self.args:
+            yield from arg.iter_references()
+
+    def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
+        all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args)
+        ready_lengths = [length for length in all_lengths if length is not None]
+        if len(ready_lengths) != len(self.args):
+            return None  # If any of the referenced XComs is not ready, we are not ready either.

Review Comment:
   I'm not sure this is the right behaviour when fillvalue is provided, especially given things like https://github.com/apache/airflow/issues/24338.



##########
airflow/serialization/serialized_objects.py:
##########
@@ -393,7 +392,7 @@ def _serialize(cls, var: Any) -> Any:  # Unfortunately there is no support for r
         elif isinstance(var, Param):
             return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
         elif isinstance(var, XComArg):
-            return cls._encode(cls._serialize_xcomarg(var), type_=DAT.XCOM_REF)
+            return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)

Review Comment:
   This leads to a slightly odd serialization, with the type "doubled":
   
   ```json
   { "_type": "xcom_ref", "_val": { "type": "", ... }}
   ```
   
   Is this the right thing to do?



##########
airflow/models/xcom_arg.py:
##########
@@ -285,8 +288,13 @@ def iter_references(self) -> Iterator[Tuple["Operator", str]]:
 
     def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
         if self.key != XCOM_RETURN_KEY:
-            raise ValueError
-        return MapXComArg(self, [f])
+            raise ValueError("cannot map against non-return XCom")
+        return super().map(f)
+
+    def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg":
+        if self.key != XCOM_RETURN_KEY:
+            raise ValueError("cannot map against non-return XCom")

Review Comment:
   What is the reason for this limitation btw? (I know we have had it on map for a while, but I can't think of anything that would actually break if we didn't have it)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org