You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by mi...@apache.org on 2023/01/05 19:57:25 UTC

[superset] 05/14: chore(sqla): refactor query utils (#21811)

This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 1.5
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 7c98e266ce6bd39fe8db22d8e25af711012bc9f2
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Mon Oct 17 10:40:42 2022 +0100

    chore(sqla): refactor query utils (#21811)
    
    Co-authored-by: Ville Brofeldt <vi...@apple.com>
---
 superset/connectors/sqla/models.py               | 26 +++++--
 superset/connectors/sqla/utils.py                |  7 ++
 superset/models/core.py                          | 11 ++-
 tests/integration_tests/charts/data/api_tests.py | 71 ++++++++++++++++-
 tests/integration_tests/conftest.py              | 99 +++++++++++++++++++++++-
 tests/integration_tests/sqla_models_tests.py     |  2 +-
 tests/integration_tests/test_app.py              | 20 ++++-
 7 files changed, 221 insertions(+), 15 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 4697530328..dbcfb80eb3 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -83,6 +83,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampEx
 from superset.exceptions import (
     QueryClauseValidationException,
     QueryObjectValidationError,
+    SupersetSecurityException,
 )
 from superset.jinja_context import (
     BaseTemplateProcessor,
@@ -514,19 +515,19 @@ def _process_sql_expression(
     expression: Optional[str],
     database_id: int,
     schema: str,
-    template_processor: Optional[BaseTemplateProcessor],
+    template_processor: Optional[BaseTemplateProcessor] = None,
 ) -> Optional[str]:
     if template_processor and expression:
         expression = template_processor.process_template(expression)
     if expression:
-        expression = validate_adhoc_subquery(
-            expression,
-            database_id,
-            schema,
-        )
         try:
+            expression = validate_adhoc_subquery(
+                expression,
+                database_id,
+                schema,
+            )
             expression = sanitize_clause(expression)
-        except QueryClauseValidationException as ex:
+        except (QueryClauseValidationException, SupersetSecurityException) as ex:
             raise QueryObjectValidationError(ex.message) from ex
     return expression
 
@@ -1465,6 +1466,11 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                             msg=ex.message,
                         )
                     ) from ex
+                where = _process_sql_expression(
+                    expression=where,
+                    database_id=self.database_id,
+                    schema=self.schema,
+                )
                 where_clause_and += [self.text(where)]
             having = extras.get("having")
             if having:
@@ -1477,7 +1483,13 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                             msg=ex.message,
                         )
                     ) from ex
+                having = _process_sql_expression(
+                    expression=having,
+                    database_id=self.database_id,
+                    schema=self.schema,
+                )
                 having_clause_and += [self.text(having)]
+
         if apply_fetch_values_predicate and self.fetch_values_predicate:
             qry = qry.where(self.get_fetch_values_predicate())
         if granularity:
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index a2b54201d6..5359c9e214 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -22,6 +22,7 @@ import sqlparse
 from flask_babel import lazy_gettext as _
 from sqlalchemy import and_, inspect, or_
 from sqlalchemy.engine import Engine
+from sqlalchemy.engine.url import URL as SqlaURL
 from sqlalchemy.exc import NoSuchTableError
 from sqlalchemy.orm import Session
 from sqlalchemy.sql.type_api import TypeEngine
@@ -37,6 +38,7 @@ from superset.result_set import SupersetResultSet
 from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table
 from superset.superset_typing import ResultSetColumnType
 from superset.tables.models import Table as NewTable
