Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2021/04/16 07:23:17 UTC

[GitHub] [airflow] yuqian90 commented on a change in pull request #14640: Allow ExternalTaskSensor to wait for taskgroup

yuqian90 commented on a change in pull request #14640:
URL: https://github.com/apache/airflow/pull/14640#discussion_r614600073



##########
File path: airflow/models/dag.py
##########
@@ -745,10 +745,18 @@ def tasks(self, val):
     def task_ids(self) -> List[str]:
         return list(self.task_dict.keys())
 
+    @property
+    def task_group_dict(self) -> Dict[str, "TaskGroup"]:
+        return {k: v for k, v in self.task_group.get_task_group_dict().items() if k is not None}

Review comment:
       `get_task_group_dict()` is a recursive function that can be costly. I think we should keep it a method instead of making it a property (which tends to suggest to users that it's cheap to access).
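
    A minimal sketch of what I mean (same body as in the diff, just without the `@property` decorator):
    ```
    def task_group_dict(self) -> Dict[str, "TaskGroup"]:
        """Return all TaskGroups in the DAG keyed by group_id, skipping the root group (whose id is None).

        Walks the whole group tree via get_task_group_dict(), so it is not a cheap lookup.
        """
        return {k: v for k, v in self.task_group.get_task_group_dict().items() if k is not None}
    ```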

##########
File path: airflow/models/dag.py
##########
@@ -1201,7 +1209,7 @@ def clear(
             tis = tis.filter(TI.task_id.in_(self.task_ids))
 
         if include_parentdag and self.is_subdag and self.parent_dag is not None:
-            p_dag = self.parent_dag.sub_dag(
+            p_dag = self.parent_dag.partial_subset(

Review comment:
      Since you are changing this to `partial_subset`, let's use the new capability of this method. You may consider passing it like this to avoid the regex dance (for a subdag, `self.dag_id` is `<parent_dag_id>.<subdag_task_id>`, so `split('.')[1]` is the SubDagOperator's task id in the parent DAG):
   ```
   task_ids_or_regex=[self.dag_id.split('.')[1]]
   ```
   
    This is what the docstring of `partial_subset` says:
   ```
           :param task_ids_or_regex: Either a list of task_ids, or a regex to
               match against task ids (as a string, or compiled regex pattern).
   ```

##########
File path: airflow/sensors/external_task.py
##########
@@ -206,29 +232,48 @@ def get_count(self, dttm_filter, session, states) -> int:
         """
         TI = TaskInstance
         DR = DagRun
+
         if self.external_task_id:
             count = (
-                session.query(func.count())  # .count() is inefficient
-                .filter(
-                    TI.dag_id == self.external_dag_id,
-                    TI.task_id == self.external_task_id,
-                    TI.state.in_(states),  # pylint: disable=no-member
-                    TI.execution_date.in_(dttm_filter),
-                )
+                self._count_query(TI, session, states, dttm_filter)
+                .filter(TI.task_id == self.external_task_id)
                 .scalar()
             )
-        else:
+        elif self.external_task_group_id:
+            external_task_group_task_ids = self.get_external_task_group_task_ids(session)
             count = (
-                session.query(func.count())
-                .filter(
-                    DR.dag_id == self.external_dag_id,
-                    DR.state.in_(states),  # pylint: disable=no-member
-                    DR.execution_date.in_(dttm_filter),
-                )
+                self._count_query(TI, session, states, dttm_filter)
+                .filter(TI.task_id.in_(external_task_group_task_ids))
                 .scalar()
-            )
+            ) / len(external_task_group_task_ids)
+        else:
+            count = self._count_query(DR, session, states, dttm_filter).scalar()
+
         return count
 
+    def _count_query(self, model, session, states, dttm_filter) -> "Query":
+        query = session.query(func.count()).filter(  # .count() is inefficient
+            model.dag_id == self.external_dag_id,
+            model.state.in_(states),  # pylint: disable=no-member
+            model.execution_date.in_(dttm_filter),
+        )
+
+        return query
+
+    def get_external_task_group_task_ids(self, session):
+        """Return task ids for the external TaskGroup"""
+        refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session)
+        task_group: Optional["TaskGroup"] = refreshed_dag_info.task_group_dict.get(
+            self.external_task_group_id
+        )
+        if not task_group:
+            raise AirflowException(
+                f"The external task group {self.external_task_group_id} in "
+                f"DAG {self.external_dag_id} does not exist."
+            )
+        task_ids = [task.task_id for task in task_group]
+        return task_ids
+

Review comment:
      The existing task execution code creates its own DagBag from the DAG file instead of reading serialized DAGs from the db. For example, the line linked below creates a DagBag. I think we should do the same here: it's important for tasks to get the latest view of the DAG during execution.
   
   https://github.com/apache/airflow/blob/f1edc220d3f9cb050016d23246a682276bd09eee/airflow/sensors/external_task.py#L213
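
    Roughly what I have in mind, modelled on the linked `_check_for_existence` code path (a sketch only; the `DagModel` query and the re-parse from `fileloc` follow that existing code, the rest reuses the PR's logic):
    ```
    from airflow.models import DagBag, DagModel

    def get_external_task_group_task_ids(self, session):
        """Return task ids for the external TaskGroup, re-parsing the DAG file for a fresh view."""
        dag_to_wait = session.query(DagModel).filter(DagModel.dag_id == self.external_dag_id).first()
        if not dag_to_wait:
            raise AirflowException(f"The external DAG {self.external_dag_id} does not exist.")
        # Re-parse the DAG from its file location instead of reading the serialized DAG from the db.
        refreshed_dag_info = DagBag(dag_to_wait.fileloc).get_dag(self.external_dag_id)
        task_group = refreshed_dag_info.task_group_dict.get(self.external_task_group_id)
        if not task_group:
            raise AirflowException(
                f"The external task group {self.external_task_group_id} in "
                f"DAG {self.external_dag_id} does not exist."
            )
        return [task.task_id for task in task_group]
    ```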
   

##########
File path: tests/sensors/test_external_task.py
##########
@@ -445,7 +545,7 @@ def clear_tasks(dag_bag, dag, task, start_date=DEFAULT_DATE, end_date=DEFAULT_DA
     """
     Clear the task and its downstream tasks recursively for the dag in the given dagbag.
     """
-    subdag = dag.sub_dag(task_ids_or_regex=f"^{task.task_id}$", include_downstream=True)
+    subdag = dag.partial_subset(task_ids_or_regex=f"^{task.task_id}$", include_downstream=True)

Review comment:
      Same here: `task_ids_or_regex` can be a list of task_ids.
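
    Something like this (same call, with the exact id passed as a one-element list instead of a regex):
    ```
    subdag = dag.partial_subset(task_ids_or_regex=[task.task_id], include_downstream=True)
    ```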

##########
File path: airflow/models/dag.py
##########
@@ -1282,7 +1290,7 @@ def clear(
                             external_dag = dag_bag.get_dag(tii.dag_id)
                             if not external_dag:
                                 raise AirflowException(f"Could not find dag {tii.dag_id}")
-                            downstream = external_dag.sub_dag(
+                            downstream = external_dag.partial_subset(

Review comment:
      Same here. This wants an exact match of the task_id, so passing `task_ids_or_regex` as a one-element list with the task_id is better.
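
    For example (a sketch; I'm assuming the loop variable `tii` from the surrounding diff and that the other keyword arguments stay as in the existing code):
    ```
    downstream = external_dag.partial_subset(
        task_ids_or_regex=[tii.task_id],
        include_downstream=True,
        include_upstream=False,
    )
    ```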

##########
File path: airflow/models/dag.py
##########
@@ -745,10 +745,18 @@ def tasks(self, val):
     def task_ids(self) -> List[str]:
         return list(self.task_dict.keys())
 
+    @property
+    def task_group_dict(self) -> Dict[str, "TaskGroup"]:
+        return {k: v for k, v in self.task_group.get_task_group_dict().items() if k is not None}
+
     @property
     def task_group(self) -> "TaskGroup":
         return self._task_group
 
+    @property
+    def task_groups(self) -> List["TaskGroup"]:
+        return list(self.task_group_dict.values())

Review comment:
      Same here for `task_groups`: it goes through `task_group_dict`, so it isn't a cheap lookup either and is better kept as a method.

##########
File path: tests/sensors/test_external_task.py
##########
@@ -24,32 +23,49 @@
 from airflow.exceptions import AirflowException, AirflowSensorTimeout
 from airflow.models import DagBag, TaskInstance
 from airflow.models.dag import DAG
+from airflow.models.serialized_dag import SerializedDagModel
 from airflow.operators.bash import BashOperator
 from airflow.operators.dummy import DummyOperator
 from airflow.sensors.external_task import ExternalTaskMarker, ExternalTaskSensor
 from airflow.sensors.time_sensor import TimeSensor
 from airflow.serialization.serialized_objects import SerializedBaseOperator
 from airflow.utils.state import State
+from airflow.utils.task_group import TaskGroup
 from airflow.utils.timezone import datetime
+from tests.test_utils.db import clear_db_runs
 
 DEFAULT_DATE = datetime(2015, 1, 1)
 TEST_DAG_ID = 'unit_test_dag'
 TEST_TASK_ID = 'time_sensor_check'
+TEST_TASK_GROUP_ID = 'dummy_task_group'
 DEV_NULL = '/dev/null'
 
 
-class TestExternalTaskSensor(unittest.TestCase):
-    def setUp(self):
-        self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
+class TestExternalTaskSensor:
+    def setup_method(self):
         self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
         self.dag = DAG(TEST_DAG_ID, default_args=self.args)
+        SerializedDagModel.write_dag(self.dag)

Review comment:
      I don't think `SerializedDagModel.write_dag` is needed if the task creates its own DagBag, as it previously did.



