You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by el...@apache.org on 2023/10/20 22:54:43 UTC

[superset] 05/07: fix: CTE queries with non-SELECT statements (#25014)

This is an automated email from the ASF dual-hosted git repository.

elizabeth pushed a commit to branch 2.1
in repository https://gitbox.apache.org/repos/asf/superset.git

commit f16dcd4fed8b401fd99cf903189d1f74758421e4
Author: Daniel Vaz Gaspar <da...@gmail.com>
AuthorDate: Sat Aug 19 15:49:15 2023 +0100

    fix: CTE queries with non-SELECT statements (#25014)
---
 superset/sql_parse.py               |  55 +++++++++++++++++++
 tests/unit_tests/sql_parse_tests.py | 103 ++++++++++++++++++++++++++++++++++++
 2 files changed, 158 insertions(+)

diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index a3c1af87b0..216b4e8825 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -216,9 +216,53 @@ class ParsedQuery:
     def limit(self) -> Optional[int]:
         return self._limit
 
+    def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
+        if "with" not in parsed:
+            return []
+        return parsed["with"].get("cte_tables", [])
+
+    def _check_cte_is_select(self, oxide_parse: list[dict[str, Any]]) -> bool:
+        """
+        Check if a sqloxide-parsed CTE contains only SELECT statements
+
+        :param oxide_parse: list of statements as parsed by sqloxide
+        :return: True if every CTE table's body is a SELECT statement
+        """
+        for query in oxide_parse:
+            parsed_query = query["Query"]
+            cte_tables = self._get_cte_tables(parsed_query)
+            for cte_table in cte_tables:
+                is_select = all(
+                    key == "Select" for key in cte_table["query"]["body"].keys()
+                )
+                if not is_select:
+                    return False
+        return True
+
     def is_select(self) -> bool:
         # make sure we strip comments; prevents a bug with coments in the CTE
         parsed = sqlparse.parse(self.strip_comments())
+
+        # Check if this is a CTE
+        if parsed[0].is_group and parsed[0][0].ttype == Keyword.CTE:
+            if sqloxide_parse is not None:
+                try:
+                    if not self._check_cte_is_select(
+                        sqloxide_parse(self.strip_comments(), dialect="ansi")
+                    ):
+                        return False
+                except ValueError:
+                    # sqloxide was not able to parse the query, so let's continue with
+                    # sqlparse
+                    pass
+            inner_cte = self.get_inner_cte_expression(parsed[0].tokens) or []
+            # Check if the inner CTE is not a SELECT
+            if any(token.ttype == DDL for token in inner_cte) or any(
+                token.ttype == DML and token.normalized != "SELECT"
+                for token in inner_cte
+            ):
+                return False
+
         if parsed[0].get_type() == "SELECT":
             return True
 
@@ -240,6 +284,17 @@ class ParsedQuery:
             token.ttype == DML and token.value == "SELECT" for token in parsed[0]
         )
 
+    def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
+        for token in tokens:
+            if self._is_identifier(token):
+                for identifier_token in token.tokens:
+                    if (
+                        isinstance(identifier_token, Parenthesis)
+                        and identifier_token.is_group
+                    ):
+                        return identifier_token.tokens
+        return None
+
     def is_valid_ctas(self) -> bool:
         parsed = sqlparse.parse(self.strip_comments())
         return parsed[-1].get_type() == "SELECT"
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index d6939fa080..09eeabce2f 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -1008,6 +1008,109 @@ FROM foo f"""
     assert sql.is_select()
 
 
+def test_cte_is_select_lowercase() -> None:
+    """
+    Some CTEs with lowercase select are not correctly identified as SELECTS.
+    """
+    sql = ParsedQuery(
+        """WITH foo AS(
+select
+  FLOOR(__time TO WEEK) AS "week",
+  name,
+  COUNT(DISTINCT user_id) AS "unique_users"
+FROM "druid"."my_table"
+GROUP BY 1,2
+)
+select
+  f.week,
+  f.name,
+  f.unique_users
+FROM foo f"""
+    )
+    assert sql.is_select()
+
+
+def test_cte_insert_is_not_select() -> None:
+    """
+    A CTE wrapping an INSERT ... RETURNING must not be identified as a SELECT.
+    """
+    sql = ParsedQuery(
+        """WITH foo AS(
+        INSERT INTO foo (id) VALUES (1) RETURNING 1
+        ) select * FROM foo f"""
+    )
+    assert sql.is_select() is False
+
+
+def test_cte_delete_is_not_select() -> None:
+    """
+    A CTE wrapping a DELETE ... RETURNING must not be identified as a SELECT.
+    """
+    sql = ParsedQuery(
+        """WITH foo AS(
+        DELETE FROM foo RETURNING *
+        ) select * FROM foo f"""
+    )
+    assert sql.is_select() is False
+
+
+def test_cte_is_not_select_lowercase() -> None:
+    """
+    A CTE wrapping a lowercase insert must not be identified as a SELECT.
+    """
+    sql = ParsedQuery(
+        """WITH foo AS(
+        insert into foo (id) values (1) RETURNING 1
+        ) select * FROM foo f"""
+    )
+    assert sql.is_select() is False
+
+
+def test_cte_with_multiple_selects() -> None:
+    sql = ParsedQuery(
+        "WITH a AS ( select * from foo1 ), b as (select * from foo2) SELECT * FROM a;"
+    )
+    assert sql.is_select()
+
+
+def test_cte_with_multiple_with_non_select() -> None:
+    sql = ParsedQuery(
+        """WITH a AS (
+        select * from foo1
+        ), b as (
+        update foo2 set id=2
+        ) SELECT * FROM a"""
+    )
+    assert sql.is_select() is False
+    sql = ParsedQuery(
+        """WITH a AS (
+         update foo2 set name=2
+         ),
+        b as (
+        select * from foo1
+        ) SELECT * FROM a"""
+    )
+    assert sql.is_select() is False
+    sql = ParsedQuery(
+        """WITH a AS (
+         update foo2 set name=2
+         ),
+        b as (
+        update foo1 set name=2
+        ) SELECT * FROM a"""
+    )
+    assert sql.is_select() is False
+    sql = ParsedQuery(
+        """WITH a AS (
+        INSERT INTO foo (id) VALUES (1)
+        ),
+        b as (
+        select 1
+        ) SELECT * FROM a"""
+    )
+    assert sql.is_select() is False
+
+
 def test_unknown_select() -> None:
     """
     Test that `is_select` works when sqlparse fails to identify the type.