Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2022/07/20 07:28:31 UTC

[GitHub] [airflow] uranusjr opened a new pull request, #25176: Implement XComArg.zip(*xcom_args)

uranusjr opened a new pull request, #25176:
URL: https://github.com/apache/airflow/pull/25176

   This needs #25085 to be merged first. I'm submitting it now so CI can confirm that I’ve done the refactoring correctly.


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [airflow] ashb commented on a diff in pull request #25176: Implement XComArg.zip(*xcom_args)

Posted by GitBox <gi...@apache.org>.
ashb commented on code in PR #25176:
URL: https://github.com/apache/airflow/pull/25176#discussion_r932026131


##########
airflow/models/xcom_arg.py:
##########
@@ -379,13 +388,96 @@ def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[in
     @provide_session
     def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
         value = self.arg.resolve(context, session=session)
-        assert isinstance(value, (Sequence, dict))  # Validation was done when XCom was pushed.
+        if not isinstance(value, (Sequence, dict)):
+            raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
         return _MapResult(value, self.callables)
 
 
+class _ZipResult(Sequence):
+    def __init__(self, values: Sequence[Union[Sequence, dict]], *, fillvalue: Any = NOTSET) -> None:
+        self.values = values
+        self.fillvalue = fillvalue
+
+    @staticmethod
+    def _get_or_fill(container: Union[Sequence, dict], index: Any, fillvalue: Any) -> Any:
+        try:
+            return container[index]
+        except (IndexError, KeyError):
+            return fillvalue
+
+    def __getitem__(self, index: Any) -> Any:
+        if index >= len(self):
+            raise IndexError(index)
+        return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values)
+
+    def __len__(self) -> int:
+        lengths = (len(v) for v in self.values)
+        if self.fillvalue is NOTSET:
+            return min(lengths)
+        return max(lengths)
+
+
+class ZipXComArg(XComArg):
+    """An XCom reference with ``zip()`` applied.
+
+    This is constructed from multiple XComArg instances, and presents an
+    iterable that "zips" them together like the built-in ``zip()`` (and
+    ``itertools.zip_longest()`` if ``fillvalue`` is provided).
+    """
+
+    def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
+        if not args:
+            raise ValueError("At least one input is required")
+        self.args = args
+        self.fillvalue = fillvalue
+
+    def __repr__(self) -> str:
+        args_iter = iter(self.args)
+        first = repr(next(args_iter))
+        rest = ", ".join(repr(arg) for arg in args_iter)
+        if self.fillvalue is NOTSET:
+            return f"{first}.zip({rest})"
+        return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})"
+
+    def _serialize(self) -> Dict[str, Any]:
+        args = [serialize_xcom_arg(arg) for arg in self.args]
+        if self.fillvalue is NOTSET:
+            return {"args": args}
+        return {"args": args, "fillvalue": self.fillvalue}
+
+    @classmethod
+    def _deserialize(cls, data: Dict[str, Any], dag: "DAG") -> XComArg:
+        return cls(
+            [deserialize_xcom_arg(arg, dag) for arg in data["args"]],
+            fillvalue=data.get("fillvalue", NOTSET),
+        )
+
+    def iter_references(self) -> Iterator[Tuple["Operator", str]]:
+        for arg in self.args:
+            yield from arg.iter_references()
+
+    def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
+        all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args)
+        ready_lengths = [length for length in all_lengths if length is not None]
+        if len(ready_lengths) != len(self.args):
+            return None  # If any of the referenced XComs is not ready, we are not ready either.

Review Comment:
   I'm not sure this is the right behaviour when `fillvalue` is provided, especially given things like https://github.com/apache/airflow/issues/24338.
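For reference, the `fillvalue` semantics that the `ZipXComArg` docstring above describes (built-in `zip()` without it, `itertools.zip_longest()` with it) can be sketched as a small standalone lazy sequence. This is a sketch under stated assumptions, not Airflow code: `ZipSketch` and `_NOTSET` are hypothetical names, `_NOTSET` stands in for Airflow's `NOTSET` sentinel, and unlike `_ZipResult` the sketch only accepts sequences (not dicts) and non-negative integer indexes.

```python
from collections.abc import Sequence
from itertools import zip_longest
from typing import Any, Tuple

# Hypothetical stand-in for Airflow's NOTSET sentinel: a unique object,
# so that None remains a valid fillvalue.
_NOTSET: Any = object()


class ZipSketch(Sequence):
    """Lazily zip sequences like zip(), or like zip_longest() when a fillvalue is given."""

    def __init__(self, *values: Sequence, fillvalue: Any = _NOTSET) -> None:
        if not values:
            raise ValueError("At least one input is required")
        self.values = values
        self.fillvalue = fillvalue

    def __len__(self) -> int:
        lengths = [len(v) for v in self.values]
        # zip() stops at the shortest input; zip_longest() runs to the longest one.
        return min(lengths) if self.fillvalue is _NOTSET else max(lengths)

    def __getitem__(self, index: int) -> Tuple[Any, ...]:
        # Only non-negative integer indexes are supported in this sketch.
        if not 0 <= index < len(self):
            raise IndexError(index)
        # An input can be exhausted here only when fillvalue is set, because len() is then max().
        return tuple(v[index] if index < len(v) else self.fillvalue for v in self.values)


numbers, letters = [1, 2, 3], ["x", "y"]
# Without fillvalue: [(1, 'x'), (2, 'y')]
assert list(ZipSketch(numbers, letters)) == list(zip(numbers, letters))
# With fillvalue=None: [(1, 'x'), (2, 'y'), (3, None)]
assert list(ZipSketch(numbers, letters, fillvalue=None)) == list(
    zip_longest(numbers, letters, fillvalue=None)
)
```

Using a private sentinel rather than `None` as the default is what lets `fillvalue=None` mean "pad with `None`" instead of "no padding".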



##########
airflow/serialization/serialized_objects.py:
##########
@@ -393,7 +392,7 @@ def _serialize(cls, var: Any) -> Any:  # Unfortunately there is no support for r
         elif isinstance(var, Param):
             return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
         elif isinstance(var, XComArg):
-            return cls._encode(cls._serialize_xcomarg(var), type_=DAT.XCOM_REF)
+            return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)

Review Comment:
   This leads to a slightly odd serialization with type "doubled":
   
   ```json
   { "_type": "xcom_ref", "_val": { "type": "", ... }}
   ```
   
   Is this the right thing to do?



##########
airflow/models/xcom_arg.py:
##########
@@ -285,8 +288,13 @@ def iter_references(self) -> Iterator[Tuple["Operator", str]]:
 
     def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
         if self.key != XCOM_RETURN_KEY:
-            raise ValueError
-        return MapXComArg(self, [f])
+            raise ValueError("cannot map against non-return XCom")
+        return super().map(f)
+
+    def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg":
+        if self.key != XCOM_RETURN_KEY:
+            raise ValueError("cannot map against non-return XCom")

Review Comment:
   What is the reason for this limitation, btw? (I know we have had it on `map` for a while, but I can't think of anything that would actually break if we didn't have it.)





[GitHub] [airflow] uranusjr merged pull request #25176: Implement XComArg.zip(*xcom_args)

Posted by GitBox <gi...@apache.org>.
uranusjr merged PR #25176:
URL: https://github.com/apache/airflow/pull/25176




[GitHub] [airflow] uranusjr commented on a diff in pull request #25176: Implement XComArg.zip(*xcom_args)

Posted by GitBox <gi...@apache.org>.
uranusjr commented on code in PR #25176:
URL: https://github.com/apache/airflow/pull/25176#discussion_r932049350


##########
airflow/models/xcom_arg.py:
##########
@@ -379,13 +388,96 @@ def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[in
+    def get_task_map_length(self, run_id: str, *, session: "Session") -> Optional[int]:
+        all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args)
+        ready_lengths = [length for length in all_lengths if length is not None]
+        if len(ready_lengths) != len(self.args):
+            return None  # If any of the referenced XComs is not ready, we are not ready either.

Review Comment:
   Yes, #24338 will require changing this, but I want to make that change in the PR that fixes the issue.
   
   Returning `None` here means “I don’t know how many tasks the downstream needs to be expanded into”, so `fillvalue` is not relevant here. Until all referenced XComs are available (or we know they won’t be, to address #24338), we can’t know how long the zipped result is, because a missing input could still turn out shorter or longer than the ones we already have, and thus we cannot decide how many tasks are needed. So I think this is the correct logic (although yes, we will need to change the `if` clause when fixing #24338).
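The reasoning above can be sketched as a small standalone function. This is a hypothetical helper (`zipped_map_length` is not part of the PR, and the rest of the PR's `get_task_map_length` is not quoted in this thread); it assumes each zipped XCom reports either a known length or `None` for "not available yet".

```python
from typing import Optional, Sequence


def zipped_map_length(lengths: Sequence[Optional[int]], *, has_fillvalue: bool) -> Optional[int]:
    """Hypothetical helper: how many mapped tasks a zipped XCom expands into.

    ``lengths`` has one entry per zipped XCom; None means that XCom is not available yet.
    Returning None means "not known yet", as described in the comment above.
    """
    ready = [n for n in lengths if n is not None]
    if not ready or len(ready) != len(lengths):
        # A missing input could still be shorter (lowering min()) or longer (raising max()),
        # so the answer is unknown whether or not a fillvalue was given.
        return None
    return max(ready) if has_fillvalue else min(ready)


assert zipped_map_length([3, 5], has_fillvalue=False) == 3  # like zip()
assert zipped_map_length([3, 5], has_fillvalue=True) == 5  # like zip_longest()
assert zipped_map_length([3, None], has_fillvalue=True) is None  # still waiting on one input
```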





[GitHub] [airflow] uranusjr commented on a diff in pull request #25176: Implement XComArg.zip(*xcom_args)

Posted by GitBox <gi...@apache.org>.
uranusjr commented on code in PR #25176:
URL: https://github.com/apache/airflow/pull/25176#discussion_r932034109


##########
airflow/serialization/serialized_objects.py:
##########
@@ -393,7 +392,7 @@ def _serialize(cls, var: Any) -> Any:  # Unfortunately there is no support for r
         elif isinstance(var, Param):
             return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
         elif isinstance(var, XComArg):
-            return cls._encode(cls._serialize_xcomarg(var), type_=DAT.XCOM_REF)
+            return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)

Review Comment:
   This is somewhat intentional, since I don’t want to pollute the serialisation logic with the different XComArg types. Maybe using a different key name would improve the result? Say:
   
   ```json
   {"_type": "xcom_ref", "_val": {"task_id": "...", "key": "..."}}
   
   {
       "_type": "xcom_ref",
       "_val": {
           "kind": "map",
           "arg": {"task_id": "...", "key": "..."},
           "callables": [...]
       }
   }
   ```
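One way to read the proposal: plain references keep their current `kind`-less shape, and any other reference type carries a `kind` key that the deserializer dispatches on, so the outer `xcom_ref` envelope never needs to know about individual types. The sketch below illustrates that idea with hypothetical stand-in classes (`PlainRef`, `ZipRef`) and a hypothetical registry; it is not Airflow's serialization code, and it omits `fillvalue` and `map` callables for brevity.

```python
import json
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


# Hypothetical stand-ins for XComArg subclasses; only the serialized shape matters here.
@dataclass
class PlainRef:
    task_id: str
    key: str


@dataclass
class ZipRef:
    args: List[Any]


def serialize_ref(ref: Any) -> Dict[str, Any]:
    if isinstance(ref, PlainRef):
        # Plain references keep the existing shape, with no "kind" key.
        return {"task_id": ref.task_id, "key": ref.key}
    if isinstance(ref, ZipRef):
        return {"kind": "zip", "args": [serialize_ref(arg) for arg in ref.args]}
    raise TypeError(f"cannot serialize {type(ref).__name__}")


def _deserialize_zip(data: Dict[str, Any]) -> ZipRef:
    return ZipRef([deserialize_ref(arg) for arg in data["args"]])


# Registry of non-plain kinds; a "map" kind would be registered the same way.
_DESERIALIZERS: Dict[str, Callable[[Dict[str, Any]], Any]] = {"zip": _deserialize_zip}


def deserialize_ref(data: Dict[str, Any]) -> Any:
    kind = data.get("kind")
    if kind is None:
        return PlainRef(data["task_id"], data["key"])
    return _DESERIALIZERS[kind](data)


def encode(ref: Any) -> Dict[str, Any]:
    # The outer envelope is identical for every reference type.
    return {"_type": "xcom_ref", "_val": serialize_ref(ref)}


ref = ZipRef([PlainRef("a", "return_value"), PlainRef("b", "return_value")])
restored = deserialize_ref(json.loads(json.dumps(encode(ref)))["_val"])
assert restored == ref
```

A registry keeps `encode` unaware of the individual reference types; for only a handful of kinds, a plain `if`/`elif` chain would work just as well.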





[GitHub] [airflow] uranusjr commented on a diff in pull request #25176: Implement XComArg.zip(*xcom_args)

Posted by GitBox <gi...@apache.org>.
uranusjr commented on code in PR #25176:
URL: https://github.com/apache/airflow/pull/25176#discussion_r932035824


##########
airflow/models/xcom_arg.py:
##########
@@ -285,8 +288,13 @@ def iter_references(self) -> Iterator[Tuple["Operator", str]]:
 
     def map(self, f: Callable[[Any], Any]) -> "MapXComArg":
         if self.key != XCOM_RETURN_KEY:
-            raise ValueError
-        return MapXComArg(self, [f])
+            raise ValueError("cannot map against non-return XCom")
+        return super().map(f)
+
+    def zip(self, *others: "XComArg", fillvalue: Any = NOTSET) -> "ZipXComArg":
+        if self.key != XCOM_RETURN_KEY:
+            raise ValueError("cannot map against non-return XCom")

Review Comment:
   Nothing theoretical; I’m just not ready to deal with these being used outside a task-mapping context yet. These restrictions can be removed later, once other use cases have been experimented with.


