You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2024/03/14 21:26:06 UTC

(superset) branch remove-sqlparse-kusto updated (40f3607af8 -> 3582579c15)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a change to branch remove-sqlparse-kusto
in repository https://gitbox.apache.org/repos/asf/superset.git


 discard 40f3607af8 feat: support for KQL in SQLQuery
     new 3582579c15 feat: support for KQL in SQLQuery

This update added new revisions after undoing existing revisions.
That is to say, some revisions that were in the old version of the
branch are not in the new version.  This situation occurs
when a user --force pushes a change and generates a repository
containing something like this:

 * -- * -- B -- O -- O -- O   (40f3607af8)
            \
             N -- N -- N   refs/heads/remove-sqlparse-kusto (3582579c15)

You should already have received notification emails for all of the O
revisions, and so the following emails describe only the N revisions
from the common base, B.

Any revisions marked "omit" are not gone; other references still
refer to them.  Any revisions marked "discard" are gone forever.

The 1 revision listed above as "new" is entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 superset/sql_parse.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)


(superset) 01/01: feat: support for KQL in SQLQuery

Posted by be...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch remove-sqlparse-kusto
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 3582579c152c5165d315e90703db5cbb89c83ca5
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Wed Jan 24 10:02:47 2024 -0500

    feat: support for KQL in SQLQuery
---
 superset/sql_parse.py               | 392 ++++++++++++++++++++++++++++++------
 tests/unit_tests/sql_parse_tests.py |  88 +++++++-
 2 files changed, 411 insertions(+), 69 deletions(-)

diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 58dc210e2b..eebf12c55d 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -17,12 +17,15 @@
 
 # pylint: disable=too-many-lines
 
+from __future__ import annotations
+
+import enum
 import logging
 import re
 import urllib.parse
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
-from typing import Any, cast, Optional, Union
+from typing import Any, cast, Generic, TypeVar
 
 import sqlglot
 import sqlparse
@@ -138,7 +141,7 @@ class CtasMethod(StrEnum):
     VIEW = "VIEW"
 
 
-def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
+def _extract_limit_from_query(statement: TokenList) -> int | None:
     """
     Extract limit clause from SQL statement.
 
@@ -159,9 +162,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
     return None
 
 
-def extract_top_from_query(
-    statement: TokenList, top_keywords: set[str]
-) -> Optional[int]:
+def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None:
     """
     Extract top clause value from SQL statement.
 
