You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by pi...@apache.org on 2023/03/06 21:47:13 UTC
[airflow] 29/37: Resolve all variables in pickled XCom iterator (#28982)
This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit e2cf93305803027d462373ca356bd5badee1128e
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Wed Jan 18 18:19:35 2023 +0800
Resolve all variables in pickled XCom iterator (#28982)
(cherry picked from commit ccf53e167ea57716c76ec7ab5bd1223f0c0d47d3)
---
airflow/models/xcom.py | 7 +++++-
tests/conftest.py | 2 +-
tests/models/test_taskinstance.py | 49 ++++++++++++++++++++++++++++++++++++++-
3 files changed, 55 insertions(+), 3 deletions(-)
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 3b43618424..6294fa3d7f 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -731,7 +731,12 @@ class LazyXComAccess(collections.abc.Sequence):
# do the same for count(), but I think it should be performant enough to
# calculate only that eagerly.
with self._get_bound_query() as query:
- statement = query.statement.compile(query.session.get_bind())
+ statement = query.statement.compile(
+ query.session.get_bind(),
+ # This inlines all the values into the SQL string to simplify
+ # cross-process communication as much as possible.
+ compile_kwargs={"literal_binds": True},
+ )
return (str(statement), query.count())
def __setstate__(self, state: Any) -> None:
diff --git a/tests/conftest.py b/tests/conftest.py
index d71d8eb0f0..3fb6d83489 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -291,7 +291,7 @@ def skip_if_not_marked_with_backend(selected_backend, item):
if selected_backend in backend_names:
return
pytest.skip(
- f"The test is skipped because it does not have the right backend marker "
+ f"The test is skipped because it does not have the right backend marker. "
f"Only tests marked with pytest.mark.backend('{selected_backend}') are run: {item}"
)
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 17ce74178d..849358cc69 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -3592,6 +3592,40 @@ class TestMappedTaskInstanceReceiveValue:
assert out_lines == ["hello FOO", "goodbye FOO", "hello BAR", "goodbye BAR"]
+def _get_lazy_xcom_access_expected_sql_lines() -> list[str]:
+ backend = os.environ.get("BACKEND")
+ if backend == "mssql":
+ return [
+ "SELECT xcom.value",
+ "FROM xcom",
+ "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' "
+ "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.[key] = 'xxx'",
+ ]
+ elif backend == "mysql":
+ return [
+ "SELECT xcom.value",
+ "FROM xcom",
+ "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' "
+ "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.`key` = 'xxx'",
+ ]
+ elif backend == "postgres":
+ return [
+ "SELECT xcom.value",
+ "FROM xcom",
+ "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' "
+ "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.key = 'xxx'",
+ ]
+ elif backend == "sqlite":
+ return [
+ "SELECT xcom.value",
+ "FROM xcom",
+ "WHERE xcom.dag_id = 'test_dag' AND xcom.run_id = 'test' "
+ "AND xcom.task_id = 't' AND xcom.map_index = -1 AND xcom.\"key\" = 'xxx'",
+ ]
+ else:
+ raise RuntimeError(f"unknown backend {backend!r}")
+
+
def test_lazy_xcom_access_does_not_pickle_session(dag_maker, session):
with dag_maker(session=session):
EmptyOperator(task_id="t")
@@ -3599,9 +3633,22 @@ def test_lazy_xcom_access_does_not_pickle_session(dag_maker, session):
run: DagRun = dag_maker.create_dagrun()
run.get_task_instance("t", session=session).xcom_push("xxx", 123, session=session)
- original = LazyXComAccess.build_from_xcom_query(session.query(XCom))
+ query = session.query(XCom.value).filter_by(
+ dag_id=run.dag_id,
+ run_id=run.run_id,
+ task_id="t",
+ map_index=-1,
+ key="xxx",
+ )
+
+ original = LazyXComAccess.build_from_xcom_query(query)
processed = pickle.loads(pickle.dumps(original))
+ # After the object went through pickling, the underlying ORM query should be
+ # replaced by one backed by a literal SQL string with all variables bound.
+ sql_lines = [line.strip() for line in str(processed._query.statement.compile(None)).splitlines()]
+ assert sql_lines == _get_lazy_xcom_access_expected_sql_lines()
+
assert len(processed) == 1
assert list(processed) == [123]