You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by je...@apache.org on 2022/09/23 20:33:10 UTC

[airflow] branch main updated: Allow MapXComArg to resolve after serialization (#26591)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3e01c0d97a Allow MapXComArg to resolve after serialization  (#26591)
3e01c0d97a is described below

commit 3e01c0d97aeefce303e1fdb5cef160f192cce4fa
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Sat Sep 24 04:33:01 2022 +0800

    Allow MapXComArg to resolve after serialization  (#26591)
    
    This is useful for cases where we want to resolve an XCom without
    running a worker, e.g. to display the value in UI.
    
    Since we don't want to actually call the mapper function in this case
    (the function is arbitrary code, and not running it is the entire point
    of serializing operators), "resolving" the XComArg in this case would
    merely produce some kind of quasi-meaningful string representation,
    instead of the actual value we'd get in the worker.
    
    Also note that this only affects a very small number of cases, since
    once a worker is run for the task instance, RenderedTaskInstanceFields
    would store the real resolved value and take over UI representation,
    so this fake resolving logic is never reached at all.
---
 airflow/models/xcom_arg.py    | 48 ++++++++++++++++++++++++++++++++++++-------
 tests/models/test_xcom_arg.py |  4 ++--
 2 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 2fb60195ef..9be82976ae 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -14,10 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 from __future__ import annotations
 
+import contextlib
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, overload
+from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, Union, overload
 
 from sqlalchemy import func
 from sqlalchemy.orm import Session
@@ -35,6 +37,11 @@ if TYPE_CHECKING:
     from airflow.models.dag import DAG
     from airflow.models.operator import Operator
 
+# Callable objects contained by MapXComArg. We only accept callables from
+# the user, but a deserialized XComArg holds them as source strings instead,
+# for safety (those callables are arbitrary user code).
+MapCallables = Sequence[Union[Callable[[Any], Any], str]]
+
 
 class XComArg(DependencyMixin):
     """Reference to an XCom value pushed from another operator.
@@ -322,15 +329,39 @@ class PlainXComArg(XComArg):
         raise XComNotFound(context["ti"].dag_id, task_id, self.key)
 
 
+def _get_callable_name(f: Callable | str) -> str:
+    """Try to "describe" a callable by getting its name."""
+    if callable(f):
+        return f.__name__
+    # Parse the source to find whatever is behind "def". For safety, we don't
+    # want to evaluate the code in any meaningful way!
+    with contextlib.suppress(Exception):
+        kw, name, _ = f.lstrip().split(None, 2)
+        if kw == "def":
+            return name
+    return "<function>"
+
+
 class _MapResult(Sequence):
-    def __init__(self, value: Sequence | dict, callables: Sequence[Callable[[Any], Any]]) -> None:
+    def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
         self.value = value
         self.callables = callables
 
     def __getitem__(self, index: Any) -> Any:
         value = self.value[index]
-        for f in self.callables:
-            value = f(value)
+
+        # In the worker, we can access all actual callables. Call them.
+        callables = [f for f in self.callables if callable(f)]
+        if len(callables) == len(self.callables):
+            for f in callables:
+                value = f(value)
+            return value
+
+        # In the scheduler, we don't have access to the actual callables, nor do
+        # we want to run it since it's arbitrary code. This builds a string to
+        # represent the call chain in the UI or logs instead.
+        for v in self.callables:
+            value = f"{_get_callable_name(v)}({value})"
         return value
 
     def __len__(self) -> int:
@@ -342,9 +373,11 @@ class MapXComArg(XComArg):
 
     This is based on an XComArg, but also applies a series of "transforms" that
     convert the pulled XCom value.
+
+    :meta private:
     """
 
-    def __init__(self, arg: XComArg, callables: Sequence[Callable[[Any], Any]]) -> None:
+    def __init__(self, arg: XComArg, callables: MapCallables) -> None:
         for c in callables:
             if getattr(c, "_airflow_is_task_decorator", False):
                 raise ValueError("map() argument must be a plain function, not a @task operator")
@@ -352,12 +385,13 @@ class MapXComArg(XComArg):
         self.callables = callables
 
     def __repr__(self) -> str:
-        return f"{self.arg!r}.map([{len(self.callables)} functions])"
+        map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables)
+        return f"{self.arg!r}{map_calls}"
 
     def _serialize(self) -> dict[str, Any]:
         return {
             "arg": serialize_xcom_arg(self.arg),
-            "callables": [inspect.getsource(c) for c in self.callables],
+            "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables],
         }
 
     @classmethod
diff --git a/tests/models/test_xcom_arg.py b/tests/models/test_xcom_arg.py
index 18cbe87de1..1f9a342c02 100644
--- a/tests/models/test_xcom_arg.py
+++ b/tests/models/test_xcom_arg.py
@@ -211,14 +211,14 @@ def test_xcom_zip(dag_maker, session, fillvalue, expected_results):
 
     # Run "push_letters" and "push_numbers".
     decision = dr.task_instance_scheduling_decisions(session=session)
-    assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis)
+    assert sorted(ti.task_id for ti in decision.schedulable_tis) == ["push_letters", "push_numbers"]
     for ti in decision.schedulable_tis:
         ti.run(session=session)
     session.commit()
 
     # Run "pull".
     decision = dr.task_instance_scheduling_decisions(session=session)
-    assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis)
+    assert sorted(ti.task_id for ti in decision.schedulable_tis) == ["pull"] * len(expected_results)
     for ti in decision.schedulable_tis:
         ti.run(session=session)