You are viewing a plain-text version of this content. The canonical link to the original message was a hyperlink in the source page and was not preserved in this copy.
Posted to commits@superset.apache.org by vi...@apache.org on 2022/04/04 08:22:24 UTC

[superset] 15/24: feat: improve adhoc SQL validation (#19454)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to tag 1.5.0rc1
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 840be9972fd3318d70c9a3a153c4b0db3867a6d7
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Thu Mar 31 11:55:19 2022 -0700

    feat: improve adhoc SQL validation (#19454)
    
    * feat: improve adhoc SQL validation
    
    * Small changes
    
    * Add more unit tests
    
    (cherry picked from commit 6828624f61fff21485b0b2e91ac53701d43cb0d7)
---
 superset/connectors/sqla/models.py  | 38 +++++++++++++----
 superset/connectors/sqla/utils.py   | 40 +++++++++++-------
 superset/sql_parse.py               | 84 ++++++++++++++++++++++++-------------
 tests/unit_tests/sql_parse_tests.py | 80 ++++++++++++++++++++++++++---------
 4 files changed, 170 insertions(+), 72 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 535eb81a57..31f3339541 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -923,7 +923,11 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         elif expression_type == utils.AdhocMetricExpressionType.SQL:
             tp = self.get_template_processor()
             expression = tp.process_template(cast(str, metric["sqlExpression"]))
-            validate_adhoc_subquery(expression)
+            expression = validate_adhoc_subquery(
+                expression,
+                self.database_id,
+                self.schema,
+            )
             try:
                 expression = sanitize_clause(expression)
             except QueryClauseValidationException as ex:
@@ -952,7 +956,11 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         if template_processor and expression:
             expression = template_processor.process_template(expression)
         if expression:
-            validate_adhoc_subquery(expression)
+            expression = validate_adhoc_subquery(
+                expression,
+                self.database_id,
+                self.schema,
+            )
             try:
                 expression = sanitize_clause(expression)
             except QueryClauseValidationException as ex:
@@ -1006,9 +1014,9 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
             if is_alias_used_in_orderby(col):
                 col.name = f"{col.name}__"
 
-    def _get_sqla_row_level_filters(
+    def get_sqla_row_level_filters(
         self, template_processor: BaseTemplateProcessor
-    ) -> List[str]:
+    ) -> List[TextClause]:
         """
         Return the appropriate row level security filters for
         this table and the current user.
@@ -1016,7 +1024,6 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
         :param BaseTemplateProcessor template_processor: The template
         processor to apply to the filters.
         :returns: A list of SQL clauses to be ANDed together.
-        :rtype: List[str]
         """
         all_filters: List[TextClause] = []
         filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
@@ -1169,6 +1176,12 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
             col: Union[AdhocMetric, ColumnElement] = orig_col
             if isinstance(col, dict):
                 col = cast(AdhocMetric, col)
+                if col.get("sqlExpression"):
+                    col["sqlExpression"] = validate_adhoc_subquery(
+                        cast(str, col["sqlExpression"]),
+                        self.database_id,
+                        self.schema,
+                    )
                 if utils.is_adhoc_metric(col):
                     # add adhoc sort by column to columns_by_name if not exists
                     col = self.adhoc_metric_to_sqla(col, columns_by_name)
@@ -1218,7 +1231,11 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                     elif selected in columns_by_name:
                         outer = columns_by_name[selected].get_sqla_col()
                     else:
-                        validate_adhoc_subquery(selected)
+                        selected = validate_adhoc_subquery(
+                            selected,
+                            self.database_id,
+                            self.schema,
+                        )
                         outer = literal_column(f"({selected})")
                         outer = self.make_sqla_column_compatible(outer, selected)
                 else:
@@ -1231,7 +1248,11 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                 select_exprs.append(outer)
         elif columns:
             for selected in columns:
-                validate_adhoc_subquery(selected)
+                selected = validate_adhoc_subquery(
+                    selected,
+                    self.database_id,
+                    self.schema,
+                )
                 select_exprs.append(
                     columns_by_name[selected].get_sqla_col()
                     if selected in columns_by_name
@@ -1402,7 +1423,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                             _("Invalid filter operation type: %(op)s", op=op)
                         )
         if is_feature_enabled("ROW_LEVEL_SECURITY"):
-            where_clause_and += self._get_sqla_row_level_filters(template_processor)
+            where_clause_and += self.get_sqla_row_level_filters(template_processor)
         if extras:
             where = extras.get("where")
             if where:
@@ -1449,7 +1470,6 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                 and db_engine_spec.allows_hidden_cc_in_orderby
                 and col.name in [select_col.name for select_col in select_exprs]
             ):
-                validate_adhoc_subquery(str(col.expression))
                 col = literal_column(col.name)
             direction = asc if ascending else desc
             qry = qry.order_by(direction(col))
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 4fc11a4d1d..766b74e57c 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -33,7 +33,7 @@ from superset.exceptions import (
 )
 from superset.models.core import Database
 from superset.result_set import SupersetResultSet
-from superset.sql_parse import has_table_query, ParsedQuery, Table
+from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table
 from superset.tables.models import Table as NewTable
 
 if TYPE_CHECKING:
@@ -136,29 +136,39 @@ def get_virtual_table_metadata(dataset: "SqlaTable") -> List[Dict[str, str]]:
     return cols
 
 
-def validate_adhoc_subquery(raw_sql: str) -> None:
+def validate_adhoc_subquery(
+    sql: str,
+    database_id: int,
+    default_schema: str,
+) -> str:
     """
-    Check if adhoc SQL contains sub-queries or nested sub-queries with table
-    :param raw_sql: adhoc sql expression
+    Check if adhoc SQL contains sub-queries or nested sub-queries with table.
+
+    If sub-queries are allowed, the adhoc SQL is modified to insert any applicable RLS
+    predicates to it.
+
+    :param sql: adhoc sql expression
     :raise SupersetSecurityException if sql contains sub-queries or
     nested sub-queries with table
     """
     # pylint: disable=import-outside-toplevel
     from superset import is_feature_enabled
 
-    if is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
-        return
-
-    for statement in sqlparse.parse(raw_sql):
+    statements = []
+    for statement in sqlparse.parse(sql):
         if has_table_query(statement):
-            raise SupersetSecurityException(
-                SupersetError(
-                    error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
-                    message=_("Custom SQL fields cannot contain sub-queries."),
-                    level=ErrorLevel.ERROR,
+            if not is_feature_enabled("ALLOW_ADHOC_SUBQUERY"):
+                raise SupersetSecurityException(
+                    SupersetError(
+                        error_type=SupersetErrorType.ADHOC_SUBQUERY_NOT_ALLOWED_ERROR,
+                        message=_("Custom SQL fields cannot contain sub-queries."),
+                        level=ErrorLevel.ERROR,
+                    )
                 )
-            )
-    return
+            statement = insert_rls(statement, database_id, default_schema)
+        statements.append(statement)
+
+    return ";\n".join(str(statement) for statement in statements)
 
 
 def load_or_create_tables(  # pylint: disable=too-many-arguments
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 95361b39a6..6bfb63c425 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -18,10 +18,11 @@ import logging
 import re
 from dataclasses import dataclass
 from enum import Enum
-from typing import List, Optional, Set, Tuple
+from typing import cast, List, Optional, Set, Tuple
 from urllib import parse
 
 import sqlparse
+from sqlalchemy import and_
 from sqlparse.sql import (
     Identifier,
     IdentifierList,
@@ -283,7 +284,7 @@ class ParsedQuery:
         return statements
 
     @staticmethod
-    def _get_table(tlist: TokenList) -> Optional[Table]:
+    def get_table(tlist: TokenList) -> Optional[Table]:
         """
         Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
         construct.
@@ -324,7 +325,7 @@ class ParsedQuery:
         """
         # exclude subselects
         if "(" not in str(token_list):
-            table = self._get_table(token_list)
+            table = self.get_table(token_list)
             if table and not table.table.startswith(CTE_PREFIX):
                 self._tables.add(table)
             return
@@ -500,7 +501,7 @@ def has_table_query(token_list: TokenList) -> bool:
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
 
-        # # Recurse into child token list
+        # Recurse into child token list
         if isinstance(token, TokenList) and has_table_query(token):
             return True
 
@@ -523,7 +524,7 @@ def has_table_query(token_list: TokenList) -> bool:
 
 def add_table_name(rls: TokenList, table: str) -> None:
     """
-    Modify a RLS expression ensuring columns are fully qualified.
+    Modify a RLS expression inplace ensuring columns are fully qualified.
     """
     tokens = rls.tokens[:]
     while tokens:
@@ -539,45 +540,70 @@ def add_table_name(rls: TokenList, table: str) -> None:
             tokens.extend(token.tokens)
 
 
-def matches_table_name(candidate: Token, table: str) -> bool:
+def get_rls_for_table(
+    candidate: Token,
+    database_id: int,
+    default_schema: Optional[str],
+) -> Optional[TokenList]:
     """
-    Returns if the token represents a reference to the table.
-
-    Tables can be fully qualified with periods.
-
-    Note that in theory a table should be represented as an identifier, but due to
-    sqlparse's aggressive list of keywords (spanning multiple dialects) often it gets
-    classified as a keyword.
+    Given a table name, return any associated RLS predicates.
     """
+    # pylint: disable=import-outside-toplevel
+    from superset import db
+    from superset.connectors.sqla.models import SqlaTable
+
     if not isinstance(candidate, Identifier):
         candidate = Identifier([Token(Name, candidate.value)])
 
-    target = sqlparse.parse(table)[0].tokens[0]
-    if not isinstance(target, Identifier):
-        target = Identifier([Token(Name, target.value)])
+    table = ParsedQuery.get_table(candidate)
+    if not table:
+        return None
 
-    # match from right to left, splitting on the period, eg, schema.table == table
-    for left, right in zip(candidate.tokens[::-1], target.tokens[::-1]):
-        if left.value != right.value:
-            return False
+    dataset = (
+        db.session.query(SqlaTable)
+        .filter(
+            and_(
+                SqlaTable.database_id == database_id,
+                SqlaTable.schema == (table.schema or default_schema),
+                SqlaTable.table_name == table.table,
+            )
+        )
+        .one_or_none()
+    )
+    if not dataset:
+        return None
+
+    template_processor = dataset.get_template_processor()
+    # pylint: disable=protected-access
+    predicate = " AND ".join(
+        str(filter_)
+        for filter_ in dataset.get_sqla_row_level_filters(template_processor)
+    )
+    if not predicate:
+        return None
+
+    rls = sqlparse.parse(predicate)[0]
+    add_table_name(rls, str(dataset))
 
-    return True
+    return rls
 
 
-def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
+def insert_rls(
+    token_list: TokenList,
+    database_id: int,
+    default_schema: Optional[str],
+) -> TokenList:
     """
-    Update a statement inplace applying an RLS associated with a given table.
+    Update a statement inplace applying any associated RLS predicates.
     """
-    # make sure the identifier has the table name
-    add_table_name(rls, table)
-
+    rls: Optional[TokenList] = None
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
 
         # Recurse into child token list
         if isinstance(token, TokenList):
             i = token_list.tokens.index(token)
-            token_list.tokens[i] = insert_rls(token, table, rls)
+            token_list.tokens[i] = insert_rls(token, database_id, default_schema)
 
         # Found a source keyword (FROM/JOIN)
         if imt(token, m=[(Keyword, "FROM"), (Keyword, "JOIN")]):
@@ -587,12 +613,14 @@ def insert_rls(token_list: TokenList, table: str, rls: TokenList) -> TokenList:
         elif state == InsertRLSState.SEEN_SOURCE and (
             isinstance(token, Identifier) or token.ttype == Keyword
         ):
-            if matches_table_name(token, table):
+            rls = get_rls_for_table(token, database_id, default_schema)
+            if rls:
                 state = InsertRLSState.FOUND_TABLE
 
         # Found WHERE clause, insert RLS. Note that we insert it even it already exists,
         # to be on the safe side: it could be present in a clause like `1=1 OR RLS`.
         elif state == InsertRLSState.FOUND_TABLE and isinstance(token, Where):
+            rls = cast(TokenList, rls)
             token.tokens[1:1] = [Token(Whitespace, " "), Token(Punctuation, "(")]
             token.tokens.extend(
                 [
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 75f099e52b..4a1ff89d74 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -14,21 +14,24 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
-# pylint: disable=invalid-name, too-many-lines
+# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines
 
 import unittest
-from typing import Set
+from typing import Optional, Set
 
 import pytest
 import sqlparse
+from pytest_mock import MockerFixture
+from sqlalchemy import text
+from sqlparse.sql import Identifier, Token, TokenList
+from sqlparse.tokens import Name
 
 from superset.exceptions import QueryClauseValidationException
 from superset.sql_parse import (
     add_table_name,
+    get_rls_for_table,
     has_table_query,
     insert_rls,
-    matches_table_name,
     ParsedQuery,
     sanitize_clause,
     strip_comments_from_sql,
@@ -1391,13 +1394,37 @@ def test_has_table_query(sql: str, expected: bool) -> None:
         ),
     ],
 )
-def test_insert_rls(sql: str, table: str, rls: str, expected: str) -> None:
+def test_insert_rls(
+    mocker: MockerFixture, sql: str, table: str, rls: str, expected: str
+) -> None:
     """
     Insert into a statement a given RLS condition associated with a table.
     """
-    statement = sqlparse.parse(sql)[0]
     condition = sqlparse.parse(rls)[0]
-    assert str(insert_rls(statement, table, condition)).strip() == expected.strip()
+    add_table_name(condition, table)
+
+    # pylint: disable=unused-argument
+    def get_rls_for_table(
+        candidate: Token, database_id: int, default_schema: str
+    ) -> Optional[TokenList]:
+        """
+        Return the RLS ``condition`` if ``candidate`` matches ``table``.
+        """
+        # compare ignoring schema
+        for left, right in zip(str(candidate).split(".")[::-1], table.split(".")[::-1]):
+            if left != right:
+                return None
+        return condition
+
+    mocker.patch("superset.sql_parse.get_rls_for_table", new=get_rls_for_table)
+
+    statement = sqlparse.parse(sql)[0]
+    assert (
+        str(
+            insert_rls(token_list=statement, database_id=1, default_schema="my_schema")
+        ).strip()
+        == expected.strip()
+    )
 
 
 @pytest.mark.parametrize(
@@ -1415,16 +1442,29 @@ def test_add_table_name(rls: str, table: str, expected: str) -> None:
     assert str(condition) == expected
 
 
-@pytest.mark.parametrize(
-    "candidate,table,expected",
-    [
-        ("table", "table", True),
-        ("schema.table", "table", True),
-        ("table", "schema.table", True),
-        ('schema."my table"', '"my table"', True),
-        ('schema."my.table"', '"my.table"', True),
-    ],
-)
-def test_matches_table_name(candidate: str, table: str, expected: bool) -> None:
-    token = sqlparse.parse(candidate)[0].tokens[0]
-    assert matches_table_name(token, table) == expected
+def test_get_rls_for_table(mocker: MockerFixture, app_context: None) -> None:
+    """
+    Tests for ``get_rls_for_table``.
+    """
+    candidate = Identifier([Token(Name, "some_table")])
+    db = mocker.patch("superset.db")
+    dataset = db.session.query().filter().one_or_none()
+    dataset.__str__.return_value = "some_table"
+
+    dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")]
+    assert (
+        str(get_rls_for_table(candidate, 1, "public"))
+        == "some_table.organization_id = 1"
+    )
+
+    dataset.get_sqla_row_level_filters.return_value = [
+        text("organization_id = 1"),
+        text("foo = 'bar'"),
+    ]
+    assert (
+        str(get_rls_for_table(candidate, 1, "public"))
+        == "some_table.organization_id = 1 AND some_table.foo = 'bar'"
+    )
+
+    dataset.get_sqla_row_level_filters.return_value = []
+    assert get_rls_for_table(candidate, 1, "public") is None