You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by je...@apache.org on 2021/11/03 21:30:03 UTC

[airflow] 04/17: Use ``execution_date`` to check for existing ``DagRun`` for ``TriggerDagRunOperator`` (#18968)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 901901a4d19e3fbe33b58881bf7e0b09e4982fed
Author: Gulshan Gill <gu...@gmail.com>
AuthorDate: Thu Nov 4 03:09:41 2021 +0800

    Use ``execution_date`` to check for existing ``DagRun`` for ``TriggerDagRunOperator`` (#18968)
    
    A small suggestion: change the existing-run check in `trigger_dag` so that it looks for a `DagRun` by `execution_date` as well as by `run_id`, instead of calling `DagRun.find` with `run_id` alone.
    
    I feel this is better than checking `run_id` alone, because a `run_id`-only lookup misses a scheduled run that already ran at the same `execution_date`. Creating a new run with that `execution_date` then throws the error below:
    
    ```
    sqlalchemy.exc.IntegrityError: (psycopg2.errors.UniqueViolation) duplicate key value violates unique constraint "dag_run_dag_id_execution_date_key"
    ```
    
    The `dag_run` table has a unique constraint called `dag_run_dag_id_execution_date_key`, which is defined [in `airflow/models/dagrun.py`](https://github.com/apache/airflow/blob/c4f5233cd10ae03ee69fba861c8a6fa64e1f8a71/airflow/models/dagrun.py#L103).
    
    (cherry picked from commit e54ee6e0d38ca469be6ba686e32ce7a3a34d03ca)
---
 airflow/api/common/experimental/trigger_dag.py    |  6 ++-
 airflow/models/dagrun.py                          | 65 +++++++++++++++++------
 tests/api/common/experimental/test_trigger_dag.py |  6 +--
 tests/models/test_dagrun.py                       | 23 ++++++++
 4 files changed, 78 insertions(+), 22 deletions(-)

diff --git a/airflow/api/common/experimental/trigger_dag.py b/airflow/api/common/experimental/trigger_dag.py
index 2e64f86..38a873c 100644
--- a/airflow/api/common/experimental/trigger_dag.py
+++ b/airflow/api/common/experimental/trigger_dag.py
@@ -68,10 +68,12 @@ def _trigger_dag(
             )
 
     run_id = run_id or DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
-    dag_run = DagRun.find(dag_id=dag_id, run_id=run_id)
+    dag_run = DagRun.find_duplicate(dag_id=dag_id, execution_date=execution_date, run_id=run_id)
 
     if dag_run:
-        raise DagRunAlreadyExists(f"Run id {run_id} already exists for dag id {dag_id}")
+        raise DagRunAlreadyExists(
+            f"A Dag Run already exists for dag id {dag_id} at {execution_date} with run id {run_id}"
+        )
 
     run_conf = None
     if conf:
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 800720c..8d2ab2a 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -285,12 +285,13 @@ class DagRun(Base, LoggingMixin):
             query.limit(max_number), of=cls, session=session, **skip_locked(session=session)
         )
 
-    @staticmethod
+    @classmethod
     @provide_session
     def find(
+        cls,
         dag_id: Optional[Union[str, List[str]]] = None,
         run_id: Optional[str] = None,
-        execution_date: Optional[datetime] = None,
+        execution_date: Optional[Union[datetime, List[datetime]]] = None,
         state: Optional[DagRunState] = None,
         external_trigger: Optional[bool] = None,
         no_backfills: bool = False,
@@ -324,35 +325,65 @@ class DagRun(Base, LoggingMixin):
         :param execution_end_date: dag run that was executed until this date
         :type execution_end_date: datetime.datetime
         """
-        DR = DagRun
-
-        qry = session.query(DR)
+        qry = session.query(cls)
         dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
         if dag_ids:
-            qry = qry.filter(DR.dag_id.in_(dag_ids))
+            qry = qry.filter(cls.dag_id.in_(dag_ids))
         if run_id:
-            qry = qry.filter(DR.run_id == run_id)
+            qry = qry.filter(cls.run_id == run_id)
         if execution_date:
             if isinstance(execution_date, list):
-                qry = qry.filter(DR.execution_date.in_(execution_date))
+                qry = qry.filter(cls.execution_date.in_(execution_date))
             else:
-                qry = qry.filter(DR.execution_date == execution_date)
+                qry = qry.filter(cls.execution_date == execution_date)
         if execution_start_date and execution_end_date:
-            qry = qry.filter(DR.execution_date.between(execution_start_date, execution_end_date))
+            qry = qry.filter(cls.execution_date.between(execution_start_date, execution_end_date))
         elif execution_start_date:
-            qry = qry.filter(DR.execution_date >= execution_start_date)
+            qry = qry.filter(cls.execution_date >= execution_start_date)
         elif execution_end_date:
-            qry = qry.filter(DR.execution_date <= execution_end_date)
+            qry = qry.filter(cls.execution_date <= execution_end_date)
         if state:
-            qry = qry.filter(DR.state == state)
+            qry = qry.filter(cls.state == state)
         if external_trigger is not None:
-            qry = qry.filter(DR.external_trigger == external_trigger)
+            qry = qry.filter(cls.external_trigger == external_trigger)
         if run_type:
-            qry = qry.filter(DR.run_type == run_type)
+            qry = qry.filter(cls.run_type == run_type)
         if no_backfills:
-            qry = qry.filter(DR.run_type != DagRunType.BACKFILL_JOB)
+            qry = qry.filter(cls.run_type != DagRunType.BACKFILL_JOB)
+
+        return qry.order_by(cls.execution_date).all()
+
+    @classmethod
+    @provide_session
+    def find_duplicate(
+        cls,
+        dag_id: str,
+        run_id: str,
+        execution_date: datetime,
+        session: Session = None,
+    ) -> Optional['DagRun']:
+        """
+        Return an existing run for the DAG with a specific run_id or execution_date.
 
-        return qry.order_by(DR.execution_date).all()
+        *None* is returned if no such DAG run is found.
+
+        :param dag_id: the dag_id to find duplicates for
+        :type dag_id: str
+        :param run_id: defines the run id for this dag run
+        :type run_id: str
+        :param execution_date: the execution date
+        :type execution_date: datetime.datetime
+        :param session: database session
+        :type session: sqlalchemy.orm.session.Session
+        """
+        return (
+            session.query(cls)
+            .filter(
+                cls.dag_id == dag_id,
+                or_(cls.run_id == run_id, cls.execution_date == execution_date),
+            )
+            .one_or_none()
+        )
 
     @staticmethod
     def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
diff --git a/tests/api/common/experimental/test_trigger_dag.py b/tests/api/common/experimental/test_trigger_dag.py
index cbca935..2f16446 100644
--- a/tests/api/common/experimental/test_trigger_dag.py
+++ b/tests/api/common/experimental/test_trigger_dag.py
@@ -49,7 +49,7 @@ class TestTriggerDag(unittest.TestCase):
         dag = DAG(dag_id)
         dag_bag_mock.dags = [dag_id]
         dag_bag_mock.get_dag.return_value = dag
-        dag_run_mock.find.return_value = DagRun()
+        dag_run_mock.find_duplicate.return_value = DagRun()
         with pytest.raises(AirflowException):
             _trigger_dag(dag_id, dag_bag_mock)
 
@@ -60,7 +60,7 @@ class TestTriggerDag(unittest.TestCase):
         dag_id = "trigger_dag"
         dag_bag_mock.dags = [dag_id]
         dag_bag_mock.get_dag.return_value = dag_mock
-        dag_run_mock.find.return_value = None
+        dag_run_mock.find_duplicate.return_value = None
         dag1 = mock.MagicMock(subdags=[])
         dag2 = mock.MagicMock(subdags=[])
         dag_mock.subdags = [dag1, dag2]
@@ -76,7 +76,7 @@ class TestTriggerDag(unittest.TestCase):
         dag_id = "trigger_dag"
         dag_bag_mock.dags = [dag_id]
         dag_bag_mock.get_dag.return_value = dag_mock
-        dag_run_mock.find.return_value = None
+        dag_run_mock.find_duplicate.return_value = None
         dag1 = mock.MagicMock(subdags=[])
         dag2 = mock.MagicMock(subdags=[dag1])
         dag_mock.subdags = [dag1, dag2]
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index c4ef287..00799be 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -142,6 +142,29 @@ class TestDagRun(unittest.TestCase):
         assert 0 == len(models.DagRun.find(dag_id=dag_id2, external_trigger=True))
         assert 1 == len(models.DagRun.find(dag_id=dag_id2, external_trigger=False))
 
+    def test_dagrun_find_duplicate(self):
+        session = settings.Session()
+        now = timezone.utcnow()
+
+        dag_id = "test_dagrun_find_duplicate"
+        dag_run = models.DagRun(
+            dag_id=dag_id,
+            run_id=dag_id,
+            run_type=DagRunType.MANUAL,
+            execution_date=now,
+            start_date=now,
+            state=State.RUNNING,
+            external_trigger=True,
+        )
+        session.add(dag_run)
+
+        session.commit()
+
+        assert models.DagRun.find_duplicate(dag_id=dag_id, run_id=dag_id, execution_date=now) is not None
+        assert models.DagRun.find_duplicate(dag_id=dag_id, run_id=dag_id, execution_date=None) is not None
+        assert models.DagRun.find_duplicate(dag_id=dag_id, run_id=None, execution_date=now) is not None
+        assert models.DagRun.find_duplicate(dag_id=dag_id, run_id=None, execution_date=None) is None
+
     def test_dagrun_success_when_all_skipped(self):
         """
         Tests that a DAG run succeeds when all tasks are skipped