You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2021/01/21 01:16:45 UTC

[airflow] branch master updated: BaseBranchOperator will push to xcom by default. (#13704) (#13763)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 3e25795  BaseBranchOperator will push to xcom by default. (#13704) (#13763)
3e25795 is described below

commit 3e257950990a6edd817c372036352f96d4f8a76b
Author: Ashmeet Lamba <as...@gmail.com>
AuthorDate: Thu Jan 21 06:46:32 2021 +0530

    BaseBranchOperator will push to xcom by default. (#13704) (#13763)
    
    This change makes BaseBranchOperator push the branch it chooses to follow to XCom.
    It will also add support to use the do_xcom_push parameter.
    
    The added change returns the result received by running choose_branch().
    
    Closes: #13704
---
 airflow/operators/branch.py             |  4 +++-
 tests/operators/test_branch_operator.py | 21 +++++++++++++++++++++
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/airflow/operators/branch.py b/airflow/operators/branch.py
index a465341..d1502b4 100644
--- a/airflow/operators/branch.py
+++ b/airflow/operators/branch.py
@@ -49,4 +49,6 @@ class BaseBranchOperator(BaseOperator, SkipMixin):
         raise NotImplementedError
 
     def execute(self, context: Dict):
-        self.skip_all_except(context['ti'], self.choose_branch(context))
+        branches_to_execute = self.choose_branch(context)
+        self.skip_all_except(context['ti'], branches_to_execute)
+        return branches_to_execute
diff --git a/tests/operators/test_branch_operator.py b/tests/operators/test_branch_operator.py
index d372534..f54dafe 100644
--- a/tests/operators/test_branch_operator.py
+++ b/tests/operators/test_branch_operator.py
@@ -170,3 +170,24 @@ class TestBranchOperator(unittest.TestCase):
                 assert ti.state == State.NONE
             else:
                 raise Exception
+
+    def test_xcom_push(self):
+        self.branch_op = ChooseBranchOne(task_id='make_choice', dag=self.dag)
+
+        self.branch_1.set_upstream(self.branch_op)
+        self.branch_2.set_upstream(self.branch_op)
+        self.dag.clear()
+
+        dr = self.dag.create_dagrun(
+            run_type=DagRunType.MANUAL,
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING,
+        )
+
+        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == 'make_choice':
+                assert ti.xcom_pull(task_ids='make_choice') == 'branch_1'