You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@airflow.apache.org by jh...@apache.org on 2021/05/11 19:10:16 UTC

[airflow] branch master updated: Return output of last task from task_group wrapper. (#15779)

This is an automated email from the ASF dual-hosted git repository.

jhtimmins pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 303c89f  Return output of last task from task_group wrapper. (#15779)
303c89f is described below

commit 303c89fea0a7cf8a857436182abe1b042d473022
Author: James Timmins <ja...@astronomer.io>
AuthorDate: Tue May 11 12:09:58 2021 -0700

    Return output of last task from task_group wrapper. (#15779)
---
 airflow/decorators/task_group.py |  7 ++-----
 tests/utils/test_task_group.py   | 38 +++++++++++++++++++++++++++++++++++---
 2 files changed, 37 insertions(+), 8 deletions(-)

diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py
index 8c169dd..89283b2 100644
--- a/airflow/decorators/task_group.py
+++ b/airflow/decorators/task_group.py
@@ -58,12 +58,9 @@ def task_group(python_callable: Optional[Callable] = None, *tg_args, **tg_kwargs
             # Initialize TaskGroup with bound arguments
             with TaskGroup(
                 *task_group_bound_args.args, add_suffix_on_collision=True, **task_group_bound_args.kwargs
-            ) as tg_obj:
+            ):
                 # Invoke function to run Tasks inside the TaskGroup
-                f(*args, **kwargs)
-
-            # Return task_group object such that it's accessible in Globals.
-            return tg_obj
+                return f(*args, **kwargs)
 
         return cast(T, factory)
 
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index b21a9d7..2cb71e9 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -19,11 +19,13 @@
 import pendulum
 import pytest
 
-from airflow.decorators import task_group as task_group_decorator
+from airflow.decorators import dag, task_group as task_group_decorator
 from airflow.models import DAG
+from airflow.models.xcom_arg import XComArg
 from airflow.operators.bash import BashOperator
 from airflow.operators.dummy import DummyOperator
 from airflow.operators.python import PythonOperator
+from airflow.utils.dates import days_ago
 from airflow.utils.task_group import TaskGroup
 from airflow.www.views import dag_edges, task_group_to_dict
 
@@ -676,7 +678,6 @@ def test_build_task_group_deco_context_manager():
                     },
                     {'id': 'section_1.task_1'},
                     {'id': 'section_1.task_2'},
-                    {'id': 'section_1.downstream_join_id'},
                 ],
             },
             {'id': 'task_end'},
@@ -805,7 +806,6 @@ def test_task_group_context_mix():
                             {'id': 'section_1.section_2.task_1'},
                             {'id': 'section_1.section_2.task_2'},
                             {'id': 'section_1.section_2.task_3'},
-                            {'id': 'section_1.section_2.downstream_join_id'},
                         ],
                     },
                     {'id': 'section_1.task_1'},
@@ -947,3 +947,35 @@ def test_call_taskgroup_twice():
     }
 
     assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids
+
+
+def test_pass_taskgroup_output_to_task():
+    """Test that the output of a task group can be passed to a task."""
+    from airflow.decorators import task
+
+    @task
+    def one():
+        return 1
+
+    @task_group_decorator
+    def addition_task_group(num):
+        @task
+        def add_one(i):
+            return i + 1
+
+        return add_one(num)
+
+    @task
+    def increment(num):
+        return num + 1
+
+    @dag(schedule_interval=None, start_date=days_ago(1), default_args={"owner": "airflow"})
+    def wrap():
+        total_1 = one()
+        assert isinstance(total_1, XComArg)
+        total_2 = addition_task_group(total_1)
+        assert isinstance(total_2, XComArg)
+        total_3 = increment(total_2)
+        assert isinstance(total_3, XComArg)
+
+    wrap()