+from superset.utils.memoized import memoized
 
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import SqlaTable
@@ -252,3 +254,8 @@ def load_or_create_tables(  # pylint: disable=too-many-arguments
             existing.add((table.schema, table.table))
 
     return new_tables
+
+
+@memoized  # the value returned is the dialect's bound ``identifier_preparer.quote`` callable
+def get_identifier_quoter(drivername: str) -> Callable[[str], str]:
+    return SqlaURL(drivername=drivername).get_dialect()().identifier_preparer.quote
diff --git a/superset/models/core.py b/superset/models/core.py
index fcc7cf16d8..97e1d763b9 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -21,7 +21,7 @@ import json
 import logging
 import textwrap
 from ast import literal_eval
-from contextlib import closing
+from contextlib import closing, contextmanager
 from copy import deepcopy
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
@@ -345,6 +345,15 @@ class Database(
                 effective_username = g.user.username
         return effective_username
 
+    @contextmanager  # thin wrapper: yields get_sqla_engine(...) with no teardown step
+    def get_sqla_engine_with_context(
+        self,
+        schema: Optional[str] = None,
+        nullpool: bool = True,
+        source: Optional[utils.QuerySource] = None,
+    ) -> Engine:  # NOTE(review): this is a generator; Iterator[Engine] is more accurate
+        yield self.get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
+
     @memoized(
         watch=(
             "impersonate_user",
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index e8f258421f..212c9d01af 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -21,7 +21,7 @@ import unittest
 import copy
 from datetime import datetime
 from io import BytesIO
-from typing import Optional
+from typing import Optional, Dict, Any
 from unittest import mock
 from zipfile import ZipFile
 
@@ -974,3 +974,72 @@ class TestGetChartDataApi(BaseTestChartDataApi):
         unique_genders = {row["male_or_female"] for row in data}
         assert unique_genders == {"male", "female"}
         assert result["applied_filters"] == [{"column": "male_or_female"}]
+
+
+@pytest.fixture()
+def physical_query_context(physical_dataset) -> Dict[str, Any]:
+    return {  # chart-data request payload built around the physical_dataset fixture
+        "datasource": {
+            "type": physical_dataset.type,
+            "id": physical_dataset.id,
+        },
+        "queries": [
+            {
+                "columns": ["col1"],
+                "metrics": ["count"],
+                "orderby": [["col1", True]],
+            }
+        ],
+        "result_type": ChartDataResultType.FULL,
+        "force": True,
+    }
+
+
+@pytest.mark.parametrize(
+    "status_code,extras",
+    [
+        (200, {"where": "1 = 1"}),
+        (200, {"having": "count(*) > 0"}),
+        (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
+        (400, {"having": "count(*) > (select count(*) from physical_dataset)"}),
+    ],
+)
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=False)  # flag off: subquery cases expect 400
+@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+def test_chart_data_subquery_not_allowed(
+    test_client,
+    login_as_admin,
+    physical_dataset,
+    physical_query_context,
+    status_code,
+    extras,
+):
+    physical_query_context["queries"][0]["extras"] = extras  # clause under test
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+
+    assert rv.status_code == status_code
+
+
+@pytest.mark.parametrize(
+    "status_code,extras",
+    [
+        (200, {"where": "1 = 1"}),
+        (200, {"having": "count(*) > 0"}),
+        (200, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
+        (200, {"having": "count(*) > (select count(*) from physical_dataset)"}),
+    ],
+)
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)  # flag on: every case expects 200
+@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+def test_chart_data_subquery_allowed(
+    test_client,
+    login_as_admin,
+    physical_dataset,
+    physical_query_context,
+    status_code,
+    extras,
+):
+    physical_query_context["queries"][0]["extras"] = extras  # clause under test
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+
+    assert rv.status_code == status_code
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index fee13c8950..c605819ee6 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -21,13 +21,15 @@ from typing import Any, Callable, Generator, Optional, TYPE_CHECKING
 from unittest.mock import patch
 
 import pytest
+from flask.ctx import AppContext
+from flask.testing import FlaskClient
 from sqlalchemy.engine import Engine
 
 from superset import db
 from superset.extensions import feature_flag_manager
 from superset.utils.core import json_dumps_w_dates
 from superset.utils.database import get_example_database, remove_database
-from tests.integration_tests.test_app import app
+from tests.integration_tests.test_app import app, login
 
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import Database
@@ -42,6 +44,29 @@ def app_context():
         yield
 
 
+@pytest.fixture
+def test_client(app_context: AppContext):  # Flask test client, one per test function
+    with app.test_client() as client:
+        yield client
+
+
+@pytest.fixture
+def login_as(test_client: "FlaskClient[Any]"):
+    """Fixture yielding a helper that logs ``test_client`` in as the given user."""
+
+    def _login_as(username: str, password: str = "general"):
+        login(test_client, username=username, password=password)
+
+    yield _login_as
+    # no need to log out as both app_context and test_client are
+    # function level fixtures anyway
+
+
+@pytest.fixture
+def login_as_admin(login_as: Callable[..., None]):
+    yield login_as("admin")  # logs in with the default password; the yielded value is None
+
+
 @pytest.fixture(autouse=True, scope="session")
 def setup_sample_data() -> Any:
     # TODO(john-bodley): Determine a cleaner way of setting up the sample data without
@@ -180,3 +205,75 @@ def with_feature_flags(**mock_feature_flags):
         return functools.update_wrapper(wrapper, test_fn)
 
     return decorate
+
+
+@pytest.fixture
+def physical_dataset():  # creates the table and its SqlaTable record, yields it, then removes both
+    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+    from superset.connectors.sqla.utils import get_identifier_quoter
+
+    example_database = get_example_database()
+
+    with example_database.get_sqla_engine_with_context() as engine:
+        quoter = get_identifier_quoter(engine.name)
+        # sqlite can only execute one statement at a time
+        engine.execute(
+            f"""
+            CREATE TABLE IF NOT EXISTS physical_dataset(
+            col1 INTEGER,
+            col2 VARCHAR(255),
+            col3 DECIMAL(4,2),
+            col4 VARCHAR(255),
+            col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
+            col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
+            {quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
+            );
+            """
+        )
+        engine.execute(
+            """
+            INSERT INTO physical_dataset values
+            (0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
+            (1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
+            (2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
+            (3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
+            (4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
+            (5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
+            (6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
+            (7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
+            (8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
+            (9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
+        """
+        )
+
+    dataset = SqlaTable(
+        table_name="physical_dataset",
+        database=example_database,
+    )
+    TableColumn(column_name="col1", type="INTEGER", table=dataset)
+    TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
+    TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
+    TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
+    TableColumn(column_name="col5", type="TIMESTAMP", is_dttm=True, table=dataset)
+    TableColumn(column_name="col6", type="TIMESTAMP", is_dttm=True, table=dataset)
+    TableColumn(
+        column_name="time column with spaces",
+        type="TIMESTAMP",
+        is_dttm=True,
+        table=dataset,
+    )
+    SqlMetric(metric_name="count", expression="count(*)", table=dataset)
+    db.session.merge(dataset)
+    db.session.commit()
+
+    yield dataset
+
+    engine.execute(  # NOTE(review): `engine` is reused here after its `with` block has ended
+        """
+        DROP TABLE physical_dataset;
+    """
+    )
+    dataset = db.session.query(SqlaTable).filter_by(table_name="physical_dataset").all()
+    for ds in dataset:  # `dataset` now holds the list of matching SqlaTable rows
+        db.session.delete(ds)
+    db.session.commit()
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 8990243c6b..f06836720d 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -262,7 +262,7 @@ class TestDatabaseModel(SupersetTestCase):
         )
         db.session.commit()
 
-        with pytest.raises(SupersetSecurityException):
+        with pytest.raises(QueryObjectValidationError):
             table.get_sqla_query(**base_query_obj)
         # Cleanup
         db.session.delete(table)
diff --git a/tests/integration_tests/test_app.py b/tests/integration_tests/test_app.py
index 798f3e9cda..fb7b47b67c 100644
--- a/tests/integration_tests/test_app.py
+++ b/tests/integration_tests/test_app.py
@@ -14,11 +14,23 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from typing import TYPE_CHECKING
 
-"""
-Here is where we create the app which ends up being shared across all tests.integration_tests. A future
-optimization will be to create a separate app instance for each test class.
-"""
 from superset.app import create_app
 
+if TYPE_CHECKING:
+    from typing import Any
+
+    from flask.testing import FlaskClient
+
 app = create_app()
+
+
+def login(
+    client: "FlaskClient[Any]", username: str = "admin", password: str = "general"
+):
+    """POST the credentials to ``/login/`` and assert no confirmation prompt."""
+    credentials = {"username": username, "password": password}
+    response = client.post("/login/", data=credentials)
+    page = response.get_data(as_text=True)
+    assert "User confirmation needed" not in page