You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by jh...@apache.org on 2021/05/11 19:10:16 UTC
[airflow] branch master updated: Return output of last task from task_group wrapper. (#15779)
This is an automated email from the ASF dual-hosted git repository.
jhtimmins pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new 303c89f Return output of last task from task_group wrapper. (#15779)
303c89f is described below
commit 303c89fea0a7cf8a857436182abe1b042d473022
Author: James Timmins <ja...@astronomer.io>
AuthorDate: Tue May 11 12:09:58 2021 -0700
Return output of last task from task_group wrapper. (#15779)
---
airflow/decorators/task_group.py | 7 ++-----
tests/utils/test_task_group.py | 38 +++++++++++++++++++++++++++++++++++---
2 files changed, 37 insertions(+), 8 deletions(-)
diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py
index 8c169dd..89283b2 100644
--- a/airflow/decorators/task_group.py
+++ b/airflow/decorators/task_group.py
@@ -58,12 +58,9 @@ def task_group(python_callable: Optional[Callable] = None, *tg_args, **tg_kwargs
# Initialize TaskGroup with bound arguments
with TaskGroup(
*task_group_bound_args.args, add_suffix_on_collision=True, **task_group_bound_args.kwargs
- ) as tg_obj:
+ ):
# Invoke function to run Tasks inside the TaskGroup
- f(*args, **kwargs)
-
- # Return task_group object such that it's accessible in Globals.
- return tg_obj
+ return f(*args, **kwargs)
return cast(T, factory)
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index b21a9d7..2cb71e9 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -19,11 +19,13 @@
import pendulum
import pytest
-from airflow.decorators import task_group as task_group_decorator
+from airflow.decorators import dag, task_group as task_group_decorator
from airflow.models import DAG
+from airflow.models.xcom_arg import XComArg
from airflow.operators.bash import BashOperator
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import PythonOperator
+from airflow.utils.dates import days_ago
from airflow.utils.task_group import TaskGroup
from airflow.www.views import dag_edges, task_group_to_dict
@@ -676,7 +678,6 @@ def test_build_task_group_deco_context_manager():
},
{'id': 'section_1.task_1'},
{'id': 'section_1.task_2'},
- {'id': 'section_1.downstream_join_id'},
],
},
{'id': 'task_end'},
@@ -805,7 +806,6 @@ def test_task_group_context_mix():
{'id': 'section_1.section_2.task_1'},
{'id': 'section_1.section_2.task_2'},
{'id': 'section_1.section_2.task_3'},
- {'id': 'section_1.section_2.downstream_join_id'},
],
},
{'id': 'section_1.task_1'},
@@ -947,3 +947,35 @@ def test_call_taskgroup_twice():
}
assert extract_node_id(task_group_to_dict(dag.task_group)) == node_ids
+
+
+def test_pass_taskgroup_output_to_task():
+ """Test that the output of a task group can be passed to a task."""
+ from airflow.decorators import task
+
+ @task
+ def one():
+ return 1
+
+ @task_group_decorator
+ def addition_task_group(num):
+ @task
+ def add_one(i):
+ return i + 1
+
+ return add_one(num)
+
+ @task
+ def increment(num):
+ return num + 1
+
+ @dag(schedule_interval=None, start_date=days_ago(1), default_args={"owner": "airflow"})
+ def wrap():
+ total_1 = one()
+ assert isinstance(total_1, XComArg)
+ total_2 = addition_task_group(total_1)
+ assert isinstance(total_2, XComArg)
+ total_3 = increment(total_2)
+ assert isinstance(total_3, XComArg)
+
+ wrap()