You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2024/03/14 21:07:50 UTC

(superset) 01/01: feat: support for KQL in SQLQuery

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch remove-sqlparse-kusto
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 40f3607af8ad69a51e6b9ea07a8c411f2c712257
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Wed Jan 24 10:02:47 2024 -0500

    feat: support for KQL in SQLQuery
---
 superset/sql_parse.py               | 389 ++++++++++++++++++++++++++++++------
 tests/unit_tests/sql_parse_tests.py |  88 +++++++-
 2 files changed, 408 insertions(+), 69 deletions(-)

diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 58dc210e2b..29a7e3425d 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -17,12 +17,15 @@
 
 # pylint: disable=too-many-lines
 
+from __future__ import annotations
+
+import enum
 import logging
 import re
 import urllib.parse
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
-from typing import Any, cast, Optional, Union
+from typing import Any, cast, Generic, TypeVar
 
 import sqlglot
 import sqlparse
@@ -138,7 +141,7 @@ class CtasMethod(StrEnum):
     VIEW = "VIEW"
 
 
-def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
+def _extract_limit_from_query(statement: TokenList) -> int | None:
     """
     Extract limit clause from SQL statement.
 
@@ -159,9 +162,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
     return None
 
 
-def extract_top_from_query(
-    statement: TokenList, top_keywords: set[str]
-) -> Optional[int]:
+def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None:
     """
     Extract top clause value from SQL statement.
 
