You are viewing a plain-text version of this content. The canonical link for it was not preserved in this copy.
Posted to commits@airflow.apache.org by ta...@apache.org on 2023/11/06 08:07:20 UTC

(airflow) branch main updated: Fix query in `get_dag_by_pickle` util function (#35339)

This is an automated email from the ASF dual-hosted git repository.

taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d519c648dd Fix query in `get_dag_by_pickle` util function (#35339)
d519c648dd is described below

commit d519c648dd21e8f2392b7c85c8e1f2cb4ce87693
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Mon Nov 6 12:07:12 2023 +0400

    Fix query in `get_dag_by_pickle` util function (#35339)
---
 airflow/utils/cli.py         |  2 +-
 tests/utils/test_cli_util.py | 19 ++++++++++++++++++-
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py
index 4343c797d9..9134d16855 100644
--- a/airflow/utils/cli.py
+++ b/airflow/utils/cli.py
@@ -269,7 +269,7 @@ def get_dag_by_pickle(pickle_id: int, session: Session = NEW_SESSION) -> DAG:
     """Fetch DAG from the database using pickling."""
     from airflow.models import DagPickle
 
-    dag_pickle = session.scalar(select(DagPickle).where(DagPickle.id == pickle_id)).first()
+    dag_pickle = session.scalar(select(DagPickle).where(DagPickle.id == pickle_id).limit(1))
     if not dag_pickle:
         raise AirflowException(f"pickle_id could not be found in DagPickle.id list: {pickle_id}")
     pickle_dag = dag_pickle.pickle
diff --git a/tests/utils/test_cli_util.py b/tests/utils/test_cli_util.py
index 5c0edcd2a2..cad05ab3dc 100644
--- a/tests/utils/test_cli_util.py
+++ b/tests/utils/test_cli_util.py
@@ -34,7 +34,7 @@ from airflow import settings
 from airflow.exceptions import AirflowException
 from airflow.models.log import Log
 from airflow.utils import cli, cli_action_loggers, timezone
-from airflow.utils.cli import _search_for_dag_file
+from airflow.utils.cli import _search_for_dag_file, get_dag_by_pickle
 
 repo_root = Path(airflow.__file__).parent.parent
 
@@ -171,6 +171,23 @@ class TestCliUtil:
         pid, _, _, _ = cli.setup_locations(process=process_name)
         assert pid == default_pid_path
 
+    @pytest.mark.db_test
+    def test_get_dag_by_pickle(self, session, dag_maker):
+        from airflow.models.dagpickle import DagPickle
+
+        with dag_maker(dag_id="test_get_dag_by_pickle") as dag:
+            pass
+
+        dp = DagPickle(dag=dag)
+        session.add(dp)
+        session.commit()
+
+        dp_from_db = get_dag_by_pickle(pickle_id=dp.id, session=session)
+        assert dp_from_db.dag_id == "test_get_dag_by_pickle"
+
+        with pytest.raises(AirflowException, match="pickle_id could not be found .* -42"):
+            get_dag_by_pickle(pickle_id=-42, session=session)
+
 
 @contextmanager
 def fail_action_logger_callback():