@@ -185,7 +186,7 @@ def extract_top_from_query(
     return top
 
 
-def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
+def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
     """
     parse the SQL and return the CTE and rest of the block to the caller
 
@@ -193,7 +194,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     :return: CTE and remainder block to the caller
 
     """
-    cte: Optional[str] = None
+    cte: str | None = None
     remainder = sql
     stmt = sqlparse.parse(sql)[0]
 
@@ -211,7 +212,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     return cte, remainder
 
 
-def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
+def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
     """
     Strips comments from a SQL statement, does a simple test first
     to avoid always instantiating the expensive ParsedQuery constructor
@@ -235,8 +236,8 @@ class Table:
     """
 
     table: str
-    schema: Optional[str] = None
-    catalog: Optional[str] = None
+    schema: str | None = None
+    catalog: str | None = None
 
     def __str__(self) -> str:
         """
@@ -255,7 +256,7 @@ class Table:
 
 def extract_tables_from_statement(
     statement: exp.Expression,
-    dialect: Optional[Dialects],
+    dialect: Dialects | None,
 ) -> set[Table]:
     """
     Extract all table references in a single statement.
@@ -326,82 +327,152 @@ def is_cte(source: exp.Table, scope: Scope) -> bool:
     return source.name in ctes_in_scope
 
 
-class SQLScript:
+# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
+# an "internal representation", which is the AST of the SQL statement. For most of the
+# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special
+# case: KustoKQL uses a different syntax and there are no Python parsers for it, so we
+# store the AST as a string (the original query), and manipulate it with regular
+# expressions.
+InternalRepresentation = TypeVar("InternalRepresentation")
+
+# The base type. This helps type checking the `split_query` method correctly, since each
+# derived class has a more specific return type (the class itself). This will no longer
+# be needed once Python 3.11 is the lowest version supported. See PEP 673 for more
+# information: https://peps.python.org/pep-0673/
+TBaseSQLStatement = TypeVar("TBaseSQLStatement")
+
+
+class BaseSQLStatement(Generic[InternalRepresentation]):
     """
-    A SQL script, with 0+ statements.
+    Base class for SQL statements.
+
+    The class can be instantiated with a string representation of the query or, for
+    efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
+    which will split a query in multiple already parsed statements.
+
+    The `engine` parameter comes from the `engine` attribute in a Superset DB engine
+    spec.
     """
 
     def __init__(
         self,
-        query: str,
-        engine: Optional[str] = None,
+        statement: str | InternalRepresentation,
+        engine: str,
     ):
-        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+        self._parsed: InternalRepresentation = (
+            self._parse_statement(statement, engine)
+            if isinstance(statement, str)
+            else statement
+        )
+        self.engine = engine
+        self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
 
-        self.statements = [
-            SQLStatement(statement, engine=engine)
-            for statement in parse(query, dialect=dialect)
-            if statement
-        ]
+    @classmethod
+    def split_query(
+        cls: type[TBaseSQLStatement],
+        query: str,
+        engine: str,
+    ) -> list[TBaseSQLStatement]:
+        """
+        Split a query into multiple instantiated statements.
+
+        This is a helper function to split a full SQL query into multiple
+        `BaseSQLStatement` instances. It's used by `SQLScript` when instantiating the
+        statements within a query.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _parse_statement(
+        cls,
+        statement: str,
+        engine: str,
+    ) -> InternalRepresentation:
+        """
+        Parse a string containing a single SQL statement, and returns the parsed AST.
+
+        Derived classes should not assume that `statement` contains a single statement,
+        and MUST explicitly validate that. Since this validation is parser dependent the
+        responsibility is left to the children classes.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _extract_tables_from_statement(
+        cls,
+        parsed: InternalRepresentation,
+        engine: str,
+    ) -> set[Table]:
+        """
+        Extract all table references in a given statement.
+        """
+        raise NotImplementedError()
 
     def format(self, comments: bool = True) -> str:
         """
-        Pretty-format the SQL query.
+        Format the statement, optionally omitting comments.
         """
-        return ";\n".join(statement.format(comments) for statement in self.statements)
+        raise NotImplementedError()
 
-    def get_settings(self) -> dict[str, str]:
+    def get_settings(self) -> dict[str, str | bool]:
         """
-        Return the settings for the SQL query.
+        Return any settings set by the statement.
 
-            >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
-            >>> statement.get_settings()
-            {"foo": "'baz'"}
+        For example, for this statement:
 
+            sql> SET foo = 'bar';
+
+        The method should return `{"foo": "'bar'"}`. Note the single quotes.
         """
-        settings: dict[str, str] = {}
-        for statement in self.statements:
-            settings.update(statement.get_settings())
+        raise NotImplementedError()
 
-        return settings
+    def __str__(self) -> str:
+        return self.format()
 
 
-class SQLStatement:
+class SQLStatement(BaseSQLStatement[exp.Expression]):
     """
     A SQL statement.
 
-    This class provides helper methods to manipulate and introspect SQL.
+    This class is used for all engines with dialects that can be parsed using sqlglot.
     """
 
     def __init__(
         self,
-        statement: Union[str, exp.Expression],
-        engine: Optional[str] = None,
+        statement: str | exp.Expression,
+        engine: str,
     ):
-        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
-
-        if isinstance(statement, str):
-            try:
-                self._parsed = self._parse_statement(statement, dialect)
-            except ParseError as ex:
-                raise SupersetParseError(statement, engine) from ex
-        else:
-            self._parsed = statement
+        self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+        super().__init__(statement, engine)
 
-        self._dialect = dialect
-        self.tables = extract_tables_from_statement(self._parsed, dialect)
+    @classmethod
+    def split_query(
+        cls,
+        query: str,
+        engine: str,
+    ) -> list[SQLStatement]:
+        return [
+            cls(statement, engine)
+            for statement in sqlglot.parse(query, engine)
+            if statement
+        ]
 
-    @staticmethod
+    @classmethod
     def _parse_statement(
-        sql_statement: str,
-        dialect: Optional[Dialects],
+        cls,
+        statement: str,
+        engine: str,
     ) -> exp.Expression:
         """
         Parse a single SQL statement.
         """
+        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+
+        # We could parse with `sqlglot.parse_one` to get a single statement, but we need
+        # to verify that the string contains exactly one statement.
         statements = [
             statement
-            for statement in sqlglot.parse(sql_statement, dialect=dialect)
+            for statement in sqlglot.parse(statement, dialect=dialect)
             if statement
         ]
         if len(statements) != 1:
@@ -409,6 +480,18 @@ class SQLStatement:
 
         return statements[0]
 
+    @classmethod
+    def _extract_tables_from_statement(
+        cls,
+        parsed: exp.Expression,
+        engine: str,
+    ) -> set[Table]:
+        """
+        Find all referenced tables.
+        """
+        dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
+        return extract_tables_from_statement(parsed, dialect)
+
     def format(self, comments: bool = True) -> str:
         """
         Pretty-format the SQL statement.
@@ -416,7 +499,7 @@ class SQLStatement:
         write = Dialect.get_or_raise(self._dialect)
         return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
 
-    def get_settings(self) -> dict[str, str]:
+    def get_settings(self) -> dict[str, str | bool]:
         """
         Return the settings for the SQL statement.
 
@@ -432,12 +515,191 @@ class SQLStatement:
         }
 
 