@@ -185,7 +186,7 @@ def extract_top_from_query(
     return top
 
 
-def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
+def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
     """
     parse the SQL and return the CTE and rest of the block to the caller
 
@@ -193,7 +194,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     :return: CTE and remainder block to the caller
 
     """
-    cte: Optional[str] = None
+    cte: str | None = None
     remainder = sql
     stmt = sqlparse.parse(sql)[0]
 
@@ -211,7 +212,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     return cte, remainder
 
 
-def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
+def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
     """
     Strips comments from a SQL statement, does a simple test first
     to avoid always instantiating the expensive ParsedQuery constructor
@@ -235,8 +236,8 @@ class Table:
     """
 
     table: str
-    schema: Optional[str] = None
-    catalog: Optional[str] = None
+    schema: str | None = None
+    catalog: str | None = None
 
     def __str__(self) -> str:
         """
@@ -255,7 +256,7 @@ class Table:
 
 def extract_tables_from_statement(
     statement: exp.Expression,
-    dialect: Optional[Dialects],
+    dialect: Dialects | None,
 ) -> set[Table]:
     """
     Extract all table references in a single statement.
@@ -326,82 +327,151 @@ def is_cte(source: exp.Table, scope: Scope) -> bool:
     return source.name in ctes_in_scope
 
 
-class SQLScript:
+# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
+# an "internal representation", which is the AST of the SQL statement. For most of the
+# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special
+# case: KustoKQL uses a different syntax and has no Python parser, so we store the
+# statement as a plain string (the original query) and manipulate it with regexes.
+InternalRepresentation = TypeVar("InternalRepresentation")
+
+# The base type. This helps type check the `split_query` method correctly, since each
+# derived class has a more specific return type (the class itself). This will no longer
+# be needed once Python 3.11 is the minimum supported version (use `typing.Self`). See
+# PEP 673 for more information: https://peps.python.org/pep-0673/
+TBaseSQLStatement = TypeVar("TBaseSQLStatement")
+
+
+class BaseSQLStatement(Generic[InternalRepresentation]):
     """
-    A SQL script, with 0+ statements.
+    Base class for SQL statements.
+
+    The class can be instantiated with a string representation of the query or, for
+    efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
+    which splits a query into multiple already-parsed statements.
+
+    The `engine` parameter comes from the `engine` attribute in a Superset DB engine
+    spec.
     """
 
     def __init__(
         self,
-        query: str,
-        engine: Optional[str] = None,
+        statement: str | InternalRepresentation,
+        engine: str,
     ):
-        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+        self._parsed: InternalRepresentation = (
+            self._parse_statement(statement, engine)
+            if isinstance(statement, str)
+            else statement
+        )
+        self.engine = engine
+        self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
 
-        self.statements = [
-            SQLStatement(statement, engine=engine)
-            for statement in parse(query, dialect=dialect)
-            if statement
-        ]
+    @classmethod
+    def split_query(
+        cls: type[TBaseSQLStatement],
+        query: str,
+        engine: str,
+    ) -> list[TBaseSQLStatement]:
+        """
+        Split a query into multiple instantiated statements.
+
+        This is a helper function to split a full SQL query into multiple
+        `BaseSQLStatement` instances. It's used by `SQLScript` when instantiating the
+        statements within a query.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _parse_statement(
+        cls,
+        statement: str,
+        engine: str,
+    ) -> InternalRepresentation:
+        """
+        Parse a string containing a single SQL statement, and returns the parsed AST.
+
+        Derived classes should not assume that `statement` contains a single statement,
+        and MUST explicitly validate that. Since this validation is parser dependent the
+        responsibility is left to the children classes.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _extract_tables_from_statement(
+        cls,
+        parsed: InternalRepresentation,
+        engine: str,
+    ) -> set[Table]:
+        """
+        Extract all table references in a given statement.
+        """
+        raise NotImplementedError()
 
     def format(self, comments: bool = True) -> str:
         """
-        Pretty-format the SQL query.
+        Format the statement, optionally omitting comments.
         """
-        return ";\n".join(statement.format(comments) for statement in self.statements)
+        raise NotImplementedError()
 
-    def get_settings(self) -> dict[str, str]:
+    def get_settings(self) -> dict[str, str | bool]:
         """
-        Return the settings for the SQL query.
+        Return any settings set by the statement.
 
-            >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
-            >>> statement.get_settings()
-            {"foo": "'baz'"}
+        For example, for this statement:
+
+            sql> SET foo = 'bar';
 
+        The method should return `{"foo": "'bar'"}`. Note the single quotes.
         """
-        settings: dict[str, str] = {}
-        for statement in self.statements:
-            settings.update(statement.get_settings())
+        raise NotImplementedError()
 
-        return settings
+    def __str__(self) -> str:
+        return self.format()
 
 
-class SQLStatement:
+class SQLStatement(BaseSQLStatement[exp.Expression]):
     """
     A SQL statement.
 
-    This class provides helper methods to manipulate and introspect SQL.
+    This class is used for all engines with dialects that can be parsed using sqlglot.
     """
 
     def __init__(
         self,
-        statement: Union[str, exp.Expression],
-        engine: Optional[str] = None,
+        statement: str | exp.Expression,
+        engine: str,
     ):
-        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
-
-        if isinstance(statement, str):
-            try:
-                self._parsed = self._parse_statement(statement, dialect)
-            except ParseError as ex:
-                raise SupersetParseError(statement, engine) from ex
-        else:
-            self._parsed = statement
+        self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+        super().__init__(statement, engine)
 
-        self._dialect = dialect
-        self.tables = extract_tables_from_statement(self._parsed, dialect)
+    @classmethod
+    def split_query(cls, query: str, engine: str) -> list[SQLStatement]:
+        """
+        Split a query into parsed statements, wrapping parse errors.
+        """
+        # map the Superset engine name to a sqlglot dialect, as elsewhere in this module
+        try:
+            statements = sqlglot.parse(query, dialect=SQLGLOT_DIALECTS.get(engine))
+        except ParseError as ex:
+            raise SupersetParseError(query, engine) from ex
+        return [cls(statement, engine) for statement in statements if statement]
 
-    @staticmethod
+    @classmethod
     def _parse_statement(
-        sql_statement: str,
-        dialect: Optional[Dialects],
+        cls,
+        statement: str,
+        engine: str,
     ) -> exp.Expression:
         """
         Parse a single SQL statement.
         """
+        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+
+        # We could parse with `sqlglot.parse_one` to get a single statement, but we need
+        # to verify that the string contains exactly one statement.
         statements = [
             statement
-            for statement in sqlglot.parse(sql_statement, dialect=dialect)
+            for statement in sqlglot.parse(statement, dialect=dialect)
             if statement
         ]
         if len(statements) != 1:
@@ -409,6 +479,18 @@ class SQLStatement:
 
         return statements[0]
 
+    @classmethod
+    def _extract_tables_from_statement(
+        cls,
+        parsed: exp.Expression,
+        engine: str,
+    ) -> set[Table]:
+        """
+        Find all referenced tables.
+        """
+        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+        return extract_tables_from_statement(parsed, dialect)
+
     def format(self, comments: bool = True) -> str:
         """
         Pretty-format the SQL statement.
@@ -416,7 +498,7 @@ class SQLStatement:
         write = Dialect.get_or_raise(self._dialect)
         return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
 
-    def get_settings(self) -> dict[str, str]:
+    def get_settings(self) -> dict[str, str | bool]:
         """
         Return the settings for the SQL statement.
 
@@ -432,12 +514,189 @@ class SQLStatement:
         }
 
 
