You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2021/09/15 21:20:35 UTC

[airflow] 01/07: Fix ``DagRunState`` enum query for ``MySQLdb`` driver (#17886)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 66dcbf429aee0316e206a1d6ded089580fc94ddf
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Mon Aug 30 19:05:08 2021 +0100

    Fix ``DagRunState`` enum query for ``MySQLdb`` driver (#17886)
    
    Same fix as https://github.com/apache/airflow/pull/13278, but for `DagRunState` (introduced in https://github.com/apache/airflow/pull/16854).
    
    closes https://github.com/apache/airflow/issues/17879
    
    (cherry picked from commit a3f9c690aa80d12ff1d5c42eaaff4fced07b9429)
---
 airflow/utils/state.py    |  3 +++
 tests/utils/test_state.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 61 insertions(+)

diff --git a/airflow/utils/state.py b/airflow/utils/state.py
index e95b409..00e518e 100644
--- a/airflow/utils/state.py
+++ b/airflow/utils/state.py
@@ -65,6 +65,9 @@ class DagRunState(str, Enum):
     SUCCESS = "success"
     FAILED = "failed"
 
+    def __str__(self) -> str:
+        return self.value
+
 
 class State:
     """
diff --git a/tests/utils/test_state.py b/tests/utils/test_state.py
new file mode 100644
index 0000000..88e4a0d
--- /dev/null
+++ b/tests/utils/test_state.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.models import DAG
+from airflow.models.dagrun import DagRun
+from airflow.utils.session import create_session
+from airflow.utils.state import DagRunState
+from airflow.utils.types import DagRunType
+from tests.models import DEFAULT_DATE
+
+
+def test_dagrun_state_enum_escape():
+    """
+    Make sure DagRunState.QUEUED is rendered as the plain string 'queued'
+    (not 'DagRunState.QUEUED') when referenced in a DB query
+    """
+    with create_session() as session:
+        dag = DAG(dag_id='test_dagrun_state_enum_escape', start_date=DEFAULT_DATE)
+        dag.create_dagrun(
+            run_type=DagRunType.SCHEDULED,
+            state=DagRunState.QUEUED,
+            execution_date=DEFAULT_DATE,
+            start_date=DEFAULT_DATE,
+            session=session,
+        )
+
+        query = session.query(DagRun.dag_id, DagRun.state, DagRun.run_type,).filter(
+            DagRun.dag_id == dag.dag_id,
+            # make sure enum value can be used in filter queries
+            DagRun.state == DagRunState.QUEUED,
+        )
+        assert str(query.statement.compile(compile_kwargs={"literal_binds": True})) == (
+            'SELECT dag_run.dag_id, dag_run.state, dag_run.run_type \n'
+            'FROM dag_run \n'
+            "WHERE dag_run.dag_id = 'test_dagrun_state_enum_escape' AND dag_run.state = 'queued'"
+        )
+
+        rows = query.all()
+        assert len(rows) == 1
+        assert rows[0].dag_id == dag.dag_id
+        # make sure value in db is stored as `queued`, not `DagRunState.QUEUED`
+        assert rows[0].state == 'queued'
+
+        session.rollback()