You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by "hussein-awala (via GitHub)" <gi...@apache.org> on 2023/03/11 01:59:04 UTC

[GitHub] [airflow] hussein-awala commented on a diff in pull request #29913: Fix mapped tasks partial arguments when DAG default args are provided

hussein-awala commented on code in PR #29913:
URL: https://github.com/apache/airflow/pull/29913#discussion_r1133009097


##########
airflow/models/baseoperator.py:
##########
@@ -240,72 +240,123 @@ def partial(
         task_id = task_group.child_id(task_id)
 
     # Merge DAG and task group level defaults into user-supplied values.
-    partial_kwargs, partial_params = get_merged_defaults(
+    default_partial_kwargs, partial_params = get_merged_defaults(
         dag=dag,
         task_group=task_group,
         task_params=params,
         task_default_args=kwargs.pop("default_args", None),
     )
-    partial_kwargs.update(kwargs)
-
-    # Always fully populate partial kwargs to exclude them from map().
-    partial_kwargs.setdefault("dag", dag)
-    partial_kwargs.setdefault("task_group", task_group)
-    partial_kwargs.setdefault("task_id", task_id)
-    partial_kwargs.setdefault("start_date", start_date)
-    partial_kwargs.setdefault("end_date", end_date)
-    partial_kwargs.setdefault("owner", owner)
-    partial_kwargs.setdefault("email", email)
-    partial_kwargs.setdefault("trigger_rule", trigger_rule)
-    partial_kwargs.setdefault("depends_on_past", depends_on_past)
-    partial_kwargs.setdefault("ignore_first_depends_on_past", ignore_first_depends_on_past)
-    partial_kwargs.setdefault("wait_for_past_depends_before_skipping", wait_for_past_depends_before_skipping)
-    partial_kwargs.setdefault("wait_for_downstream", wait_for_downstream)
-    partial_kwargs.setdefault("retries", retries)
-    partial_kwargs.setdefault("queue", queue)
-    partial_kwargs.setdefault("pool", pool)
-    partial_kwargs.setdefault("pool_slots", pool_slots)
-    partial_kwargs.setdefault("execution_timeout", execution_timeout)
-    partial_kwargs.setdefault("max_retry_delay", max_retry_delay)
-    partial_kwargs.setdefault("retry_delay", retry_delay)
-    partial_kwargs.setdefault("retry_exponential_backoff", retry_exponential_backoff)
-    partial_kwargs.setdefault("priority_weight", priority_weight)
-    partial_kwargs.setdefault("weight_rule", weight_rule)
-    partial_kwargs.setdefault("sla", sla)
-    partial_kwargs.setdefault("max_active_tis_per_dag", max_active_tis_per_dag)
-    partial_kwargs.setdefault("on_execute_callback", on_execute_callback)
-    partial_kwargs.setdefault("on_failure_callback", on_failure_callback)
-    partial_kwargs.setdefault("on_retry_callback", on_retry_callback)
-    partial_kwargs.setdefault("on_success_callback", on_success_callback)
-    partial_kwargs.setdefault("run_as_user", run_as_user)
-    partial_kwargs.setdefault("executor_config", executor_config)
-    partial_kwargs.setdefault("inlets", inlets or [])
-    partial_kwargs.setdefault("outlets", outlets or [])
-    partial_kwargs.setdefault("resources", resources)
-    partial_kwargs.setdefault("doc", doc)
-    partial_kwargs.setdefault("doc_json", doc_json)
-    partial_kwargs.setdefault("doc_md", doc_md)
-    partial_kwargs.setdefault("doc_rst", doc_rst)
-    partial_kwargs.setdefault("doc_yaml", doc_yaml)
+
+    # Create partial_kwargs from args and kwargs
+    partial_kwargs = {
+        **kwargs,
+        "dag": dag,
+        "task_group": task_group,
+        "task_id": task_id,
+        "start_date": start_date,
+        "end_date": end_date,
+        "owner": owner,
+        "email": email,
+        "trigger_rule": trigger_rule,
+        "depends_on_past": depends_on_past,
+        "ignore_first_depends_on_past": ignore_first_depends_on_past,
+        "wait_for_past_depends_before_skipping": wait_for_past_depends_before_skipping,
+        "wait_for_downstream": wait_for_downstream,
+        "retries": retries,
+        "queue": queue,
+        "pool": pool,
+        "pool_slots": pool_slots,
+        "execution_timeout": execution_timeout,
+        "max_retry_delay": max_retry_delay,
+        "retry_delay": retry_delay,
+        "retry_exponential_backoff": retry_exponential_backoff,
+        "priority_weight": priority_weight,
+        "weight_rule": weight_rule,
+        "sla": sla,
+        "max_active_tis_per_dag": max_active_tis_per_dag,
+        "on_execute_callback": on_execute_callback,
+        "on_failure_callback": on_failure_callback,
+        "on_retry_callback": on_retry_callback,
+        "on_success_callback": on_success_callback,
+        "run_as_user": run_as_user,
+        "executor_config": executor_config,
+        "inlets": inlets,
+        "outlets": outlets,
+        "resources": resources,
+        "doc": doc,
+        "doc_json": doc_json,
+        "doc_md": doc_md,
+        "doc_rst": doc_rst,
+        "doc_yaml": doc_yaml,
+    }
+
+    # Override None kwargs by dag default values
+    for k, v in default_partial_kwargs.items():
+        if partial_kwargs.get(k) is None:
+            partial_kwargs[k] = v
+
+    # Override None kwargs which don't have a dag default value by Airflow default value
+    partial_kwargs["owner"] = partial_kwargs["owner"] or DEFAULT_OWNER
+    partial_kwargs["trigger_rule"] = partial_kwargs["trigger_rule"] or DEFAULT_TRIGGER_RULE
+    partial_kwargs["depends_on_past"] = (
+        partial_kwargs["depends_on_past"] if partial_kwargs["depends_on_past"] is not None else False
+    )
+    partial_kwargs["ignore_first_depends_on_past"] = (
+        partial_kwargs["ignore_first_depends_on_past"]
+        if partial_kwargs["ignore_first_depends_on_past"] is not None
+        else DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST
+    )
+    partial_kwargs["wait_for_past_depends_before_skipping"] = (
+        partial_kwargs["wait_for_past_depends_before_skipping"]
+        if partial_kwargs["wait_for_past_depends_before_skipping"] is not None
+        else DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
+    )
+    partial_kwargs["wait_for_downstream"] = (
+        partial_kwargs["wait_for_downstream"] if partial_kwargs["wait_for_downstream"] is not None else False
+    )
+    partial_kwargs["retries"] = (
+        partial_kwargs["retries"] if partial_kwargs["retries"] is not None else DEFAULT_RETRIES
+    )
+    partial_kwargs["queue"] = partial_kwargs["queue"] or DEFAULT_QUEUE
+    partial_kwargs["pool_slots"] = (
+        partial_kwargs["pool_slots"] if partial_kwargs["pool_slots"] is not None else DEFAULT_POOL_SLOTS
+    )
+    partial_kwargs["execution_timeout"] = (
+        partial_kwargs["execution_timeout"] or DEFAULT_TASK_EXECUTION_TIMEOUT
+    )
+    partial_kwargs["retry_delay"] = (
+        partial_kwargs["retry_delay"] if partial_kwargs["retry_delay"] is not None else DEFAULT_RETRY_DELAY
+    )
+    partial_kwargs["retry_exponential_backoff"] = (
+        partial_kwargs["retry_exponential_backoff"] if partial_kwargs["retry_exponential_backoff"] else False
+    )
+    partial_kwargs["priority_weight"] = (
+        partial_kwargs["priority_weight"] if partial_kwargs["priority_weight"] else DEFAULT_PRIORITY_WEIGHT
+    )
+    partial_kwargs["weight_rule"] = partial_kwargs["weight_rule"] or DEFAULT_WEIGHT_RULE
+    partial_kwargs["inlets"] = partial_kwargs["inlets"] or []
+    partial_kwargs["outlets"] = partial_kwargs["outlets"] or []
 
     # Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
     if "task_concurrency" in kwargs:  # Reject deprecated option.
         raise TypeError("unexpected argument: task_concurrency")
     if partial_kwargs["wait_for_downstream"]:
         partial_kwargs["depends_on_past"] = True
-    partial_kwargs["start_date"] = timezone.convert_to_utc(partial_kwargs["start_date"])
-    partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"])
+    partial_kwargs["start_date"] = timezone.convert_to_utc(partial_kwargs["start_date"])  # type: ignore
+    partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"])  # type: ignore

Review Comment:
   Since I didn't define the type of `partial_kwargs`, Python's type checker inferred it as Dict[str, <union of all the value types>], which is not compatible with the functions used here, so I fixed it by declaring the type as Dict[str, Any].



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org