You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by vi...@apache.org on 2020/08/14 17:58:54 UTC

[incubator-superset] branch master updated: fix(chart-data-api): assert referenced columns are present in datasource (#10451)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new acb00f5  fix(chart-data-api): assert referenced columns are present in datasource (#10451)
acb00f5 is described below

commit acb00f509c193ea90aecc7486eee7c6e9fe1a8b3
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Fri Aug 14 20:58:24 2020 +0300

    fix(chart-data-api): assert referenced columns are present in datasource (#10451)
    
    * fix(chart-data-api): assert requested columns are present in datasource
    
    * add filter tests
    
    * add column_names to AnnotationDatasource
    
    * add assertion for simple metrics
    
    * lint
---
 superset/charts/schemas.py         |  2 +-
 superset/common/query_context.py   | 16 +++++++
 superset/connectors/sqla/models.py | 17 ++++++-
 superset/utils/core.py             | 57 ++++++++++++++++++++----
 superset/viz.py                    | 18 ++++++++
 tests/core_tests.py                | 19 ++++++++
 tests/query_context_tests.py       | 91 +++++++++++++++++++++++++++++++++-----
 7 files changed, 196 insertions(+), 24 deletions(-)

diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 34fa4d8..1fba09c 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -797,7 +797,7 @@ class ChartDataQueryObjectSchema(Schema):
         deprecated=True,
     )
     having_filters = fields.List(
-        fields.Dict(),
+        fields.Nested(ChartDataFilterSchema),
         description="HAVING filters to be added to legacy Druid datasource queries. "
         "This field is deprecated and should be passed to `extras` "
         "as `having_druid`.",
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 0d33f9c..d2cecae 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -22,6 +22,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Union
 
 import numpy as np
 import pandas as pd
+from flask_babel import gettext as _
 
 from superset import app, cache, db, security_manager
 from superset.common.query_object import QueryObject
@@ -235,6 +236,21 @@ class QueryContext:
 
         if query_obj and not is_loaded:
             try:
+                invalid_columns = [
+                    col
+                    for col in query_obj.columns
+                    + query_obj.groupby
+                    + [flt["col"] for flt in query_obj.filter]
+                    + utils.get_column_names_from_metrics(query_obj.metrics)
+                    if col not in self.datasource.column_names
+                ]
+                if invalid_columns:
+                    raise QueryObjectValidationError(
+                        _(
+                            "Columns missing in datasource: %(invalid_columns)s",
+                            invalid_columns=invalid_columns,
+                        )
+                    )
                 query_result = self.get_query_result(query_obj)
                 status = query_result["status"]
                 query = query_result["query"]
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index cfc807d..97336d4 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -90,6 +90,19 @@ class AnnotationDatasource(BaseDatasource):
     cache_timeout = 0
     changed_on = None
     type = "annotation"
+    column_names = [
+        "created_on",
+        "changed_on",
+        "id",
+        "start_dttm",
+        "end_dttm",
+        "layer_id",
+        "short_descr",
+        "long_descr",
+        "json_metadata",
+        "created_by_fk",
+        "changed_by_fk",
+    ]
 
     def query(self, query_obj: QueryObjectDict) -> QueryResult:
         error_message = None
@@ -721,7 +734,7 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
         expression_type = metric.get("expressionType")
         label = utils.get_metric_name(metric)
 
-        if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]:
+        if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
             column_name = metric["column"].get("column_name")
             table_column = columns_by_name.get(column_name)
             if table_column:
@@ -729,7 +742,7 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
             else:
                 sqla_column = column(column_name)
             sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
-        elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]:
+        elif expression_type == utils.AdhocMetricExpressionType.SQL:
             sqla_metric = literal_column(metric.get("sqlExpression"))
         else:
             return None
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 3f998ca..aa3a10a 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -42,6 +42,7 @@ from types import TracebackType
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Iterable,
     Iterator,
@@ -102,7 +103,6 @@ logging.getLogger("MARKDOWN").setLevel(logging.INFO)
 logger = logging.getLogger(__name__)
 
 DTTM_ALIAS = "__timestamp"
-ADHOC_METRIC_EXPRESSION_TYPES = {"SIMPLE": "SIMPLE", "SQL": "SQL"}
 
 JS_MAX_INTEGER = 9007199254740991  # Largest int Java Script can handle 2^53-1
 
@@ -1038,20 +1038,23 @@ def backend() -> str:
 
 
 def is_adhoc_metric(metric: Metric) -> bool:
