You are viewing a plain-text version of this content; the canonical link was embedded in the original HTML message and is not reproduced here.
Posted to commits@airflow.apache.org by pi...@apache.org on 2023/01/12 00:08:32 UTC
[airflow] branch v2-5-test updated: Allow XComArgs for external_task_ids of ExternalTaskSensor (#28692)
This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v2-5-test by this push:
new b967fc97d1 Allow XComArgs for external_task_ids of ExternalTaskSensor (#28692)
b967fc97d1 is described below
commit b967fc97d19f18ceeff38945fedb627382061a2c
Author: Victor Chiapaikeo <vc...@gmail.com>
AuthorDate: Wed Jan 4 06:39:53 2023 -0500
Allow XComArgs for external_task_ids of ExternalTaskSensor (#28692)
(cherry picked from commit 7f18fa96e434c64288d801904caf1fcde18e2cbf)
---
airflow/sensors/external_task.py | 6 ++-
tests/sensors/test_external_task_sensor.py | 72 ++++++++++++++++++++++++++----
2 files changed, 68 insertions(+), 10 deletions(-)
diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index e9573a0671..967bb5a276 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -162,8 +162,7 @@ class ExternalTaskSensor(BaseSensorOperator):
f"when `external_task_id` or `external_task_ids` or `external_task_group_id` "
f"is not `None`: {State.task_states}"
)
- if external_task_ids and len(external_task_ids) > len(set(external_task_ids)):
- raise ValueError("Duplicate task_ids passed in external_task_ids parameter")
+
elif not total_states <= set(State.dag_states):
raise ValueError(
f"Valid values for `allowed_states` and `failed_states` "
@@ -196,6 +195,9 @@ class ExternalTaskSensor(BaseSensorOperator):
@provide_session
def poke(self, context, session=None):
+ if self.external_task_ids and len(self.external_task_ids) > len(set(self.external_task_ids)):
+ raise ValueError("Duplicate task_ids passed in external_task_ids parameter")
+
dttm_filter = self._get_dttm_filter(context)
serialized_dttm_filter = ",".join(dt.isoformat() for dt in dttm_filter)
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index 80f538e868..b594210b13 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -33,8 +33,10 @@ from airflow.exceptions import AirflowException, AirflowSensorTimeout
from airflow.models import DagBag, DagRun, TaskInstance
from airflow.models.dag import DAG
from airflow.models.serialized_dag import SerializedDagModel
+from airflow.models.xcom_arg import XComArg
from airflow.operators.bash import BashOperator
from airflow.operators.empty import EmptyOperator
+from airflow.operators.python import PythonOperator
from airflow.sensors.external_task import ExternalTaskMarker, ExternalTaskSensor, ExternalTaskSensorLink
from airflow.sensors.time_sensor import TimeSensor
from airflow.serialization.serialized_objects import SerializedBaseOperator
@@ -45,6 +47,7 @@ from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
from tests.models import TEST_DAGS_FOLDER
from tests.test_utils.db import clear_db_runs
+from tests.test_utils.mock_operators import MockOperator
DEFAULT_DATE = datetime(2015, 1, 1)
TEST_DAG_ID = "unit_test_dag"
@@ -579,17 +582,70 @@ exit 0
dag=self.dag,
)
+ def test_external_task_sensor_with_xcom_arg_does_not_fail_on_init(self):
+ self.add_time_sensor()
+ op1 = MockOperator(task_id="op1", dag=self.dag)
+ op2 = ExternalTaskSensor(
+ task_id="test_external_task_sensor_with_xcom_arg_does_not_fail_on_init",
+ external_dag_id=TEST_DAG_ID,
+ external_task_ids=XComArg(op1),
+ allowed_states=["success"],
+ dag=self.dag,
+ )
+ assert isinstance(op2.external_task_ids, XComArg)
+
def test_catch_duplicate_task_ids(self):
self.add_time_sensor()
- # Test By passing same task_id multiple times
+ op1 = ExternalTaskSensor(
+ task_id="test_external_task_duplicate_task_ids",
+ external_dag_id=TEST_DAG_ID,
+ external_task_ids=[TEST_TASK_ID, TEST_TASK_ID],
+ allowed_states=["success"],
+ dag=self.dag,
+ )
with pytest.raises(ValueError):
- ExternalTaskSensor(
- task_id="test_external_task_duplicate_task_ids",
- external_dag_id=TEST_DAG_ID,
- external_task_ids=[TEST_TASK_ID, TEST_TASK_ID],
- allowed_states=["success"],
- dag=self.dag,
- )
+ op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ def test_catch_duplicate_task_ids_with_xcom_arg(self):
+ self.add_time_sensor()
+ op1 = PythonOperator(
+ python_callable=lambda: ["dupe_value", "dupe_value"],
+ task_id="op1",
+ do_xcom_push=True,
+ dag=self.dag,
+ )
+
+ op2 = ExternalTaskSensor(
+ task_id="test_external_task_duplicate_task_ids_with_xcom_arg",
+ external_dag_id=TEST_DAG_ID,
+ external_task_ids=XComArg(op1),
+ allowed_states=["success"],
+ dag=self.dag,
+ )
+ with pytest.raises(ValueError):
+ op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+ def test_catch_duplicate_task_ids_with_multiple_xcom_args(self):
+ self.add_time_sensor()
+
+ op1 = PythonOperator(
+ python_callable=lambda: "value",
+ task_id="op1",
+ do_xcom_push=True,
+ dag=self.dag,
+ )
+
+ op2 = ExternalTaskSensor(
+ task_id="test_external_task_duplicate_task_ids_with_xcom_arg",
+ external_dag_id=TEST_DAG_ID,
+ external_task_ids=[XComArg(op1), XComArg(op1)],
+ allowed_states=["success"],
+ dag=self.dag,
+ )
+ with pytest.raises(ValueError):
+ op1.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+ op2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
def test_catch_invalid_allowed_states(self):
with pytest.raises(ValueError):