You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by pi...@apache.org on 2023/03/06 21:47:13 UTC

[airflow] 29/37: Resolve all variables in pickled XCom iterator (#28982)

This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit e2cf93305803027d462373ca356bd5badee1128e
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Wed Jan 18 18:19:35 2023 +0800

    Resolve all variables in pickled XCom iterator (#28982)
    
    (cherry picked from commit ccf53e167ea57716c76ec7ab5bd1223f0c0d47d3)
---
 airflow/models/xcom.py            |  7 +++++-
 tests/conftest.py                 |  2 +-
 tests/models/test_taskinstance.py | 49 ++++++++++++++++++++++++++++++++++++++-
 3 files changed, 55 insertions(+), 3 deletions(-)

diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 3b43618424..6294fa3d7f 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -731,7 +731,12 @@ class LazyXComAccess(collections.abc.Sequence):
         # do the same for count(), but I think it should be performant enough to
         # calculate only that eagerly.
         with self._get_bound_query() as query:
-            statement = query.statement.compile(query.session.get_bind())
+            statement = query.statement.compile(
+                query.session.get_bind(),
+                # This inlines all the values into the SQL string to simplify
+                # cross-process communication as much as possible.
+                compile_kwargs={"literal_binds": True},
+            )
             return (str(statement), query.count())
 
     def __setstate__(self, state: Any) -> None:
diff --git a/tests/conftest.py b/tests/conftest.py
index d71d8eb0f0..3fb6d83489 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -291,7 +291,7 @@ def skip_if_not_marked_with_backend(selected_backend, item):
         if selected_backend in backend_names:
             return
     pytest.skip(
-        f"The test is skipped because it does not have the right backend marker "
+        f"The test is skipped because it does not have the right backend marker. "
         f"Only tests marked with pytest.mark.backend('{selected_backend}') are run: {item}"
     )
 
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 17ce74178d..849358cc69 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3592,6 +3592,40 @@ class TestMappedTaskInstanceReceiveValue:
         assert out_lines == ["hello FOO", "goodbye FOO", "hello BAR", "goodbye BAR"]
 
 
+def _get_lazy_xcom_access_expected_sql_lines() -> list[str]:  # expected SQL lines, per backend
+    backend = os.environ.get("BACKEND")  # None if $BACKEND is unset
+    if backend == "mssql":  # key column quoted as [key]
+        return [
+            "SELECT xcom.value",
+            "FROM xcom",
+            "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' "
+            "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.[key] = 'xxx'",
+        ]
+    elif backend == "mysql":  # key column quoted with backticks
+        return [
+            "SELECT xcom.value",
+            "FROM xcom",
+            "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' "
+            "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.`key` = 'xxx'",
+        ]
+    elif backend == "postgres":  # key column left unquoted
+        return [
+            "SELECT xcom.value",
+            "FROM xcom",
+            "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' "
+            "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.key = 'xxx'",
+        ]
+    elif backend == "sqlite":  # key column quoted with double quotes
+        return [
+            "SELECT xcom.value",
+            "FROM xcom",
+            "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' "
+            "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.\"key\" = 'xxx'",
+        ]
+    else:
+        raise RuntimeError(f"unknown backend {backend!r}")  # also hit when $BACKEND is unset
+
+
 def test_lazy_xcom_access_does_not_pickle_session(dag_maker, session):
     with dag_maker(session=session):
         EmptyOperator(task_id="t")
@@ -3599,9 +3633,22 @@ def test_lazy_xcom_access_does_not_pickle_session(dag_maker, session):
     run: DagRun = dag_maker.create_dagrun()
     run.get_task_instance("t", session=session).xcom_push("xxx", 123, session=session)
 
-    original = LazyXComAccess.build_from_xcom_query(session.query(XCom))
+    query = session.query(XCom.value).filter_by(
+        dag_id=run.dag_id,
+        run_id=run.run_id,
+        task_id="t",
+        map_index=-1,
+        key="xxx",
+    )
+
+    original = LazyXComAccess.build_from_xcom_query(query)
     processed = pickle.loads(pickle.dumps(original))
 
+    # After the object went through pickling, the underlying ORM query should be
+    # replaced by one backed by a literal SQL string with all bound values inlined.
+    sql_lines = [line.strip() for line in str(processed._query.statement.compile(None)).splitlines()]
+    assert sql_lines == _get_lazy_xcom_access_expected_sql_lines()
+
     assert len(processed) == 1
     assert list(processed) == [123]