+class KustoKQLStatement(BaseSQLStatement[str]):
+    """
+    Special class for Kusto KQL.
+
+    Kusto KQL is a SQL-like language, but it's not supported by sqlglot. Queries look
+    like this:
+
+        StormEvents
+        | summarize PropertyDamage = sum(DamageProperty) by State
+        | join kind=innerunique PopulationData on State
+        | project State, PropertyDamagePerCapita = PropertyDamage / Population
+        | sort by PropertyDamagePerCapita
+
+    See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more
+    details about it.
+    """
+
+    @classmethod
+    def split_query(
+        cls,
+        query: str,
+        engine: str,
+    ) -> list[KustoKQLStatement]:
+        """
+        Split a query at semi-colons.
+
+        Since we don't have a parser, we use a simple state machine based function. See
+        https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
+        for more information.
+        """
+
+        class KQLSplitState(enum.Enum):
+            """
+            State machine for splitting a KQL query.
+
+            The state machine keeps track of whether we're inside a string or not, so we
+            don't split the query at a semi-colon that's part of a string.
+            """
+
+            OUTSIDE_STRING = enum.auto()
+            INSIDE_SINGLE_QUOTED_STRING = enum.auto()
+            INSIDE_DOUBLE_QUOTED_STRING = enum.auto()
+            INSIDE_MULTILINE_STRING = enum.auto()
+
+        statements = []
+        state = KQLSplitState.OUTSIDE_STRING
+        statement_start = 0
+        query = query if query.endswith(";") else query + ";"
+        for i, character in enumerate(query):
+            if state == KQLSplitState.OUTSIDE_STRING:
+                if character == ";":
+                    statements.append(query[statement_start:i])
+                    statement_start = i + 1
+                elif character == "'":
+                    state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
+                elif character == '"':
+                    state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
+                elif character == "`" and query[i - 2 : i] == "``":
+                    state = KQLSplitState.INSIDE_MULTILINE_STRING
+
+            elif (
+                state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
+                and character == "'"
+                and query[i - 1] != "\\"
+            ):
+                state = KQLSplitState.OUTSIDE_STRING
+
+            elif (
+                state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
+                and character == '"'
+                and query[i - 1] != "\\"
+            ):
+                state = KQLSplitState.OUTSIDE_STRING
+
+            elif (
+                state == KQLSplitState.INSIDE_MULTILINE_STRING
+                and character == "`"
+                and query[i - 2 : i] == "``"
+            ):
+                state = KQLSplitState.OUTSIDE_STRING
+
+        return [cls(statement, engine) for statement in statements]
+
+    @classmethod
+    def _parse_statement(
+        cls,
+        statement: str,
+        engine: str,
+    ) -> str:
+        if engine != "kustokql":
+            raise ValueError(f"Invalid engine: {engine}")
+
+        statements = cls.split_query(statement, engine)
+        if len(statements) != 1:
+            raise ValueError("SQLStatement should have exactly one statement")
+
+        return statements[0]._parsed.strip()
+
+    @classmethod
+    def _extract_tables_from_statement(cls, parsed: str, engine: str) -> set[Table]:
+        """
+        Extract all tables referenced in the statement.
+
+            StormEvents
+            | where InjuriesDirect + InjuriesIndirect > 50
+            | join (PopulationData) on State
+            | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect
+
+        """
+        logger.warning(
+            "Kusto KQL doesn't support table extraction. This means that data access "
+            "roles will not be enforced by Superset in the database."
+        )
+        return set()
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Pretty-format the SQL statement.
+        """
+        return self._parsed
+
+    def get_settings(self) -> dict[str, str | bool]:
+        """
+        Return the settings for the SQL statement.
+
+            >>> statement = KustoKQLStatement("set querytrace;")
+            >>> statement.get_settings()
+            {"querytrace": True}
+
+        """
+        set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$"
+        if match := re.match(set_regex, self._parsed, re.IGNORECASE):
+            return {match.group("name"): match.group("value") or True}
+
+        return {}
+
+
+class SQLScript:
+    """
+    A SQL script, with 0+ statements.
+    """
+
+    # Special engines that can't be parsed using sqlglot. Supporting non-SQL engines
+    # adds a lot of complexity to Superset, so we should avoid adding new engines to
+    # this data structure.
+    special_engines = {
+        "kustokql": KustoKQLStatement,
+    }
+
+    def __init__(
+        self,
+        query: str,
+        engine: str,
+    ):
+        statement_class = self.special_engines.get(engine, SQLStatement)
+        self.statements = statement_class.split_query(query, engine)
+
+    def format(self, comments: bool = True) -> str:
+        """
+        Pretty-format the SQL query.
+        """
+        return ";\n".join(statement.format(comments) for statement in self.statements)
+
+    def get_settings(self) -> dict[str, str | bool]:
+        """
+        Return the settings for the SQL query.
+
+            >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'")
+            >>> statement.get_settings()
+            {"foo": "'baz'"}
+
+        """
+        settings: dict[str, str | bool] = {}
+        for statement in self.statements:
+            settings.update(statement.get_settings())
+
+        return settings
+
+
 class ParsedQuery:
     def __init__(
         self,
         sql_statement: str,
         strip_comments: bool = False,
-        engine: Optional[str] = None,
+        engine: str | None = None,
     ):
         if strip_comments:
             sql_statement = sqlparse.format(sql_statement, strip_comments=True)
@@ -446,7 +708,7 @@ class ParsedQuery:
         self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
         self._tables: set[Table] = set()
         self._alias_names: set[str] = set()
-        self._limit: Optional[int] = None
+        self._limit: int | None = None
 
         logger.debug("Parsing with sqlparse statement: %s", self.sql)
         self._parsed = sqlparse.parse(self.stripped())
@@ -550,7 +812,7 @@ class ParsedQuery:
         return source.name in ctes_in_scope
 
     @property
-    def limit(self) -> Optional[int]:
+    def limit(self) -> int | None:
         return self._limit
 
     def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
@@ -631,7 +893,7 @@ class ParsedQuery:
 
         return True
 
-    def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
+    def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
         for token in tokens:
             if self._is_identifier(token):
                 for identifier_token in token.tokens:
@@ -695,7 +957,7 @@ class ParsedQuery:
         return statements
 
     @staticmethod
-    def get_table(tlist: TokenList) -> Optional[Table]:
+    def get_table(tlist: TokenList) -> Table | None:
         """
         Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
         construct.
