You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text version).
Posted to commits@superset.apache.org by mi...@apache.org on 2023/01/05 19:57:25 UTC
[superset] 05/14: chore(sqla): refactor query utils (#21811)
This is an automated email from the ASF dual-hosted git repository.
michaelsmolina pushed a commit to branch 1.5
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 7c98e266ce6bd39fe8db22d8e25af711012bc9f2
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Mon Oct 17 10:40:42 2022 +0100
chore(sqla): refactor query utils (#21811)
Co-authored-by: Ville Brofeldt <vi...@apple.com>
---
superset/connectors/sqla/models.py | 26 +++++--
superset/connectors/sqla/utils.py | 7 ++
superset/models/core.py | 11 ++-
tests/integration_tests/charts/data/api_tests.py | 71 ++++++++++++++++-
tests/integration_tests/conftest.py | 99 +++++++++++++++++++++++-
tests/integration_tests/sqla_models_tests.py | 2 +-
tests/integration_tests/test_app.py | 20 ++++-
7 files changed, 221 insertions(+), 15 deletions(-)
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 4697530328..dbcfb80eb3 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -83,6 +83,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampEx
from superset.exceptions import (
QueryClauseValidationException,
QueryObjectValidationError,
+ SupersetSecurityException,
)
from superset.jinja_context import (
BaseTemplateProcessor,
@@ -514,19 +515,19 @@ def _process_sql_expression(
expression: Optional[str],
database_id: int,
schema: str,
- template_processor: Optional[BaseTemplateProcessor],
+ template_processor: Optional[BaseTemplateProcessor] = None,
) -> Optional[str]:
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
- expression = validate_adhoc_subquery(
- expression,
- database_id,
- schema,
- )
try:
+ # validate_adhoc_subquery now runs inside the try so that any
+ # SupersetSecurityException it raises is re-raised below as a
+ # QueryObjectValidationError, the same as sanitize_clause failures.
+ expression = validate_adhoc_subquery(
+ expression,
+ database_id,
+ schema,
+ )
expression = sanitize_clause(expression)
- except QueryClauseValidationException as ex:
+ except (QueryClauseValidationException, SupersetSecurityException) as ex:
raise QueryObjectValidationError(ex.message) from ex
return expression
@@ -1465,6 +1466,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
msg=ex.message,
)
) from ex
+ # No template_processor is passed, so only the ad-hoc sub-query check and
+ # sanitize_clause are applied to the WHERE clause here.
+ where = _process_sql_expression(
+ expression=where,
+ database_id=self.database_id,
+ schema=self.schema,
+ )
where_clause_and += [self.text(where)]
having = extras.get("having")
if having:
@@ -1477,7 +1483,13 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
msg=ex.message,
)
) from ex
+ # Same validation as the WHERE clause above (sub-query check + sanitize_clause).
+ having = _process_sql_expression(
+ expression=having,
+ database_id=self.database_id,
+ schema=self.schema,
+ )
having_clause_and += [self.text(having)]
+
if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())
if granularity:
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index a2b54201d6..5359c9e214 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -22,6 +22,7 @@ import sqlparse
from flask_babel import lazy_gettext as _
from sqlalchemy import and_, inspect, or_
from sqlalchemy.engine import Engine
+from sqlalchemy.engine.url import URL as SqlaURL
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm import Session
from sqlalchemy.sql.type_api import TypeEngine
@@ -37,6 +38,7 @@ from superset.result_set import SupersetResultSet
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table
from superset.superset_typing import ResultSetColumnType
from superset.tables.models import Table as NewTable
+from superset.utils.memoized import memoized
if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
@@ -252,3 +254,8 @@ def load_or_create_tables( # pylint: disable=too-many-arguments
existing.add((table.schema, table.table))
return new_tables
+
+
+@memoized
+def get_identifier_quoter(drivername: str) -> Callable[[str], str]:
+ """Return the identifier-quoting callable of the dialect for ``drivername``."""
+ return SqlaURL(drivername=drivername).get_dialect()().identifier_preparer.quote
diff --git a/superset/models/core.py b/superset/models/core.py
index fcc7cf16d8..97e1d763b9 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -21,7 +21,7 @@ import json
import logging
import textwrap
from ast import literal_eval
-from contextlib import closing
+from contextlib import closing, contextmanager
from copy import deepcopy
from datetime import datetime
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
+from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type
@@ -345,6 +345,15 @@ class Database(
effective_username = g.user.username
return effective_username
+ @contextmanager
+ def get_sqla_engine_with_context(
+ self,
+ schema: Optional[str] = None,
+ nullpool: bool = True,
+ source: Optional[utils.QuerySource] = None,
+ ) -> Iterator[Engine]:
+ """Yield ``get_sqla_engine(...)`` so callers can use it in a ``with`` block."""
+ yield self.get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
+
@memoized(
watch=(
"impersonate_user",
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index e8f258421f..212c9d01af 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -21,7 +21,7 @@ import unittest
import copy
from datetime import datetime
from io import BytesIO
-from typing import Optional
+from typing import Any, Dict, Optional
from unittest import mock
from zipfile import ZipFile
@@ -974,3 +974,72 @@ class TestGetChartDataApi(BaseTestChartDataApi):
unique_genders = {row["male_or_female"] for row in data}
assert unique_genders == {"male", "female"}
assert result["applied_filters"] == [{"column": "male_or_female"}]
+
+
+@pytest.fixture()
+def physical_query_context(physical_dataset) -> Dict[str, Any]:
+ """Chart-data payload: ``count`` by ``col1`` over ``physical_dataset``."""
+ return {
+ "datasource": {
+ "type": physical_dataset.type,
+ "id": physical_dataset.id,
+ },
+ "queries": [
+ {
+ "columns": ["col1"],
+ "metrics": ["count"],
+ "orderby": [["col1", True]],
+ }
+ ],
+ "result_type": ChartDataResultType.FULL,
+ "force": True,
+ }
+
+
+@pytest.mark.parametrize(
+ "status_code,extras",
+ [
+ (200, {"where": "1 = 1"}),
+ (200, {"having": "count(*) > 0"}),
+ (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
+ (400, {"having": "count(*) > (select count(*) from physical_dataset)"}),
+ ],
+)
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=False)
+@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+def test_chart_data_subquery_not_allowed(
+ test_client,
+ login_as_admin,
+ physical_dataset,
+ physical_query_context,
+ status_code,
+ extras,
+):
+ """With ALLOW_ADHOC_SUBQUERY off, WHERE/HAVING sub-queries are rejected (400)."""
+ physical_query_context["queries"][0]["extras"] = extras
+ rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+
+ assert rv.status_code == status_code
+
+
+@pytest.mark.parametrize(
+ "status_code,extras",
+ [
+ (200, {"where": "1 = 1"}),
+ (200, {"having": "count(*) > 0"}),
+ (200, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
+ (200, {"having": "count(*) > (select count(*) from physical_dataset)"}),
+ ],
+)
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
+@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+def test_chart_data_subquery_allowed(
+ test_client,
+ login_as_admin,
+ physical_dataset,
+ physical_query_context,
+ status_code,
+ extras,
+):
+ """With ALLOW_ADHOC_SUBQUERY on, WHERE/HAVING sub-queries are accepted (200)."""
+ physical_query_context["queries"][0]["extras"] = extras
+ rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+
+ assert rv.status_code == status_code
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index fee13c8950..c605819ee6 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -21,13 +21,15 @@ from typing import Any, Callable, Generator, Optional, TYPE_CHECKING
from unittest.mock import patch
import pytest
+from flask.ctx import AppContext
+from flask.testing import FlaskClient
from sqlalchemy.engine import Engine
from superset import db
from superset.extensions import feature_flag_manager
from superset.utils.core import json_dumps_w_dates
from superset.utils.database import get_example_database, remove_database
-from tests.integration_tests.test_app import app
+from tests.integration_tests.test_app import app, login
if TYPE_CHECKING:
from superset.connectors.sqla.models import Database
@@ -42,6 +44,29 @@ def app_context():
yield
+@pytest.fixture
+def test_client(app_context: AppContext):
+ """Yield a Flask test client for ``app``; depends on ``app_context``."""
+ with app.test_client() as client:
+ yield client
+
+
+@pytest.fixture
+def login_as(test_client: "FlaskClient[Any]"):
+ """Yield a helper that logs ``test_client`` in as the given user."""
+
+ def _login_as(username: str, password: str = "general"):
+ login(test_client, username=username, password=password)
+
+ yield _login_as
+ # no need to log out as both app_context and test_client are
+ # function level fixtures anyway
+
+
+@pytest.fixture
+def login_as_admin(login_as: Callable[..., None]):
+ """Log the test client in as the ``admin`` user."""
+ yield login_as("admin")
+
+
@pytest.fixture(autouse=True, scope="session")
def setup_sample_data() -> Any:
# TODO(john-bodley): Determine a cleaner way of setting up the sample data without
@@ -180,3 +205,75 @@ def with_feature_flags(**mock_feature_flags):
return functools.update_wrapper(wrapper, test_fn)
return decorate
+
+
+@pytest.fixture
+def physical_dataset():
+ """Create the ``physical_dataset`` table and dataset; drop both on teardown."""
+ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+ from superset.connectors.sqla.utils import get_identifier_quoter
+
+ example_database = get_example_database()
+
+ with example_database.get_sqla_engine_with_context() as engine:
+ quoter = get_identifier_quoter(engine.name)
+ # sqlite can only execute one statement at a time
+ engine.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS physical_dataset(
+ col1 INTEGER,
+ col2 VARCHAR(255),
+ col3 DECIMAL(4,2),
+ col4 VARCHAR(255),
+ col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
+ col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
+ {quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
+ );
+ """
+ )
+ engine.execute(
+ """
+ INSERT INTO physical_dataset values
+ (0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
+ (1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
+ (2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
+ (3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
+ (4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
+ (5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
+ (6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
+ (7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
+ (8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
+ (9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
+ """
+ )
+
+ dataset = SqlaTable(
+ table_name="physical_dataset",
+ database=example_database,
+ )
+ TableColumn(column_name="col1", type="INTEGER", table=dataset)
+ TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
+ TableColumn(column_name="col3", type="DECIMAL(4,2)", table=dataset)
+ TableColumn(column_name="col4", type="VARCHAR(255)", table=dataset)
+ TableColumn(column_name="col5", type="TIMESTAMP", is_dttm=True, table=dataset)
+ TableColumn(column_name="col6", type="TIMESTAMP", is_dttm=True, table=dataset)
+ TableColumn(
+ column_name="time column with spaces",
+ type="TIMESTAMP",
+ is_dttm=True,
+ table=dataset,
+ )
+ SqlMetric(metric_name="count", expression="count(*)", table=dataset)
+ # merge() returns the session-bound instance; keep it so the yielded
+ # dataset is the persisted one (with its id populated after commit)
+ dataset = db.session.merge(dataset)
+ db.session.commit()
+
+ yield dataset
+
+ # open a fresh engine context rather than reusing `engine` from the setup block
+ with example_database.get_sqla_engine_with_context() as engine:
+ engine.execute(
+ """
+ DROP TABLE physical_dataset;
+ """
+ )
+ datasets = db.session.query(SqlaTable).filter_by(table_name="physical_dataset").all()
+ for ds in datasets:
+ db.session.delete(ds)
+ db.session.commit()
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 8990243c6b..f06836720d 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -262,7 +262,7 @@ class TestDatabaseModel(SupersetTestCase):
)
db.session.commit()
- with pytest.raises(SupersetSecurityException):
+ # _process_sql_expression now re-raises SupersetSecurityException as
+ # QueryObjectValidationError, so that is what get_sqla_query surfaces.
+ # NOTE(review): this hunk adds no import (diffstat: 2 +-); confirm
+ # QueryObjectValidationError is already imported in this module.
+ with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**base_query_obj)
# Cleanup
db.session.delete(table)
diff --git a/tests/integration_tests/test_app.py b/tests/integration_tests/test_app.py
index 798f3e9cda..fb7b47b67c 100644
--- a/tests/integration_tests/test_app.py
+++ b/tests/integration_tests/test_app.py
@@ -14,11 +14,23 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Here is where we create the app which ends up being shared across all tests.integration_tests. A future
optimization will be to create a separate app instance for each test class.
"""
+from typing import TYPE_CHECKING
from superset.app import create_app
+if TYPE_CHECKING:
+ from typing import Any
+
+ from flask.testing import FlaskClient
+
app = create_app()
+
+
+def login(
+ client: "FlaskClient[Any]", username: str = "admin", password: str = "general"
+):
+ """POST the credentials to ``/login/`` and assert no user-confirmation page came back."""
+ resp = client.post(
+ "/login/",
+ data=dict(username=username, password=password),
+ ).get_data(as_text=True)
+ assert "User confirmation needed" not in resp