You are viewing a plain-text version of this content. The hyperlink to its canonical version was not preserved in this plain-text rendering.
Posted to commits@airflow.apache.org by ta...@apache.org on 2023/11/06 08:07:20 UTC
(airflow) branch main updated: Fix query in `get_dag_by_pickle` util function (#35339)
This is an automated email from the ASF dual-hosted git repository.
taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d519c648dd Fix query in `get_dag_by_pickle` util function (#35339)
d519c648dd is described below
commit d519c648dd21e8f2392b7c85c8e1f2cb4ce87693
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Mon Nov 6 12:07:12 2023 +0400
Fix query in `get_dag_by_pickle` util function (#35339)
---
airflow/utils/cli.py | 2 +-
tests/utils/test_cli_util.py | 19 ++++++++++++++++++-
2 files changed, 19 insertions(+), 2 deletions(-)
diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py
index 4343c797d9..9134d16855 100644
--- a/airflow/utils/cli.py
+++ b/airflow/utils/cli.py
@@ -269,7 +269,7 @@ def get_dag_by_pickle(pickle_id: int, session: Session = NEW_SESSION) -> DAG:
"""Fetch DAG from the database using pickling."""
from airflow.models import DagPickle
- dag_pickle = session.scalar(select(DagPickle).where(DagPickle.id == pickle_id)).first()
+ dag_pickle = session.scalar(select(DagPickle).where(DagPickle.id == pickle_id).limit(1))
if not dag_pickle:
raise AirflowException(f"pickle_id could not be found in DagPickle.id list: {pickle_id}")
pickle_dag = dag_pickle.pickle
diff --git a/tests/utils/test_cli_util.py b/tests/utils/test_cli_util.py
index 5c0edcd2a2..cad05ab3dc 100644
--- a/tests/utils/test_cli_util.py
+++ b/tests/utils/test_cli_util.py
@@ -34,7 +34,7 @@ from airflow import settings
from airflow.exceptions import AirflowException
from airflow.models.log import Log
from airflow.utils import cli, cli_action_loggers, timezone
-from airflow.utils.cli import _search_for_dag_file
+from airflow.utils.cli import _search_for_dag_file, get_dag_by_pickle
repo_root = Path(airflow.__file__).parent.parent
@@ -171,6 +171,23 @@ class TestCliUtil:
pid, _, _, _ = cli.setup_locations(process=process_name)
assert pid == default_pid_path
+ @pytest.mark.db_test
+ def test_get_dag_by_pickle(self, session, dag_maker):
+ from airflow.models.dagpickle import DagPickle
+
+ with dag_maker(dag_id="test_get_dag_by_pickle") as dag:
+ pass
+
+ dp = DagPickle(dag=dag)
+ session.add(dp)
+ session.commit()
+
+ dp_from_db = get_dag_by_pickle(pickle_id=dp.id, session=session)
+ assert dp_from_db.dag_id == "test_get_dag_by_pickle"
+
+ with pytest.raises(AirflowException, match="pickle_id could not be found .* -42"):
+ get_dag_by_pickle(pickle_id=-42, session=session)
+
@contextmanager
def fail_action_logger_callback():