You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2022/02/17 12:43:14 UTC

[GitHub] [airflow] uranusjr commented on a change in pull request #21641: Implement mapped value unpacking

uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809008037



##########
File path: airflow/models/mappedoperator.py
##########
@@ -475,37 +517,119 @@ def expand_mapped_task(
             state = unmapped_ti.state
             self.log.debug("Updated in place to become %s", unmapped_ti)
             ret.append(unmapped_ti)
-            indexes_to_map = range(1, task_map_info_length)
+            indexes_to_map = range(1, total_length)
         else:
             # Only create "missing" ones.
             current_max_mapping = (
                 session.query(func.max(TaskInstance.map_index))
                 .filter(
-                    TaskInstance.dag_id == upstream_ti.dag_id,
+                    TaskInstance.dag_id == self.dag_id,
                     TaskInstance.task_id == self.task_id,
-                    TaskInstance.run_id == upstream_ti.run_id,
+                    TaskInstance.run_id == run_id,
                 )
                 .scalar()
             )
-            indexes_to_map = range(current_max_mapping + 1, task_map_info_length)
+            indexes_to_map = range(current_max_mapping + 1, total_length)
 
         for index in indexes_to_map:
             # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
             # TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator
-            ti = TaskInstance(self, run_id=upstream_ti.run_id, map_index=index, state=state)  # type: ignore
+            ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)  # type: ignore
             self.log.debug("Expanding TIs upserted %s", ti)
             task_instance_mutation_hook(ti)
             ret.append(session.merge(ti))
 
         # Set to "REMOVED" any (old) TaskInstances with map indices greater
         # than the current map value
         session.query(TaskInstance).filter(
-            TaskInstance.dag_id == upstream_ti.dag_id,
+            TaskInstance.dag_id == self.dag_id,
             TaskInstance.task_id == self.task_id,
-            TaskInstance.run_id == upstream_ti.run_id,
-            TaskInstance.map_index >= task_map_info_length,
+            TaskInstance.run_id == run_id,
+            TaskInstance.map_index >= total_length,
         ).update({TaskInstance.state: TaskInstanceState.REMOVED})
 
         session.flush()
 
         return ret
+
+    def prepare_for_execution(self) -> "MappedOperator":
+        # Since a mapped operator cannot be used for execution, and an unmapped
+        # BaseOperator needs to be created later (see render_template_fields),
+        # we don't need to create a copy of the MappedOperator here.
+        return self
+
+    def render_template_fields(
+        self,
+        context: Context,
+        jinja_env: Optional["jinja2.Environment"] = None,
+    ) -> "BaseOperator":
+        """Template all attributes listed in template_fields.
+
+        Different from the BaseOperator implementation, this renders the
+        template fields on the *unmapped* BaseOperator.
+
+        :param context: Dict with values to apply on content
+        :param jinja_env: Jinja environment
+        :return: The unmapped, populated BaseOperator
+        """
+        if not jinja_env:
+            jinja_env = self.get_template_env()
+        unmapped_task = self.unmap()
+        self._do_render_template_fields(
+            parent=unmapped_task,
+            template_fields=unmapped_task.template_fields,
+            context=context,
+            jinja_env=jinja_env,
+            seen_oids=set(),
+        )
+        return unmapped_task
+
+    def _render_template_field(
+        self,
+        key: str,
+        content: Any,
+        context: Context,
+        jinja_env: Optional["jinja2.Environment"] = None,
+        seen_oids: Optional[Set] = None,
+        *,
+        session: Session,
+    ) -> Any:
+        """Override the ordinary template rendering to add more logic.
+
+        Specifically, if we're rendering a mapped argument, we need to "unmap"
+        the value as well to assign it to the unmapped operator.
+        """
+        content = super()._render_template_field(key, content, context, jinja_env, seen_oids, session=session)
+        return self._expand_mapped_field(key, content, context, session=session)
+
+    def _expand_mapped_field(self, key: str, content: Any, context: Context, *, session: Session) -> Any:
+        map_index = context["ti"].map_index
+        if map_index < 0:
+            return content
+        expansion_kwargs = self._get_expansion_kwargs()
+        all_lengths = self._get_map_lengths(context["run_id"], session=session)
+
+        def _find_index_for_this_field(index: int) -> int:
+            # Need to use self.mapped_kwargs for the original argument order.
+            for mapped_key in reversed(list(expansion_kwargs)):
+                mapped_length = all_lengths[mapped_key]
+                if mapped_key == key:
+                    return index % mapped_length
+                index //= mapped_length
+            return -1
+
+        found_index = _find_index_for_this_field(map_index)
+        if found_index < 0:
+            return content
+        if isinstance(content, collections.abc.Sequence):
+            return content[found_index]
+        if not isinstance(content, dict):
+            raise TypeError(f"can't map over value of type {type(content)}")
+        for i, (k, v) in enumerate(content.items()):
+            if i == found_index:
+                return k, v
+        raise IndexError(f"index {map_index} is over mapped length")
+
+
+class _FieldNotMapped(Exception):
+    """Raised by _expand_mapped_field if a field is not mapped."""

Review comment:
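       Forgot to delete this after refactoring: `_FieldNotMapped` is no longer raised (`_expand_mapped_field` now returns the content unchanged instead).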




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org