Posted to commits@superset.apache.org by vi...@apache.org on 2020/08/14 17:58:54 UTC
[incubator-superset] branch master updated: fix(chart-data-api):
assert referenced columns are present in datasource (#10451)
This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new acb00f5 fix(chart-data-api): assert referenced columns are present in datasource (#10451)
acb00f5 is described below
commit acb00f509c193ea90aecc7486eee7c6e9fe1a8b3
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Fri Aug 14 20:58:24 2020 +0300
fix(chart-data-api): assert referenced columns are present in datasource (#10451)
* fix(chart-data-api): assert requested columns are present in datasource
* add filter tests
* add column_names to AnnotationDatasource
* add assertion for simple metrics
* lint
---
superset/charts/schemas.py | 2 +-
superset/common/query_context.py | 16 +++++++
superset/connectors/sqla/models.py | 17 ++++++-
superset/utils/core.py | 57 ++++++++++++++++++++----
superset/viz.py | 18 ++++++++
tests/core_tests.py | 19 ++++++++
tests/query_context_tests.py | 91 +++++++++++++++++++++++++++++++++-----
7 files changed, 196 insertions(+), 24 deletions(-)
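
At a high level, the commit closes an injection vector: column references arriving through the chart-data API were previously handed to query generation without being checked against the datasource, so a SQL expression such as currentDatabase() could be smuggled in as a "column". A simplified illustration of the failure mode follows (string building for clarity only; Superset composes SQL via SQLAlchemy, but the effect of an unchecked column name is the same kind of expression smuggling):

    # Illustrative only: shows why an unvalidated "column" name amounts to
    # arbitrary-expression injection, which the assertion in this commit prevents.
    def naive_select(columns, table):
        return f"SELECT {', '.join(columns)} FROM {table}"

    print(naive_select(["name"], "birth_names"))
    # SELECT name FROM birth_names
    print(naive_select(["currentDatabase()"], "birth_names"))
    # SELECT currentDatabase() FROM birth_names  <- attacker-controlled expression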
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 34fa4d8..1fba09c 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -797,7 +797,7 @@ class ChartDataQueryObjectSchema(Schema):
deprecated=True,
)
having_filters = fields.List(
- fields.Dict(),
+ fields.Nested(ChartDataFilterSchema),
description="HAVING filters to be added to legacy Druid datasource queries. "
"This field is deprecated and should be passed to `extras` "
"as `having_druid`.",
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 0d33f9c..d2cecae 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -22,6 +22,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Union
import numpy as np
import pandas as pd
+from flask_babel import gettext as _
from superset import app, cache, db, security_manager
from superset.common.query_object import QueryObject
@@ -235,6 +236,21 @@ class QueryContext:
if query_obj and not is_loaded:
try:
+ invalid_columns = [
+ col
+ for col in query_obj.columns
+ + query_obj.groupby
+ + [flt["col"] for flt in query_obj.filter]
+ + utils.get_column_names_from_metrics(query_obj.metrics)
+ if col not in self.datasource.column_names
+ ]
+ if invalid_columns:
+ raise QueryObjectValidationError(
+ _(
+ "Columns missing in datasource: %(invalid_columns)s",
+ invalid_columns=invalid_columns,
+ )
+ )
query_result = self.get_query_result(query_obj)
status = query_result["status"]
query = query_result["query"]
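
The guard above gathers every column a query references (explicit columns, groupby, filter columns, and the columns behind simple ad-hoc metrics) and rejects the query before any SQL is built if one of them is unknown to the datasource. A self-contained sketch of the same check, with illustrative names (find_invalid_columns, datasource_columns):

    from typing import Any, Dict, List

    def find_invalid_columns(
        query: Dict[str, Any], datasource_columns: List[str]
    ) -> List[str]:
        """Return referenced columns that the datasource does not know about."""
        referenced = (
            query.get("columns", [])
            + query.get("groupby", [])
            + [flt["col"] for flt in query.get("filters", [])]
        )
        # Simple ad-hoc metrics reference a physical column; SQL metrics do not.
        for metric in query.get("metrics", []):
            if isinstance(metric, dict) and metric.get("expressionType") == "SIMPLE":
                referenced.append(metric["column"]["column_name"])
        return [col for col in referenced if col not in datasource_columns]

    # "currentDatabase()" is not a real column, so it is flagged up front.
    print(find_invalid_columns({"groupby": ["currentDatabase()"]}, ["name", "num"]))
    # ['currentDatabase()']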
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index cfc807d..97336d4 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -90,6 +90,19 @@ class AnnotationDatasource(BaseDatasource):
cache_timeout = 0
changed_on = None
type = "annotation"
+ column_names = [
+ "created_on",
+ "changed_on",
+ "id",
+ "start_dttm",
+ "end_dttm",
+ "layer_id",
+ "short_descr",
+ "long_descr",
+ "json_metadata",
+ "created_by_fk",
+ "changed_by_fk",
+ ]
def query(self, query_obj: QueryObjectDict) -> QueryResult:
error_message = None
@@ -721,7 +734,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
expression_type = metric.get("expressionType")
label = utils.get_metric_name(metric)
- if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]:
+ if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
column_name = metric["column"].get("column_name")
table_column = columns_by_name.get(column_name)
if table_column:
@@ -729,7 +742,7 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
else:
sqla_column = column(column_name)
sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
- elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]:
+ elif expression_type == utils.AdhocMetricExpressionType.SQL:
sqla_metric = literal_column(metric.get("sqlExpression"))
else:
return None
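
Two things happen in this file: AnnotationDatasource gains an explicit column_names list so the new validation can also run against annotation queries, and the string-dict lookups are migrated to the AdhocMetricExpressionType enum introduced in utils/core.py below. Because the enum subclasses both str and Enum, its members compare equal to the raw strings found in API payloads, so the swap is behavior-preserving:

    from enum import Enum

    class AdhocMetricExpressionType(str, Enum):
        SIMPLE = "SIMPLE"
        SQL = "SQL"

    # str-valued enum members compare equal to plain payload strings,
    # so `==` checks keep working after the migration.
    assert AdhocMetricExpressionType.SIMPLE == "SIMPLE"
    assert {"expressionType": "SQL"}["expressionType"] == AdhocMetricExpressionType.SQL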
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 3f998ca..aa3a10a 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -42,6 +42,7 @@ from types import TracebackType
from typing import (
Any,
Callable,
+ cast,
Dict,
Iterable,
Iterator,
@@ -102,7 +103,6 @@ logging.getLogger("MARKDOWN").setLevel(logging.INFO)
logger = logging.getLogger(__name__)
DTTM_ALIAS = "__timestamp"
-ADHOC_METRIC_EXPRESSION_TYPES = {"SIMPLE": "SIMPLE", "SQL": "SQL"}
JS_MAX_INTEGER = 9007199254740991 # Largest int JavaScript can handle (2^53 - 1)
@@ -1038,20 +1038,23 @@ def backend() -> str:
def is_adhoc_metric(metric: Metric) -> bool:
+ if not isinstance(metric, dict):
+ return False
+ metric = cast(Dict[str, Any], metric)
return bool(
- isinstance(metric, dict)
- and (
+ (
(
- metric["expressionType"] == ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]
- and metric["column"]
- and metric["aggregate"]
+ metric.get("expressionType") == AdhocMetricExpressionType.SIMPLE
+ and metric.get("column")
+ and cast(Dict[str, Any], metric["column"]).get("column_name")
+ and metric.get("aggregate")
)
or (
- metric["expressionType"] == ADHOC_METRIC_EXPRESSION_TYPES["SQL"]
- and metric["sqlExpression"]
+ metric.get("expressionType") == AdhocMetricExpressionType.SQL
+ and metric.get("sqlExpression")
)
)
- and metric["label"]
+ and metric.get("label")
)
@@ -1398,6 +1401,37 @@ def get_form_data_token(form_data: Dict[str, Any]) -> str:
return form_data.get("token") or "token_" + uuid.uuid4().hex[:8]
+def get_column_name_from_metric(metric: Metric) -> Optional[str]:
+ """
+ Extract the column that a metric references. Returns `None` unless
+ the metric is a simple ad-hoc metric.
+
+ :param metric: Ad-hoc metric
+ :return: column name if simple metric, otherwise None
+ """
+ if is_adhoc_metric(metric):
+ metric = cast(Dict[str, Any], metric)
+ if metric["expressionType"] == AdhocMetricExpressionType.SIMPLE:
+ return cast(Dict[str, Any], metric["column"])["column_name"]
+ return None
+
+
+def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
+ """
+ Extract the columns referenced by a list of metrics. Excludes all
+ SQL metrics.
+
+ :param metrics: list of ad-hoc and saved metrics
+ :return: column names referenced by simple ad-hoc metrics
+ """
+ columns: List[str] = []
+ for metric in metrics:
+ column_name = get_column_name_from_metric(metric)
+ if column_name:
+ columns.append(column_name)
+ return columns
+
+
class LenientEnum(Enum):
"""Enums that do not raise ValueError when value is invalid"""
@@ -1523,3 +1557,8 @@ class PostProcessingContributionOrientation(str, Enum):
ROW = "row"
COLUMN = "column"
+
+
+class AdhocMetricExpressionType(str, Enum):
+ SIMPLE = "SIMPLE"
+ SQL = "SQL"
diff --git a/superset/viz.py b/superset/viz.py
index 34054cb..14eedf0 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -481,6 +481,24 @@ class BaseViz:
if query_obj and not is_loaded:
try:
+ invalid_columns = [
+ col
+ for col in (query_obj.get("columns") or [])
+ + (query_obj.get("groupby") or [])
+ + utils.get_column_names_from_metrics(
+ cast(
+ List[Union[str, Dict[str, Any]]], query_obj.get("metrics"),
+ )
+ )
+ if col not in self.datasource.column_names
+ ]
+ if invalid_columns:
+ raise QueryObjectValidationError(
+ _(
+ "Columns missing in datasource: %(invalid_columns)s",
+ invalid_columns=invalid_columns,
+ )
+ )
df = self.get_df(query_obj)
if self.status != utils.QueryStatus.FAILED:
stats_logger.incr("loaded_from_source")
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 4f2d1bf..d625860 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -1202,6 +1202,25 @@ class TestCore(SupersetTestCase):
database.extra = json.dumps(extra)
self.assertEqual(database.explore_database_id, explore_database.id)
+ def test_get_column_names_from_metric(self):
+ simple_metric = {
+ "expressionType": utils.AdhocMetricExpressionType.SIMPLE.value,
+ "column": {"column_name": "my_col"},
+ "aggregate": "SUM",
+ "label": "My Simple Label",
+ }
+ assert utils.get_column_name_from_metric(simple_metric) == "my_col"
+
+ sql_metric = {
+ "expressionType": utils.AdhocMetricExpressionType.SQL.value,
+ "sqlExpression": "SUM(my_label)",
+ "label": "My SQL Label",
+ }
+ assert utils.get_column_name_from_metric(sql_metric) is None
+ assert utils.get_column_names_from_metrics([simple_metric, sql_metric]) == [
+ "my_col"
+ ]
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
index f816bcd..0b0230f 100644
--- a/tests/query_context_tests.py
+++ b/tests/query_context_tests.py
@@ -17,11 +17,12 @@
import tests.test_app
from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
-from superset.common.query_context import QueryContext
from superset.connectors.connector_registry import ConnectorRegistry
from superset.utils.core import (
+ AdhocMetricExpressionType,
ChartDataResultFormat,
ChartDataResultType,
+ FilterOperator,
TimeRangeEndpoint,
)
from tests.base_tests import SupersetTestCase
@@ -75,7 +76,7 @@ class TestQueryContext(SupersetTestCase):
payload = get_query_context(table.name, table.id, table.type)
# construct baseline cache_key
- query_context = QueryContext(**payload)
+ query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_original = query_context.cache_key(query_object)
@@ -92,7 +93,7 @@ class TestQueryContext(SupersetTestCase):
db.session.commit()
# create new QueryContext with unchanged attributes and extract new cache_key
- query_context = QueryContext(**payload)
+ query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_new = query_context.cache_key(query_object)
@@ -108,20 +109,20 @@ class TestQueryContext(SupersetTestCase):
)
# construct baseline cache_key from query_context with post processing operation
- query_context = QueryContext(**payload)
+ query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_original = query_context.cache_key(query_object)
# ensure added None post_processing operation doesn't change cache_key
payload["queries"][0]["post_processing"].append(None)
- query_context = QueryContext(**payload)
+ query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_with_null = query_context.cache_key(query_object)
self.assertEqual(cache_key_original, cache_key_with_null)
# ensure query without post processing operation is different
payload["queries"][0].pop("post_processing")
- query_context = QueryContext(**payload)
+ query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_without_post_processing = query_context.cache_key(query_object)
self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
@@ -136,7 +137,7 @@ class TestQueryContext(SupersetTestCase):
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
del payload["queries"][0]["extras"]["time_range_endpoints"]
- query_context = QueryContext(**payload)
+ query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
extras = query_object.to_dict()["extras"]
self.assertTrue("time_range_endpoints" in extras)
@@ -155,8 +156,8 @@ class TestQueryContext(SupersetTestCase):
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["queries"][0]["granularity_sqla"] = "timecol"
- payload["queries"][0]["having_filters"] = {"col": "a", "op": "==", "val": "b"}
- query_context = QueryContext(**payload)
+ payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
+ query_context = ChartDataQueryContextSchema().load(payload)
self.assertEqual(len(query_context.queries), 1)
query_object = query_context.queries[0]
self.assertEqual(query_object.granularity, "timecol")
@@ -172,13 +173,79 @@ class TestQueryContext(SupersetTestCase):
payload = get_query_context(table.name, table.id, table.type)
payload["result_format"] = ChartDataResultFormat.CSV.value
payload["queries"][0]["row_limit"] = 10
- query_context = QueryContext(**payload)
+ query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
data = responses[0]["data"]
self.assertIn("name,sum__num\n", data)
self.assertEqual(len(data.split("\n")), 12)
+ def test_sql_injection_via_groupby(self):
+ """
+ Ensure that invalid column names in groupby are caught
+ """
+ self.login(username="admin")
+ table_name = "birth_names"
+ table = self.get_table_by_name(table_name)
+ payload = get_query_context(table.name, table.id, table.type)
+ payload["queries"][0]["groupby"] = ["currentDatabase()"]
+ query_context = ChartDataQueryContextSchema().load(payload)
+ query_payload = query_context.get_payload()
+ assert query_payload[0].get("error") is not None
+
+ def test_sql_injection_via_columns(self):
+ """
+ Ensure that invalid column names in columns are caught
+ """
+ self.login(username="admin")
+ table_name = "birth_names"
+ table = self.get_table_by_name(table_name)
+ payload = get_query_context(table.name, table.id, table.type)
+ payload["queries"][0]["groupby"] = []
+ payload["queries"][0]["metrics"] = []
+ payload["queries"][0]["columns"] = ["*, 'extra'"]
+ query_context = ChartDataQueryContextSchema().load(payload)
+ query_payload = query_context.get_payload()
+ assert query_payload[0].get("error") is not None
+
+ def test_sql_injection_via_filters(self):
+ """
+ Ensure that invalid column names in filters are caught
+ """
+ self.login(username="admin")
+ table_name = "birth_names"
+ table = self.get_table_by_name(table_name)
+ payload = get_query_context(table.name, table.id, table.type)
+ payload["queries"][0]["groupby"] = ["name"]
+ payload["queries"][0]["metrics"] = []
+ payload["queries"][0]["filters"] = [
+ {"col": "*", "op": FilterOperator.EQUALS.value, "val": ";"}
+ ]
+ query_context = ChartDataQueryContextSchema().load(payload)
+ query_payload = query_context.get_payload()
+ assert query_payload[0].get("error") is not None
+
+ def test_sql_injection_via_metrics(self):
+ """
+ Ensure that invalid column names in metrics are caught
+ """
+ self.login(username="admin")
+ table_name = "birth_names"
+ table = self.get_table_by_name(table_name)
+ payload = get_query_context(table.name, table.id, table.type)
+ payload["queries"][0]["groupby"] = ["name"]
+ payload["queries"][0]["metrics"] = [
+ {
+ "expressionType": AdhocMetricExpressionType.SIMPLE.value,
+ "column": {"column_name": "invalid_col"},
+ "aggregate": "SUM",
+ "label": "My Simple Label",
+ }
+ ]
+ query_context = ChartDataQueryContextSchema().load(payload)
+ query_payload = query_context.get_payload()
+ assert query_payload[0].get("error") is not None
+
def test_samples_response_type(self):
"""
Ensure that samples result type works
@@ -189,7 +256,7 @@ class TestQueryContext(SupersetTestCase):
payload = get_query_context(table.name, table.id, table.type)
payload["result_type"] = ChartDataResultType.SAMPLES.value
payload["queries"][0]["row_limit"] = 5
- query_context = QueryContext(**payload)
+ query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
data = responses[0]["data"]
@@ -206,7 +273,7 @@ class TestQueryContext(SupersetTestCase):
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["result_type"] = ChartDataResultType.QUERY.value
- query_context = QueryContext(**payload)
+ query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
response = responses[0]
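
For API consumers, the net effect is that a chart-data request referencing an unknown column now comes back with a per-query error entry rather than being executed. A hedged sketch of what a caller might observe (host and payload are illustrative and abridged; authentication omitted):

    import requests

    resp = requests.post(
        "https://superset.example.com/api/v1/chart/data",  # illustrative host
        json={
            "datasource": {"id": 3, "type": "table"},
            "queries": [{"groupby": ["currentDatabase()"], "metrics": ["count"]}],
        },
    )
    # The offending query is rejected up front with a message along the lines of
    # "Columns missing in datasource: ..." instead of reaching the database.
    print(resp.json()["result"][0].get("error"))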