You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by mi...@apache.org on 2024/03/22 12:26:50 UTC

(superset) 05/06: fix(sql_parse): Ensure table extraction handles Jinja templating (#27470)

This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 3.1
in repository https://gitbox.apache.org/repos/asf/superset.git

commit b4fde185084a956913b779dd535a19b148eb983c
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Fri Mar 22 13:39:28 2024 +1300

    fix(sql_parse): Ensure table extraction handles Jinja templating (#27470)
---
 superset/commands/sql_lab/execute.py |   4 +-
 superset/jinja_context.py            |  10 ++--
 superset/models/sql_lab.py           |  40 ++++++++-----
 superset/security/manager.py         |  13 ++---
 superset/sql_parse.py                | 105 +++++++++++++++++++++++++++--------
 superset/sqllab/query_render.py      |   3 +-
 tests/unit_tests/sql_parse_tests.py  |  41 ++++++++++++++
 7 files changed, 163 insertions(+), 53 deletions(-)

diff --git a/superset/commands/sql_lab/execute.py b/superset/commands/sql_lab/execute.py
index 5d955571d8..533264fb28 100644
--- a/superset/commands/sql_lab/execute.py
+++ b/superset/commands/sql_lab/execute.py
@@ -144,11 +144,13 @@ class ExecuteSqlCommand(BaseCommand):
         try:
             logger.info("Triggering query_id: %i", query.id)
 
+            # Necessary to check access before rendering the Jinjafied query as the
+            # some Jinja macros execute statements upon rendering.
+            self._validate_access(query)
             self._execution_context.set_query(query)
             rendered_query = self._sql_query_render.render(self._execution_context)
             validate_rendered_query = copy.copy(query)
             validate_rendered_query.sql = rendered_query
-            self._validate_access(validate_rendered_query)
             self._set_query_limit_if_required(rendered_query)
             self._query_dao.update(
                 query, {"limit": self._execution_context.query.limit}
diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index 54d1f54866..9edddf24d9 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -24,7 +24,7 @@ from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Unio
 import dateutil
 from flask import current_app, g, has_request_context, request
 from flask_babel import gettext as _
-from jinja2 import DebugUndefined
+from jinja2 import DebugUndefined, Environment
 from jinja2.sandbox import SandboxedEnvironment
 from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.sql.expression import bindparam
@@ -462,11 +462,11 @@ class BaseTemplateProcessor:
         self._applied_filters = applied_filters
         self._removed_filters = removed_filters
         self._context: dict[str, Any] = {}
-        self._env = SandboxedEnvironment(undefined=DebugUndefined)
+        self.env: Environment = SandboxedEnvironment(undefined=DebugUndefined)
         self.set_context(**kwargs)
 
         # custom filters
-        self._env.filters["where_in"] = WhereInMacro(database.get_dialect())
+        self.env.filters["where_in"] = WhereInMacro(database.get_dialect())
 
     def set_context(self, **kwargs: Any) -> None:
         self._context.update(kwargs)
@@ -479,7 +479,7 @@ class BaseTemplateProcessor:
         >>> process_template(sql)
         "SELECT '2017-01-01T00:00:00'"
         """
-        template = self._env.from_string(sql)
+        template = self.env.from_string(sql)
         kwargs.update(self._context)
 
         context = validate_template_context(self.engine, kwargs)
@@ -623,7 +623,7 @@ class TrinoTemplateProcessor(PrestoTemplateProcessor):
     engine = "trino"
 
     def process_template(self, sql: str, **kwargs: Any) -> str:
-        template = self._env.from_string(sql)
+        template = self.env.from_string(sql)
         kwargs.update(self._context)
 
         # Backwards compatibility if migrating from Presto.
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index aff8c1ce3d..2fd0d2dc76 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -46,6 +46,7 @@ from sqlalchemy.orm import backref, relationship
 from sqlalchemy.sql.elements import ColumnElement, literal_column
 
 from superset import security_manager
+from superset.exceptions import SupersetSecurityException
 from superset.jinja_context import BaseTemplateProcessor, get_template_processor
 from superset.models.helpers import (
     AuditMixinNullable,
@@ -53,7 +54,7 @@ from superset.models.helpers import (
     ExtraJSONMixin,
     ImportExportMixin,
 )
-from superset.sql_parse import CtasMethod, ParsedQuery, Table
+from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table
 from superset.sqllab.limiting_factor import LimitingFactor
 from superset.utils.core import get_column_name, QueryStatus, user_label
 
@@ -65,8 +66,25 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class SqlTablesMixin:  # pylint: disable=too-few-public-methods
+    """Mixin exposing the tables referenced by the model's Jinjafied ``sql``."""
+
+    @property
+    def sql_tables(self) -> list[Table]:
+        try:
+            sql = self.sql  # type: ignore
+            engine = self.database.db_engine_spec.engine  # type: ignore
+            return list(extract_tables_from_jinja_sql(sql, engine))
+        except SupersetSecurityException:
+            # The SQL could not be parsed: report no tables instead of failing.
+            return []
+
+
 class Query(
-    Model, ExtraJSONMixin, ExploreMixin
+    SqlTablesMixin,
+    ExtraJSONMixin,
+    ExploreMixin,
+    Model,
 ):  # pylint: disable=abstract-method,too-many-public-methods
     """ORM model for SQL query
 
@@ -181,10 +199,6 @@ class Query(
     def username(self) -> str:
         return self.user.username
 
-    @property
-    def sql_tables(self) -> list[Table]:
-        return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)
-
     @property
     def columns(self) -> list["TableColumn"]:
         from superset.connectors.sqla.models import (  # pylint: disable=import-outside-toplevel
@@ -355,7 +369,13 @@ class Query(
         return self.make_sqla_column_compatible(sqla_column, label)
 
 
-class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
+class SavedQuery(
+    SqlTablesMixin,
+    AuditMixinNullable,
+    ExtraJSONMixin,
+    ImportExportMixin,
+    Model,
+):
     """ORM model for SQL query"""
 
     __tablename__ = "saved_query"
@@ -425,12 +445,6 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
     def url(self) -> str:
         return f"/sqllab?savedQueryId={self.id}"
 
-    @property
-    def sql_tables(self) -> list[Table]:
-        return list(
-            ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
-        )
-
     @property
     def last_run_humanized(self) -> str:
         return naturaltime(datetime.now() - self.changed_on)
diff --git a/superset/security/manager.py b/superset/security/manager.py
index bc235ad505..3c4b4106d3 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -52,14 +52,12 @@ from sqlalchemy.orm import eagerload, Session
 from sqlalchemy.orm.mapper import Mapper
 from sqlalchemy.orm.query import Query as SqlaQuery
 
-from superset import sql_parse
 from superset.constants import RouteMethod
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import (
     DatasetInvalidPermissionEvaluationException,
     SupersetSecurityException,
 )
-from superset.jinja_context import get_template_processor
 from superset.security.guest_token import (
     GuestToken,
     GuestTokenResources,
@@ -68,6 +66,7 @@ from superset.security.guest_token import (
     GuestTokenUser,
     GuestUser,
 )
+from superset.sql_parse import extract_tables_from_jinja_sql
 from superset.superset_typing import Metric
 from superset.utils.core import (
     DatasourceName,
@@ -1961,16 +1960,12 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
                 return
 
             if query:
-                # make sure the quuery is valid SQL by rendering any Jinja
-                processor = get_template_processor(database=query.database)
-                rendered_sql = processor.process_template(query.sql)
                 default_schema = database.get_default_schema_for_query(query)
                 tables = {
                     Table(table_.table, table_.schema or default_schema)
-                    for table_ in sql_parse.ParsedQuery(
-                        rendered_sql,
-                        engine=database.db_engine_spec.engine,
-                    ).tables
+                    for table_ in extract_tables_from_jinja_sql(
+                        query.sql, database.db_engine_spec.engine
+                    )
                 }
             elif table:
                 tables = {table}
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index db51991e22..58bca48a6e 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -16,16 +16,19 @@
 # under the License.
 
 # pylint: disable=too-many-lines
+from __future__ import annotations
 
 import logging
 import re
 import urllib.parse
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
-from typing import Any, cast, Optional
+from typing import Any, cast
+from unittest.mock import Mock
 
 import sqlparse
 from flask_babel import gettext as __
+from jinja2 import nodes
 from sqlalchemy import and_
 from sqlglot import exp, parse, parse_one
 from sqlglot.dialects import Dialects
@@ -142,7 +145,7 @@ class CtasMethod(StrEnum):
     VIEW = "VIEW"
 
 
-def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
+def _extract_limit_from_query(statement: TokenList) -> int | None:
     """
     Extract limit clause from SQL statement.
 
@@ -163,9 +166,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
     return None
 
 
-def extract_top_from_query(
-    statement: TokenList, top_keywords: set[str]
-) -> Optional[int]:
+def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None:
     """
     Extract top clause value from SQL statement.
 
@@ -189,7 +190,7 @@ def extract_top_from_query(
     return top
 
 
-def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
+def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
     """
     parse the SQL and return the CTE and rest of the block to the caller
 
@@ -197,7 +198,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     :return: CTE and remainder block to the caller
 
     """
-    cte: Optional[str] = None
+    cte: str | None = None
     remainder = sql
     stmt = sqlparse.parse(sql)[0]
 
@@ -215,7 +216,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     return cte, remainder
 
 
-def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
+def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
     """
     Strips comments from a SQL statement, does a simple test first
     to avoid always instantiating the expensive ParsedQuery constructor
@@ -239,8 +240,8 @@ class Table:
     """
 
     table: str
-    schema: Optional[str] = None
-    catalog: Optional[str] = None
+    schema: str | None = None
+    catalog: str | None = None
 
     def __str__(self) -> str:
         """
@@ -262,7 +263,7 @@ class ParsedQuery:
         self,
         sql_statement: str,
         strip_comments: bool = False,
-        engine: Optional[str] = None,
+        engine: str | None = None,
     ):
         if strip_comments:
             sql_statement = sqlparse.format(sql_statement, strip_comments=True)
@@ -271,7 +272,7 @@ class ParsedQuery:
         self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
         self._tables: set[Table] = set()
         self._alias_names: set[str] = set()
-        self._limit: Optional[int] = None
+        self._limit: int | None = None
 
         logger.debug("Parsing with sqlparse statement: %s", self.sql)
         self._parsed = sqlparse.parse(self.stripped())
@@ -382,7 +383,7 @@ class ParsedQuery:
         return source.name in ctes_in_scope
 
     @property
-    def limit(self) -> Optional[int]:
+    def limit(self) -> int | None:
         return self._limit
 
     def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
@@ -463,7 +464,7 @@ class ParsedQuery:
 
         return True
 
-    def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
+    def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
         for token in tokens:
             if self._is_identifier(token):
                 for identifier_token in token.tokens:
@@ -527,7 +528,7 @@ class ParsedQuery:
         return statements
 
     @staticmethod
-    def get_table(tlist: TokenList) -> Optional[Table]:
+    def get_table(tlist: TokenList) -> Table | None:
         """
         Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
         construct.
@@ -563,7 +564,7 @@ class ParsedQuery:
     def as_create_table(
         self,
         table_name: str,
-        schema_name: Optional[str] = None,
+        schema_name: str | None = None,
         overwrite: bool = False,
         method: CtasMethod = CtasMethod.TABLE,
     ) -> str:
@@ -723,8 +724,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
 def get_rls_for_table(
     candidate: Token,
     database_id: int,
-    default_schema: Optional[str],
-) -> Optional[TokenList]:
+    default_schema: str | None,
+) -> TokenList | None:
     """
     Given a table name, return any associated RLS predicates.
     """
@@ -770,7 +771,7 @@ def get_rls_for_table(
 def insert_rls_as_subquery(
     token_list: TokenList,
     database_id: int,
-    default_schema: Optional[str],
+    default_schema: str | None,
 ) -> TokenList:
     """
     Update a statement inplace applying any associated RLS predicates.
@@ -786,7 +787,7 @@ def insert_rls_as_subquery(
     This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
     databases.
     """
-    rls: Optional[TokenList] = None
+    rls: TokenList | None = None
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
         # Recurse into child token list
@@ -862,7 +863,7 @@ def insert_rls_as_subquery(
 def insert_rls_in_predicate(
     token_list: TokenList,
     database_id: int,
-    default_schema: Optional[str],
+    default_schema: str | None,
 ) -> TokenList:
     """
     Update a statement inplace applying any associated RLS predicates.
@@ -873,7 +874,7 @@ def insert_rls_in_predicate(
         after:  SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
 
     """
-    rls: Optional[TokenList] = None
+    rls: TokenList | None = None
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
         # Recurse into child token list
@@ -1007,7 +1008,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
 
 def extract_table_references(
     sql_text: str, sqla_dialect: str, show_warning: bool = True
-) -> set["Table"]:
+) -> set[Table]:
     """
     Return all the dependencies from a SQL sql_text.
     """
@@ -1051,3 +1052,61 @@ def extract_table_references(
         Table(*[part["value"] for part in table["name"][::-1]])
         for table in find_nodes_by_key(tree, "Table")
     }
+
+
+def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]:
+    """
+    Extract all table references in the Jinjafied SQL statement.
+
+    Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL
+    statement may represent invalid SQL which is non-parsable by SQLGlot.
+
+    Firstly, we extract any tables referenced within the confines of specific Jinja
+    macros. Secondly, we replace these non-SQL Jinja calls with a pseudo-benign SQL
+    expression to help keep the resulting SQL statements parsable by SQLGlot.
+
+    :param sql: The Jinjafied SQL statement
+    :param engine: The associated database engine
+    :returns: The set of tables referenced in the SQL statement
+    :raises SupersetSecurityException: If the referenced tables cannot be determined
+    """
+    # pylint: disable=import-outside-toplevel
+    from superset.jinja_context import get_template_processor
+
+    # Mock the required database as the processor signature is exposed publicly.
+    processor = get_template_processor(database=Mock(backend=engine))
+    template = processor.env.parse(sql)
+    tables: set[Table] = set()
+
+    for node in template.find_all(nodes.Call):
+        if isinstance(node.node, nodes.Getattr) and node.node.attr in (
+            "latest_partition",
+            "latest_sub_partition",
+        ):
+            # Rendering the macro queries its table; fail closed unless it's a literal.
+            try:
+                name = node.args[0].as_const() if len(node.args) == 1 else None
+            except nodes.Impossible:
+                name = None
+            parts = name.split(".")[::-1] if isinstance(name, str) else []
+            if not 1 <= len(parts) <= 3:
+                raise SupersetSecurityException(
+                    SupersetError(
+                        error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
+                        message=__("Unable to determine the partition macro table"),
+                        level=ErrorLevel.ERROR,
+                    )
+                )
+            tables.add(Table(*[remove_quotes(part) for part in parts]))
+            # Replace the potentially problematic Jinja macro with some benign SQL.
+            node.__class__ = nodes.TemplateData
+            node.fields = nodes.TemplateData.fields
+            node.data = "NULL"
+
+    return (
+        tables
+        | ParsedQuery(
+            sql_statement=processor.process_template(template),
+            engine=engine,
+        ).tables
+    )
diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py
index 5597bcb086..caf9a3cb2b 100644
--- a/superset/sqllab/query_render.py
+++ b/superset/sqllab/query_render.py
@@ -79,8 +79,7 @@ class SqlQueryRenderImpl(SqlQueryRender):
         sql_template_processor: BaseTemplateProcessor,
     ) -> None:
         if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
-            # pylint: disable=protected-access
-            syntax_tree = sql_template_processor._env.parse(rendered_query)
+            syntax_tree = sql_template_processor.env.parse(rendered_query)
             undefined_parameters = find_undeclared_variables(syntax_tree)
             if undefined_parameters:
                 self._raise_undefined_parameter_exception(
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index fc1ae1b231..42b796bc58 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -32,6 +32,7 @@ from superset.exceptions import (
 from superset.sql_parse import (
     add_table_name,
     extract_table_references,
+    extract_tables_from_jinja_sql,
     get_rls_for_table,
     has_table_query,
     insert_rls_as_subquery,
@@ -1875,3 +1876,43 @@ WITH t AS (
 )
 SELECT * FROM t"""
     ).is_select()
+
+
+@pytest.mark.parametrize("engine", ["hive", "presto", "trino"])
+@pytest.mark.parametrize(
+    "macro",
+    [
+        "latest_partition('foo.bar')",
+        "latest_sub_partition('foo.bar', baz='qux')",
+    ],
+)
+@pytest.mark.parametrize(
+    "sql,expected",
+    [
+        # The macro is the only table reference.
+        (
+            "SELECT '{{{{ {engine}.{macro} }}}}'",
+            {Table(table="bar", schema="foo")},
+        ),
+        # The macro table is reported alongside the tables parsed from the SQL.
+        (
+            "SELECT * FROM foo.baz WHERE quux = '{{{{ {engine}.{macro} }}}}'",
+            {Table(table="bar", schema="foo"), Table(table="baz", schema="foo")},
+        ),
+    ],
+)
+def test_extract_tables_from_jinja_sql(
+    engine: str,
+    macro: str,
+    sql: str,
+    expected: set[Table],
+) -> None:
+    """
+    Test that tables referenced by partition macros are extracted.
+
+    The doubled braces survive ``str.format`` as literal Jinja delimiters, so the
+    SQL embeds e.g. ``{{ presto.latest_partition('foo.bar') }}``.
+    """
+    jinja_sql = sql.format(engine=engine, macro=macro)
+    tables = extract_tables_from_jinja_sql(jinja_sql, engine)
+    assert tables == expected