+class KustoKQLStatement(BaseSQLStatement[str]):
+    """
+    Special class for Kusto KQL.
+
+    Kusto KQL is a SQL-like language, but it's not supported by sqlglot. Queries look
+    like this:
+
+        StormEvents
+        | summarize PropertyDamage = sum(DamageProperty) by State
+        | join kind=innerunique PopulationData on State
+        | project State, PropertyDamagePerCapita = PropertyDamage / Population
+        | sort by PropertyDamagePerCapita
+
+    See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more
+    details about it.
+    """
+
+    @classmethod
+    def split_query(
+        cls,
+        query: str,
+        engine: str,
+    ) -> list[KustoKQLStatement]:
+        """
+        Split a query at semi-colons, dropping empty statements.
+
+        Since we don't have a parser, we use a simple state machine that ignores
+        semi-colons inside strings (honoring backslash escapes) and `//` comments. See
+        https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
+        for more information. Verbatim (`@'...'`) strings are not special-cased.
+        """
+
+        class KQLSplitState(enum.Enum):
+            """
+            State machine for splitting a KQL query.
+            """
+
+            OUTSIDE_STRING = enum.auto()
+            INSIDE_SINGLE_QUOTED_STRING = enum.auto()
+            INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
+            INSIDE_MULTILINE_STRING = enum.auto()
+
+        closing_quotes = {
+            KQLSplitState.INSIDE_SINGLE_QUOTED_STRING: "'",
+            KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING: '"',
+        }
+        statements: list[str] = []
+        state = KQLSplitState.OUTSIDE_STRING
+        statement_start = 0
+        i = 0
+        while i < len(query):
+            character = query[i]
+            if state == KQLSplitState.OUTSIDE_STRING:
+                if character == ";":
+                    statements.append(query[statement_start:i])
+                    statement_start = i + 1
+                elif character == "'":
+                    state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
+                elif character == '"':
+                    state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
+                elif query.startswith("```", i):
+                    state = KQLSplitState.INSIDE_MULTILINE_STRING
+                    i += 2  # skip the rest of the opening delimiter
+                elif query.startswith("//", i):
+                    # skip line comments, which may contain quotes or semi-colons
+                    newline = query.find("\n", i)
+                    i = len(query) if newline == -1 else newline
+                    continue
+            elif state == KQLSplitState.INSIDE_MULTILINE_STRING:
+                if query.startswith("```", i):
+                    state = KQLSplitState.OUTSIDE_STRING
+                    i += 2  # skip the rest of the closing delimiter
+            elif character == "\\":
+                i += 1  # skip the escaped character, eg, `\'` or `\\`
+            elif character == closing_quotes[state]:
+                state = KQLSplitState.OUTSIDE_STRING
+            i += 1
+
+        # the last statement doesn't need to be terminated by a semi-colon
+        statements.append(query[statement_start:])
+
+        return [cls(stmt, engine) for stmt in statements if stmt.strip()]
+
+    @classmethod
+    def _parse_statement(
+        cls,
+        statement: str,
+        engine: str,
+    ) -> str:
+        if engine != "kustokql":
+            raise ValueError(f"Invalid engine: {engine}")
+
+        # TODO: check if it's just a single statement
+
+        return statement.strip()
+
+    @classmethod
+    def _extract_tables_from_statement(cls, parsed: str, engine: str) -> set[Table]:
+        """
+        Extract all tables referenced in the statement.
+
+            StormEvents
+            | where InjuriesDirect + InjuriesIndirect > 50
+            | join (PopulationData) on State
+            | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect
+
+        """
+        logger.warning(
+            "Kusto KQL doesn't support table extraction. This means that data access "
+            "roles will not be enforced by Superset in the database."
+        )
+        return set()
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Return the stripped statement unchanged; `comments` is currently ignored.
+        """
+        return self._parsed
+
+    def get_settings(self) -> dict[str, str | bool]:
+        """
+        Return the settings for the statement, or `{}` if it's not a simple `set`.
+
+            >>> KustoKQLStatement("set querytrace", "kustokql").get_settings()
+            {'querytrace': True}
+
+        A setting without a value maps to `True`; values must be a single word.
+        """
+        set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$"
+        if match := re.match(set_regex, self._parsed, re.IGNORECASE):
+            return {match.group("name"): match.group("value") or True}
+
+        return {}
+
+
+class SQLScript:
+    """
+    A SQL script, with 0+ statements.
+    """
+
+    # Special engines that can't be parsed using sqlglot. Supporting non-SQL engines
+    # adds a lot of complexity to Superset, so we should avoid adding new engines to
+    # this data structure.
+    special_engines = {
+        "kustokql": KustoKQLStatement,
+    }
+
+    def __init__(
+        self,
+        query: str,
+        engine: str,
+    ):
+        statement_class = self.special_engines.get(engine, SQLStatement)
+        self.statements = statement_class.split_query(query, engine)
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Pretty-format the SQL query.
+        """
+        return ";\n".join(statement.format(comments) for statement in self.statements)
+
+    def get_settings(self) -> dict[str, str | bool]:
+        """
+        Return the settings for the SQL query; later statements override earlier ones.
+
+            >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'", "sqlite")
+            >>> statement.get_settings()
+            {"foo": "'baz'"}
+
+        """
+        settings: dict[str, str | bool] = {}
+        for statement in self.statements:
+            settings.update(statement.get_settings())
+
+        return settings
+
+
 class ParsedQuery:
     def __init__(
         self,
         sql_statement: str,
         strip_comments: bool = False,
-        engine: Optional[str] = None,
+        engine: str | None = None,
     ):
         if strip_comments:
             sql_statement = sqlparse.format(sql_statement, strip_comments=True)
