You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by jo...@apache.org on 2024/03/22 00:39:35 UTC
(superset) branch master updated: fix(sql_parse): Ensure table extraction handles Jinja templating (#27470)
This is an automated email from the ASF dual-hosted git repository.
johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new b25dd0c055 fix(sql_parse): Ensure table extraction handles Jinja templating (#27470)
b25dd0c055 is described below
commit b25dd0c055a437f93a870a3f3188fd5b83d9ecfe
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Fri Mar 22 13:39:28 2024 +1300
fix(sql_parse): Ensure table extraction handles Jinja templating (#27470)
---
superset/commands/sql_lab/execute.py | 4 ++-
superset/jinja_context.py | 10 +++---
superset/models/sql_lab.py | 40 ++++++++++++++++--------
superset/security/manager.py | 13 +++-----
superset/sql_parse.py | 60 ++++++++++++++++++++++++++++++++++++
superset/sqllab/query_render.py | 3 +-
tests/unit_tests/sql_parse_tests.py | 41 ++++++++++++++++++++++++
7 files changed, 141 insertions(+), 30 deletions(-)
diff --git a/superset/commands/sql_lab/execute.py b/superset/commands/sql_lab/execute.py
index 5d955571d8..533264fb28 100644
--- a/superset/commands/sql_lab/execute.py
+++ b/superset/commands/sql_lab/execute.py
@@ -144,11 +144,13 @@ class ExecuteSqlCommand(BaseCommand):
try:
logger.info("Triggering query_id: %i", query.id)
+ # Necessary to check access before rendering the Jinjafied query, as
+ # some Jinja macros execute statements upon rendering.
+ self._validate_access(query)
self._execution_context.set_query(query)
rendered_query = self._sql_query_render.render(self._execution_context)
validate_rendered_query = copy.copy(query)
validate_rendered_query.sql = rendered_query
- self._validate_access(validate_rendered_query)
self._set_query_limit_if_required(rendered_query)
self._query_dao.update(
query, {"limit": self._execution_context.query.limit}
diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index 2990953bae..0ee7667811 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -24,7 +24,7 @@ from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Unio
import dateutil
from flask import current_app, has_request_context, request
from flask_babel import gettext as _
-from jinja2 import DebugUndefined
+from jinja2 import DebugUndefined, Environment
from jinja2.sandbox import SandboxedEnvironment
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.expression import bindparam
@@ -479,11 +479,11 @@ class BaseTemplateProcessor:
self._applied_filters = applied_filters
self._removed_filters = removed_filters
self._context: dict[str, Any] = {}
- self._env = SandboxedEnvironment(undefined=DebugUndefined)
+ self.env: Environment = SandboxedEnvironment(undefined=DebugUndefined)
self.set_context(**kwargs)
# custom filters
- self._env.filters["where_in"] = WhereInMacro(database.get_dialect())
+ self.env.filters["where_in"] = WhereInMacro(database.get_dialect())
def set_context(self, **kwargs: Any) -> None:
self._context.update(kwargs)
@@ -496,7 +496,7 @@ class BaseTemplateProcessor:
>>> process_template(sql)
"SELECT '2017-01-01T00:00:00'"
"""
- template = self._env.from_string(sql)
+ template = self.env.from_string(sql)
kwargs.update(self._context)
context = validate_template_context(self.engine, kwargs)
@@ -643,7 +643,7 @@ class TrinoTemplateProcessor(PrestoTemplateProcessor):
engine = "trino"
def process_template(self, sql: str, **kwargs: Any) -> str:
- template = self._env.from_string(sql)
+ template = self.env.from_string(sql)
kwargs.update(self._context)
# Backwards compatibility if migrating from Presto.
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index f4724d6dab..2d7384a74e 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -46,6 +46,7 @@ from sqlalchemy.orm import backref, relationship
from sqlalchemy.sql.elements import ColumnElement, literal_column
from superset import security_manager
+from superset.exceptions import SupersetSecurityException
from superset.jinja_context import BaseTemplateProcessor, get_template_processor
from superset.models.helpers import (
AuditMixinNullable,
@@ -53,7 +54,7 @@ from superset.models.helpers import (
ExtraJSONMixin,
ImportExportMixin,
)
-from superset.sql_parse import CtasMethod, ParsedQuery, Table
+from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table
from superset.sqllab.limiting_factor import LimitingFactor
from superset.utils.core import get_column_name, MediumText, QueryStatus, user_label
@@ -65,8 +66,25 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class SqlTablesMixin:  # pylint: disable=too-few-public-methods
+    """
+    Mixin providing ``sql_tables`` for models which expose ``sql`` and ``database``.
+    """
+
+    @property
+    def sql_tables(self) -> list[Table]:
+        """
+        The tables referenced by the model's (possibly Jinjafied) SQL statement.
+
+        An empty list is returned when table extraction raises a
+        ``SupersetSecurityException``.
+        """
+        try:
+            # The mixin cannot see ``sql``/``database``, hence the type ignores.
+            referenced = extract_tables_from_jinja_sql(
+                self.sql,  # type: ignore
+                self.database.db_engine_spec.engine,  # type: ignore
+            )
+        except SupersetSecurityException:
+            return []
+
+        return list(referenced)
+
+
class Query(
- ExtraJSONMixin, ExploreMixin, Model
+ SqlTablesMixin,
+ ExtraJSONMixin,
+ ExploreMixin,
+ Model,
): # pylint: disable=abstract-method,too-many-public-methods
"""ORM model for SQL query
@@ -181,10 +199,6 @@ class Query(
def username(self) -> str:
return self.user.username
- @property
- def sql_tables(self) -> list[Table]:
- return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)
-
@property
def columns(self) -> list["TableColumn"]:
from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
@@ -355,7 +369,13 @@ class Query(
return self.make_sqla_column_compatible(sqla_column, label)
-class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
+class SavedQuery(
+ SqlTablesMixin,
+ AuditMixinNullable,
+ ExtraJSONMixin,
+ ImportExportMixin,
+ Model,
+):
"""ORM model for SQL query"""
__tablename__ = "saved_query"
@@ -425,12 +445,6 @@ class SavedQuery(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
def url(self) -> str:
return f"/sqllab?savedQueryId={self.id}"
- @property
- def sql_tables(self) -> list[Table]:
- return list(
- ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
- )
-
@property
def last_run_humanized(self) -> str:
return naturaltime(datetime.now() - self.changed_on)
diff --git a/superset/security/manager.py b/superset/security/manager.py
index a532431433..2833e88645 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -52,14 +52,12 @@ from sqlalchemy.orm import eagerload
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import Query as SqlaQuery
-from superset import sql_parse
from superset.constants import RouteMethod
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
DatasetInvalidPermissionEvaluationException,
SupersetSecurityException,
)
-from superset.jinja_context import get_template_processor
from superset.security.guest_token import (
GuestToken,
GuestTokenResources,
@@ -68,6 +66,7 @@ from superset.security.guest_token import (
GuestTokenUser,
GuestUser,
)
+from superset.sql_parse import extract_tables_from_jinja_sql
from superset.superset_typing import Metric
from superset.utils.core import (
DatasourceName,
@@ -1961,16 +1960,12 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return
if query:
- # make sure the quuery is valid SQL by rendering any Jinja
- processor = get_template_processor(database=query.database)
- rendered_sql = processor.process_template(query.sql)
default_schema = database.get_default_schema_for_query(query)
tables = {
Table(table_.table, table_.schema or default_schema)
- for table_ in sql_parse.ParsedQuery(
- rendered_sql,
- engine=database.db_engine_spec.engine,
- ).tables
+ for table_ in extract_tables_from_jinja_sql(
+ query.sql, database.db_engine_spec.engine
+ )
}
elif table:
tables = {table}
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 9367d3c59f..f721f456d0 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -25,10 +25,12 @@ import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast
+from unittest.mock import Mock
import sqlglot
import sqlparse
from flask_babel import gettext as __
+from jinja2 import nodes
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects.dialect import Dialect, Dialects
@@ -1232,3 +1234,61 @@ def extract_table_references(
Table(*[part["value"] for part in table["name"][::-1]])
for table in find_nodes_by_key(tree, "Table")
}
+
+
+def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]:
+    """
+    Extract all table references in the Jinjafied SQL statement.
+
+    Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL
+    statement may represent invalid SQL which is non-parsable by SQLGlot.
+
+    Firstly, we extract any tables referenced within the confines of specific Jinja
+    macros. Secondly, we replace these non-SQL Jinja calls with a pseudo-benign SQL
+    expression to help ensure that the resulting SQL statements are parsable by
+    SQLGlot.
+
+    The table referenced by a ``latest_partition`` or ``latest_sub_partition`` macro
+    can only be determined statically when it is passed as the sole positional
+    argument and folds to a constant string, e.g., ``latest_partition('foo.bar')``.
+    Any other form is rejected rather than ignored. Otherwise the macro would
+    reference a table which is absent from the returned set.
+
+    :param sql: The Jinjafied SQL statement
+    :param engine: The associated database engine
+    :returns: The set of tables referenced in the SQL statement
+    :raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement
+        or the table referenced by a partition macro cannot be determined
+    """
+
+    # pylint: disable=import-outside-toplevel
+    from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+    from superset.exceptions import SupersetSecurityException
+    from superset.jinja_context import get_template_processor
+
+    # Mock the required database as the processor signature is exposed publicly.
+    # NOTE(review): relying on ``unittest.mock`` in production code is a stopgap; a
+    # lightweight stand-in object would avoid depending on a testing module.
+    processor = get_template_processor(database=Mock(backend=engine))
+    template = processor.env.parse(sql)
+    eval_ctx = nodes.EvalContext(processor.env)
+
+    tables = set()
+
+    for node in template.find_all(nodes.Call):
+        if isinstance(node.node, nodes.Getattr) and node.node.attr in (
+            "latest_partition",
+            "latest_sub_partition",
+        ):
+            # Only index ``node.args`` once its arity is known. Use ``as_const``
+            # rather than ``.value``, which only exists on ``nodes.Const``, so that a
+            # non-constant argument raises ``Impossible`` instead of crashing.
+            table_name = None
+
+            if len(node.args) == 1:
+                try:
+                    table_name = node.args[0].as_const(eval_ctx)
+                except nodes.Impossible:
+                    pass
+
+            if not isinstance(table_name, str) or not table_name.strip():
+                raise SupersetSecurityException(
+                    SupersetError(
+                        error_type=SupersetErrorType.QUERY_SECURITY_ACCESS_ERROR,
+                        message=__(
+                            "Unable to determine the table referenced by the "
+                            "%(macro)s macro.",
+                            macro=node.node.attr,
+                        ),
+                        level=ErrorLevel.ERROR,
+                    )
+                )
+
+            # Extract the table referenced in the macro, e.g., 'foo.bar' yields
+            # Table(table="bar", schema="foo").
+            tables.add(
+                Table(
+                    *[
+                        remove_quotes(part.strip())
+                        for part in table_name.split(".")[::-1]
+                    ]
+                )
+            )
+
+            # Replace the potentially problematic Jinja macro with some benign SQL.
+            node.__class__ = nodes.TemplateData
+            node.fields = nodes.TemplateData.fields
+            node.data = "NULL"
+
+    # Render the modified AST rather than the original string. Jinja's
+    # ``from_string``, used by ``process_template``, accepts a ``nodes.Template``.
+    # This way the replaced macros are never executed.
+    return (
+        tables
+        | ParsedQuery(
+            sql_statement=processor.process_template(template),
+            engine=engine,
+        ).tables
+    )
diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py
index 5597bcb086..caf9a3cb2b 100644
--- a/superset/sqllab/query_render.py
+++ b/superset/sqllab/query_render.py
@@ -79,8 +79,7 @@ class SqlQueryRenderImpl(SqlQueryRender):
sql_template_processor: BaseTemplateProcessor,
) -> None:
if is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
- # pylint: disable=protected-access
- syntax_tree = sql_template_processor._env.parse(rendered_query)
+ syntax_tree = sql_template_processor.env.parse(rendered_query)
undefined_parameters = find_undeclared_variables(syntax_tree)
if undefined_parameters:
self._raise_undefined_parameter_exception(
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 973a4e3793..aa4171e763 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -32,6 +32,7 @@ from superset.exceptions import (
from superset.sql_parse import (
add_table_name,
extract_table_references,
+ extract_tables_from_jinja_sql,
get_rls_for_table,
has_table_query,
insert_rls_as_subquery,
@@ -1909,3 +1910,43 @@ def test_sqlstatement() -> None:
statement = SQLStatement("SET a=1")
assert statement.get_settings() == {"a": "1"}
+
+
+@pytest.mark.parametrize(
+    "engine",
+    ("hive", "presto", "trino"),
+)
+@pytest.mark.parametrize(
+    "macro",
+    (
+        "latest_partition('foo.bar')",
+        "latest_sub_partition('foo.bar', baz='qux')",
+    ),
+)
+@pytest.mark.parametrize(
+    ("sql", "expected"),
+    (
+        (
+            "SELECT '{{{{ {engine}.{macro} }}}}'",
+            {Table(table="bar", schema="foo")},
+        ),
+        (
+            "SELECT * FROM foo.baz WHERE quux = '{{{{ {engine}.{macro} }}}}'",
+            {Table(table="bar", schema="foo"), Table(table="baz", schema="foo")},
+        ),
+    ),
+)
+def test_extract_tables_from_jinja_sql(
+    engine: str,
+    macro: str,
+    sql: str,
+    expected: set[Table],
+) -> None:
+    """
+    Test that tables referenced via the partition macros are extracted alongside
+    any tables referenced directly in the SQL.
+
+    ``sql`` is a ``str.format`` template (hence the doubled braces) into which the
+    engine and macro are substituted to yield the Jinjafied SQL statement.
+    """
+    jinja_sql = sql.format(engine=engine, macro=macro)
+    tables = extract_tables_from_jinja_sql(jinja_sql, engine)
+    assert tables == expected