You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by vi...@apache.org on 2022/04/04 08:22:24 UTC
[superset] 15/24: feat: improve adhoc SQL validation (#19454)
This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to tag 1.5.0rc1
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 840be9972fd3318d70c9a3a153c4b0db3867a6d7
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Thu Mar 31 11:55:19 2022 -0700
feat: improve adhoc SQL validation (#19454)
* feat: improve adhoc SQL validation
* Small changes
* Add more unit tests
(cherry picked from commit 6828624f61fff21485b0b2e91ac53701d43cb0d7)
---
superset/connectors/sqla/models.py | 38 +++++++++++++----
superset/connectors/sqla/utils.py | 40 +++++++++++-------
superset/sql_parse.py | 84 ++++++++++++++++++++++++-------------
tests/unit_tests/sql_parse_tests.py | 80 ++++++++++++++++++++++++++---------
4 files changed, 170 insertions(+), 72 deletions(-)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 535eb81a57..31f3339541 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -923,7 +923,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
elif expression_type == utils.AdhocMetricExpressionType.SQL:
tp = self.get_template_processor()
expression = tp.process_template(cast(str, metric["sqlExpression"]))
- validate_adhoc_subquery(expression)
+ expression = validate_adhoc_subquery(
+ expression,
+ self.database_id,
+ self.schema,
+ )
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
@@ -952,7 +956,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
- validate_adhoc_subquery(expression)
+ expression = validate_adhoc_subquery(
+ expression,
+ self.database_id,
+ self.schema,
+ )
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
@@ -1006,9 +1014,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
if is_alias_used_in_orderby(col):
col.name = f"{col.name}__"
- def _get_sqla_row_level_filters(
+ def get_sqla_row_level_filters(
self, template_processor: BaseTemplateProcessor
- ) -> List[str]:
+ ) -> List[TextClause]:
"""
Return the appropriate row level security filters for
this table and the current user.
@@ -1016,7 +1024,6 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
:param BaseTemplateProcessor template_processor: The template
processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
- :rtype: List[str]
"""
all_filters: List[TextClause] = []
filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
@@ -1169,6 +1176,12 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
col: Union[AdhocMetric, ColumnElement] = orig_col
if isinstance(col, dict):
col = cast(AdhocMetric, col)
+ if col.get("sqlExpression"):
+ col["sqlExpression"] = validate_adhoc_subquery(
+ cast(str, col["sqlExpression"]),
+ self.database_id,
+ self.schema,
+ )
if utils.is_adhoc_metric(col):
# add adhoc sort by column to columns_by_name if not exists
col = self.adhoc_metric_to_sqla(col, columns_by_name)
@@ -1218,7 +1231,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
elif selected in columns_by_name:
outer = columns_by_name[selected].get_sqla_col()
else:
- validate_adhoc_subquery(selected)
+ selected = validate_adhoc_subquery(
+ selected,
+ self.database_id,
+ self.schema,
+ )
outer = literal_column(f"({selected})")
outer = self.make_sqla_column_compatible(outer, selected)
else:
@@ -1231,7 +1248,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
select_exprs.append(outer)
elif columns:
for selected in columns:
- validate_adhoc_subquery(selected)
+ selected = validate_adhoc_subquery(
+ selected,
+ self.database_id,
+ self.schema,
+ )
select_exprs.append(
columns_by_name[selected].get_sqla_col()
if selected in columns_by_name
@@ -1402,7 +1423,7 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
_("Invalid filter operation type: %(op)s", op=op)
)
if is_feature_enabled("ROW_LEVEL_SECURITY"):
- where_clause_and += self._get_sqla_row_level_filters(template_processor)
+ where_clause_and += self.get_sqla_row_level_filters(template_processor)
if extras:
where = extras.get("where")
if where:
@@ -1449,7 +1470,6 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
and db_engine_spec.allows_hidden_cc_in_orderby
and col.name in [select_col.name for select_col in select_exprs]
):
- validate_adhoc_subquery(str(col.expression))
col = literal_column(col.name)
direction = asc if ascending else desc
qry = qry.order_by(direction(col))
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 4fc11a4d1d..766b74e57c 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -33,7 +33,7 @@ from superset.exceptions import (
)
from superset.models.core import Database
from superset.result_set import SupersetResultSet
-from superset.sql_parse import has_table_query, ParsedQuery, Table
+from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table
from superset.tables.models import Table as NewTable
if TYPE_CHECKING:
@@ -136,29 +136,39 @@ def get_virtual_table_metadata(dataset: "SqlaTable") -> List[Dict[str, str]]:
return cols
-def validate_adhoc_subquery(raw_sql: str) -> None:
+def validate_adhoc_subquery(
+ sql: str,
+ database_id: int,
+ default_schema: str,
+) -> str:
"""
- Check if adhoc SQL contains sub-queries or nested sub-queries with table
- :param raw_sql: adhoc sql expression
+ Check if adhoc SQL contains sub-queries or nested sub-queries with table.
+
+ If sub-queries are allowed, the adhoc SQL is modified to insert any applicable RLS
+ predicates into it.
+
+ :param sql: adhoc sql expression
:raise SupersetSecurityException if sql contains sub-queries or
nested sub-queries with table
"""
# pylint: disable=import-outside-toplevel
from superset import is_feature_enabled
- if is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
- return
-
- for statement in sqlparse.parse(raw_sql):
+ statements = []
+ for statement in sqlparse.parse(sql):
if has_table_query(statement):
- raise SupersetSecurityException(
- SupersetError(
- error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
- message=_("Custom SQL fields cannot contain sub-queries."),
- level=ErrorLevel.ERROR,
+ if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
+ raise SupersetSecurityException(
+ SupersetError(
+ error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
+ message=_("Custom SQL fields cannot contain sub-queries."),
+ level=ErrorLevel.ERROR,
+ )
)
- )
- return
+ statement = insert_rls(statement, database_id, default_schema)
+ statements.append(statement)
+
+ return ";\n".join(str(statement) for statement in statements)
def load_or_create_tables( # pylint: disable=too-many-arguments
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 95361b39a6..6bfb63c425 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -18,10 +18,11 @@ import logging
import re
from dataclasses import dataclass
from enum import Enum
-from typing import List, Optional, Set, Tuple
+from typing import cast, List, Optional, Set, Tuple
from urllib import parse
import sqlparse
+from sqlalchemy import and_
from sqlparse.sql import (
Identifier,
IdentifierList,
@@ -283,7 +284,7 @@ class ParsedQuery:
return statements
@staticmethod
- def _get_table(tlist: TokenList) -> Optional[Table]:
+ def get_table(tlist: TokenList) -> Optional[Table]:
"""
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
construct.
@@ -324,7 +325,7 @@ class ParsedQuery:
"""
# exclude subselects
if "(" not in str(token_list):
- table = self._get_table(token_list)
+ table = self.get_table(token_list)
if table and not table.table.startswith(CTE_PREFIX):
self._tables.add(table)
return
@@ -500,7 +501,7 @@ def has_table_query(token_list: TokenList) -> bool:
state = InsertRLSState.SCANNING
for token in token_list.tokens:
- # # Recurse into child token list
+ # Recurse into child token list
if isinstance(token, TokenList) and has_table_query(token):
return True
@@ -523,7 +524,7 @@ def has_table_query(token_list: TokenList) -> bool:
def add_table_name(rls: TokenList, table: str) -> None:
"""
- Modify a RLS expression ensuring columns are fully qualified.
+ Modify an RLS expression in place, ensuring columns are fully qualified.
"""
tokens = rls.tokens[:]
while tokens:
@@ -539,45 +540,70 @@ def add_table_name(rls: TokenList, table: str) -> None:
tokens.extend(token.tokens)
-def matches_table_name(candidate: Token, table: str) -> bool:
+def get_rls_for_table(
+ candidate: Token,
+ database_id: int,
+ default_schema: Optional[str],
+) -> Optional[TokenList]:
"""
- Returns if the token represents a reference to the table.
-
- Tables can be fully qualified with periods.
-
- Note that in theory a table should be represented as an identifier, but due to
- sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
- classified as a keyword.
+ Given a table name, return any associated RLS predicates.
"""
+ # pylint: disable=import-outside-toplevel
+ from superset import db
+ from superset.connectors.sqla.models import SqlaTable
+
if not isinstance(candidate, Identifier):
candidate = Identifier([Token(Name, candidate.value)])
- target = sqlparse.parse(table)[0].tokens[0]
- if not isinstance(target, Identifier):
- target = Identifier([Token(Name, target.value)])
+ table = ParsedQuery.get_table(candidate)
+ if not table:
+ return None
- # match from right to left, splitting on the period, eg, schema.table == table
- for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
- if left.value != right.value:
- return False
+ dataset = (
+ db.session.query(SqlaTable)
+ .filter(
+ and_(
+ SqlaTable.database_id == database_id,
+ SqlaTable.schema == (table.schema or default_schema),
+ SqlaTable.table_name == table.table,
+ )
+ )
+ .one_or_none()
+ )
+ if not dataset:
+ return None
+
+ template_processor = dataset.get_template_processor()
+ # pylint: disable=protected-access
+ predicate = " AND ".join(
+ str(filter_)
+ for filter_ in dataset.get_sqla_row_level_filters(template_processor)
+ )
+ if not predicate:
+ return None
+
+ rls = sqlparse.parse(predicate)[0]
+ add_table_name(rls, str(dataset))
- return True
+ return rls
-def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
+def insert_rls(
+ token_list: TokenList,
+ database_id: int,
+ default_schema: Optional[str],
+) -> TokenList:
"""
- Update a statement inplace applying an RLS associated with a given table.
+ Update a statement in place, applying any associated RLS predicates.
"""
- # make sure the identifier has the table name
- add_table_name(rls, table)
-
+ rls: Optional[TokenList] = None
state = InsertRLSState.SCANNING
for token in token_list.tokens:
# Recurse into child token list
if isinstance(token, TokenList):
i = token_list.tokens.index(token)
- token_list.tokens[i] = insert_rls(token, table, rls)
+ token_list.tokens[i] = insert_rls(token, database_id, default_schema)
# Found a source keyword (FROM/JOIN)
if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
@@ -587,12 +613,14 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
elif state == InsertRLSState.SEEN_SOURCE and (
isinstance(token, Identifier) or token.ttype == Keyword
):
- if matches_table_name(token, table):
+ rls = get_rls_for_table(token, database_id, default_schema)
+ if rls:
state = InsertRLSState.FOUND_TABLE
# Found WHERE clause, insert RLS. Note that we insert it even it already exists,
# to be on the safe side: it could be present in a clause like `1=1 OR RLS`.
elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
+ rls = cast(TokenList, rls)
token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")]
token.tokens.extend(
[
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 75f099e52b..4a1ff89d74 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -14,21 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-# pylint: disable=invalid-name, too-many-lines
+# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines
import unittest
-from typing import Set
+from typing import Optional, Set
import pytest
import sqlparse
+from pytest_mock import MockerFixture
+from sqlalchemy import text
+from sqlparse.sql import Identifier, Token, TokenList
+from sqlparse.tokens import Name
from superset.exceptions import QueryClauseValidationException
from superset.sql_parse import (
add_table_name,
+ get_rls_for_table,
has_table_query,
insert_rls,
- matches_table_name,
ParsedQuery,
sanitize_clause,
strip_comments_from_sql,
@@ -1391,13 +1394,37 @@ def test_has_table_query(sql: str, expected: bool) -> None:
),
],
)
-def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None:
+def test_insert_rls(
+ mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
+) -> None:
"""
Insert into a statement a given RLS condition associated with a table.
"""
- statement = sqlparse.parse(sql)[0]
condition = sqlparse.parse(rls)[0]
- assert str(insert_rls(statement, table, condition)).strip() == expected.strip()
+ add_table_name(condition, table)
+
+ # pylint: disable=unused-argument
+ def get_rls_for_table(
+ candidate: Token, database_id: int, default_schema: str
+ ) -> Optional[TokenList]:
+ """
+ Return the RLS ``condition`` if ``candidate`` matches ``table``.
+ """
+ # compare ignoring schema
+ for left, right in zip(str(candidate).split(".")[::-1], table.split(".")[::-1]):
+ if left != right:
+ return None
+ return condition
+
+ mocker.patch("superset.sql_parse.get_rls_for_table", new=get_rls_for_table)
+
+ statement = sqlparse.parse(sql)[0]
+ assert (
+ str(
+ insert_rls(token_list=statement, database_id=1, default_schema="my_schema")
+ ).strip()
+ == expected.strip()
+ )
@pytest.mark.parametrize(
@@ -1415,16 +1442,29 @@ def test_add_table_name(rls: str, table: str, expected: str) -> None:
assert str(condition) == expected
-@pytest.mark.parametrize(
- "candidate,table,expected",
- [
- ("table", "table", True),
- ("schema.table", "table", True),
- ("table", "schema.table", True),
- ('schema."my table"', '"my table"', True),
- ('schema."my.table"', '"my.table"', True),
- ],
-)
-def test_matches_table_name(candidate: str, table: str, expected: bool) -> None:
- token = sqlparse.parse(candidate)[0].tokens[0]
- assert matches_table_name(token, table) == expected
+def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None:
+ """
+ Tests for ``get_rls_for_table``.
+ """
+ candidate = Identifier([Token(Name, "some_table")])
+ db = mocker.patch("superset.db")
+ dataset = db.session.query().filter().one_or_none()
+ dataset.__str__.return_value = "some_table"
+
+ dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]
+ assert (
+ str(get_rls_for_table(candidate, 1, "public"))
+ == "some_table.organization_id = 1"
+ )
+
+ dataset.get_sqla_row_level_filters.return_value = [
+ text("organization_id = 1"),
+ text("foo = 'bar'"),
+ ]
+ assert (
+ str(get_rls_for_table(candidate, 1, "public"))
+ == "some_table.organization_id = 1 AND some_table.foo = 'bar'"
+ )
+
+ dataset.get_sqla_row_level_filters.return_value = []
+ assert get_rls_for_table(candidate, 1, "public") is None