You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bb...@apache.org on 2022/04/20 14:31:33 UTC

[airflow] 02/09: Introduce tuple_().in_() shim for MSSQL compat

This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 898480765fe8117938037b194a10663565d44e3a
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Apr 19 18:01:55 2022 +0800

    Introduce tuple_().in_() shim for MSSQL compat
---
 airflow/api/common/mark_tasks.py |  5 +++--
 airflow/jobs/scheduler_job.py    | 47 +++++++++++++++-------------------------
 airflow/models/dag.py            |  8 +++----
 airflow/models/taskinstance.py   | 21 +++++-------------
 airflow/utils/sqlalchemy.py      | 25 +++++++++++++++++++--
 5 files changed, 52 insertions(+), 54 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 349b935e82..1d4709fb82 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -20,7 +20,7 @@
 from datetime import datetime
 from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
 
-from sqlalchemy import or_, tuple_
+from sqlalchemy import or_
 from sqlalchemy.orm import contains_eager
 from sqlalchemy.orm.session import Session as SASession
 
@@ -32,6 +32,7 @@ from airflow.operators.subdag import SubDagOperator
 from airflow.utils import timezone
 from airflow.utils.helpers import exactly_one
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
@@ -203,7 +204,7 @@ def get_all_dag_task_query(
     if is_string_list:
         qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids))
     else:
-        qry_dag = qry_dag.filter(tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids))
+        qry_dag = qry_dag.filter(tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), task_ids))
     qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
         contains_eager(TaskInstance.dag_run)
     )
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index e0b8c437ac..ac1d25833b 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -28,7 +28,7 @@ from collections import defaultdict
 from datetime import timedelta
 from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple
 
-from sqlalchemy import and_, func, not_, or_, text, tuple_
+from sqlalchemy import func, not_, or_, text
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import load_only, selectinload
 from sqlalchemy.orm.session import Session, make_transient
@@ -55,7 +55,13 @@ from airflow.utils.docs import get_docs_url
 from airflow.utils.event_scheduler import EventScheduler
 from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
 from airflow.utils.session import create_session, provide_session
-from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import (
+    is_lock_not_available_error,
+    prohibit_commit,
+    skip_locked,
+    tuple_in_condition,
+    with_row_locks,
+)
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
@@ -321,17 +327,7 @@ class SchedulerJob(BaseJob):
                 query = query.filter(not_(TI.dag_id.in_(starved_dags)))
 
             if starved_tasks:
-                if settings.engine.dialect.name == 'mssql':
-                    task_filter = or_(
-                        and_(
-                            TaskInstance.dag_id == dag_id,
-                            TaskInstance.task_id == task_id,
-                        )
-                        for (dag_id, task_id) in starved_tasks
-                    )
-                else:
-                    task_filter = tuple_(TaskInstance.dag_id, TaskInstance.task_id).in_(starved_tasks)
-
+                task_filter = tuple_in_condition((TaskInstance.dag_id, TaskInstance.task_id), starved_tasks)
                 query = query.filter(not_(task_filter))
 
             query = query.limit(max_tis)
@@ -980,24 +976,15 @@ class SchedulerJob(BaseJob):
         # as DagModel.dag_id and DagModel.next_dagrun
         # This list is used to verify if the DagRun already exists so that we don't attempt to create
         # duplicate dag runs