@@ -446,7 +705,7 @@ class ParsedQuery:
         self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
         self._tables: set[Table] = set()
         self._alias_names: set[str] = set()
-        self._limit: Optional[int] = None
+        self._limit: int | None = None
 
         logger.debug("Parsing with sqlparse statement: %s", self.sql)
         self._parsed = sqlparse.parse(self.stripped())
@@ -550,7 +809,7 @@ class ParsedQuery:
         return source.name in ctes_in_scope
 
     @property
-    def limit(self) -> Optional[int]:
+    def limit(self) -> int | None:
         return self._limit
 
     def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
@@ -631,7 +890,7 @@ class ParsedQuery:
 
         return True
 
-    def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
+    def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
         for token in tokens:
             if self._is_identifier(token):
                 for identifier_token in token.tokens:
@@ -695,7 +954,7 @@ class ParsedQuery:
         return statements
 
     @staticmethod
-    def get_table(tlist: TokenList) -> Optional[Table]:
+    def get_table(tlist: TokenList) -> Table | None:
         """
         Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
         construct.
@@ -731,7 +990,7 @@ class ParsedQuery:
     def as_create_table(
         self,
         table_name: str,
-        schema_name: Optional[str] = None,
+        schema_name: str | None = None,
         overwrite: bool = False,
         method: CtasMethod = CtasMethod.TABLE,
     ) -> str:
@@ -891,8 +1150,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
 def get_rls_for_table(
     candidate: Token,
     database_id: int,
-    default_schema: Optional[str],
-) -> Optional[TokenList]:
+    default_schema: str | None,
+) -> TokenList | None:
     """
     Given a table name, return any associated RLS predicates.
     """
@@ -938,7 +1197,7 @@ def get_rls_for_table(
 def insert_rls_as_subquery(
     token_list: TokenList,
     database_id: int,
-    default_schema: Optional[str],
+    default_schema: str | None,
 ) -> TokenList:
     """
     Update a statement inplace applying any associated RLS predicates.