+    if not isinstance(metric, dict):
+        return False
+    metric = cast(Dict[str, Any], metric)
     return bool(
-        isinstance(metric, dict)
-        and (
+        (
             (
-                metric["expressionType"] == ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]
-                and metric["column"]
-                and metric["aggregate"]
+                metric.get("expressionType") == AdhocMetricExpressionType.SIMPLE
+                and metric.get("column")
+                and cast(Dict[str, Any], metric["column"]).get("column_name")
+                and metric.get("aggregate")
             )
             or (
-                metric["expressionType"] == ADHOC_METRIC_EXPRESSION_TYPES["SQL"]
-                and metric["sqlExpression"]
+                metric.get("expressionType") == AdhocMetricExpressionType.SQL
+                and metric.get("sqlExpression")
             )
         )
-        and metric["label"]
+        and metric.get("label")
     )
 
 
@@ -1398,6 +1401,37 @@ def get_form_data_token(form_data: Dict[str, Any]) -> str:
     return form_data.get("token") or "token_" + uuid.uuid4().hex[:8]
 
 
+def get_column_name_from_metric(metric: Metric) -> Optional[str]:
+    """
+    Extract the column that a metric is referencing. If the metric isn't
+    a simple metric, always returns `None`.
+
+    :param metric: Ad-hoc metric
+    :return: column name if simple metric, otherwise None
+    """
+    if is_adhoc_metric(metric):
+        metric = cast(Dict[str, Any], metric)
+        if metric["expressionType"] == AdhocMetricExpressionType.SIMPLE:
+            return cast(Dict[str, Any], metric["column"])["column_name"]
+    return None
+
+
+def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
+    """
+    Extract the columns that a list of metrics are referencing. Excludes all
+    SQL metrics.
+
+    :param metrics: List of ad-hoc metrics
+    :return: list of column names referenced by the simple metrics
+    """
+    columns: List[str] = []
+    for metric in metrics:
+        column_name = get_column_name_from_metric(metric)
+        if column_name:
+            columns.append(column_name)
+    return columns
+
+
 class LenientEnum(Enum):
     """Enums that do not raise ValueError when value is invalid"""
 
@@ -1523,3 +1557,8 @@ class PostProcessingContributionOrientation(str, Enum):
 
     ROW = "row"
     COLUMN = "column"
+
+
+class AdhocMetricExpressionType(str, Enum):
+    SIMPLE = "SIMPLE"
+    SQL = "SQL"
diff --git a/superset/viz.py b/superset/viz.py
index 34054cb..14eedf0 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -481,6 +481,24 @@ class BaseViz:
 
         if query_obj and not is_loaded:
             try:
+                invalid_columns = [
+                    col
+                    for col in (query_obj.get("columns") or [])
+                    + (query_obj.get("groupby") or [])
+                    + utils.get_column_names_from_metrics(
+                        cast(
+                            List[Union[str, Dict[str, Any]]], query_obj.get("metrics"),
+                        )
+                    )
+                    if col not in self.datasource.column_names
+                ]
+                if invalid_columns:
+                    raise QueryObjectValidationError(
+                        _(
+                            "Columns missing in datasource: %(invalid_columns)s",
+                            invalid_columns=invalid_columns,
+                        )
+                    )
                 df = self.get_df(query_obj)
                 if self.status != utils.QueryStatus.FAILED:
                     stats_logger.incr("loaded_from_source")
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 4f2d1bf..d625860 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -1202,6 +1202,25 @@ class TestCore(SupersetTestCase):
         database.extra = json.dumps(extra)
         self.assertEqual(database.explore_database_id, explore_database.id)
 
+    def test_get_column_names_from_metric(self):
+        simple_metric = {
+            "expressionType": utils.AdhocMetricExpressionType.SIMPLE.value,
+            "column": {"column_name": "my_col"},
+            "aggregate": "SUM",
+            "label": "My Simple Label",
+        }
+        assert utils.get_column_name_from_metric(simple_metric) == "my_col"
+
+        sql_metric = {
+            "expressionType": utils.AdhocMetricExpressionType.SQL.value,
+            "sqlExpression": "SUM(my_label)",
+            "label": "My SQL Label",
+        }
+        assert utils.get_column_name_from_metric(sql_metric) is None
+        assert utils.get_column_names_from_metrics([simple_metric, sql_metric]) == [
+            "my_col"
+        ]
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
index f816bcd..0b0230f 100644
--- a/tests/query_context_tests.py
+++ b/tests/query_context_tests.py
@@ -17,11 +17,12 @@
 import tests.test_app
 from superset import db
 from superset.charts.schemas import ChartDataQueryContextSchema
-from superset.common.query_context import QueryContext
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.utils.core import (
+    AdhocMetricExpressionType,
     ChartDataResultFormat,
     ChartDataResultType,
+    FilterOperator,
     TimeRangeEndpoint,
 )
 from tests.base_tests import SupersetTestCase
