You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text export.
Posted to commits@superset.apache.org by vi...@apache.org on 2021/12/09 15:51:02 UTC

[superset] branch master updated: chore(sql): clean up invalid filter clause exception types (#17702)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 3a42071  chore(sql): clean up invalid filter clause exception types (#17702)
3a42071 is described below

commit 3a42071e0ff181a7a0f1b55a69e39440e2570018
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Thu Dec 9 17:49:32 2021 +0200

    chore(sql): clean up invalid filter clause exception types (#17702)
    
    * chore(sql): clean up invalid filter clause exception types
    
    * fix lint
    
    * rename exception
---
 superset/common/query_object.py                  | 16 ++++++-
 superset/exceptions.py                           |  4 ++
 superset/sql_parse.py                            | 22 +++++++++
 superset/viz.py                                  | 11 +++++
 tests/integration_tests/charts/data/api_tests.py | 22 +++++++++
 tests/unit_tests/sql_parse_tests.py              | 57 +++++++++++++++++++++++-
 6 files changed, 130 insertions(+), 2 deletions(-)

diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index ff1ad71..03ee9cb 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -25,7 +25,11 @@ from flask_babel import gettext as _
 from pandas import DataFrame
 
 from superset.common.chart_data import ChartDataResultType
-from superset.exceptions import QueryObjectValidationError
+from superset.exceptions import (
+    QueryClauseValidationException,
+    QueryObjectValidationError,
+)
+from superset.sql_parse import validate_filter_clause
 from superset.typing import Column, Metric, OrderBy
 from superset.utils import pandas_postprocessing
 from superset.utils.core import (
@@ -267,6 +271,7 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
         try:
             self._validate_there_are_no_missing_series()
             self._validate_no_have_duplicate_labels()
+            self._validate_filters()
             return None
         except QueryObjectValidationError as ex:
             if raise_exceptions:
@@ -285,6 +290,15 @@ class QueryObject:  # pylint: disable=too-many-instance-attributes
                 )
             )
 
+    def _validate_filters(self) -> None:
+        for param in ("where", "having"):
+            clause = self.extras.get(param)
+            if clause:
+                try:
+                    validate_filter_clause(clause)
+                except QueryClauseValidationException as ex:
+                    raise QueryObjectValidationError(ex.message) from ex
+
     def _validate_there_are_no_missing_series(self) -> None:
         missing_series = [col for col in self.series_columns if col not in self.columns]
         if missing_series:
diff --git a/superset/exceptions.py b/superset/exceptions.py
index 76da484..2a90260 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -194,6 +194,10 @@ class CacheLoadError(SupersetException):
     status = 404
 
 
+class QueryClauseValidationException(SupersetException):
+    status = 400
+
+
 class DashboardImportException(SupersetException):
     pass
 
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 47a9e5c..1130763 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -32,6 +32,8 @@ from sqlparse.sql import (
 from sqlparse.tokens import DDL, DML, Keyword, Name, Punctuation, String, Whitespace
 from sqlparse.utils import imt
 
+from superset.exceptions import QueryClauseValidationException
+
 RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
 ON_KEYWORD = "ON"
 PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
@@ -378,3 +380,23 @@ class ParsedQuery:
         for i in statement.tokens:
             str_res += str(i.value)
         return str_res
+
+
+def validate_filter_clause(clause: str) -> None:
+    if sqlparse.format(clause, strip_comments=True) != sqlparse.format(clause):
+        raise QueryClauseValidationException("Filter clause contains comment")
+
+    statements = sqlparse.parse(clause)
+    if len(statements) != 1:
+        raise QueryClauseValidationException("Filter clause contains multiple queries")
+    open_parens = 0
+
+    for token in statements[0]:
+        if token.value in (")", "("):
+            open_parens += 1 if token.value == "(" else -1
+            if open_parens < 0:
+                raise QueryClauseValidationException(
+                    "Closing unclosed parenthesis in filter clause"
+                )
+    if open_parens > 0:
+        raise QueryClauseValidationException("Unclosed parenthesis in filter clause")
diff --git a/superset/viz.py b/superset/viz.py
index 53bc333..23f2cf3 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -62,12 +62,14 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import (
     CacheLoadError,
     NullValueException,
+    QueryClauseValidationException,
     QueryObjectValidationError,
     SpatialException,
     SupersetSecurityException,
 )
 from superset.extensions import cache_manager, security_manager
 from superset.models.helpers import QueryResult
+from superset.sql_parse import validate_filter_clause
 from superset.typing import Column, Metric, QueryObjectDict, VizData, VizPayload
 from superset.utils import core as utils, csv
 from superset.utils.cache import set_and_log_cache
@@ -373,6 +375,15 @@ class BaseViz:  # pylint: disable=too-many-public-methods
         self.from_dttm = from_dttm
         self.to_dttm = to_dttm
 
+        # validate sql filters
+        for param in ("where", "having"):
+            clause = self.form_data.get(param)
+            if clause:
+                try:
+                    validate_filter_clause(clause)
+                except QueryClauseValidationException as ex:
+                    raise QueryObjectValidationError(ex.message) from ex
+
         # extras are used to query elements specific to a datasource type
         # for instance the extra where clause that applies only to Tables
         extras = {
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index 12d667f..cf6d0b5 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -425,6 +425,28 @@ class TestPostChartDataApi(BaseTestChartDataApi):
 
         assert rv.status_code == 400
 
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_with_invalid_where_parameter_closing_unclosed__400(self):
+        self.query_context_payload["queries"][0]["filters"] = []
+        self.query_context_payload["queries"][0]["extras"][
+            "where"
+        ] = "state = 'CA') OR (state = 'NY'"
+
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+
+        assert rv.status_code == 400
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_with_invalid_having_parameter_closing_and_comment__400(self):
+        self.query_context_payload["queries"][0]["filters"] = []
+        self.query_context_payload["queries"][0]["extras"][
+            "having"
+        ] = "COUNT(1) = 0) UNION ALL SELECT 'abc', 1--comment"
+
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+
+        assert rv.status_code == 400
+
     def test_with_invalid_datasource__400(self):
         self.query_context_payload["datasource"] = "abc"
 
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 61ea2e0..f405b9f 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -20,9 +20,16 @@
 import unittest
 from typing import Set
 
+import pytest
 import sqlparse
 
-from superset.sql_parse import ParsedQuery, strip_comments_from_sql, Table
+from superset.exceptions import QueryClauseValidationException
+from superset.sql_parse import (
+    ParsedQuery,
+    strip_comments_from_sql,
+    Table,
+    validate_filter_clause,
+)
 
 
 def extract_tables(query: str) -> Set[Table]:
@@ -1144,3 +1151,51 @@ def test_strip_comments_from_sql() -> None:
         strip_comments_from_sql("SELECT '--abc' as abc, col2 FROM table1\n")
         == "SELECT '--abc' as abc, col2 FROM table1"
     )
+
+
+def test_validate_filter_clause_valid():
+    # regular clauses
+    assert validate_filter_clause("col = 1") is None
+    assert validate_filter_clause("1=\t\n1") is None
+    assert validate_filter_clause("(col = 1)") is None
+    assert validate_filter_clause("(col1 = 1) AND (col2 = 2)") is None
+
+    # Valid literal values that appear to be invalid
+    assert validate_filter_clause("col = 'col1 = 1) AND (col2 = 2'") is None
+    assert validate_filter_clause("col = 'select 1; select 2'") is None
+    assert validate_filter_clause("col = 'abc -- comment'") is None
+
+
+def test_validate_filter_clause_closing_unclosed():
+    with pytest.raises(QueryClauseValidationException):
+        validate_filter_clause("col1 = 1) AND (col2 = 2)")
+
+
+def test_validate_filter_clause_unclosed():
+    with pytest.raises(QueryClauseValidationException):
+        validate_filter_clause("(col1 = 1) AND (col2 = 2")
+
+
+def test_validate_filter_clause_closing_and_unclosed():
+    with pytest.raises(QueryClauseValidationException):
+        validate_filter_clause("col1 = 1) AND (col2 = 2")
+
+
+def test_validate_filter_clause_closing_and_unclosed_nested():
+    with pytest.raises(QueryClauseValidationException):
+        validate_filter_clause("(col1 = 1)) AND ((col2 = 2)")
+
+
+def test_validate_filter_clause_multiple():
+    with pytest.raises(QueryClauseValidationException):
+        validate_filter_clause("TRUE; SELECT 1")
+
+
+def test_validate_filter_clause_comment():
+    with pytest.raises(QueryClauseValidationException):
+        validate_filter_clause("1 = 1 -- comment")
+
+
+def test_validate_filter_clause_subquery_comment():
+    with pytest.raises(QueryClauseValidationException):
+        validate_filter_clause("(1 = 1 -- comment\n)")