@@ -954,7 +1213,7 @@ def insert_rls_as_subquery(
     This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
     databases.
     """
-    rls: Optional[TokenList] = None
+    rls: TokenList | None = None
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
         # Recurse into child token list
@@ -1030,7 +1289,7 @@ def insert_rls_as_subquery(
 def insert_rls_in_predicate(
     token_list: TokenList,
     database_id: int,
-    default_schema: Optional[str],
+    default_schema: str | None,
 ) -> TokenList:
     """
     Update a statement inplace applying any associated RLS predicates.
@@ -1041,7 +1300,7 @@ def insert_rls_in_predicate(
         after:  SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
 
     """
-    rls: Optional[TokenList] = None
+    rls: TokenList | None = None
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
         # Recurse into child token list
@@ -1175,7 +1434,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
 
 def extract_table_references(
     sql_text: str, sqla_dialect: str, show_warning: bool = True
-) -> set["Table"]:
+) -> set[Table]:
     """
     Return all the dependencies from a SQL sql_text.
     """
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index f097fd1df3..8bdd1ee4de 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -33,6 +33,7 @@ from superset.sql_parse import (
     has_table_query,
     insert_rls_as_subquery,
     insert_rls_in_predicate,
+    KustoKQLStatement,
     ParsedQuery,
     sanitize_clause,
     SQLScript,
@@ -1858,21 +1859,31 @@ def test_sqlquery() -> None:
     """
     Test the `SQLScript` class.
     """
-    script = SQLScript("SELECT 1; SELECT 2;")
+    script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
 
     assert len(script.statements) == 2
     assert script.format() == "SELECT\n  1;\nSELECT\n  2"
     assert script.statements[0].format() == "SELECT\n  1"
 
-    script = SQLScript("SET a=1; SET a=2; SELECT 3;")
+    script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
     assert script.get_settings() == {"a": "2"}
 
+    query = SQLScript(
+        """set querytrace;
+Events | take 100""",
+        "kustokql",
+    )
+    assert query.get_settings() == {"querytrace": True}
+
 
 def test_sqlstatement() -> None:
     """
     Test the `SQLStatement` class.
     """
-    statement = SQLStatement("SELECT * FROM table1 UNION ALL SELECT * FROM table2")
+    statement = SQLStatement(
+        "SELECT * FROM table1 UNION ALL SELECT * FROM table2",
+        "sqlite",
+    )
 
     assert statement.tables == {
         Table(table="table1", schema=None, catalog=None),
@@ -1883,5 +1894,74 @@ def test_sqlstatement() -> None:
         == "SELECT\n  *\nFROM table1\nUNION ALL\nSELECT\n  *\nFROM table2"
     )
 
-    statement = SQLStatement("SET a=1")
+    statement = SQLStatement("SET a=1", "sqlite")
     assert statement.get_settings() == {"a": "1"}
+
+
+def test_kustokqlstatement() -> None:
+    """
+    Test the `KustoKQLStatement` class.
+    """
+    statements = KustoKQLStatement.split_query(
+        """
+let totalPagesPerDay = PageViews
+| summarize by Page, Day = startofday(Timestamp)
+| summarize count() by Day;
+let materializedScope = PageViews
+| summarize by Page, Day = startofday(Timestamp);
+let cachedResult = materialize(materializedScope);
+cachedResult
+| project Page, Day1 = Day
+| join kind = inner
+(
+    cachedResult
+    | project Page, Day2 = Day
+)
+on Page
+| where Day2 > Day1
+| summarize count() by Day1, Day2
+| join kind = inner
+    totalPagesPerDay
+on $left.Day1 == $right.Day
+| project Day1, Day2, Percentage = count_*100.0/count_1
+        """,
+        "kustokql",
+    )
+    assert len(statements) == 4
+
+    statements = KustoKQLStatement.split_query(
+        """
+print program = ```
+  public class Program {
+    public static void Main() {
+      System.Console.WriteLine("Hello!");
+    }
+  }```
+        """,
+        "kustokql",
+    )
+    assert len(statements) == 1
+
+    statements = KustoKQLStatement.split_query(
+        """
+set querytrace;
+Events | take 100
+        """,
+        "kustokql",
+    )
+    assert len(statements) == 2
+    assert statements[0].format() == "set querytrace"
+    assert statements[1].format() == "Events | take 100"
+
+
+@pytest.mark.parametrize(
+    "kql,statements",
+    [
+        ('print banner=strcat("Hello", ", ", "World!")', 1),
+        (r"print 'O\'Malley\'s'", 1),
+        (r"print 'O\'Mal;ley\'s'", 1),
+        ("print ```foo;\nbar;\nbaz;```\n", 1),
+    ],
+)
+def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
+    assert len(KustoKQLStatement.split_query(kql, "kustokql")) == statements