You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2024/03/22 16:48:26 UTC

(superset) branch master updated: feat: support for KQL in `SQLScript` (#27522)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new cd7972d05b feat: support for KQL in `SQLScript` (#27522)
cd7972d05b is described below

commit cd7972d05b2ed0ebb110d01965df7e5e54f9ee15
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Fri Mar 22 12:48:20 2024 -0400

    feat: support for KQL in `SQLScript` (#27522)
---
 superset/sql_parse.py               | 363 +++++++++++++++++++++++++++++++-----
 tests/unit_tests/sql_parse_tests.py | 154 ++++++++++++++-
 2 files changed, 468 insertions(+), 49 deletions(-)

diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index f721f456d0..11e4279aa2 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -19,12 +19,13 @@
 
 from __future__ import annotations
 
+import enum
 import logging
 import re
 import urllib.parse
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
-from typing import Any, cast
+from typing import Any, cast, Generic, TypeVar
 from unittest.mock import Mock
 
 import sqlglot
@@ -334,89 +335,175 @@ def is_cte(source: exp.Table, scope: Scope) -> bool:
     return source.name in ctes_in_scope
 
 
-class SQLScript:
+# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
+# an "internal representation", which is the AST of the SQL statement. For most of the
+# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special
+# case: KustoKQL uses a different syntax and there are no Python parsers for it, so we
+# store the raw statement text instead of an AST, and manipulate it with regular
+# expressions.
+InternalRepresentation = TypeVar("InternalRepresentation")
+
+# Type variable for the concrete statement class. It lets `split_query` be typed so
+# that each derived class returns a list of its own instances. This will no longer
+# be needed once Python 3.11 is the lowest version supported (use `typing.Self`;
+# see PEP 673: https://peps.python.org/pep-0673/). NOTE(review): consider `bound=`.
+TBaseSQLStatement = TypeVar("TBaseSQLStatement")  # pylint: disable=invalid-name
+
+
+class BaseSQLStatement(Generic[InternalRepresentation]):
     """
-    A SQL script, with 0+ statements.
+    Base class for SQL statements.
+
+    The class can be instantiated with a string representation of the query or, for
+    efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
+    which will split a query into multiple already parsed statements.
+
+    The `engine` parameter comes from the `engine` attribute in a Superset DB engine
+    spec.
     """
 
     def __init__(
         self,
-        query: str,
-        engine: str | None = None,
+        statement: str | InternalRepresentation,
+        engine: str,
     ):
-        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+        self._parsed: InternalRepresentation = (
+            self._parse_statement(statement, engine)
+            if isinstance(statement, str)  # else: an already-parsed AST
+            else statement
+        )
+        self.engine = engine
+        self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
 
-        self.statements = [
-            SQLStatement(statement, engine=engine)
-            for statement in parse(query, dialect=dialect)
-            if statement
-        ]
+    @classmethod
+    def split_query(
+        cls: type[TBaseSQLStatement],
+        query: str,
+        engine: str,
+    ) -> list[TBaseSQLStatement]:
+        """
+        Split a query into multiple instantiated statements.
+
+        This is a helper function to split a full SQL query into multiple
+        `BaseSQLStatement` instances. It's used by `SQLScript` when instantiating the
+        statements within a query.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _parse_statement(
+        cls,
+        statement: str,
+        engine: str,
+    ) -> InternalRepresentation:
+        """
+        Parse a string containing a single SQL statement and return the parsed AST.
+
+        Derived classes should not assume that `statement` contains a single statement,
+        and MUST explicitly validate that. Since this validation is parser-dependent, the
+        responsibility is left to derived classes.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _extract_tables_from_statement(
+        cls,
+        parsed: InternalRepresentation,
+        engine: str,
+    ) -> set[Table]:
+        """
+        Extract all table references in a given statement.
+        """
+        raise NotImplementedError()
 
     def format(self, comments: bool = True) -> str:
         """
-        Pretty-format the SQL query.
+        Format the statement, optionally omitting comments.
         """
-        return ";\n".join(statement.format(comments) for statement in self.statements)
+        raise NotImplementedError()
 
-    def get_settings(self) -> dict[str, str]:
+    def get_settings(self) -> dict[str, str | bool]:
         """
-        Return the settings for the SQL query.
+        Return any settings set by the statement.
 
-            >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
-            >>> statement.get_settings()
-            {"foo": "'baz'"}
+        For example, for this statement:
 
+            sql> SET foo = 'bar';
+
+        The method should return `{"foo": "'bar'"}`. Note the single quotes.
         """
-        settings: dict[str, str] = {}
-        for statement in self.statements:
-            settings.update(statement.get_settings())
+        raise NotImplementedError()
 
-        return settings
+    def __str__(self) -> str:
+        return self.format()
 
 
-class SQLStatement:
+class SQLStatement(BaseSQLStatement[exp.Expression]):
     """
     A SQL statement.
 
-    This class provides helper methods to manipulate and introspect SQL.
+    This class is used for all engines with dialects that can be parsed using sqlglot.
     """
 
     def __init__(
         self,
         statement: str | exp.Expression,
-        engine: str | None = None,
+        engine: str,
     ):
-        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+        self._dialect = SQLGLOT_DIALECTS.get(engine)
+        super().__init__(statement, engine)
 
-        if isinstance(statement, str):
-            try:
-                self._parsed = self._parse_statement(statement, dialect)
-            except ParseError as ex:
-                raise SupersetParseError(statement, engine) from ex
-        else:
-            self._parsed = statement
+    @classmethod
+    def split_query(
+        cls,
+        query: str,
+        engine: str,
+    ) -> list[SQLStatement]:
+        dialect = SQLGLOT_DIALECTS.get(engine)
 
-        self._dialect = dialect
-        self.tables = extract_tables_from_statement(self._parsed, dialect)
+        try:
+            statements = sqlglot.parse(query, dialect=dialect)
+        except sqlglot.errors.ParseError as ex:
+            raise SupersetParseError("Unable to split query") from ex
 
-    @staticmethod
+        return [cls(statement, engine) for statement in statements if statement]
+
+    @classmethod
     def _parse_statement(
-        sql_statement: str,
-        dialect: Dialects | None,
+        cls,
+        statement: str,
+        engine: str,
     ) -> exp.Expression:
         """
         Parse a single SQL statement.
         """
-        statements = [
-            statement
-            for statement in sqlglot.parse(sql_statement, dialect=dialect)
-            if statement
-        ]
+        dialect = SQLGLOT_DIALECTS.get(engine)
+
+        # We could parse with `sqlglot.parse_one` to get a single statement, but we need
+        # to verify that the string contains exactly one statement.
+        try:
+            statements = sqlglot.parse(statement, dialect=dialect)
+        except sqlglot.errors.ParseError as ex:
+            raise SupersetParseError("Unable to split query") from ex
+
+        statements = [statement for statement in statements if statement]
         if len(statements) != 1:
-            raise ValueError("SQLStatement should have exactly one statement")
+            raise SupersetParseError("SQLStatement should have exactly one statement")
 
         return statements[0]
 
+    @classmethod
+    def _extract_tables_from_statement(
+        cls,
+        parsed: exp.Expression,
+        engine: str,
+    ) -> set[Table]:
+        """
+        Find all tables referenced by the parsed statement, using the engine's dialect.
+        """
+        dialect = SQLGLOT_DIALECTS.get(engine)
+        return extract_tables_from_statement(parsed, dialect)
+
     def format(self, comments: bool = True) -> str:
         """
         Pretty-format the SQL statement.
@@ -424,7 +511,7 @@ class SQLStatement:
         write = Dialect.get_or_raise(self._dialect)
         return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
 
-    def get_settings(self) -> dict[str, str]:
+    def get_settings(self) -> dict[str, str | bool]:
         """
         Return the settings for the SQL statement.
 
@@ -440,6 +527,192 @@ class SQLStatement:
         }
 
 
+class KQLSplitState(enum.Enum):
+    """
+    States of the state machine used by `split_kql` to split a KQL query.
+
+    The state machine keeps track of whether we're inside a string or not, so we
+    don't split the query at a semicolon that's part of a string.
+    """
+
+    OUTSIDE_STRING = enum.auto()  # a `;` here ends a statement
+    INSIDE_SINGLE_QUOTED_STRING = enum.auto()  # inside '...'
+    INSIDE_DOUBLE_QUOTED_STRING = enum.auto()  # inside "..."
+    INSIDE_MULTILINE_STRING = enum.auto()  # inside ```...```
+
+
+def split_kql(kql: str) -> list[str]:
+    """Split KQL into statements at semicolons that are not inside string literals.
+    NOTE(review): `//` comments and an escaped backslash just before a closing quote
+    are not handled, and whitespace after a final `;` yields an extra statement."""
+    statements = []
+    state = KQLSplitState.OUTSIDE_STRING
+    statement_start = 0
+    query = kql if kql.endswith(";") else kql + ";"  # terminate the last statement
+    for i, character in enumerate(query):
+        if state == KQLSplitState.OUTSIDE_STRING:
+            if character == ";":
+                statements.append(query[statement_start:i])  # excludes the `;`
+                statement_start = i + 1
+            elif character == "'":
+                state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
+            elif character == '"':
+                state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
+            elif character == "`" and query[i - 2 : i] == "``":  # opening ```
+                state = KQLSplitState.INSIDE_MULTILINE_STRING
+
+        elif (
+            state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
+            and character == "'"
+            and query[i - 1] != "\\"  # ignore a backslash-escaped quote
+        ):
+            state = KQLSplitState.OUTSIDE_STRING
+
+        elif (
+            state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
+            and character == '"'
+            and query[i - 1] != "\\"  # ignore a backslash-escaped quote
+        ):
+            state = KQLSplitState.OUTSIDE_STRING
+
+        elif (
+            state == KQLSplitState.INSIDE_MULTILINE_STRING
+            and character == "`"
+            and query[i - 2 : i] == "``"  # closing ```
+        ):
+            state = KQLSplitState.OUTSIDE_STRING
+
+    return statements
+
+
+class KustoKQLStatement(BaseSQLStatement[str]):
+    """
+    Special class for Kusto KQL.
+
+    Kusto KQL is a SQL-like language, but it's not supported by sqlglot. Queries look
+    like this:
+
+        StormEvents
+        | summarize PropertyDamage = sum(DamageProperty) by State
+        | join kind=innerunique PopulationData on State
+        | project State, PropertyDamagePerCapita = PropertyDamage / Population
+        | sort by PropertyDamagePerCapita
+
+    See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more
+    details about it.
+    """
+
+    @classmethod
+    def split_query(
+        cls,
+        query: str,
+        engine: str,
+    ) -> list[KustoKQLStatement]:
+        """
+        Split a query at semicolons that are outside string literals.
+
+        Since we don't have a parser, we use a simple state machine based function. See
+        https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
+        for the string literal syntax that it has to skip over.
+        """
+        return [cls(statement, engine) for statement in split_kql(query)]
+
+    @classmethod
+    def _parse_statement(
+        cls,
+        statement: str,
+        engine: str,
+    ) -> str:
+        if engine != "kustokql":  # this class only handles the KQL engine
+            raise SupersetParseError(f"Invalid engine: {engine}")
+
+        statements = split_kql(statement)  # NOTE(review): "q; " splits in two
+        if len(statements) != 1:
+            raise SupersetParseError("SQLStatement should have exactly one statement")
+
+        return statements[0].strip()  # the stored form is the stripped text
+
+    @classmethod
+    def _extract_tables_from_statement(cls, parsed: str, engine: str) -> set[Table]:
+        """
+        Return an empty set: KQL table extraction is not implemented. Tables such as
+        `StormEvents` and `PopulationData` in the query below are not detected:
+            StormEvents
+            | where InjuriesDirect + InjuriesIndirect > 50
+            | join (PopulationData) on State
+            | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect
+
+        """
+        logger.warning(
+            "Kusto KQL doesn't support table extraction. This means that data access "
+            "roles will not be enforced by Superset in the database."
+        )
+        return set()
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Return the stored (stripped) statement unchanged; `comments` is ignored.
+        """
+        return self._parsed
+
+    def get_settings(self) -> dict[str, str | bool]:
+        """
+        Return the setting from a `set name [= value]` statement (`True` if no value).
+
+            >>> statement = KustoKQLStatement("set querytrace;", "kustokql")
+            >>> statement.get_settings()
+            {"querytrace": True}
+
+        """
+        set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$"
+        if match := re.match(set_regex, self._parsed, re.IGNORECASE):
+            return {match.group("name"): match.group("value") or True}
+
+        return {}
+
+
+class SQLScript:
+    """
+    A SQL script with zero or more statements.
+    """
+
+    # Special engines that can't be parsed using sqlglot. Supporting non-SQL engines
+    # adds a lot of complexity to Superset, so we should avoid adding new engines to
+    # this data structure.
+    special_engines = {
+        "kustokql": KustoKQLStatement,
+    }
+
+    def __init__(
+        self,
+        query: str,
+        engine: str,
+    ):
+        statement_class = self.special_engines.get(engine, SQLStatement)
+        self.statements = statement_class.split_query(query, engine)
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Format each statement and join them with a semicolon and a newline.
+        """
+        return ";\n".join(statement.format(comments) for statement in self.statements)
+
+    def get_settings(self) -> dict[str, str | bool]:
+        """
+        Merge the settings of all statements; later statements override earlier ones.
+
+            >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'", "sqlite")
+            >>> statement.get_settings()
+            {"foo": "'baz'"}
+
+        """
+        settings: dict[str, str | bool] = {}
+        for statement in self.statements:
+            settings.update(statement.get_settings())
+
+        return settings
+
+
 class ParsedQuery:
     def __init__(
         self,
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index aa4171e763..79958b0743 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -37,8 +37,10 @@ from superset.sql_parse import (
     has_table_query,
     insert_rls_as_subquery,
     insert_rls_in_predicate,
+    KustoKQLStatement,
     ParsedQuery,
     sanitize_clause,
+    split_kql,
     SQLScript,
     SQLStatement,
     strip_comments_from_sql,
@@ -1883,21 +1885,31 @@ def test_sqlquery() -> None:
     """
     Test the `SQLScript` class.
     """
-    script = SQLScript("SELECT 1; SELECT 2;")
+    script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
 
     assert len(script.statements) == 2
     assert script.format() == "SELECT\n  1;\nSELECT\n  2"
     assert script.statements[0].format() == "SELECT\n  1"
 
-    script = SQLScript("SET a=1; SET a=2; SELECT 3;")
+    script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
     assert script.get_settings() == {"a": "2"}
 
+    query = SQLScript(
+        """set querytrace;
+Events | take 100""",
+        "kustokql",
+    )
+    assert query.get_settings() == {"querytrace": True}
+
 
 def test_sqlstatement() -> None:
     """
     Test the `SQLStatement` class.
     """
-    statement = SQLStatement("SELECT * FROM table1 UNION ALL SELECT * FROM table2")
+    statement = SQLStatement(
+        "SELECT * FROM table1 UNION ALL SELECT * FROM table2",
+        "sqlite",
+    )
 
     assert statement.tables == {
         Table(table="table1", schema=None, catalog=None),
@@ -1908,7 +1920,7 @@ def test_sqlstatement() -> None:
         == "SELECT\n  *\nFROM table1\nUNION ALL\nSELECT\n  *\nFROM table2"
     )
 
-    statement = SQLStatement("SET a=1")
+    statement = SQLStatement("SET a=1", "sqlite")
     assert statement.get_settings() == {"a": "1"}
 
 
@@ -1950,3 +1962,137 @@ def test_extract_tables_from_jinja_sql(
         extract_tables_from_jinja_sql(sql.format(engine=engine, macro=macro), engine)
         == expected
     )
+
+
+def test_kustokqlstatement_split_query() -> None:
+    """
+    Test that `KustoKQLStatement.split_query` splits a multi-statement query at `;`.
+    """
+    statements = KustoKQLStatement.split_query(
+        """
+let totalPagesPerDay = PageViews
+| summarize by Page, Day = startofday(Timestamp)
+| summarize count() by Day;
+let materializedScope = PageViews
+| summarize by Page, Day = startofday(Timestamp);
+let cachedResult = materialize(materializedScope);
+cachedResult
+| project Page, Day1 = Day
+| join kind = inner
+(
+    cachedResult
+    | project Page, Day2 = Day
+)
+on Page
+| where Day2 > Day1
+| summarize count() by Day1, Day2
+| join kind = inner
+    totalPagesPerDay
+on $left.Day1 == $right.Day
+| project Day1, Day2, Percentage = count_*100.0/count_1
+        """,
+        "kustokql",
+    )
+    assert len(statements) == 4  # three `let` statements plus the final query
+
+
+def test_kustokqlstatement_with_program() -> None:
+    """
+    Test that a `;` inside a ``` multi-line string does not split the query.
+    """
+    statements = KustoKQLStatement.split_query(
+        """
+print program = ```
+  public class Program {
+    public static void Main() {
+      System.Console.WriteLine("Hello!");
+    }
+  }```
+        """,
+        "kustokql",
+    )
+    assert len(statements) == 1  # the `;` after WriteLine(...) is inside the string
+
+
+def test_kustokqlstatement_with_set() -> None:
+    """
+    Test splitting a query that starts with a `set` command; statements are stripped.
+    """
+    statements = KustoKQLStatement.split_query(
+        """
+set querytrace;
+Events | take 100
+        """,
+        "kustokql",
+    )
+    assert len(statements) == 2
+    assert statements[0].format() == "set querytrace"
+    assert statements[1].format() == "Events | take 100"
+
+
+@pytest.mark.parametrize(
+    "kql,statements",
+    [
+        ('print banner=strcat("Hello", ", ", "World!")', 1),  # double-quoted strings
+        (r"print 'O\'Malley\'s'", 1),  # backslash-escaped single quotes
+        (r"print 'O\'Mal;ley\'s'", 1),  # `;` inside a single-quoted string
+        ("print ```foo;\nbar;\nbaz;```\n", 1),  # `;` inside a ``` string
+    ],
+)
+def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
+    assert len(KustoKQLStatement.split_query(kql, "kustokql")) == statements
+
+
+def test_split_kql() -> None:
+    """
+    Test that `split_kql` drops the `;` separators but keeps surrounding whitespace.
+    """
+    kql = """
+let totalPagesPerDay = PageViews
+| summarize by Page, Day = startofday(Timestamp)
+| summarize count() by Day;
+let materializedScope = PageViews
+| summarize by Page, Day = startofday(Timestamp);
+let cachedResult = materialize(materializedScope);
+cachedResult
+| project Page, Day1 = Day
+| join kind = inner
+(
+    cachedResult
+    | project Page, Day2 = Day
+)
+on Page
+| where Day2 > Day1
+| summarize count() by Day1, Day2
+| join kind = inner
+    totalPagesPerDay
+on $left.Day1 == $right.Day
+| project Day1, Day2, Percentage = count_*100.0/count_1
+    """
+    assert split_kql(kql) == [
+        """
+let totalPagesPerDay = PageViews
+| summarize by Page, Day = startofday(Timestamp)
+| summarize count() by Day""",
+        """
+let materializedScope = PageViews
+| summarize by Page, Day = startofday(Timestamp)""",
+        """
+let cachedResult = materialize(materializedScope)""",
+        """
+cachedResult
+| project Page, Day1 = Day
+| join kind = inner
+(
+    cachedResult
+    | project Page, Day2 = Day
+)
+on Page
+| where Day2 > Day1
+| summarize count() by Day1, Day2
+| join kind = inner
+    totalPagesPerDay
+on $left.Day1 == $right.Day
+| project Day1, Day2, Percentage = count_*100.0/count_1
+    """,
+    ]