You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2024/03/14 21:07:50 UTC
(superset) 01/01: feat: support for KQL in SQLQuery
This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch remove-sqlparse-kusto
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 40f3607af8ad69a51e6b9ea07a8c411f2c712257
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Wed Jan 24 10:02:47 2024 -0500
feat: support for KQL in SQLQuery
---
superset/sql_parse.py | 389 ++++++++++++++++++++++++++++++------
tests/unit_tests/sql_parse_tests.py | 88 +++++++-
2 files changed, 408 insertions(+), 69 deletions(-)
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 58dc210e2b..29a7e3425d 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -17,12 +17,15 @@
# pylint: disable=too-many-lines
+from __future__ import annotations
+
+import enum
import logging
import re
import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
-from typing import Any, cast, Optional, Union
+from typing import Any, cast, Generic, TypeVar
import sqlglot
import sqlparse
@@ -138,7 +141,7 @@ class CtasMethod(StrEnum):
VIEW = "VIEW"
-def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
+def _extract_limit_from_query(statement: TokenList) -> int | None:
"""
Extract limit clause from SQL statement.
@@ -159,9 +162,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
return None
-def extract_top_from_query(
- statement: TokenList, top_keywords: set[str]
-) -> Optional[int]:
+def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None:
"""
Extract top clause value from SQL statement.
@@ -185,7 +186,7 @@ def extract_top_from_query(
return top
-def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
+def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
"""
parse the SQL and return the CTE and rest of the block to the caller
@@ -193,7 +194,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
:return: CTE and remainder block to the caller
"""
- cte: Optional[str] = None
+ cte: str | None = None
remainder = sql
stmt = sqlparse.parse(sql)[0]
@@ -211,7 +212,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
return cte, remainder
-def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
+def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
"""
Strips comments from a SQL statement, does a simple test first
to avoid always instantiating the expensive ParsedQuery constructor
@@ -235,8 +236,8 @@ class Table:
"""
table: str
- schema: Optional[str] = None
- catalog: Optional[str] = None
+ schema: str | None = None
+ catalog: str | None = None
def __str__(self) -> str:
"""
@@ -255,7 +256,7 @@ class Table:
def extract_tables_from_statement(
statement: exp.Expression,
- dialect: Optional[Dialects],
+ dialect: Dialects | None,
) -> set[Table]:
"""
Extract all table references in a single statement.
@@ -326,82 +327,151 @@ def is_cte(source: exp.Table, scope: Scope) -> bool:
return source.name in ctes_in_scope
-class SQLScript:
+# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
+# an "internal representation", which is the AST of the SQL statement. For most of the
+# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special
+# case: KustoKQL uses a different syntax and there are no Python parsers for it, so we
+# store the "AST" as a string (the original statement), and manipulate it with regular
+# expressions.
+InternalRepresentation = TypeVar("InternalRepresentation")
+
+# The base type. This helps type checking the `split_query` method correctly, since each
+# derived class has a more specific return type (the class itself). The type variable is
+# bound to `BaseSQLStatement` so that type checkers know `cls(...)` builds a statement.
+# This will no longer be needed once Python 3.11 is the smallest version supported
+# (`typing.Self`). See PEP 673 for more information: https://peps.python.org/pep-0673/
+TBaseSQLStatement = TypeVar("TBaseSQLStatement", bound="BaseSQLStatement")
+
+
+class BaseSQLStatement(Generic[InternalRepresentation]):
    """
-    A SQL script, with 0+ statements.
+    Base class for SQL statements.
+
+    The class can be instantiated with a string representation of the query or, for
+    efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
+    which will split a query into multiple already parsed statements.
+
+    The `engine` parameter comes from the `engine` attribute in a Superset DB engine
+    spec.
    """
    def __init__(
        self,
-        query: str,
-        engine: Optional[str] = None,
+        statement: str | InternalRepresentation,
+        engine: str,
    ):
-        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+        # Raw text is parsed by the subclass; anything else is assumed to already be
+        # the internal representation. For subclasses whose internal representation
+        # is itself a `str`, every input therefore goes through `_parse_statement`.
+        self._parsed: InternalRepresentation = (
+            self._parse_statement(statement, engine)
+            if isinstance(statement, str)
+            else statement
+        )
+        self.engine = engine
+        self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
-        self.statements = [
-            SQLStatement(statement, engine=engine)
-            for statement in parse(query, dialect=dialect)
-            if statement
-        ]
+    @classmethod
+    def split_query(
+        cls: type[TBaseSQLStatement],
+        query: str,
+        engine: str,
+    ) -> list[TBaseSQLStatement]:
+        """
+        Split a query into multiple instantiated statements.
+
+        This is a helper function to split a full SQL query into multiple
+        `BaseSQLStatement` instances. It's used by `SQLScript` when instantiating the
+        statements within a query.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _parse_statement(
+        cls,
+        statement: str,
+        engine: str,
+    ) -> InternalRepresentation:
+        """
+        Parse a string containing a single SQL statement, and return the parsed AST.
+
+        Derived classes should not assume that `statement` contains a single statement,
+        and MUST explicitly validate that. Since this validation is parser-dependent, the
+        responsibility is left to the child classes.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _extract_tables_from_statement(
+        cls,
+        parsed: InternalRepresentation,
+        engine: str,
+    ) -> set[Table]:
+        """
+        Extract all table references in a given statement.
+        """
+        raise NotImplementedError()
    def format(self, comments: bool = True) -> str:
        """
-        Pretty-format the SQL query.
+        Format the statement, optionally omitting comments.
        """
-        return ";\n".join(statement.format(comments) for statement in self.statements)
+        raise NotImplementedError()
-    def get_settings(self) -> dict[str, str]:
+    def get_settings(self) -> dict[str, str | bool]:
        """
-        Return the settings for the SQL query.
+        Return any settings set by the statement.
-        >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
-        >>> statement.get_settings()
-        {"foo": "'baz'"}
+        For example, for this statement:
+
+            sql> SET foo = 'bar';
+
+        The method should return `{"foo": "'bar'"}`. Note the single quotes. Settings
+        enabled without an explicit value (eg, KQL's `set querytrace`) map to `True`.
        """
-        settings: dict[str, str] = {}
-        for statement in self.statements:
-            settings.update(statement.get_settings())
+        raise NotImplementedError()
-        return settings
+    def __str__(self) -> str:
+        """Return the statement as produced by `format()`."""
+        return self.format()
-class SQLStatement:
+class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
A SQL statement.
- This class provides helper methods to manipulate and introspect SQL.
+ This class is used for all engines with dialects that can be parsed using sqlglot.
"""
def __init__(
self,
- statement: Union[str, exp.Expression],
- engine: Optional[str] = None,
+ statement: str | exp.Expression,
+ engine: str,
):
- dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
-
- if isinstance(statement, str):
- try:
- self._parsed = self._parse_statement(statement, dialect)
- except ParseError as ex:
- raise SupersetParseError(statement, engine) from ex
- else:
- self._parsed = statement
+ self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+ super().__init__(statement, engine)
- self._dialect = dialect
- self.tables = extract_tables_from_statement(self._parsed, dialect)
+    @classmethod
+    def split_query(
+        cls,
+        query: str,
+        engine: str,
+    ) -> list[SQLStatement]:
+        """
+        Split a query into parsed statements, using the sqlglot dialect for `engine`.
+
+        Superset engine names don't always match sqlglot dialect names, so the engine is
+        mapped through `SQLGLOT_DIALECTS` instead of being handed to sqlglot directly.
+        Errors raised by sqlglot while parsing are re-raised as `SupersetParseError`.
+        """
+        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+
+        try:
+            parsed = sqlglot.parse(query, dialect=dialect)
+        except ParseError as ex:
+            raise SupersetParseError(query, engine) from ex
+
+        return [cls(statement, engine) for statement in parsed if statement]
- @staticmethod
+ @classmethod
def _parse_statement(
- sql_statement: str,
- dialect: Optional[Dialects],
+ cls,
+ statement: str,
+ engine: str,
) -> exp.Expression:
"""
Parse a single SQL statement.
"""
+ dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+
+ # We could parse with `sqlglot.parse_one` to get a single statement, but we need
+ # to verify that the string contains exactly one statement.
statements = [
statement
- for statement in sqlglot.parse(sql_statement, dialect=dialect)
+ for statement in sqlglot.parse(statement, dialect=dialect)
if statement
]
if len(statements) != 1:
@@ -409,6 +479,18 @@ class SQLStatement:
return statements[0]
+    @classmethod
+    def _extract_tables_from_statement(
+        cls,
+        parsed: exp.Expression,
+        engine: str,
+    ) -> set[Table]:
+        """
+        Return the tables referenced by an already parsed sqlglot statement.
+
+        The Superset engine name is translated into a sqlglot dialect first; when no
+        engine is given the dialect is left unspecified.
+        """
+        if not engine:
+            return extract_tables_from_statement(parsed, None)
+        return extract_tables_from_statement(parsed, SQLGLOT_DIALECTS.get(engine))
+
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
@@ -416,7 +498,7 @@ class SQLStatement:
write = Dialect.get_or_raise(self._dialect)
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
- def get_settings(self) -> dict[str, str]:
+ def get_settings(self) -> dict[str, str | bool]:
"""
Return the settings for the SQL statement.
@@ -432,12 +514,189 @@ class SQLStatement:
}
+class KustoKQLStatement(BaseSQLStatement[str]):
+    """
+    Special class for Kusto KQL.
+
+    Kusto KQL is a SQL-like language, but it's not supported by sqlglot. Queries look
+    like this:
+
+    StormEvents
+    | summarize PropertyDamage = sum(DamageProperty) by State
+    | join kind=innerunique PopulationData on State
+    | project State, PropertyDamagePerCapita = PropertyDamage / Population
+    | sort by PropertyDamagePerCapita
+
+    See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more
+    details about it.
+
+    Since there is no parser, the internal representation of a statement is its text,
+    stripped of surrounding whitespace and of the terminating semi-colon.
+    """
+
+    @classmethod
+    def split_query(
+        cls,
+        query: str,
+        engine: str,
+    ) -> list[KustoKQLStatement]:
+        """
+        Split a query at semi-colons.
+
+        Semi-colons inside string literals or comments don't separate statements, and
+        empty statements (eg, after a trailing semi-colon) are dropped.
+        """
+        return [cls(statement, engine) for statement in cls._split_script(query)]
+
+    @staticmethod
+    def _split_script(query: str) -> list[str]:
+        """
+        Split KQL text into its non-empty statements, without the semi-colons.
+
+        Since we don't have a parser, we use a simple state machine that tracks whether
+        we're inside a string literal or a comment, so we don't split the query at a
+        semi-colon that's part of one. See
+        https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
+        for more information on KQL string literals.
+        """
+
+        class KQLSplitState(enum.Enum):
+            """
+            State machine for splitting a KQL query.
+            """
+
+            OUTSIDE_STRING = enum.auto()
+            INSIDE_SINGLE_QUOTED_STRING = enum.auto()
+            INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
+            INSIDE_MULTILINE_STRING = enum.auto()
+            INSIDE_COMMENT = enum.auto()
+
+        statements: list[str] = []
+        state = KQLSplitState.OUTSIDE_STRING
+        statement_start = 0
+        # Backslash escapes are honored in regular strings, but not in verbatim strings
+        # (prefixed with `@`), where a quote is escaped by doubling it instead.
+        escapes_allowed = True
+
+        i = 0
+        while i < len(query):
+            character = query[i]
+
+            if state == KQLSplitState.OUTSIDE_STRING:
+                if character == ";":
+                    statements.append(query[statement_start:i])
+                    statement_start = i + 1
+                elif query.startswith("```", i):
+                    state = KQLSplitState.INSIDE_MULTILINE_STRING
+                    i += 3
+                    continue
+                elif query.startswith("//", i):
+                    state = KQLSplitState.INSIDE_COMMENT
+                    i += 2
+                    continue
+                elif character in {"'", '"'}:
+                    state = (
+                        KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
+                        if character == "'"
+                        else KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
+                    )
+                    # `query[i - 1 : i]` is empty (rather than wrapping) when i == 0.
+                    escapes_allowed = query[i - 1 : i] != "@"
+
+            elif state == KQLSplitState.INSIDE_MULTILINE_STRING:
+                # Multi-line strings have no escape sequences, and end at the next
+                # triple backtick.
+                if query.startswith("```", i):
+                    state = KQLSplitState.OUTSIDE_STRING
+                    i += 3
+                    continue
+
+            elif state == KQLSplitState.INSIDE_COMMENT:
+                # Comments run until the end of the line.
+                if character == "\n":
+                    state = KQLSplitState.OUTSIDE_STRING
+
+            else:
+                quote = (
+                    "'"
+                    if state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
+                    else '"'
+                )
+                if character == "\\" and escapes_allowed:
+                    # Skip the backslash and the character it escapes, so that an
+                    # escaped backslash right before a quote doesn't hide the quote.
+                    i += 2
+                    continue
+                if character == quote:
+                    if not escapes_allowed and query.startswith(quote, i + 1):
+                        # A doubled quote inside a verbatim string is a literal quote.
+                        i += 2
+                        continue
+                    state = KQLSplitState.OUTSIDE_STRING
+
+            i += 1
+
+        statements.append(query[statement_start:])
+
+        return [statement for statement in statements if statement.strip()]
+
+    @classmethod
+    def _parse_statement(
+        cls,
+        statement: str,
+        engine: str,
+    ) -> str:
+        """
+        "Parse" a single KQL statement.
+
+        Since there's no KQL parser, this validates that the text contains exactly one
+        statement, and returns it without surrounding whitespace and without the
+        terminating semi-colon, if any.
+        """
+        if engine != "kustokql":
+            raise ValueError(f"Invalid engine: {engine}")
+
+        statements = cls._split_script(statement)
+        if len(statements) != 1:
+            raise ValueError("KustoKQLStatement should have exactly one statement")
+
+        return statements[0].strip()
+
+    @classmethod
+    def _extract_tables_from_statement(cls, parsed: str, engine: str) -> set[Table]:
+        """
+        Extract all tables referenced in the statement.
+
+        Table extraction isn't supported for KQL, so this logs a warning and always
+        returns an empty set. For example, this statement references `StormEvents` and
+        `PopulationData`, but neither is returned:
+
+        StormEvents
+        | where InjuriesDirect + InjuriesIndirect > 50
+        | join (PopulationData) on State
+        | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect
+        """
+        logger.warning(
+            "Kusto KQL doesn't support table extraction. This means that data access "
+            "roles will not be enforced by Superset in the database."
+        )
+        return set()
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Return the statement text.
+
+        No pretty-formatting is done, and `comments` is currently ignored.
+        """
+        return self._parsed
+
+    def get_settings(self) -> dict[str, str | bool]:
+        """
+        Return the settings for the SQL statement.
+
+        >>> statement = KustoKQLStatement("set querytrace;", "kustokql")
+        >>> statement.get_settings()
+        {'querytrace': True}
+
+        Settings without a value (like `querytrace` above) map to `True`; otherwise the
+        raw text of the value is returned.
+        """
+        set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>.+))?$"
+        if match := re.match(set_regex, self._parsed, re.IGNORECASE):
+            return {match.group("name"): match.group("value") or True}
+
+        return {}
+
+class SQLScript:
+    """
+    A SQL script, with 0+ statements.
+
+    The script is split using the statement class for its engine: `SQLStatement`
+    (sqlglot based) by default, or a special class for engines sqlglot can't parse.
+    """
+
+    # Special engines that can't be parsed using sqlglot. Supporting non-SQL engines
+    # adds a lot of complexity to Superset, so we should avoid adding new engines to
+    # this data structure.
+    special_engines = {
+        "kustokql": KustoKQLStatement,
+    }
+
+    def __init__(
+        self,
+        query: str,
+        engine: str,
+    ):
+        statement_class = self.special_engines.get(engine, SQLStatement)
+        self.statements = statement_class.split_query(query, engine)
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Pretty-format the SQL query.
+
+        Each statement is formatted on its own, and the results are joined with a
+        semi-colon followed by a newline.
+        """
+        return ";\n".join(statement.format(comments) for statement in self.statements)
+
+    def get_settings(self) -> dict[str, str | bool]:
+        """
+        Return the settings for the SQL query.
+
+        Settings from all statements are merged; when a setting is set more than once
+        the last value wins:
+
+        >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'", "sqlite")
+        >>> statement.get_settings()
+        {'foo': "'baz'"}
+
+        """
+        settings: dict[str, str | bool] = {}
+        for statement in self.statements:
+            settings.update(statement.get_settings())
+
+        return settings
+
+
class ParsedQuery:
def __init__(
self,
sql_statement: str,
strip_comments: bool = False,
- engine: Optional[str] = None,
+ engine: str | None = None,
):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
@@ -446,7 +705,7 @@ class ParsedQuery:
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._tables: set[Table] = set()
self._alias_names: set[str] = set()
- self._limit: Optional[int] = None
+ self._limit: int | None = None
logger.debug("Parsing with sqlparse statement: %s", self.sql)
self._parsed = sqlparse.parse(self.stripped())
@@ -550,7 +809,7 @@ class ParsedQuery:
return source.name in ctes_in_scope
@property
- def limit(self) -> Optional[int]:
+ def limit(self) -> int | None:
return self._limit
def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
@@ -631,7 +890,7 @@ class ParsedQuery:
return True
- def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
+ def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
for token in tokens:
if self._is_identifier(token):
for identifier_token in token.tokens:
@@ -695,7 +954,7 @@ class ParsedQuery:
return statements
@staticmethod
- def get_table(tlist: TokenList) -> Optional[Table]:
+ def get_table(tlist: TokenList) -> Table | None:
"""
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
construct.
@@ -731,7 +990,7 @@ class ParsedQuery:
def as_create_table(
self,
table_name: str,
- schema_name: Optional[str] = None,
+ schema_name: str | None = None,
overwrite: bool = False,
method: CtasMethod = CtasMethod.TABLE,
) -> str:
@@ -891,8 +1150,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
def get_rls_for_table(
candidate: Token,
database_id: int,
- default_schema: Optional[str],
-) -> Optional[TokenList]:
+ default_schema: str | None,
+) -> TokenList | None:
"""
Given a table name, return any associated RLS predicates.
"""
@@ -938,7 +1197,7 @@ def get_rls_for_table(
def insert_rls_as_subquery(
token_list: TokenList,
database_id: int,
- default_schema: Optional[str],
+ default_schema: str | None,
) -> TokenList:
"""
Update a statement inplace applying any associated RLS predicates.
@@ -954,7 +1213,7 @@ def insert_rls_as_subquery(
This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
databases.
"""
- rls: Optional[TokenList] = None
+ rls: TokenList | None = None
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Recurse into child token list
@@ -1030,7 +1289,7 @@ def insert_rls_as_subquery(
def insert_rls_in_predicate(
token_list: TokenList,
database_id: int,
- default_schema: Optional[str],
+ default_schema: str | None,
) -> TokenList:
"""
Update a statement inplace applying any associated RLS predicates.
@@ -1041,7 +1300,7 @@ def insert_rls_in_predicate(
after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
"""
- rls: Optional[TokenList] = None
+ rls: TokenList | None = None
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Recurse into child token list
@@ -1175,7 +1434,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
def extract_table_references(
sql_text: str, sqla_dialect: str, show_warning: bool = True
-) -> set["Table"]:
+) -> set[Table]:
"""
Return all the dependencies from a SQL sql_text.
"""
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index f097fd1df3..8bdd1ee4de 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -33,6 +33,7 @@ from superset.sql_parse import (
has_table_query,
insert_rls_as_subquery,
insert_rls_in_predicate,
+ KustoKQLStatement,
ParsedQuery,
sanitize_clause,
SQLScript,
@@ -1858,21 +1859,31 @@ def test_sqlquery() -> None:
"""
Test the `SQLScript` class.
"""
- script = SQLScript("SELECT 1; SELECT 2;")
+ script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
assert len(script.statements) == 2
assert script.format() == "SELECT\n 1;\nSELECT\n 2"
assert script.statements[0].format() == "SELECT\n 1"
- script = SQLScript("SET a=1; SET a=2; SELECT 3;")
+ script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
assert script.get_settings() == {"a": "2"}
+ query = SQLScript(
+ """set querytrace;
+Events | take 100""",
+ "kustokql",
+ )
+ assert query.get_settings() == {"querytrace": True}
+
def test_sqlstatement() -> None:
"""
Test the `SQLStatement` class.
"""
- statement = SQLStatement("SELECT * FROM table1 UNION ALL SELECT * FROM table2")
+ statement = SQLStatement(
+ "SELECT * FROM table1 UNION ALL SELECT * FROM table2",
+ "sqlite",
+ )
assert statement.tables == {
Table(table="table1", schema=None, catalog=None),
@@ -1883,5 +1894,74 @@ def test_sqlstatement() -> None:
== "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
)
- statement = SQLStatement("SET a=1")
+ statement = SQLStatement("SET a=1", "sqlite")
assert statement.get_settings() == {"a": "1"}
+
+
+def test_kustokqlstatement() -> None:
+    """
+    Test the `KustoKQLStatement` class.
+    """
+    # Three `let` statements, each terminated by a semi-colon, followed by the final
+    # tabular expression: 4 statements in total.
+    statements = KustoKQLStatement.split_query(
+        """
+let totalPagesPerDay = PageViews
+| summarize by Page, Day = startofday(Timestamp)
+| summarize count() by Day;
+let materializedScope = PageViews
+| summarize by Page, Day = startofday(Timestamp);
+let cachedResult = materialize(materializedScope);
+cachedResult
+| project Page, Day1 = Day
+| join kind = inner
+(
+    cachedResult
+    | project Page, Day2 = Day
+)
+on Page
+| where Day2 > Day1
+| summarize count() by Day1, Day2
+| join kind = inner
+    totalPagesPerDay
+on $left.Day1 == $right.Day
+| project Day1, Day2, Percentage = count_*100.0/count_1
+    """,
+        "kustokql",
+    )
+    assert len(statements) == 4
+
+    # The semi-colon and quotes inside the triple-backtick (multi-line) string must
+    # not split the statement.
+    statements = KustoKQLStatement.split_query(
+        """
+print program = ```
+  public class Program {
+    public static void Main() {
+      System.Console.WriteLine("Hello!");
+    }
+  }```
+    """,
+        "kustokql",
+    )
+    assert len(statements) == 1
+
+    # Statements are returned without surrounding whitespace or the semi-colon.
+    statements = KustoKQLStatement.split_query(
+        """
+set querytrace;
+Events | take 100
+    """,
+        "kustokql",
+    )
+    assert len(statements) == 2
+    assert statements[0].format() == "set querytrace"
+    assert statements[1].format() == "Events | take 100"
+
+
+@pytest.mark.parametrize(
+    "kql,statements",
+    [
+        # Double-quoted strings.
+        ('print banner=strcat("Hello", ", ", "World!")', 1),
+        # Backslash-escaped single quotes inside a single-quoted string.
+        (r"print 'O\'Malley\'s'", 1),
+        # A semi-colon inside a single-quoted string with escaped quotes.
+        (r"print 'O\'Mal;ley\'s'", 1),
+        # Semi-colons inside a multi-line (triple backtick) string.
+        ("print ```foo;\nbar;\nbaz;```\n", 1),
+    ],
+)
+def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
+    """
+    Test that semi-colons inside KQL string literals don't split the query.
+    """
+    assert len(KustoKQLStatement.split_query(kql, "kustokql")) == statements