@@ -731,7 +993,7 @@ class ParsedQuery:
     def as_create_table(
         self,
         table_name: str,
-        schema_name: Optional[str] = None,
+        schema_name: str | None = None,
         overwrite: bool = False,
         method: CtasMethod = CtasMethod.TABLE,
     ) -> str:
@@ -891,8 +1153,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
 def get_rls_for_table(
     candidate: Token,
     database_id: int,
-    default_schema: Optional[str],
-) -> Optional[TokenList]:
+    default_schema: str | None,
+) -> TokenList | None:
     """
     Given a table name, return any associated RLS predicates.
     """
@@ -938,7 +1200,7 @@ def get_rls_for_table(
 def insert_rls_as_subquery(
     token_list: TokenList,
     database_id: int,
-    default_schema: Optional[str],
+    default_schema: str | None,
 ) -> TokenList:
     """
     Update a statement inplace applying any associated RLS predicates.
@@ -954,7 +1216,7 @@ def insert_rls_as_subquery(
     This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
     databases.
     """
-    rls: Optional[TokenList] = None
+    rls: TokenList | None = None
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
         # Recurse into child token list
@@ -1030,7 +1292,7 @@ def insert_rls_as_subquery(
 def insert_rls_in_predicate(
     token_list: TokenList,
     database_id: int,
-    default_schema: Optional[str],
+    default_schema: str | None,
 ) -> TokenList:
     """
     Update a statement inplace applying any associated RLS predicates.
@@ -1041,7 +1303,7 @@ def insert_rls_in_predicate(
         after:  SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
 
     """
-    rls: Optional[TokenList] = None
+    rls: TokenList | None = None
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
         # Recurse into child token list
@@ -1175,7 +1437,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
 
 def extract_table_references(
     sql_text: str, sqla_dialect: str, show_warning: bool = True
-) -> set["Table"]:
+) -> set[Table]:
     """
     Return all the dependencies from a SQL sql_text.
     """
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index f097fd1df3..8bdd1ee4de 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -33,6 +33,7 @@ from superset.sql_parse import (
     has_table_query,
     insert_rls_as_subquery,
     insert_rls_in_predicate,
+    KustoKQLStatement,
     ParsedQuery,
     sanitize_clause,
     SQLScript,
@@ -1858,21 +1859,31 @@ def test_sqlquery() -> None:
     """
     Test the `SQLScript` class.
     """
-    script = SQLScript("SELECT 1; SELECT 2;")
+    script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
 
     assert len(script.statements) == 2
     assert script.format() == "SELECT\n  1;\nSELECT\n  2"
     assert script.statements[0].format() == "SELECT\n  1"
 
-    script = SQLScript("SET a=1; SET a=2; SELECT 3;")
+    script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
     assert script.get_settings() == {"a": "2"}
 
+    query = SQLScript(
+        """set querytrace;
+Events | take 100""",
+        "kustokql",
+    )
+    assert query.get_settings() == {"querytrace": True}
+
 
 def test_sqlstatement() -> None:
     """
     Test the `SQLStatement` class.
     """
-    statement = SQLStatement("SELECT * FROM table1 UNION ALL SELECT * FROM table2")
+    statement = SQLStatement(
+        "SELECT * FROM table1 UNION ALL SELECT * FROM table2",
+        "sqlite",
+    )
 
     assert statement.tables == {
         Table(table="table1", schema=None, catalog=None),
@@ -1883,5 +1894,74 @@ def test_sqlstatement() -> None:
         == "SELECT\n  *\nFROM table1\nUNION ALL\nSELECT\n  *\nFROM table2"
     )
 
-    statement = SQLStatement("SET a=1")
+    statement = SQLStatement("SET a=1", "sqlite")
     assert statement.get_settings() == {"a": "1"}
+
+
+def test_kustokqlstatement() -> None:
+    """
+    Test the `KustoKQLStatement` class.
+    """
+    statements = KustoKQLStatement.split_query(
+        """
+let totalPagesPerDay = PageViews
+| summarize by Page, Day = startofday(Timestamp)
+| summarize count() by Day;
+let materializedScope = PageViews
+| summarize by Page, Day = startofday(Timestamp);
+let cachedResult = materialize(materializedScope);
+cachedResult
+| project Page, Day1 = Day
+| join kind = inner
+(
+    cachedResult
+    | project Page, Day2 = Day
+)
+on Page
+| where Day2 > Day1
+| summarize count() by Day1, Day2
+| join kind = inner
+    totalPagesPerDay
+on $left.Day1 == $right.Day
+| project Day1, Day2, Percentage = count_*100.0/count_1
+        """,
+        "kustokql",
+    )
+    assert len(statements) == 4
+
+    statements = KustoKQLStatement.split_query(
+        """
+print program = ```
+  public class Program {
+    public static void Main() {
+      System.Console.WriteLine("Hello!");
+    }
+  }```
+        """,
+        "kustokql",
+    )
+    assert len(statements) == 1
+
+    statements = KustoKQLStatement.split_query(
+        """
+set querytrace;
+Events | take 100
+        """,
+        "kustokql",
+    )
+    assert len(statements) == 2
+    assert statements[0].format() == "set querytrace"
+    assert statements[1].format() == "Events | take 100"
+
+
+@pytest.mark.parametrize(
+    "kql,statements",
+    [
+        ('print banner=strcat("Hello", ", ", "World!")', 1),
+        (r"print 'O\'Malley\'s'", 1),
+        (r"print 'O\'Mal;ley\'s'", 1),
+        ("print ```foo;\nbar;\nbaz;```\n", 1),
+    ],
+)
+def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
+    assert len(KustoKQLStatement.split_query(kql, "kustokql")) == statements