You are viewing a plain-text version of this content. The hyperlink to the canonical version ("here") was not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by pi...@apache.org on 2023/01/12 00:08:32 UTC

[airflow] branch v2-5-test updated: Allow XComArgs for external_task_ids of ExternalTaskSensor (#28692)

This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v2-5-test by this push:
     new b967fc97d1 Allow XComArgs for external_task_ids of ExternalTaskSensor (#28692)
b967fc97d1 is described below

commit b967fc97d19f18ceeff38945fedb627382061a2c
Author: Victor Chiapaikeo <vc...@gmail.com>
AuthorDate: Wed Jan 4 06:39:53 2023 -0500

    Allow XComArgs for external_task_ids of ExternalTaskSensor (#28692)
    
    (cherry picked from commit 7f18fa96e434c64288d801904caf1fcde18e2cbf)
---
 airflow/sensors/external_task.py           |  6 ++-
 tests/sensors/test_external_task_sensor.py | 72 ++++++++++++++++++++++++++----
 2 files changed, 68 insertions(+), 10 deletions(-)

diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index e9573a0671..967bb5a276 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -162,8 +162,7 @@ class ExternalTaskSensor(BaseSensorOperator):
                     f"when `external_task_id` or `external_task_ids` or `external_task_group_id` "
                     f"is not `None`: {State.task_states}"
                 )
-            if external_task_ids and len(external_task_ids) > len(set(external_task_ids)):
-                raise ValueError("Duplicate task_ids passed in external_task_ids parameter")
+
         elif not total_states <= set(State.dag_states):
             raise ValueError(
                 f"Valid values for `allowed_states` and `failed_states` "
@@ -196,6 +195,9 @@ class ExternalTaskSensor(BaseSensorOperator):
 
     @provide_session
     def poke(self, context, session=None):
+        if self.external_task_ids and len(self.external_task_ids) > len(set(self.external_task_ids)):
+            raise ValueError("Duplicate task_ids passed in external_task_ids parameter")
+
         dttm_filter = self._get_dttm_filter(context)
         serialized_dttm_filter = ",".join(dt.isoformat() for dt in dttm_filter)
 
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index 80f538e868..b594210b13 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -33,8 +33,10 @@ from airflow.exceptions import AirflowException, AirflowSensorTimeout
 from airflow.models import DagBag, DagRun, TaskInstance
 from airflow.models.dag import DAG
 from airflow.models.serialized_dag import SerializedDagModel
+from airflow.models.xcom_arg import XComArg
 from airflow.operators.bash import BashOperator
 from airflow.operators.empty import EmptyOperator
+from airflow.operators.python import PythonOperator
 from airflow.sensors.external_task import ExternalTaskMarker, ExternalTaskSensor, ExternalTaskSensorLink
 from airflow.sensors.time_sensor import TimeSensor
 from airflow.serialization.serialized_objects import SerializedBaseOperator
@@ -45,6 +47,7 @@ from airflow.utils.timezone import datetime
 from airflow.utils.types import DagRunType
 from tests.models import TEST_DAGS_FOLDER
 from tests.test_utils.db import clear_db_runs
+from tests.test_utils.mock_operators import MockOperator
 
 DEFAULT_DATE = datetime(2015, 1, 1)
 TEST_DAG_ID = "unit_test_dag"
@@ -579,17 +582,70 @@ exit 0
                 dag=self.dag,
             )
 
+    def test_external_task_sensor_with_xcom_arg_does_not_fail_on_init(self):
+        self.add_time_sensor()
+        op1 = MockOperator(task_id="op1", dag=self.dag)
+        op2 = ExternalTaskSensor(
+            task_id="test_external_task_sensor_with_xcom_arg_does_not_fail_on_init",
+            external_dag_id=TEST_DAG_ID,
+            external_task_ids=XComArg(op1),
+            allowed_states=["success"],
+            dag=self.dag,
+        )
+        assert isinstance(op2.external_task_ids, XComArg)
+
     def test_catch_duplicate_task_ids(self):
         self.add_time_sensor()
-        # Test By passing same task_id multiple times
+        op1 = ExternalTaskSensor(
+            task_id="test_external_task_duplicate_task_ids",
+            external_dag_id=TEST_DAG_ID,
+            external_task_ids=[TEST_TASK_ID, TEST_TASK_ID],
+            allowed_states=["success"],
+            dag=self.dag,
+        )
         with pytest.raises(ValueError):
-            ExternalTaskSensor(
-                task_id="test_external_task_duplicate_task_ids",
-                external_dag_id=TEST_DAG_ID,
-                external_task_ids=[TEST_TASK_ID, TEST_TASK_ID],
-                allowed_states=["success"],
-                dag=self.dag,
-            )
+            op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+    def test_catch_duplicate_task_ids_with_xcom_arg(self):
+        self.add_time_sensor()
+        op1 = PythonOperator(
+            python_callable=lambda: ["dupe_value", "dupe_value"],
+            task_id="op1",
+            do_xcom_push=True,
+            dag=self.dag,
+        )
+
+        op2 = ExternalTaskSensor(
+            task_id="test_external_task_duplicate_task_ids_with_xcom_arg",
+            external_dag_id=TEST_DAG_ID,
+            external_task_ids=XComArg(op1),
+            allowed_states=["success"],
+            dag=self.dag,
+        )
+        with pytest.raises(ValueError):
+            op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+    def test_catch_duplicate_task_ids_with_multiple_xcom_args(self):
+        self.add_time_sensor()
+
+        op1 = PythonOperator(
+            python_callable=lambda: "value",
+            task_id="op1",
+            do_xcom_push=True,
+            dag=self.dag,
+        )
+
+        op2 = ExternalTaskSensor(
+            task_id="test_external_task_duplicate_task_ids_with_xcom_arg",
+            external_dag_id=TEST_DAG_ID,
+            external_task_ids=[XComArg(op1), XComArg(op1)],
+            allowed_states=["success"],
+            dag=self.dag,
+        )
+        with pytest.raises(ValueError):
+            op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+            op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
     def test_catch_invalid_allowed_states(self):
         with pytest.raises(ValueError):