@@ -75,7 +76,7 @@ class TestQueryContext(SupersetTestCase):
         payload = get_query_context(table.name, table.id, table.type)
 
         # construct baseline cache_key
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_original = query_context.cache_key(query_object)
 
@@ -92,7 +93,7 @@ class TestQueryContext(SupersetTestCase):
         db.session.commit()
 
         # create new QueryContext with unchanged attributes and extract new cache_key
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_new = query_context.cache_key(query_object)
 
@@ -108,20 +109,20 @@ class TestQueryContext(SupersetTestCase):
         )
 
         # construct baseline cache_key from query_context with post processing operation
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_original = query_context.cache_key(query_object)
 
         # ensure added None post_processing operation doesn't change cache_key
         payload["queries"][0]["post_processing"].append(None)
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_with_null = query_context.cache_key(query_object)
         self.assertEqual(cache_key_original, cache_key_with_null)
 
         # ensure query without post processing operation is different
         payload["queries"][0].pop("post_processing")
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_without_post_processing = query_context.cache_key(query_object)
         self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
@@ -136,7 +137,7 @@ class TestQueryContext(SupersetTestCase):
         table = self.get_table_by_name(table_name)
         payload = get_query_context(table.name, table.id, table.type)
         del payload["queries"][0]["extras"]["time_range_endpoints"]
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         extras = query_object.to_dict()["extras"]
         self.assertTrue("time_range_endpoints" in extras)
@@ -155,8 +156,8 @@ class TestQueryContext(SupersetTestCase):
         table = self.get_table_by_name(table_name)
         payload = get_query_context(table.name, table.id, table.type)
         payload["queries"][0]["granularity_sqla"] = "timecol"
-        payload["queries"][0]["having_filters"] = {"col": "a", "op": "==", "val": "b"}
-        query_context = QueryContext(**payload)
+        payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
+        query_context = ChartDataQueryContextSchema().load(payload)
         self.assertEqual(len(query_context.queries), 1)
         query_object = query_context.queries[0]
         self.assertEqual(query_object.granularity, "timecol")
@@ -172,13 +173,79 @@ class TestQueryContext(SupersetTestCase):
         payload = get_query_context(table.name, table.id, table.type)
         payload["result_format"] = ChartDataResultFormat.CSV.value
         payload["queries"][0]["row_limit"] = 10
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         responses = query_context.get_payload()
         self.assertEqual(len(responses), 1)
         data = responses[0]["data"]
         self.assertIn("name,sum__num\n", data)
         self.assertEqual(len(data.split("\n")), 12)
 
+    def test_sql_injection_via_groupby(self):
+        """
+        Ensure that invalid column names in groupby are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = ["currentDatabase()"]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None
+
+    def test_sql_injection_via_columns(self):
+        """
+        Ensure that invalid column names in columns are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = []
+        payload["queries"][0]["metrics"] = []
+        payload["queries"][0]["columns"] = ["*, 'extra'"]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None
+
+    def test_sql_injection_via_filters(self):
+        """
+        Ensure that invalid column names in filters are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = ["name"]
+        payload["queries"][0]["metrics"] = []
+        payload["queries"][0]["filters"] = [
+            {"col": "*", "op": FilterOperator.EQUALS.value, "val": ";"}
+        ]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None
+
+    def test_sql_injection_via_metrics(self):
+        """
+        Ensure that invalid column names in metrics are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = ["name"]
+        payload["queries"][0]["metrics"] = [
+            {
+                "expressionType": AdhocMetricExpressionType.SIMPLE.value,
+                "column": {"column_name": "invalid_col"},
+                "aggregate": "SUM",
+                "label": "My Simple Label",
+            }
+        ]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None
+
     def test_samples_response_type(self):
         """
         Ensure that samples result type works
@@ -189,7 +256,7 @@ class TestQueryContext(SupersetTestCase):
         payload = get_query_context(table.name, table.id, table.type)
         payload["result_type"] = ChartDataResultType.SAMPLES.value
         payload["queries"][0]["row_limit"] = 5
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         responses = query_context.get_payload()
         self.assertEqual(len(responses), 1)
         data = responses[0]["data"]
@@ -206,7 +273,7 @@ class TestQueryContext(SupersetTestCase):
         table = self.get_table_by_name(table_name)
         payload = get_query_context(table.name, table.id, table.type)
         payload["result_type"] = ChartDataResultType.QUERY.value
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         responses = query_context.get_payload()
         self.assertEqual(len(responses), 1)
         response = responses[0]