Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2022/02/17 11:49:41 UTC

[GitHub] [airflow] uranusjr opened a new pull request #21641: Implement mapped value unpacking

uranusjr opened a new pull request #21641:
URL: https://github.com/apache/airflow/pull/21641


   This is another PR that is likely much too big. Many (loosely interrelated) things are done in one.
   
   #### Rewrite `@task` mapped operator expansion so it behaves correctly
   
   Previously, a mapped `@task` was not expanded correctly because its mapped argument would look something like this:
   
   ```python
   mapped_arguments={
       "op_kwargs": {
           "arg": [1, 2, 3],
       },
   }
   ```
   
   this cannot be expanded using normal logic, because `PythonOperator` natively expects something like this instead:
   
   ```python
   mapped_arguments={
       "op_kwargs": [
           {"arg": 1},
           {"arg": 2},
           {"arg": 3},
       ],
   }
   ```
   
   so additional logic was added to `expand_mapped_task` to expand it correctly.
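A minimal sketch of that reshaping in plain Python, independent of Airflow: the hypothetical helper `explode_op_kwargs` turns the "dict of lists" shape into one `op_kwargs` dict per mapped task instance. It assumes, as the `TestMappedTaskInstanceReceiveValue` tests suggest, that a dict argument contributes its `(key, value)` items and that several mapped arguments combine as a cross product; the ordering (first argument as the outer loop, like `itertools.product`) is an assumption, not necessarily what the PR does.

```python
import itertools
from typing import Any, Dict, List


def _as_items(value: Any) -> List[Any]:
    """Values a mapped argument expands over (dicts yield (key, value) pairs)."""
    if isinstance(value, dict):
        return list(value.items())
    return list(value)


def explode_op_kwargs(mapped_op_kwargs: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Turn {"arg": [1, 2, 3]} into [{"arg": 1}, {"arg": 2}, {"arg": 3}].

    Hypothetical helper for illustration; not Airflow's implementation.
    """
    names = list(mapped_op_kwargs)
    value_lists = [_as_items(mapped_op_kwargs[name]) for name in names]
    # itertools.product varies the last argument fastest, so the first
    # mapped argument acts as the outer loop.
    return [dict(zip(names, combo)) for combo in itertools.product(*value_lists)]


print(explode_op_kwargs({"arg": [1, 2, 3]}))
# [{'arg': 1}, {'arg': 2}, {'arg': 3}]
```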
   
   #### Re-implement `expand_mapped_task` logic to be able to expand literals
   
   Previously, `expand_mapped_task` only looked in TaskMap and could not expand literals, e.g. `.map(arg=[1, 2, 3])`. This implements more sophisticated logic to collect information from mapped arguments and properly expand both literal and XComArg inputs. Some refactoring was also done so the same logic can be reused for unpacking values at task runtime.
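To illustrate what expanding both literals and `XComArg` inputs involves, here is a small sketch that computes how many task instances a mapped task expands into. Literal arguments contribute `len(value)`; for upstream inputs it assumes the pushed length is already known (standing in for what Airflow records in `TaskMap`) via a hypothetical `xcom_lengths` dict keyed by upstream task ID. `UpstreamRef` is a hypothetical stand-in for `XComArg`. Multiplying the lengths matches the six instances expected in `test_map_product`, but the helper itself is illustrative only.

```python
import math
from typing import Any, Dict, Mapping


class UpstreamRef:
    """Hypothetical stand-in for an XComArg whose length is only known at run time."""

    def __init__(self, task_id: str) -> None:
        self.task_id = task_id


def get_expansion_length(mapped_kwargs: Mapping[str, Any], xcom_lengths: Dict[str, int]) -> int:
    """Total number of mapped task instances for the given map arguments."""
    lengths = []
    for value in mapped_kwargs.values():
        if isinstance(value, UpstreamRef):
            # Length recorded when the upstream task pushed its XCom.
            lengths.append(xcom_lengths[value.task_id])
        else:
            # Literal list or dict passed directly to .map().
            lengths.append(len(value))
    # Several mapped arguments combine as a cross product.
    return math.prod(lengths)
```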
   
   #### Additional checks added to ensure correct literal types are passed to `.map()`
   
   This basically ensures `.map()` only receives lists and dicts (or an XComArg, which is already checked separately when the upstream task pushes to XCom).
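A rough sketch of that check, loosely modelled on the `_validate_arg_names` hunk quoted later in this thread: any value that is not a list or dict (or an upstream reference) is rejected when passed to `.map()`. `UpstreamRef` is a hypothetical stand-in for `XComArg`, and `MAPPABLE_TYPES` is an assumption; the PR's actual `get_mappable_types()` may return a different tuple.

```python
from typing import Any, Dict


class UpstreamRef:
    """Hypothetical stand-in for XComArg; its value is validated later, at push time."""


# Assumed set of accepted types. str is excluded even though it is a
# Sequence, since the PR only accepts lists and dicts as literals.
MAPPABLE_TYPES = (UpstreamRef, list, dict)


def validate_map_kwargs(kwargs: Dict[str, Any]) -> None:
    """Raise ValueError for any literal that .map() cannot expand over."""
    for name, value in kwargs.items():
        if not isinstance(value, MAPPABLE_TYPES):
            raise ValueError(f"map() got unexpected value {type(value)!r} for keyword argument {name!r}")
```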
   
   #### Extend `TaskInstance.render_templates` to “unpack” mapped values for task execution
   
   This is the main thing. After all template fields are rendered, an additional pass “unpacks” mapped values into individual ones based on map_index, which further mutates the operator object held by a TaskInstance. This reuses logic similar to `expand_mapped_task` to calculate the total length of the map, so each TaskInstance can locate its map_index inside the series.
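The arithmetic of locating one instance's values inside the series can be sketched without materializing every combination: treat `map_index` as a mixed-radix number whose digits index into each mapped argument. `unpack_for_map_index` is a hypothetical helper, not the PR's code, and the digit order (last argument varies fastest, matching `itertools.product`) is an assumption; the review discussion further down notes that the ordering matters less than a given `map_index` always resolving to the same values.

```python
import math
from typing import Any, Dict, List


def _as_items(value: Any) -> List[Any]:
    # Dicts expand over their (key, value) pairs, lists over their elements.
    return list(value.items()) if isinstance(value, dict) else list(value)


def unpack_for_map_index(mapped: Dict[str, Any], map_index: int) -> Dict[str, Any]:
    """Return the concrete kwargs for one mapped task instance."""
    items = {name: _as_items(value) for name, value in mapped.items()}
    total = math.prod(len(v) for v in items.values())
    if not 0 <= map_index < total:
        raise IndexError(f"map_index {map_index} out of range for {total} instances")
    chosen: Dict[str, Any] = {}
    remaining = map_index
    # Walk from the last argument (fastest-varying digit) to the first.
    for name in reversed(list(items)):
        remaining, digit = divmod(remaining, len(items[name]))
        chosen[name] = items[name][digit]
    # Restore the caller's argument order.
    return {name: chosen[name] for name in mapped}
```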
   
   Previously, `render_template_fields` was only implemented on BaseOperator, and a MappedOperator was first unmapped before its templates were rendered. This approach was unfortunately wrong, since value unpacking (which needs to happen after template rendering resolves XComArg) needs to be aware of the original MappedOperator (mainly to access the user-supplied mapped kwargs). So the new logic delays unmapping until _during_ `render_templates`. If a TaskInstance’s `task` is a MappedOperator, its `render_templates` calls the MappedOperator’s `render_template_fields`, which unmaps itself, calls `render_template_fields` on the unmapped operator, unpacks values for the unmapped operator, _and then returns the unmapped operator_. The TaskInstance then reassigns its `task` to the unmapped, rendered, unpacked operator.
   
   This means that `TaskInstance.render_templates` may now have a side effect of mutating its own `task` attribute. I don’t particularly like this, but couldn’t find a cleaner approach without breaking much of the existing interface (complicated by the fact that both `BaseOperator.render_template_fields` and `BaseOperator.render_template` are public API). I think this is close to the best possible interface given the existing constraints, and I tried to document it as clearly as possible for future maintainability.
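A toy model of the control flow described above, using hypothetical classes rather than Airflow's real ones: the mapped operator's `render_template_fields` unmaps itself, renders, unpacks the value for this `map_index`, and returns the new operator, and the task instance swaps its `task` attribute for the returned object. Template rendering is reduced to `str.format` and mapping to a single list per argument purely to keep the sketch short; the `ds` context key in the usage below is likewise illustrative.

```python
from typing import Any, Dict, List, Optional


class ToyOperator:
    def __init__(self, task_id: str, **kwargs: Any) -> None:
        self.task_id = task_id
        self.kwargs = kwargs

    def render_template_fields(self, context: Dict[str, Any]) -> Optional["ToyOperator"]:
        # Render string kwargs in place; a classic operator returns nothing new.
        self.kwargs = {k: v.format(**context) if isinstance(v, str) else v for k, v in self.kwargs.items()}
        return None


class ToyMappedOperator:
    def __init__(self, task_id: str, partial_kwargs: Dict[str, Any], mapped_kwargs: Dict[str, List[Any]]) -> None:
        self.task_id = task_id
        self.partial_kwargs = partial_kwargs
        self.mapped_kwargs = mapped_kwargs

    def render_template_fields(self, context: Dict[str, Any]) -> ToyOperator:
        # 1. Unmap: build a concrete operator from the partial arguments.
        op = ToyOperator(self.task_id, **self.partial_kwargs)
        # 2. Render templates on the unmapped operator.
        op.render_template_fields(context)
        # 3. Unpack: pick this instance's value for each mapped argument.
        for name, values in self.mapped_kwargs.items():
            op.kwargs[name] = values[context["map_index"]]
        # 4. Hand the unmapped, rendered, unpacked operator back to the caller.
        return op


class ToyTaskInstance:
    def __init__(self, task: Any, map_index: int) -> None:
        self.task = task
        self.map_index = map_index

    def render_templates(self, context: Dict[str, Any]) -> None:
        context = {**context, "map_index": self.map_index}
        rendered = self.task.render_template_fields(context)
        if rendered is not None:
            # The side effect discussed above: the TI replaces its own task.
            self.task = rendered


ti = ToyTaskInstance(ToyMappedOperator("show", {"greeting": "hello {ds}"}, {"value": [10, 20, 30]}), map_index=1)
ti.render_templates({"ds": "2022-02-17"})
```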
   
   A few tests were added to check this is working (see `TestMappedTaskInstanceReceiveValue`).


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [airflow] ashb commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809099505



##########
File path: airflow/models/mappedoperator.py
##########
@@ -55,38 +58,65 @@
     TaskStateChangeCallback,
 )
 from airflow.models.pool import Pool
-from airflow.models.xcom_arg import XComArg
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
+from airflow.typing_compat import Literal
+from airflow.utils.context import Context
 from airflow.utils.operator_resources import Resources
-from airflow.utils.session import NEW_SESSION
 from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
+    import jinja2  # Slow import.
+
     from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
     from airflow.models.dag import DAG
     from airflow.models.taskinstance import TaskInstance
+    from airflow.models.xcom_arg import XComArg
+
+    # BaseOperator.map() can be called on an XComArg, sequence, or dict (not any
+    # mapping since we need the value to be ordered).
+    MapArgument = Union[XComArg, Sequence, dict]
+
+ValidationSource = Union[Literal["map"], Literal["partial"]]

Review comment:
       Ahhh







[GitHub] [airflow] uranusjr commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809721162



##########
File path: airflow/decorators/base.py
##########
@@ -369,31 +371,43 @@ class DecoratedMappedOperator(MappedOperator):
     multiple_outputs: bool
     python_callable: Callable
 
-    # We can't save these in partial_kwargs because op_args and op_kwargs need
-    # to be present in mapped_kwargs, and MappedOperator prevents duplication.
-    partial_op_kwargs: Dict[str, Any]
+    # We can't save these in mapped_kwargs because op_kwargs need to be present
+    # in partial_kwargs, and MappedOperator prevents duplication.
+    mapped_op_kwargs: Dict[str, "MapArgument"]
 
     @classmethod
     @cache
     def get_serialized_fields(cls):
-        # The magic argument-less super() does not work well with @cache
-        # (actually lru_cache in general), so we use the explicit form instead.
+        # The magic super() doesn't work here, so we use the explicit form.
+        # Not using super(..., cls) to work around pyupgrade bug.
         sup = super(DecoratedMappedOperator, DecoratedMappedOperator)
-        return sup.get_serialized_fields() | {"partial_op_kwargs"}
+        return sup.get_serialized_fields() | {"mapped_op_kwargs"}
 
-    def _create_unmapped_operator(
-        self,
-        *,
-        mapped_kwargs: Dict[str, Any],
-        partial_kwargs: Dict[str, Any],
-        real: bool,
-    ) -> "BaseOperator":
+    def __attrs_post_init__(self):
+        # The magic super() doesn't work here, so we use the explicit form.
+        # Not using super(..., self) to work around pyupgrade bug.
+        super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
+        XComArg.apply_upstream_relationship(self, self.mapped_op_kwargs)
+
+    def _get_expansion_kwargs(self) -> Dict[str, "MapArgument"]:
+        """The kwargs to calculate expansion length against.
+
+        Different from classic operators, a decorated (taskflow) operator's
+        ``map()`` contributes to the ``op_kwargs`` operator argument (not the
+        operator arguments themselves), and should therefore expand against it.
+        """
+        return self.mapped_op_kwargs
+
+    def _create_unmapped_operator(self, *, mapped_kwargs: Dict[str, Any], real: bool) -> "BaseOperator":
         assert not isinstance(self.operator_class, str)
-        mapped_kwargs = mapped_kwargs.copy()
-        del mapped_kwargs["op_kwargs"]
+        partial_kwargs = self.partial_kwargs.copy()
+        if real:
+            mapped_op_kwargs: Dict[str, Any] = self.mapped_op_kwargs
+        else:
+            mapped_op_kwargs = {k: unittest.mock.MagicMock(name=k) for k in self.mapped_op_kwargs}

Review comment:
       I extracted this into a function with a three-paragraph docstring.







[GitHub] [airflow] ashb commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r808984085



##########
File path: airflow/decorators/base.py
##########
@@ -269,22 +276,21 @@ def __call__(self, *args, **kwargs) -> XComArg:
             op.doc_md = self.function.__doc__
         return XComArg(op)
 
-    def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any], valid_names: Set[str] = set()):
-        unknown_args = kwargs.copy()
-        for name in itertools.chain(self.function_arg_names, valid_names):
-            unknown_args.pop(name, None)
-
-            if not unknown_args:
-                # If we have no args left ot check, we are valid
-                return
+    def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]):
+        kwargs_left = kwargs.copy()
+        for arg_name in self.function_arg_names:
+            value = kwargs_left.pop(arg_name, NOTSET)
+            if func != "map" or value is NOTSET or isinstance(value, get_mappable_types()):
+                continue
+            raise ValueError(f"{func} got unexpected value{type(value)!r} for keyword argument {arg_name!r}")

Review comment:
       🤦🏻 I see. Ignore me.







[GitHub] [airflow] ashb commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r811194466



##########
File path: tests/models/test_taskinstance.py
##########
@@ -2353,3 +2353,110 @@ def pull_something(value):
         assert task_map.map_index == -1
         assert task_map.length == expected_length
         assert task_map.keys == expected_keys
+
+
+class TestMappedTaskInstanceReceiveValue:
+    @pytest.mark.parametrize(
+        "literal, expected_outputs",
+        [
+            pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+            pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+        ],
+    )
+    def test_map_literal(self, literal, expected_outputs, dag_maker, session):
+        outputs = set()
+
+        with dag_maker(dag_id="literal", session=session) as dag:
+
+            @dag.task
+            def show(value):
+                outputs.add(value)
+
+            show.map(value=literal)
+
+        dag_run = dag_maker.create_dagrun()
+        show_task = dag.get_task("show")
+        tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+        assert len(tis) == len(literal)
+
+        for ti in tis:
+            ti.refresh_from_task(show_task)
+            ti.run()
+        assert outputs == expected_outputs
+
+    @pytest.mark.parametrize(
+        "upstream_return, expected_outputs",
+        [
+            pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+            pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+        ],
+    )
+    def test_map_xcom(self, upstream_return, expected_outputs, dag_maker, session):
+        outputs = set()
+
+        with dag_maker(dag_id="xcom", session=session) as dag:
+
+            @dag.task
+            def emit():
+                return upstream_return
+
+            @dag.task
+            def show(value):
+                outputs.add(value)
+
+            show.map(value=emit())
+
+        dag_run = dag_maker.create_dagrun()
+        emit_ti = dag_run.get_task_instance("emit", session=session)
+        emit_ti.refresh_from_task(dag.get_task("emit"))
+        emit_ti.run()
+
+        show_task = dag.get_task("show")
+        tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+        assert len(tis) == len(upstream_return)
+
+        for ti in tis:
+            ti.refresh_from_task(show_task)
+            ti.run()
+        assert outputs == expected_outputs
+
+    def test_map_product(self, dag_maker, session):
+        outputs = set()
+
+        with dag_maker(dag_id="product", session=session) as dag:
+
+            @dag.task
+            def emit_numbers():
+                return [1, 2]
+
+            @dag.task
+            def emit_letters():
+                return {"a": "x", "b": "y", "c": "z"}
+
+            @dag.task
+            def show(number, letter):
+                outputs.add((number, letter))
+
+            show.map(number=emit_numbers(), letter=emit_letters())
+
+        dag_run = dag_maker.create_dagrun()
+        for task_id in ["emit_numbers", "emit_letters"]:
+            ti = dag_run.get_task_instance(task_id, session=session)
+            ti.refresh_from_task(dag.get_task(task_id))
+            ti.run()
+
+        show_task = dag.get_task("show")
+        tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+        assert len(tis) == 6
+
+        for ti in tis:
+            ti.refresh_from_task(show_task)
+            ti.run()
+        assert outputs == {
+            (1, ("a", "x")),
+            (2, ("a", "x")),
+            (1, ("b", "y")),
+            (2, ("b", "y")),
+            (1, ("c", "z")),
+            (2, ("c", "z")),
+        }

Review comment:
       Yeah, execution order is not important, just that map_index=0 always gets the same values when mapping over multiple TaskMaps.







[GitHub] [airflow] ashb commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r808988190



##########
File path: airflow/models/abstractoperator.py
##########
@@ -249,3 +253,124 @@ def get_extra_links(self, dttm: datetime.datetime, link_name: str) -> Optional[D
         elif link_name in self.global_operator_extra_link_dict:
             return self.global_operator_extra_link_dict[link_name].get_link(self, dttm)
         return None
+
+    @provide_session
+    def _do_render_template_fields(
+        self,
+        parent: Any,
+        template_fields: Iterable[str],
+        context: Context,
+        jinja_env: "jinja2.Environment",
+        seen_oids: Set,
+        *,
+        session: Session = NEW_SESSION,
+    ) -> None:
+        for attr_name in template_fields:
+            try:
+                content = getattr(parent, attr_name)
+            except AttributeError:
+                raise AttributeError(
+                    f"{attr_name!r} is configured as a template field "
+                    f"but {parent.task_type} does not have this attribute."
+                )
+            if not content:
+                continue
+            rendered_content = self._render_template_field(
+                attr_name,
+                content,
+                context,

Review comment:
       nit: `context` and `content` are very similar, which makes them easy to confuse. Can we rename `content` to something like `template`?







[GitHub] [airflow] uranusjr merged pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr merged pull request #21641:
URL: https://github.com/apache/airflow/pull/21641


   





[GitHub] [airflow] uranusjr commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809020736



##########
File path: airflow/models/abstractoperator.py
##########
@@ -249,3 +253,124 @@ def get_extra_links(self, dttm: datetime.datetime, link_name: str) -> Optional[D
         elif link_name in self.global_operator_extra_link_dict:
             return self.global_operator_extra_link_dict[link_name].get_link(self, dttm)
         return None
+
+    @provide_session
+    def _do_render_template_fields(
+        self,
+        parent: Any,
+        template_fields: Iterable[str],
+        context: Context,
+        jinja_env: "jinja2.Environment",
+        seen_oids: Set,
+        *,
+        session: Session = NEW_SESSION,
+    ) -> None:
+        for attr_name in template_fields:
+            try:
+                content = getattr(parent, attr_name)
+            except AttributeError:
+                raise AttributeError(
+                    f"{attr_name!r} is configured as a template field "
+                    f"but {parent.task_type} does not have this attribute."
+                )
+            if not content:
+                continue
+            rendered_content = self._render_template_field(
+                attr_name,
+                content,
+                context,

Review comment:
       We unfortunately can’t rename the argument name in `render_template()` since it’s public API. I renamed `content` to `value` everywhere else. I avoided `template` since it is used to refer to the compiled Jinja template object.







[GitHub] [airflow] ashb commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809114724



##########
File path: airflow/decorators/base.py
##########
@@ -369,31 +371,43 @@ class DecoratedMappedOperator(MappedOperator):
     multiple_outputs: bool
     python_callable: Callable
 
-    # We can't save these in partial_kwargs because op_args and op_kwargs need
-    # to be present in mapped_kwargs, and MappedOperator prevents duplication.
-    partial_op_kwargs: Dict[str, Any]
+    # We can't save these in mapped_kwargs because op_kwargs need to be present
+    # in partial_kwargs, and MappedOperator prevents duplication.
+    mapped_op_kwargs: Dict[str, "MapArgument"]
 
     @classmethod
     @cache
     def get_serialized_fields(cls):
-        # The magic argument-less super() does not work well with @cache
-        # (actually lru_cache in general), so we use the explicit form instead.
+        # The magic super() doesn't work here, so we use the explicit form.
+        # Not using super(..., cls) to work around pyupgrade bug.
         sup = super(DecoratedMappedOperator, DecoratedMappedOperator)
-        return sup.get_serialized_fields() | {"partial_op_kwargs"}
+        return sup.get_serialized_fields() | {"mapped_op_kwargs"}
 
-    def _create_unmapped_operator(
-        self,
-        *,
-        mapped_kwargs: Dict[str, Any],
-        partial_kwargs: Dict[str, Any],
-        real: bool,
-    ) -> "BaseOperator":
+    def __attrs_post_init__(self):
+        # The magic super() doesn't work here, so we use the explicit form.
+        # Not using super(..., self) to work around pyupgrade bug.
+        super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
+        XComArg.apply_upstream_relationship(self, self.mapped_op_kwargs)
+
+    def _get_expansion_kwargs(self) -> Dict[str, "MapArgument"]:
+        """The kwargs to calculate expansion length against.
+
+        Different from classic operators, a decorated (taskflow) operator's
+        ``map()`` contributes to the ``op_kwargs`` operator argument (not the
+        operator arguments themselves), and should therefore expand against it.
+        """
+        return self.mapped_op_kwargs
+
+    def _create_unmapped_operator(self, *, mapped_kwargs: Dict[str, Any], real: bool) -> "BaseOperator":
         assert not isinstance(self.operator_class, str)
-        mapped_kwargs = mapped_kwargs.copy()
-        del mapped_kwargs["op_kwargs"]
+        partial_kwargs = self.partial_kwargs.copy()
+        if real:
+            mapped_op_kwargs: Dict[str, Any] = self.mapped_op_kwargs
+        else:
+            mapped_op_kwargs = {k: unittest.mock.MagicMock(name=k) for k in self.mapped_op_kwargs}

Review comment:
       This probably needs a huge comment saying what we're doing 😁 







[GitHub] [airflow] ashb commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r811195531



##########
File path: airflow/models/mappedoperator.py
##########
@@ -245,16 +296,13 @@ def _validate_argument_count(self) -> None:
         """
         if isinstance(self.operator_class, str):
             return  # No need to validate deserialized operator.
-        operator = self._create_unmapped_operator(
-            mapped_kwargs={k: unittest.mock.MagicMock(name=k) for k in self.mapped_kwargs},
-            partial_kwargs=self.partial_kwargs,
-            real=False,
-        )
-        if operator.task_group:
-            operator.task_group._remove(operator)
-        dag = operator.get_dag()
+        mocked_mapped_kwargs = create_mocked_kwargs(self.mapped_kwargs)
+        op = self._create_unmapped_operator(mapped_kwargs=mocked_mapped_kwargs, real=False)

Review comment:
       Yeah, I was thinking of an escape hatch for when an Operator ctor does something that causes it to fail when given a Mock.
   
   A proper class-level `validate(**kwargs)` sounds like a better idea.
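One way such an escape hatch could look, sketched with hypothetical names (neither `validate_mapped_kwargs` nor `validate_mapped_operator` exists in Airflow as far as this thread shows): if the operator class provides its own classmethod, call it with the argument names; otherwise fall back to constructing the operator with `MagicMock` placeholders, as `_validate_argument_count` does in this PR.

```python
import unittest.mock
from typing import Any, Dict, Set


def validate_mapped_operator(operator_class: type, partial_kwargs: Dict[str, Any], mapped_names: Set[str]) -> None:
    hook = getattr(operator_class, "validate_mapped_kwargs", None)
    if hook is not None:
        # The operator opted out of mock-based validation.
        hook(partial_kwargs=partial_kwargs, mapped_names=mapped_names)
        return
    mocks = {name: unittest.mock.MagicMock(name=name) for name in mapped_names}
    # Let the constructor raise TypeError on missing or unexpected arguments.
    operator_class(**partial_kwargs, **mocks)


class StrictOperator:
    """Hypothetical operator whose constructor rejects a Mock outright."""

    def __init__(self, task_id: str, port: int) -> None:
        if not isinstance(port, int):
            raise TypeError("port must be an int")
        self.task_id = task_id
        self.port = port

    @classmethod
    def validate_mapped_kwargs(cls, *, partial_kwargs: Dict[str, Any], mapped_names: Set[str]) -> None:
        # Check argument names only, without constructing an instance.
        if set(partial_kwargs) | mapped_names != {"task_id", "port"}:
            raise TypeError("StrictOperator needs exactly task_id and port")
```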







[GitHub] [airflow] ashb commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809950052



##########
File path: airflow/models/mappedoperator.py
##########
@@ -103,6 +134,24 @@ def prevent_duplicates(kwargs1: Dict[str, Any], kwargs2: Dict[str, Any], *, fail
     raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}")
 
 
+def create_mocked_kwargs(kwargs: Dict[str, "MapArgument"]) -> Dict[str, unittest.mock.MagicMock]:
+    """Create a mapping of mocks for given map arguments.
+
+    When a mapped operator is created, we want to perform basic validation on
+    the map arguments, especially the count of arguments. However, most of this
+    kind of logic lives directly on an operator class's ``__init__``, and
+    there's no good way to validate the arguments except to actually try to
+    create an operator instance.
+
+    Since the map arguments are yet to be populated when the mapped operator is
+    being parsed, we need to "invent" some mocked values for this validation
+    purpose. The :class:`~unittest.mock.MagicMock` class is a good fit for this
+    since it not only provides good run-time properties, but also enjoys special
+    treatment in Mypy.
+    """
+    return {k: unittest.mock.MagicMock(name=k) for k in kwargs}

Review comment:
       Future possible enhancement (absolutely not in this pr):
   
   Set `spec` for the mock based on the type hint of the param we are creating the mock for.
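A minimal sketch of that idea, assuming the operator's `__init__` carries ordinary class annotations: `typing.get_type_hints` reads the hints, and a hint that is a plain class becomes the mock's `spec`, so the placeholder passes `isinstance` checks and rejects attributes the real type lacks. The helper name `create_spec_mocked_kwargs` is hypothetical; hints such as `Optional[int]` are not classes and fall back to an unspecced mock here.

```python
import typing
import unittest.mock
from typing import Dict, Iterable


def create_spec_mocked_kwargs(operator_class: type, names: Iterable[str]) -> Dict[str, unittest.mock.MagicMock]:
    hints = typing.get_type_hints(operator_class.__init__)
    mocks = {}
    for name in names:
        hint = hints.get(name)
        if isinstance(hint, type):
            # A spec'd mock reports the hinted class for isinstance() checks.
            mocks[name] = unittest.mock.MagicMock(spec=hint, name=name)
        else:
            # No usable hint (missing, or a typing construct like Optional[int]).
            mocks[name] = unittest.mock.MagicMock(name=name)
    return mocks
```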

##########
File path: airflow/models/mappedoperator.py
##########
@@ -245,16 +296,13 @@ def _validate_argument_count(self) -> None:
         """
         if isinstance(self.operator_class, str):
             return  # No need to validate deserialized operator.
-        operator = self._create_unmapped_operator(
-            mapped_kwargs={k: unittest.mock.MagicMock(name=k) for k in self.mapped_kwargs},
-            partial_kwargs=self.partial_kwargs,
-            real=False,
-        )
-        if operator.task_group:
-            operator.task_group._remove(operator)
-        dag = operator.get_dag()
+        mocked_mapped_kwargs = create_mocked_kwargs(self.mapped_kwargs)
+        op = self._create_unmapped_operator(mapped_kwargs=mocked_mapped_kwargs, real=False)

Review comment:
       I do wonder if we'll need an escape hatch for this.
   
   I guess we can add that when someone finds a specific case.

##########
File path: tests/models/test_taskinstance.py
##########
@@ -2353,3 +2353,110 @@ def pull_something(value):
         assert task_map.map_index == -1
         assert task_map.length == expected_length
         assert task_map.keys == expected_keys
+
+
+class TestMappedTaskInstanceReceiveValue:
+    @pytest.mark.parametrize(
+        "literal, expected_outputs",
+        [
+            pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+            pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+        ],
+    )
+    def test_map_literal(self, literal, expected_outputs, dag_maker, session):
+        outputs = set()
+
+        with dag_maker(dag_id="literal", session=session) as dag:
+
+            @dag.task
+            def show(value):
+                outputs.add(value)
+
+            show.map(value=literal)
+
+        dag_run = dag_maker.create_dagrun()
+        show_task = dag.get_task("show")
+        tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+        assert len(tis) == len(literal)
+
+        for ti in tis:
+            ti.refresh_from_task(show_task)
+            ti.run()
+        assert outputs == expected_outputs
+
+    @pytest.mark.parametrize(
+        "upstream_return, expected_outputs",
+        [
+            pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+            pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+        ],
+    )
+    def test_map_xcom(self, upstream_return, expected_outputs, dag_maker, session):
+        outputs = set()
+
+        with dag_maker(dag_id="xcom", session=session) as dag:
+
+            @dag.task
+            def emit():
+                return upstream_return
+
+            @dag.task
+            def show(value):
+                outputs.add(value)
+
+            show.map(value=emit())
+
+        dag_run = dag_maker.create_dagrun()
+        emit_ti = dag_run.get_task_instance("emit", session=session)
+        emit_ti.refresh_from_task(dag.get_task("emit"))
+        emit_ti.run()
+
+        show_task = dag.get_task("show")
+        tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+        assert len(tis) == len(upstream_return)
+
+        for ti in tis:
+            ti.refresh_from_task(show_task)
+            ti.run()
+        assert outputs == expected_outputs
+
+    def test_map_product(self, dag_maker, session):
+        outputs = set()
+
+        with dag_maker(dag_id="product", session=session) as dag:
+
+            @dag.task
+            def emit_numbers():
+                return [1, 2]
+
+            @dag.task
+            def emit_letters():
+                return {"a": "x", "b": "y", "c": "z"}
+
+            @dag.task
+            def show(number, letter):
+                outputs.add((number, letter))
+
+            show.map(number=emit_numbers(), letter=emit_letters())
+
+        dag_run = dag_maker.create_dagrun()
+        for task_id in ["emit_numbers", "emit_letters"]:
+            ti = dag_run.get_task_instance(task_id, session=session)
+            ti.refresh_from_task(dag.get_task(task_id))
+            ti.run()
+
+        show_task = dag.get_task("show")
+        tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+        assert len(tis) == 6
+
+        for ti in tis:
+            ti.refresh_from_task(show_task)
+            ti.run()
+        assert outputs == {
+            (1, ("a", "x")),
+            (2, ("a", "x")),
+            (1, ("b", "y")),
+            (2, ("b", "y")),
+            (1, ("c", "z")),
+            (2, ("c", "z")),
+        }

Review comment:
       I don't think it matters, but I would have expected `emit_numbers` to be the outer loop since it's the "first" argument:
   
   ```python
           assert outputs == {
               (1, ("a", "x")),
               (1, ("b", "y")),
               (1, ("c", "z")),
               (2, ("a", "x")),
               (2, ("b", "y")),
               (2, ("c", "z")),
           }
   ```
   
   
   i.e. it's the equivalent of:
   ```python
   for number in emit_numbers():
       for letter in emit_letters().items():
           show(number, letter)
   ```

##########
File path: airflow/models/mappedoperator.py
##########
@@ -245,16 +296,13 @@ def _validate_argument_count(self) -> None:
         """
         if isinstance(self.operator_class, str):
             return  # No need to validate deserialized operator.
-        operator = self._create_unmapped_operator(
-            mapped_kwargs={k: unittest.mock.MagicMock(name=k) for k in self.mapped_kwargs},
-            partial_kwargs=self.partial_kwargs,
-            real=False,
-        )
-        if operator.task_group:
-            operator.task_group._remove(operator)
-        dag = operator.get_dag()
+        mocked_mapped_kwargs = create_mocked_kwargs(self.mapped_kwargs)
+        op = self._create_unmapped_operator(mapped_kwargs=mocked_mapped_kwargs, real=False)
+        if op.task_group:
+            op.task_group._remove(op)
+        dag = op.get_dag()
         if dag:
-            dag._remove_task(operator.task_id)
+            dag._remove_task(op.task_id)

Review comment:
       ```python
           dag = op.get_dag()
           if dag:
               dag._remove_task(op.task_id)
   ```
   
   I think that is all we need: `dag._remove_task` already does the task group check/removal.
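To illustrate why the separate task-group cleanup becomes redundant, here is a toy model (not Airflow's actual classes) in which `ToyDag._remove_task` detaches the task from its group itself, which is what the comment above says the real `DAG._remove_task` already does. All class names are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class ToyTaskGroup:
    children: Dict[str, "ToyTask"] = field(default_factory=dict)

    def _remove(self, task: "ToyTask") -> None:
        del self.children[task.task_id]


@dataclass
class ToyTask:
    task_id: str
    # Excluded from repr/eq to avoid recursing through group -> children -> task.
    task_group: Optional[ToyTaskGroup] = field(default=None, repr=False, compare=False)


@dataclass
class ToyDag:
    task_dict: Dict[str, ToyTask] = field(default_factory=dict)

    def add_task(self, task: ToyTask) -> None:
        self.task_dict[task.task_id] = task
        if task.task_group is not None:
            task.task_group.children[task.task_id] = task

    def _remove_task(self, task_id: str) -> None:
        # Removing from the DAG also handles the task group, so callers
        # need only this one call.
        task = self.task_dict.pop(task_id)
        if task.task_group is not None:
            task.task_group._remove(task)


group = ToyTaskGroup()
dag = ToyDag()
dag.add_task(ToyTask("show", task_group=group))
dag._remove_task("show")
assert "show" not in dag.task_dict and "show" not in group.children
```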




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [airflow] ashb commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r808982659



##########
File path: airflow/models/mappedoperator.py
##########
@@ -55,38 +58,65 @@
     TaskStateChangeCallback,
 )
 from airflow.models.pool import Pool
-from airflow.models.xcom_arg import XComArg
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
+from airflow.typing_compat import Literal
+from airflow.utils.context import Context
 from airflow.utils.operator_resources import Resources
-from airflow.utils.session import NEW_SESSION
 from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
+    import jinja2  # Slow import.
+
     from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
     from airflow.models.dag import DAG
     from airflow.models.taskinstance import TaskInstance
+    from airflow.models.xcom_arg import XComArg
+
+    # BaseOperator.map() can be called on an XComArg, sequence, or dict (not any
+    # mapping since we need the value to be ordered).
+    MapArgument = Union[XComArg, Sequence, dict]
+
+ValidationSource = Union[Literal["map"], Literal["partial"]]

Review comment:
       I think you can simplify this to:
   
   ```suggestion
   ValidationSource = Literal["map", "partial"]
   ```
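Per PEP 586, a `Literal` with several values is equivalent to a `Union` of single-value literals, so the suggestion above does not change what type checkers accept. A small sketch of the simplified alias, using `typing` directly (Python 3.8+) rather than `airflow.typing_compat`; `check_source` is a hypothetical runtime guard, not part of the PR, that reads the allowed values back with `typing.get_args`.

```python
from typing import Literal, get_args

ValidationSource = Literal["map", "partial"]


def check_source(source: str) -> ValidationSource:
    """Hypothetical guard: reject anything the Literal does not list."""
    allowed = get_args(ValidationSource)  # ("map", "partial")
    if source not in allowed:
        raise ValueError(f"unknown validation source {source!r}; expected one of {allowed}")
    return source  # type: ignore[return-value]


assert check_source("map") == "map"
```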







[GitHub] [airflow] uranusjr commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809974265



##########
File path: tests/models/test_taskinstance.py
##########
@@ -2353,3 +2353,110 @@ def pull_something(value):
         assert task_map.map_index == -1
         assert task_map.length == expected_length
         assert task_map.keys == expected_keys
+
+
+class TestMappedTaskInstanceReceiveValue:
+    @pytest.mark.parametrize(
+        "literal, expected_outputs",
+        [
+            pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+            pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+        ],
+    )
+    def test_map_literal(self, literal, expected_outputs, dag_maker, session):
+        outputs = set()
+
+        with dag_maker(dag_id="literal", session=session) as dag:
+
+            @dag.task
+            def show(value):
+                outputs.add(value)
+
+            show.map(value=literal)
+
+        dag_run = dag_maker.create_dagrun()
+        show_task = dag.get_task("show")
+        tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+        assert len(tis) == len(literal)
+
+        for ti in tis:
+            ti.refresh_from_task(show_task)
+            ti.run()
+        assert outputs == expected_outputs
+
+    @pytest.mark.parametrize(
+        "upstream_return, expected_outputs",
+        [
+            pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+            pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+        ],
+    )
+    def test_map_xcom(self, upstream_return, expected_outputs, dag_maker, session):
+        outputs = set()
+
+        with dag_maker(dag_id="xcom", session=session) as dag:
+
+            @dag.task
+            def emit():
+                return upstream_return
+
+            @dag.task
+            def show(value):
+                outputs.add(value)
+
+            show.map(value=emit())
+
+        dag_run = dag_maker.create_dagrun()
+        emit_ti = dag_run.get_task_instance("emit", session=session)
+        emit_ti.refresh_from_task(dag.get_task("emit"))
+        emit_ti.run()
+
+        show_task = dag.get_task("show")
+        tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+        assert len(tis) == len(upstream_return)
+
+        for ti in tis:
+            ti.refresh_from_task(show_task)
+            ti.run()
+        assert outputs == expected_outputs
+
+    def test_map_product(self, dag_maker, session):
+        outputs = set()
+
+        with dag_maker(dag_id="product", session=session) as dag:
+
+            @dag.task
+            def emit_numbers():
+                return [1, 2]
+
+            @dag.task
+            def emit_letters():
+                return {"a": "x", "b": "y", "c": "z"}
+
+            @dag.task
+            def show(number, letter):
+                outputs.add((number, letter))
+
+            show.map(number=emit_numbers(), letter=emit_letters())
+
+        dag_run = dag_maker.create_dagrun()
+        for task_id in ["emit_numbers", "emit_letters"]:
+            ti = dag_run.get_task_instance(task_id, session=session)
+            ti.refresh_from_task(dag.get_task(task_id))
+            ti.run()
+
+        show_task = dag.get_task("show")
+        tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+        assert len(tis) == 6
+
+        for ti in tis:
+            ti.refresh_from_task(show_task)
+            ti.run()
+        assert outputs == {
+            (1, ("a", "x")),
+            (2, ("a", "x")),
+            (1, ("b", "y")),
+            (2, ("b", "y")),
+            (1, ("c", "z")),
+            (2, ("c", "z")),
+        }

Review comment:
       This is a set, so the ordering doesn’t matter :p The actual ordering is correct if the task instances are ordered by `map_index`. (But that ordering is not guaranteed for the list of _task instances_ returned by `expand_mapped_task`; I can add additional logic to test that.)
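One way to make the test order-sensitive despite that is to sort the returned task instances by `map_index` before running them, and compare a list instead of a set. A minimal sketch; `FakeTI` and `run_in_map_index_order` are hypothetical stand-ins for Airflow's `TaskInstance` and the test loop, and the expected order assumes numbers are the outer loop.

```python
import operator
import random
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeTI:
    """Hypothetical stand-in for a mapped TaskInstance."""

    map_index: int
    args: Tuple


def run_in_map_index_order(tis: List[FakeTI]) -> List[Tuple]:
    # The returned list has no promised order, so sort explicitly.
    outputs = []
    for ti in sorted(tis, key=operator.attrgetter("map_index")):
        outputs.append(ti.args)  # stands in for ti.run() recording its args
    return outputs


expected = [
    (1, ("a", "x")), (1, ("b", "y")), (1, ("c", "z")),
    (2, ("a", "x")), (2, ("b", "y")), (2, ("c", "z")),
]
tis = [FakeTI(i, args) for i, args in enumerate(expected)]
random.shuffle(tis)  # simulate an unordered result from expand_mapped_task
assert run_in_map_index_order(tis) == expected
```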







[GitHub] [airflow] uranusjr commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r810680291



##########
File path: tests/models/test_taskinstance.py
##########

Review comment:
       We guarantee that `map_index` values are generated in a predictable order, but not that the mapped task instances execute in that same order (which is the potentially unstable part here).
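Since only the `map_index` to arguments assignment is guaranteed, a test can check it without depending on execution order by recording outputs keyed by `map_index` and comparing against an expected dict. A minimal sketch, assuming each mapped task can report its own `map_index`; `record_output` is a hypothetical helper, and the shuffled loop simulates task instances finishing in arbitrary order.

```python
import random
from typing import Dict, Tuple


def record_output(outputs: Dict[int, Tuple], map_index: int, value: Tuple) -> None:
    """Hypothetical recorder a mapped task would call with its own map_index."""
    if map_index in outputs:
        raise AssertionError(f"map_index {map_index} ran more than once")
    outputs[map_index] = value


# Assumed map_index -> arguments assignment (numbers as the outer loop).
expected = {
    0: (1, ("a", "x")),
    1: (1, ("b", "y")),
    2: (1, ("c", "z")),
    3: (2, ("a", "x")),
    4: (2, ("b", "y")),
    5: (2, ("c", "z")),
}

outputs: Dict[int, Tuple] = {}
execution_order = list(expected)
random.shuffle(execution_order)  # execution order is not guaranteed
for map_index in execution_order:
    record_output(outputs, map_index, expected[map_index])

# Dict equality ignores insertion order but still checks each index's arguments.
assert outputs == expected
```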







[GitHub] [airflow] uranusjr commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809008757



##########
File path: airflow/models/abstractoperator.py
##########
@@ -17,14 +17,18 @@
 # under the License.
 
 import datetime
-from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, List, Optional, Set, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, Iterable, List, Optional, Set, Type, Union
+
+from sqlalchemy.orm import Session

Review comment:
       Yes, but if `Session` is ever passed in as an argument, this class is already in memory anyway, making this import basically free.
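The argument rests on Python's module cache: once a module has been imported anywhere in the process, later `import` statements reuse the entry in `sys.modules` instead of re-executing the module. A minimal sketch of that behaviour, using the standard-library `json` module as a stand-in for `sqlalchemy.orm`; `second_import_is_cached` is a hypothetical helper.

```python
import importlib
import sys


def second_import_is_cached(module_name: str) -> bool:
    """Return True if re-importing `module_name` reuses the cached module object."""
    # The first import may execute the module's code (the expensive part).
    first = importlib.import_module(module_name)
    # From then on the module object is kept in sys.modules ...
    cached = sys.modules[module_name]
    # ... so a later import is essentially a dictionary lookup.
    second = importlib.import_module(module_name)
    return first is cached and cached is second


assert second_import_is_cached("json")
```

The alternative, used for `jinja2` in `mappedoperator.py` above, is to import only under `TYPE_CHECKING`; that skips the import at runtime but requires string annotations such as `Optional["jinja2.Environment"]`.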







[GitHub] [airflow] ashb commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r810005922



##########
File path: tests/models/test_taskinstance.py
##########

Review comment:
       I think we need the order to be stable/guaranteed. The specific order doesn't matter so much, as long as it is stable and consistent. (We said we'd have this in the AIP.)







[GitHub] [airflow] uranusjr commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r810680512



##########
File path: airflow/models/mappedoperator.py
##########
@@ -245,16 +296,13 @@ def _validate_argument_count(self) -> None:
         """
         if isinstance(self.operator_class, str):
             return  # No need to validate deserialized operator.
-        operator = self._create_unmapped_operator(
-            mapped_kwargs={k: unittest.mock.MagicMock(name=k) for k in self.mapped_kwargs},
-            partial_kwargs=self.partial_kwargs,
-            real=False,
-        )
-        if operator.task_group:
-            operator.task_group._remove(operator)
-        dag = operator.get_dag()
+        mocked_mapped_kwargs = create_mocked_kwargs(self.mapped_kwargs)
+        op = self._create_unmapped_operator(mapped_kwargs=mocked_mapped_kwargs, real=False)

Review comment:
       An escape hatch in case an operator cannot correctly run `__init__` with mocks?
   
   I think eventually we may want to create a separate API for validation without actually creating the operator, perhaps a classmethod called `validate` that can be called here.
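One possible shape for such a `validate` classmethod, sketched on a hypothetical operator class (nothing here is existing Airflow API): `inspect.signature(cls).bind(**kwargs)` checks the keyword arguments against `__init__`'s parameters and raises `TypeError` for missing or unexpected ones, just as a real call would, but without running `__init__`. Since `bind` never looks at values, mapped arguments would not need mocks.

```python
import inspect
from typing import Any


class HypotheticalOperator:
    instances_created = 0

    def __init__(self, *, task_id: str, value: int, multiplier: int = 1) -> None:
        HypotheticalOperator.instances_created += 1
        self.task_id = task_id
        self.value = value
        self.multiplier = multiplier

    @classmethod
    def validate(cls, **kwargs: Any) -> None:
        """Check kwargs against __init__ without creating an instance."""
        # inspect.signature(cls) reports __init__'s parameters minus `self`.
        inspect.signature(cls).bind(**kwargs)


HypotheticalOperator.validate(task_id="show", value=1)
assert HypotheticalOperator.instances_created == 0
```

The trade-off: this only checks the signature, so any validation logic inside an operator's `__init__` body would still need the operator to opt in by overriding `validate`.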







[GitHub] [airflow] ashb commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r808997216



##########
File path: airflow/models/mappedoperator.py
##########
@@ -475,37 +517,119 @@ def expand_mapped_task(
             state = unmapped_ti.state
             self.log.debug("Updated in place to become %s", unmapped_ti)
             ret.append(unmapped_ti)
-            indexes_to_map = range(1, task_map_info_length)
+            indexes_to_map = range(1, total_length)
         else:
             # Only create "missing" ones.
             current_max_mapping = (
                 session.query(func.max(TaskInstance.map_index))
                 .filter(
-                    TaskInstance.dag_id == upstream_ti.dag_id,
+                    TaskInstance.dag_id == self.dag_id,
                     TaskInstance.task_id == self.task_id,
-                    TaskInstance.run_id == upstream_ti.run_id,
+                    TaskInstance.run_id == run_id,
                 )
                 .scalar()
             )
-            indexes_to_map = range(current_max_mapping + 1, task_map_info_length)
+            indexes_to_map = range(current_max_mapping + 1, total_length)
 
         for index in indexes_to_map:
             # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
             # TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator
-            ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index, state=state)  # type: ignore
+            ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)  # type: ignore
             self.log.debug("Expanding TIs upserted %s", ti)
             task_instance_mutation_hook(ti)
             ret.append(session.merge(ti))
 
         # Set to "REMOVED" any (old) TaskInstances with map indices greater
         # than the current map value
         session.query(TaskInstance).filter(
-            TaskInstance.dag_id == upstream_ti.dag_id,
+            TaskInstance.dag_id == self.dag_id,
             TaskInstance.task_id == self.task_id,
-            TaskInstance.run_id == upstream_ti.run_id,
-            TaskInstance.map_index >= task_map_info_length,
+            TaskInstance.run_id == run_id,
+            TaskInstance.map_index >= total_length,
         ).update({TaskInstance.state: TaskInstanceState.REMOVED})
 
         session.flush()
 
         return ret
+
+    def prepare_for_execution(self) -> "MappedOperator":
+        # Since a mapped operator cannot be used for execution, and an unmapped
+        # BaseOperator needs to be created later (see render_template_fields),
+        # we don't need to create a copy of the MappedOperator here.
+        return self
+
+    def render_template_fields(
+        self,
+        context: Context,
+        jinja_env: Optional["jinja2.Environment"] = None,
+    ) -> "BaseOperator":
+        """Template all attributes listed in template_fields.
+
+        Different from the BaseOperator implementation, this renders the
+        template fields on the *unmapped* BaseOperator.
+
+        :param context: Dict with values to apply on content
+        :param jinja_env: Jinja environment
+        :return: The unmapped, populated BaseOperator
+        """
+        if not jinja_env:
+            jinja_env = self.get_template_env()
+        unmapped_task = self.unmap()
+        self._do_render_template_fields(
+            parent=unmapped_task,
+            template_fields=unmapped_task.template_fields,
+            context=context,
+            jinja_env=jinja_env,
+            seen_oids=set(),
+        )
+        return unmapped_task
+
+    def _render_template_field(
+        self,
+        key: str,
+        content: Any,
+        context: Context,
+        jinja_env: Optional["jinja2.Environment"] = None,
+        seen_oids: Optional[Set] = None,
+        *,
+        session: Session,
+    ) -> Any:
+        """Override the ordinary template rendering to add more logic.
+
+        Specifically, if we're rendering a mapped argument, we need to "unmap"
+        the value as well to assign it to the unmapped operator.
+        """
+        content = super()._render_template_field(key, content, context, jinja_env, seen_oids, session=session)
+        return self._expand_mapped_field(key, content, context, session=session)
+
+    def _expand_mapped_field(self, key: str, content: Any, context: Context, *, session: Session) -> Any:
+        map_index = context["ti"].map_index
+        if map_index < 0:
+            return content
+        expansion_kwargs = self._get_expansion_kwargs()
+        all_lengths = self._get_map_lengths(context["run_id"], session=session)
+
+        def _find_index_for_this_field(index: int) -> int:
+            # Need to use self.mapped_kwargs for the original argument order.
+            for mapped_key in reversed(list(expansion_kwargs)):
+                mapped_length = all_lengths[mapped_key]
+                if mapped_key == key:
+                    return index % mapped_length
+                index //= mapped_length
+            return -1
+
+        found_index = _find_index_for_this_field(map_index)
+        if found_index < 0:
+            return content
+        if isinstance(content, collections.abc.Sequence):
+            return content[found_index]
+        if not isinstance(content, dict):
+            raise TypeError(f"can't map over value of type {type(content)}")
+        for i, (k, v) in enumerate(content.items()):
+            if i == found_index:
+                return k, v
+        raise IndexError(f"index {map_index} is over mapped length")
+
+
+class _FieldNotMapped(Exception):
+    """Raised by _expand_mapped_field if a field is not mapped."""

Review comment:
       Not used currently?
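If the exception is meant to be kept, one possible way to wire it in is to have the index lookup raise `_FieldNotMapped` instead of returning the `-1` sentinel, and let the caller pass unmapped content through. A minimal sketch outside Airflow; `find_index` and `expand_field` are hypothetical names, and it omits the PR's early return for `map_index < 0`. It also swaps the `enumerate` loop over `content.items()` for `itertools.islice`, which does the same linear walk more compactly.

```python
import itertools
from collections.abc import Sequence
from typing import Any, Dict


class _FieldNotMapped(Exception):
    """Raised when the requested field is not one of the mapped arguments."""


def find_index(map_index: int, key: str, lengths: Dict[str, int]) -> int:
    # Same mixed-radix walk as the PR: the last mapped key varies fastest.
    for mapped_key in reversed(list(lengths)):
        length = lengths[mapped_key]
        if mapped_key == key:
            return map_index % length
        map_index //= length
    raise _FieldNotMapped(key)


def expand_field(key: str, content: Any, map_index: int, lengths: Dict[str, int]) -> Any:
    try:
        found = find_index(map_index, key, lengths)
    except _FieldNotMapped:
        return content  # unmapped fields pass through unchanged
    if isinstance(content, Sequence):  # note: like the PR, this also matches str
        return content[found]
    if isinstance(content, dict):
        return next(itertools.islice(content.items(), found, None))
    raise TypeError(f"can't map over value of type {type(content)}")


lengths = {"number": 2, "letter": 3}
assert expand_field("letter", {"a": "x", "b": "y", "c": "z"}, 4, lengths) == ("b", "y")
```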







[GitHub] [airflow] uranusjr commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809974265



##########
File path: tests/models/test_taskinstance.py
##########

Review comment:
       This is a set so the ordering doesn’t matter :p The actual ordering is correct if the task instances are ordered by `map_index`. (But that’s not guaranteed by `expand_mapped_task`; I can add additional logic to test that.)







[GitHub] [airflow] uranusjr commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809974265



##########
File path: tests/models/test_taskinstance.py
##########

Review comment:
       This is a set so the ordering doesn’t matter :p You’re right on the actual ordering. (I didn’t want to rely on that because the *list of task instances* is not guaranteed to be ordered.)




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [airflow] uranusjr commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809010392



##########
File path: airflow/models/abstractoperator.py
##########
@@ -249,3 +253,124 @@ def get_extra_links(self, dttm: datetime.datetime, link_name: str) -> Optional[D
         elif link_name in self.global_operator_extra_link_dict:
             return self.global_operator_extra_link_dict[link_name].get_link(self, dttm)
         return None
+
+    @provide_session
+    def _do_render_template_fields(
+        self,
+        parent: Any,
+        template_fields: Iterable[str],
+        context: Context,
+        jinja_env: "jinja2.Environment",
+        seen_oids: Set,
+        *,
+        session: Session = NEW_SESSION,
+    ) -> None:
+        for attr_name in template_fields:
+            try:
+                content = getattr(parent, attr_name)
+            except AttributeError:
+                raise AttributeError(
+                    f"{attr_name!r} is configured as a template field "
+                    f"but {parent.task_type} does not have this attribute."
+                )
+            if not content:
+                continue
+            rendered_content = self._render_template_field(
+                attr_name,
+                content,
+                context,

Review comment:
       Good idea, will do. (This was cut-pasted from BaseOperator.)
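For readers following the rendering change, the loop in `_do_render_template_fields` above boils down to the pattern below. This is a minimal sketch, not Airflow code: `HypotheticalTask` and `render_template_fields` are made-up names, and `str.format` stands in for Jinja. The `from None` on the re-raised `AttributeError` is an extra that the diff does not use; it hides the less helpful chained error.

```python
from typing import Any, Dict, Iterable


class HypotheticalTask:
    """Made-up stand-in for an operator; not an Airflow class."""

    template_fields = ("greeting", "target")

    def __init__(self, greeting: str, target: str) -> None:
        self.greeting = greeting
        self.target = target
        self.task_type = type(self).__name__


def render_template_fields(parent: Any, template_fields: Iterable[str], context: Dict[str, Any]) -> None:
    """Render each listed attribute of ``parent`` in place."""
    for attr_name in template_fields:
        try:
            content = getattr(parent, attr_name)
        except AttributeError:
            # "from None" suppresses the chained, less helpful original error.
            raise AttributeError(
                f"{attr_name!r} is configured as a template field "
                f"but {parent.task_type} does not have this attribute."
            ) from None
        if not content:
            continue  # Falsy values are skipped, as in the diff.
        # str.format stands in for Jinja rendering in this sketch.
        setattr(parent, attr_name, content.format(**context))


task = HypotheticalTask(greeting="Hello {name}", target="")
render_template_fields(task, task.template_fields, {"name": "world"})
print(task.greeting)  # Hello world
```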







[GitHub] [airflow] ashb commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r808983807



##########
File path: airflow/decorators/base.py
##########
@@ -269,22 +276,21 @@ def __call__(self, *args, **kwargs) -> XComArg:
             op.doc_md = self.function.__doc__
         return XComArg(op)
 
-    def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any], valid_names: Set[str] = set()):
-        unknown_args = kwargs.copy()
-        for name in itertools.chain(self.function_arg_names, valid_names):
-            unknown_args.pop(name, None)
-
-            if not unknown_args:
-                # If we have no args left ot check, we are valid
-                return
+    def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]):
+        kwargs_left = kwargs.copy()
+        for arg_name in self.function_arg_names:
+            value = kwargs_left.pop(arg_name, NOTSET)
+            if func != "map" or value is NOTSET or isinstance(value, get_mappable_types()):
+                continue
+            raise ValueError(f"{func} got unexpected value{type(value)!r} for keyword argument {arg_name!r}")

Review comment:
       Shouldn't this be a TypeError? (That's what Python raises in this situation.)
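Python itself raises `TypeError` both for an unknown keyword argument (e.g. `show(valeu=1)`) and for an argument of the wrong type (e.g. `len(3)`), so a `TypeError` matches what users expect. A minimal sketch of the check raising that exception instead, assuming (as the `MapArgument` comment in the `mappedoperator.py` hunk says) that only sequences and dicts are mappable literals; `validate_map_kwargs` is a hypothetical name and `XComArg` handling is left out:

```python
from collections.abc import Sequence
from typing import Any, Dict, Iterable


def validate_map_kwargs(func_name: str, arg_names: Iterable[str], kwargs: Dict[str, Any]) -> None:
    """Raise TypeError if a mapped keyword argument is not a sequence or dict."""
    for arg_name in arg_names:
        if arg_name not in kwargs:
            continue  # Argument not passed to this call; nothing to check.
        value = kwargs[arg_name]
        if isinstance(value, (Sequence, dict)):
            continue
        raise TypeError(
            f"{func_name}() got unexpected value of type {type(value).__name__!r} "
            f"for keyword argument {arg_name!r}"
        )


validate_map_kwargs("map", ["value"], {"value": [1, 2, 3]})  # passes silently
try:
    validate_map_kwargs("map", ["value"], {"value": 3})
except TypeError as exc:
    print(exc)  # map() got unexpected value of type 'int' for keyword argument 'value'
```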







[GitHub] [airflow] ashb commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r808986669



##########
File path: airflow/models/abstractoperator.py
##########
@@ -17,14 +17,18 @@
 # under the License.
 
 import datetime
-from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, List, Optional, Set, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, Iterable, List, Optional, Set, Type, Union
+
+from sqlalchemy.orm import Session

Review comment:
       Type checking only import?
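A type-checking-only import keeps `sqlalchemy.orm` out of the module's runtime imports. Because the name then does not exist at runtime, annotations that use it must be quoted (or the module must use `from __future__ import annotations`); on Python versions that evaluate annotations eagerly, an unquoted `session: Session = NEW_SESSION` would raise `NameError` as soon as the function is defined. A minimal sketch, with `describe_session` as a hypothetical function:

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Only static type checkers (e.g. mypy) execute this import.
    from sqlalchemy.orm import Session


def describe_session(session: Optional["Session"] = None) -> str:
    # The quoted "Session" is a forward reference, so it is never looked up
    # at runtime and the missing runtime import does no harm.
    return "no session" if session is None else "got a session"


print(describe_session())  # no session
```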







[GitHub] [airflow] uranusjr commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809014065



##########
File path: airflow/models/mappedoperator.py
##########
@@ -55,38 +58,65 @@
     TaskStateChangeCallback,
 )
 from airflow.models.pool import Pool
-from airflow.models.xcom_arg import XComArg
 from airflow.serialization.enums import DagAttributeTypes
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
+from airflow.typing_compat import Literal
+from airflow.utils.context import Context
 from airflow.utils.operator_resources import Resources
-from airflow.utils.session import NEW_SESSION
 from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
+    import jinja2  # Slow import.
+
     from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
     from airflow.models.dag import DAG
     from airflow.models.taskinstance import TaskInstance
+    from airflow.models.xcom_arg import XComArg
+
+    # BaseOperator.map() can be called on an XComArg, sequence, or dict (not any
+    # mapping since we need the value to be ordered).
+    MapArgument = Union[XComArg, Sequence, dict]
+
+ValidationSource = Union[Literal["map"], Literal["partial"]]

Review comment:
       Unfortunately this only works on 3.9+.
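The `MapArgument` comment in the hunk above accepts a plain `dict` but not an arbitrary mapping because the values must be ordered: each mapped task instance picks its element by position. Plain dicts keep insertion order (guaranteed since Python 3.7), so the n-th `(key, value)` pair is well defined. A minimal sketch of that positional lookup, using `itertools.islice` in place of the `enumerate` loop in `_expand_mapped_field`; `nth_mapped_value` is a hypothetical name:

```python
import itertools
from collections.abc import Sequence
from typing import Any, Union


def nth_mapped_value(content: Union[Sequence, dict], index: int) -> Any:
    """Return the index-th element of a sequence, or the index-th (key, value) pair of a dict."""
    if isinstance(content, Sequence):
        return content[index]
    if not isinstance(content, dict):
        raise TypeError(f"can't map over value of type {type(content)}")
    # Dicts preserve insertion order, so the index-th item is well defined;
    # islice skips ahead without a manual counter.
    for item in itertools.islice(content.items(), index, None):
        return item
    raise IndexError(f"index {index} is over mapped length")


print(nth_mapped_value({"a": 1, "b": 2}, 1))  # ('b', 2)
```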







[GitHub] [airflow] uranusjr commented on pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr commented on pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#issuecomment-1042961386


   This is actually working





[GitHub] [airflow] uranusjr commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809008037



##########
File path: airflow/models/mappedoperator.py
##########
@@ -475,37 +517,119 @@ def expand_mapped_task(
             state = unmapped_ti.state
             self.log.debug("Updated in place to become %s", unmapped_ti)
             ret.append(unmapped_ti)
-            indexes_to_map = range(1, task_map_info_length)
+            indexes_to_map = range(1, total_length)
         else:
             # Only create "missing" ones.
             current_max_mapping = (
                 session.query(func.max(TaskInstance.map_index))
                 .filter(
-                    TaskInstance.dag_id == upstream_ti.dag_id,
+                    TaskInstance.dag_id == self.dag_id,
                     TaskInstance.task_id == self.task_id,
-                    TaskInstance.run_id == upstream_ti.run_id,
+                    TaskInstance.run_id == run_id,
                 )
                 .scalar()
             )
-            indexes_to_map = range(current_max_mapping + 1, task_map_info_length)
+            indexes_to_map = range(current_max_mapping + 1, total_length)
 
         for index in indexes_to_map:
             # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
             # TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator
-            ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index, state=state)  # type: ignore
+            ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)  # type: ignore
             self.log.debug("Expanding TIs upserted %s", ti)
             task_instance_mutation_hook(ti)
             ret.append(session.merge(ti))
 
         # Set to "REMOVED" any (old) TaskInstances with map indices greater
         # than the current map value
         session.query(TaskInstance).filter(
-            TaskInstance.dag_id == upstream_ti.dag_id,
+            TaskInstance.dag_id == self.dag_id,
             TaskInstance.task_id == self.task_id,
-            TaskInstance.run_id == upstream_ti.run_id,
-            TaskInstance.map_index >= task_map_info_length,
+            TaskInstance.run_id == run_id,
+            TaskInstance.map_index >= total_length,
         ).update({TaskInstance.state: TaskInstanceState.REMOVED})
 
         session.flush()
 
         return ret
+
+    def prepare_for_execution(self) -> "MappedOperator":
+        # Since a mapped operator cannot be used for execution, and an unmapped
+        # BaseOperator needs to be created later (see render_template_fields),
+        # we don't need to create a copy of the MappedOperator here.
+        return self
+
+    def render_template_fields(
+        self,
+        context: Context,
+        jinja_env: Optional["jinja2.Environment"] = None,
+    ) -> "BaseOperator":
+        """Template all attributes listed in template_fields.
+
+        Different from the BaseOperator implementation, this renders the
+        template fields on the *unmapped* BaseOperator.
+
+        :param context: Dict with values to apply on content
+        :param jinja_env: Jinja environment
+        :return: The unmapped, populated BaseOperator
+        """
+        if not jinja_env:
+            jinja_env = self.get_template_env()
+        unmapped_task = self.unmap()
+        self._do_render_template_fields(
+            parent=unmapped_task,
+            template_fields=unmapped_task.template_fields,
+            context=context,
+            jinja_env=jinja_env,
+            seen_oids=set(),
+        )
+        return unmapped_task
+
+    def _render_template_field(
+        self,
+        key: str,
+        content: Any,
+        context: Context,
+        jinja_env: Optional["jinja2.Environment"] = None,
+        seen_oids: Optional[Set] = None,
+        *,
+        session: Session,
+    ) -> Any:
+        """Override the ordinary template rendering to add more logic.
+
+        Specifically, if we're rendering a mapped argument, we need to "unmap"
+        the value as well to assign it to the unmapped operator.
+        """
+        content = super()._render_template_field(key, content, context, jinja_env, seen_oids, session=session)
+        return self._expand_mapped_field(key, content, context, session=session)
+
+    def _expand_mapped_field(self, key: str, content: Any, context: Context, *, session: Session) -> Any:
+        map_index = context["ti"].map_index
+        if map_index < 0:
+            return content
+        expansion_kwargs = self._get_expansion_kwargs()
+        all_lengths = self._get_map_lengths(context["run_id"], session=session)
+
+        def _find_index_for_this_field(index: int) -> int:
+            # Need to use self.mapped_kwargs for the original argument order.
+            for mapped_key in reversed(list(expansion_kwargs)):
+                mapped_length = all_lengths[mapped_key]
+                if mapped_key == key:
+                    return index % mapped_length
+                index //= mapped_length
+            return -1
+
+        found_index = _find_index_for_this_field(map_index)
+        if found_index < 0:
+            return content
+        if isinstance(content, collections.abc.Sequence):
+            return content[found_index]
+        if not isinstance(content, dict):
+            raise TypeError(f"can't map over value of type {type(content)}")
+        for i, (k, v) in enumerate(content.items()):
+            if i == found_index:
+                return k, v
+        raise IndexError(f"index {map_index} is over mapped length")
+
+
+class _FieldNotMapped(Exception):
+    """Raised by _expand_mapped_field if a field is not mapped."""

Review comment:
       Forgot to delete it after refactoring.
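The `expand_mapped_task` hunk above reconciles the task instances stored for a run with the new `total_length`: an unmapped placeholder (map index -1) is reused as index 0, only missing indexes are created, and instances at or past the new length are marked REMOVED rather than deleted. The sketch below replays that bookkeeping on an in-memory `{map_index: state}` dict. It is hypothetical and skips the database; when there is no placeholder it uses `None` as the state for new instances, whereas in the diff that `state` comes from code outside the quoted hunk.

```python
from typing import Dict, Optional


def reconcile_mapped_tis(existing: Dict[int, Optional[str]], total_length: int) -> Dict[int, Optional[str]]:
    """Return a new {map_index: state} dict expanded to total_length entries.

    Hypothetical in-memory stand-in for the database updates in expand_mapped_task.
    """
    tis = dict(existing)
    state: Optional[str] = None
    if -1 in tis:
        # Reuse the unmapped placeholder as map index 0 and copy its state.
        state = tis.pop(-1)
        tis[0] = state
        indexes_to_map = range(1, total_length)
    else:
        # Only create the "missing" ones after the current highest index.
        current_max = max(tis, default=-1)
        indexes_to_map = range(current_max + 1, total_length)
    for index in indexes_to_map:
        tis[index] = state
    # Leftovers from an earlier, longer expansion are marked, not deleted.
    for index in tis:
        if index >= total_length:
            tis[index] = "removed"
    return tis


print(reconcile_mapped_tis({-1: "scheduled"}, 3))  # {0: 'scheduled', 1: 'scheduled', 2: 'scheduled'}
```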







[GitHub] [airflow] ashb commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809955067



##########
File path: tests/models/test_taskinstance.py
##########

Review comment:
       I don't think it matters, but I would have expected `emit_numbers` to be the outer loop, since it's the "first" argument:
   
   ```python
   assert outputs == {
       (1, ("a", "x")),
       (1, ("b", "y")),
       (1, ("c", "z")),
       (2, ("a", "x")),
       (2, ("b", "y")),
       (2, ("c", "z")),
   }
   ```
   
   i.e. `show.map(number=emit_numbers(), letter=emit_letters())` is the equivalent of:
   
   ```python
   for number in emit_numbers():
       # Mapping over a dict yields (key, value) pairs, hence .items().
       for letter in emit_letters().items():
           show(number, letter)
   ```
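Both orderings contain the same six pairs, so the set assertion passes either way. Assuming the flat `map_index` is decoded as in `_find_index_for_this_field` in the `mappedoperator.py` hunk (walking the mapped keys from last to first, so the last keyword argument varies fastest) and `total_length` is the product of the per-argument lengths (as the `len(tis) == 6` assertion implies), map indexes 0 through 5 follow exactly the nested-loop order, with `number` outside. A minimal sketch checking that against `itertools.product`; `index_for_key` and `expand` are hypothetical names, and the dict is passed as its list of items:

```python
import itertools
import math
from typing import Any, Dict, List, Tuple


def index_for_key(map_index: int, lengths: Dict[str, int], key: str) -> int:
    """Recover one argument's position from a flat map_index.

    Walks the mapped keys from last to first, like _find_index_for_this_field,
    so the last keyword argument varies fastest. Returns -1 if key is not mapped.
    """
    for mapped_key in reversed(list(lengths)):
        if mapped_key == key:
            return map_index % lengths[mapped_key]
        map_index //= lengths[mapped_key]
    return -1


def expand(kwargs: Dict[str, List[Any]]) -> List[Tuple[Any, ...]]:
    """Build the argument tuple for every map index, in map-index order."""
    lengths = {key: len(values) for key, values in kwargs.items()}
    total_length = math.prod(lengths.values())
    return [
        tuple(values[index_for_key(i, lengths, key)] for key, values in kwargs.items())
        for i in range(total_length)
    ]


numbers = [1, 2]
letters = list({"a": "x", "b": "y", "c": "z"}.items())
combos = expand({"number": numbers, "letter": letters})
# Same order as the nested loops: `number` outside, `letter` inside.
assert combos == list(itertools.product(numbers, letters))
print(combos[:3])  # [(1, ('a', 'x')), (1, ('b', 'y')), (1, ('c', 'z'))]
```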







[GitHub] [airflow] github-actions[bot] commented on pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
github-actions[bot] commented on pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#issuecomment-1044448888


   The PR most likely needs to run the full matrix of tests because it modifies parts of the core of Airflow. However, committers might decide to merge it quickly and take the risk. If they don't merge it quickly, please rebase it to the latest main at your convenience, or amend the last commit of the PR, and push it with --force-with-lease.





[GitHub] [airflow] uranusjr commented on a change in pull request #21641: Implement mapped value unpacking

Posted by GitBox <gi...@apache.org>.
uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r810681544



##########
File path: tests/models/test_taskinstance.py
##########

Review comment:
       With that said, I added a `sorted()` call and tweaked the tests to ensure stricter ordering.
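Which list the `sorted()` call was applied to is not shown here. One way to make ordering part of the assertion, assuming task instances are sorted by `map_index` before running and outputs are collected into a list instead of a set; `FakeTI` and `outputs_in_map_order` are hypothetical names:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeTI:
    """Hypothetical stand-in for a TaskInstance; only map_index matters here."""

    map_index: int
    output: object


def outputs_in_map_order(tis: List[FakeTI]) -> List[object]:
    # The order TIs come back in is not guaranteed, so sort by map_index first.
    return [ti.output for ti in sorted(tis, key=lambda ti: ti.map_index)]


tis = [FakeTI(2, "c"), FakeTI(0, "a"), FakeTI(1, "b")]
print(outputs_in_map_order(tis))  # ['a', 'b', 'c']
```

Comparing against a list rather than a set means a change in expansion order would fail the test instead of passing silently.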



