Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2020/12/22 19:49:24 UTC

[GitHub] [airflow] casassg commented on a change in pull request #12228: taskgroup decorator (#11870)

casassg commented on a change in pull request #12228:
URL: https://github.com/apache/airflow/pull/12228#discussion_r547475486



##########
File path: airflow/utils/task_group.py
##########
@@ -344,3 +357,52 @@ def get_current_task_group(cls, dag: Optional["DAG"]) -> Optional[TaskGroup]:
                 return dag.task_group
 
         return cls._context_managed_task_group
+
+
+T = TypeVar("T", bound=Callable)  # pylint: disable=invalid-name
+
+
+def taskgroup(python_callable: Optional[Callable] = None, *tg_args, **tg_kwargs) -> Callable[[T], T]:
+    """
+    Python TaskGroup decorator. Wraps a function into an Airflow TaskGroup.
+    Accepts kwargs for operator TaskGroup. Can be used to parametrize TaskGroup.
+
+    :param python_callable: Function to decorate
+    :param tg_args: Arguments for TaskGroup object
+    :type tg_args: list
+    :param tg_kwargs: Kwargs for TaskGroup object.
+    :type tg_kwargs: dict
+    """
+
+    def wrapper(f: T):
+        # Setting group_id as function name if not given
+        if len(tg_args) == 0 and 'group_id' not in tg_kwargs.keys():
+            tg_kwargs['group_id'] = f.__name__
+
+        # Get dag initializer signature and bind it to validate that task_group_args,
+        # and task_group_kwargs are correct
+        task_group_sig = signature(TaskGroup.__init__)
+        task_group_bound_args = task_group_sig.bind_partial(*tg_args, **tg_kwargs)
+
+        @functools.wraps(f)
+        def factory(*args, **kwargs):

Review comment:
       It may be nice to provide a specific kwarg for overriding the group_id.
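A minimal sketch of how such an override kwarg might look. It assumes a simplified, keyword-only form of the decorator. `TaskGroup` here is a hypothetical stand-in so the example runs without Airflow. `group_id_override` is also a hypothetical name, picked so it is unlikely to clash with parameters of the decorated function.

```python
import functools
from inspect import signature
from typing import Any, Callable, Dict, Optional


class TaskGroup:
    """Hypothetical stand-in for Airflow's TaskGroup; only what this sketch needs."""

    def __init__(self, group_id: str, prefix_group_id: bool = True, tooltip: str = ""):
        self.group_id = group_id
        self.prefix_group_id = prefix_group_id
        self.tooltip = tooltip

    def __enter__(self) -> "TaskGroup":
        return self

    def __exit__(self, *exc_info: Any) -> None:
        return None


def taskgroup(python_callable: Optional[Callable] = None, **tg_kwargs: Any):
    """Usable both as ``@taskgroup`` and as ``@taskgroup(tooltip=...)``."""

    def wrapper(f: Callable) -> Callable:
        # Copy so decorating several functions never shares mutated state.
        base_kwargs: Dict[str, Any] = dict(tg_kwargs)
        # Default group_id to the function name when none is given.
        base_kwargs.setdefault("group_id", f.__name__)
        # signature(TaskGroup) omits `self`; bind_partial raises TypeError
        # at decoration time if TaskGroup does not accept an argument.
        bound = signature(TaskGroup).bind_partial(**base_kwargs)

        @functools.wraps(f)
        def factory(*args: Any, group_id_override: Optional[str] = None, **kwargs: Any) -> TaskGroup:
            # Hypothetical per-call override of the group_id.
            group_kwargs = dict(bound.arguments)
            if group_id_override is not None:
                group_kwargs["group_id"] = group_id_override
            with TaskGroup(**group_kwargs) as tg:
                f(*args, **kwargs)  # in real Airflow, tasks created here would join `tg`
            return tg

        return factory

    # A bare @taskgroup passes the decorated function in as python_callable.
    if callable(python_callable):
        return wrapper(python_callable)
    return wrapper
```

With this sketch, calling an `@taskgroup`-decorated `extract()` creates a group named `extract`. Calling `extract(group_id_override="extract_again")` renames only that invocation. Binding against `signature(TaskGroup)` rather than `signature(TaskGroup.__init__)` keeps `self` out of the parameters, so a positional argument cannot end up bound to `self`.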

##########
File path: airflow/utils/task_group.py
##########
@@ -344,3 +357,52 @@ def get_current_task_group(cls, dag: Optional["DAG"]) -> Optional[TaskGroup]:
                 return dag.task_group
 
         return cls._context_managed_task_group
+
+
+T = TypeVar("T", bound=Callable)  # pylint: disable=invalid-name
+
+
+def taskgroup(python_callable: Optional[Callable] = None, *tg_args, **tg_kwargs) -> Callable[[T], T]:
+    """
+    Python TaskGroup decorator. Wraps a function into an Airflow TaskGroup.
+    Accepts kwargs for operator TaskGroup. Can be used to parametrize TaskGroup.
+
+    :param python_callable: Function to decorate
+    :param tg_args: Arguments for TaskGroup object
+    :type tg_args: list
+    :param tg_kwargs: Kwargs for TaskGroup object.
+    :type tg_kwargs: dict
+    """
+
+    def wrapper(f: T):
+        # Setting group_id as function name if not given
+        if len(tg_args) == 0 and 'group_id' not in tg_kwargs.keys():
+            tg_kwargs['group_id'] = f.__name__

Review comment:
       Can we add logic to autogenerate a unique group_id if the given one already exists on the DAG? 
   
   A taskgroup-decorated function can be invoked multiple times, and each invocation should generate a separate TaskGroup. 
   
   We could follow an approach similar to the @task decorator, which appends `__1` to the end. We can probably reuse the same regex.
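A sketch of the suffixing scheme described above, written as a standalone, hypothetical helper. It is modeled on the `__1`-style suffix the comment mentions, not copied from Airflow's @task implementation. `existing_ids` stands in for the group ids already registered on the DAG; how those are collected is left open here.

```python
import re
from typing import Iterable, List


def get_unique_group_id(group_id: str, existing_ids: Iterable[str]) -> str:
    """Return group_id unchanged, or group_id__N with the next free N if taken (hypothetical helper)."""
    existing = set(existing_ids)
    if group_id not in existing:
        return group_id
    # Strip any numeric suffix so "extract__1" and "extract" share the core "extract".
    core = re.split(r"__\d+$", group_id)[0]
    # re.escape keeps characters such as "." in the core from acting as regex syntax.
    pattern = re.compile(rf"^{re.escape(core)}__(\d+)$")
    suffixes: List[int] = []
    for gid in existing:
        match = pattern.match(gid)
        if match:
            suffixes.append(int(match.group(1)))
    # Continue after the highest suffix already in use (start at 1 if none).
    return f"{core}__{max(suffixes, default=0) + 1}"
```

Invoking a decorated `extract` function twice would then yield `extract` and `extract__1`, and a third call would yield `extract__2`.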




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org