-
-        if session.bind.dialect.name == 'mssql':
-            existing_dagruns_filter = or_(
-                *(
-                    and_(
-                        DagRun.dag_id == dm.dag_id,
-                        DagRun.execution_date == dm.next_dagrun,
-                    )
-                    for dm in dag_models
-                )
-            )
-        else:
-            existing_dagruns_filter = tuple_(DagRun.dag_id, DagRun.execution_date).in_(
-                [(dm.dag_id, dm.next_dagrun) for dm in dag_models]
-            )
-
         existing_dagruns = (
-            session.query(DagRun.dag_id, DagRun.execution_date).filter(existing_dagruns_filter).all()
+            session.query(DagRun.dag_id, DagRun.execution_date)
+            .filter(
+                tuple_in_condition(
+                    (DagRun.dag_id, DagRun.execution_date),
+                    ((dm.dag_id, dm.next_dagrun) for dm in dag_models),
+                ),
+            )
+            .all()
         )
 
         active_runs_of_dags = defaultdict(
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 9c93bcef13..83860ba591 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -52,7 +52,7 @@ import jinja2
 import pendulum
 from dateutil.relativedelta import relativedelta
 from pendulum.tz.timezone import Timezone
-from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, or_, tuple_
+from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, not_, or_
 from sqlalchemy.orm import backref, joinedload, relationship
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
@@ -85,7 +85,7 @@ from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.helpers import exactly_one, validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
 
@@ -1451,7 +1451,7 @@ class DAG(LoggingMixin):
         elif isinstance(next(iter(task_ids), None), str):
             tis = tis.filter(TI.task_id.in_(task_ids))
         else:
-            tis = tis.filter(tuple_(TI.task_id, TI.map_index).in_(task_ids))
+            tis = tis.filter(tuple_in_condition((TI.task_id, TI.map_index), task_ids))
 
         # This allows allow_trigger_in_future config to take effect, rather than mandating exec_date <= UTC
         if end_date or not self.allow_future_exec_dates:
@@ -1611,7 +1611,7 @@ class DAG(LoggingMixin):
         elif isinstance(next(iter(exclude_task_ids), None), str):
             tis = tis.filter(TI.task_id.notin_(exclude_task_ids))
         else:
-            tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids))
+            tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids)))
 
         return tis
 
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 48d3a047fb..9d135a47b8 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -67,7 +67,6 @@ from sqlalchemy import (
     inspect,
     or_,
     text,
-    tuple_,
 )
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.mutable import MutableDict
@@ -122,7 +121,7 @@ from airflow.utils.operator_helpers import context_to_airflow_vars
 from airflow.utils.platform import getuser
 from airflow.utils.retries import run_with_db_retries
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, with_row_locks
+from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, tuple_in_condition, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.timeout import timeout
 
@@ -2540,20 +2539,10 @@ class TaskInstance(Base, LoggingMixin):
                 TaskInstance.task_id == first_task_id,
             )
 
-        if settings.engine.dialect.name == 'mssql':
-            return or_(
-                and_(
-                    TaskInstance.dag_id == ti.dag_id,
-                    TaskInstance.task_id == ti.task_id,
-                    TaskInstance.run_id == ti.run_id,
-                    TaskInstance.map_index == ti.map_index,
-                )
-                for ti in tis
-            )
-        else:
-            return tuple_(
-                TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id, TaskInstance.map_index
-            ).in_([ti.key.primary for ti in tis])
+        return tuple_in_condition(
+            (TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id, TaskInstance.map_index),
+            (ti.key.primary for ti in tis),
+        )
 
 
 # State of the task instance.
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index c240a94456..5c36d826b2 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -19,15 +19,19 @@
 import datetime
 import json
 import logging
-from typing import Any, Dict
+from operator import and_, or_
+from typing import Any, Dict, Iterable, Tuple
 
 import pendulum
 from dateutil import relativedelta
-from sqlalchemy import event, nullsfirst
+from sqlalchemy import event, nullsfirst, tuple_
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm.session import Session
+from sqlalchemy.sql import ColumnElement
+from sqlalchemy.sql.expression import ColumnOperators
 from sqlalchemy.types import JSON, DateTime, Text, TypeDecorator, TypeEngine, UnicodeText
 
+from airflow import settings
 from airflow.configuration import conf
 
 log = logging.getLogger(__name__)
@@ -319,3 +323,20 @@ def is_lock_not_available_error(error: OperationalError):
     if db_err_code in ('55P03', 1205, 3572):
         return True
     return False
+
+
+def tuple_in_condition(
+    columns: Tuple[ColumnElement, ...],
+    collection: Iterable[Any],
+) -> ColumnOperators:
+    """Generates a tuple-in-collection operator to use in ``.filter()``.
+
+    For most SQL backends, this generates a simple ``([col, ...]) IN [condition]``
+    clause. This however does not work with MSSQL, where we need to expand to
+    ``(c1 = v1a AND c2 = v2a ...) OR (c1 = v1b AND c2 = v2b ...) ...`` manually.
+
+    :meta private:
+    """
+    if settings.engine.dialect.name != "mssql":
+        return tuple_(*columns).in_(collection)
+    return or_(*(and_(*(c == v for c, v in zip(columns, values))) for values in collection))