You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by bb...@apache.org on 2022/04/20 19:00:54 UTC

[airflow] 12/19: Introduce tuple_().in_() shim for MSSQL compat

This is an automated email from the ASF dual-hosted git repository.

bbovenzi pushed a commit to branch mapped-instance-actions
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 469092494da6b8baa6cfe145b76e40eaa495635e
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Apr 19 18:01:55 2022 +0800

    Introduce tuple_().in_() shim for MSSQL compat
---
 airflow/api/common/mark_tasks.py | 5 +++--
 airflow/models/dag.py            | 8 ++++----
 airflow/utils/sqlalchemy.py      | 8 +++-----
 3 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 349b935e82..1d4709fb82 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -20,7 +20,7 @@
 from datetime import datetime
 from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
 
-from sqlalchemy import or_, tuple_
+from sqlalchemy import or_
 from sqlalchemy.orm import contains_eager
 from sqlalchemy.orm.session import Session as SASession
 
@@ -32,6 +32,7 @@ from airflow.operators.subdag import SubDagOperator
 from airflow.utils import timezone
 from airflow.utils.helpers import exactly_one
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
@@ -203,7 +204,7 @@ def get_all_dag_task_query(
     if is_string_list:
         qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids))
     else:
-        qry_dag = qry_dag.filter(tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids))
+        qry_dag = qry_dag.filter(tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), task_ids))
     qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
         contains_eager(TaskInstance.dag_run)
     )
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 9c93bcef13..83860ba591 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -52,7 +52,7 @@ import jinja2
 import pendulum
 from dateutil.relativedelta import relativedelta
 from pendulum.tz.timezone import Timezone
-from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, or_, tuple_
+from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, not_, or_
 from sqlalchemy.orm import backref, joinedload, relationship
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
@@ -85,7 +85,7 @@ from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.helpers import exactly_one, validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
 
@@ -1451,7 +1451,7 @@ class DAG(LoggingMixin):
         elif isinstance(next(iter(task_ids), None), str):
             tis = tis.filter(TI.task_id.in_(task_ids))
         else:
-            tis = tis.filter(tuple_(TI.task_id, TI.map_index).in_(task_ids))
+            tis = tis.filter(tuple_in_condition((TI.task_id, TI.map_index), task_ids))
 
         # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC
         if end_date or not self.allow_future_exec_dates:
@@ -1611,7 +1611,7 @@ class DAG(LoggingMixin):
         elif isinstance(next(iter(exclude_task_ids), None), str):
             tis = tis.filter(TI.task_id.notin_(exclude_task_ids))
         else:
-            tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids))
+            tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids)))
 
         return tis
 
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index de4ad01e69..5c36d826b2 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -19,11 +19,12 @@
 import datetime
 import json
 import logging
+from operator import and_, or_
 from typing import Any, Dict, Iterable, Tuple
 
 import pendulum
 from dateutil import relativedelta
-from sqlalchemy import and_, event, false, nullsfirst, or_, tuple_
+from sqlalchemy import event, nullsfirst, tuple_
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql import ColumnElement
@@ -338,7 +339,4 @@ def tuple_in_condition(
     """
     if settings.engine.dialect.name != "mssql":
         return tuple_(*columns).in_(collection)
-    clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in collection]
-    if not clauses:
-        return false()
-    return or_(*clauses)
+    return or_(*(and_(*(c == v for c, v in zip(columns, values))) for values in collection))