Posted to commits@superset.apache.org by vi...@apache.org on 2020/09/21 11:42:11 UTC

[incubator-superset] branch 0.37 updated (98329fb -> 2fd965c)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a change to branch 0.37
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git.


    from 98329fb  chore: update version and changelog
     new 4655723  fix(chart-data-api): assert referenced columns are present in datasource (#10451)
     new 1f3a93b  fix(legacy-druid): undefined filter key (#10931)
     new 315acf4  fix(jinja): make context attrs private on SQL templates (#10934)
     new 2fd965c  fix: simply is_adhoc_metric (#10964)

The 4 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 superset/charts/schemas.py                        |  2 +-
 superset/common/query_context.py                  | 16 ++++
 superset/connectors/druid/models.py               |  2 +-
 superset/connectors/sqla/models.py                | 17 ++++-
 superset/jinja_context.py                         | 36 ++++-----
 superset/utils/core.py                            | 54 ++++++++++----
 superset/viz.py                                   | 18 +++++
 tests/core_tests.py                               | 19 +++++
 tests/query_context_tests.py                      | 91 ++++++++++++++++++++---
 tests/superset_test_custom_template_processors.py |  2 +-
 10 files changed, 206 insertions(+), 51 deletions(-)


[incubator-superset] 01/04: fix(chart-data-api): assert referenced columns are present in datasource (#10451)

Posted by vi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch 0.37
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git

commit 465572325b6c880b81189a94a27417bbb592f540
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Fri Aug 14 20:58:24 2020 +0300

    fix(chart-data-api): assert referenced columns are present in datasource (#10451)
    
    * fix(chart-data-api): assert requested columns are present in datasource
    
    * add filter tests
    
    * add column_names to AnnotationDatasource
    
    * add assertion for simple metrics
    
    * lint
---
 superset/charts/schemas.py         |  2 +-
 superset/common/query_context.py   | 16 +++++++
 superset/connectors/sqla/models.py | 17 ++++++-
 superset/utils/core.py             | 57 ++++++++++++++++++++----
 superset/viz.py                    | 18 ++++++++
 tests/core_tests.py                | 19 ++++++++
 tests/query_context_tests.py       | 91 +++++++++++++++++++++++++++++++++-----
 7 files changed, 196 insertions(+), 24 deletions(-)
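
This commit closes a gap where column references in a chart-data request went unvalidated: every column named in `columns`, `groupby`, `filters`, or behind a simple ad-hoc metric must now exist in the datasource. A minimal sketch of a payload the new check should reject (shape abridged; the `groupby` value mirrors the test_sql_injection_via_groupby test added below):

    # Hypothetical, abridged chart-data payload: "currentDatabase()" is a SQL
    # expression, not a datasource column, so validation now fails fast
    # instead of letting the expression reach the generated query.
    payload = {
        "datasource": {"id": 1, "type": "table"},
        "queries": [{"groupby": ["currentDatabase()"], "metrics": []}],
    }
    # Posting this to the chart data API is expected to yield an error entry
    # ("Columns missing in datasource: ...") rather than a result set.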

diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index cdea43c..8d580e3 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -718,7 +718,7 @@ class ChartDataQueryObjectSchema(Schema):
         deprecated=True,
     )
     having_filters = fields.List(
-        fields.Dict(),
+        fields.Nested(ChartDataFilterSchema),
         description="HAVING filters to be added to legacy Druid datasource queries. "
         "This field is deprecated and should be passed to `extras` "
         "as `having_druid`.",
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 0d33f9c..d2cecae 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -22,6 +22,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Union
 
 import numpy as np
 import pandas as pd
+from flask_babel import gettext as _
 
 from superset import app, cache, db, security_manager
 from superset.common.query_object import QueryObject
@@ -235,6 +236,21 @@ class QueryContext:
 
         if query_obj and not is_loaded:
             try:
+                invalid_columns = [
+                    col
+                    for col in query_obj.columns
+                    + query_obj.groupby
+                    + [flt["col"] for flt in query_obj.filter]
+                    + utils.get_column_names_from_metrics(query_obj.metrics)
+                    if col not in self.datasource.column_names
+                ]
+                if invalid_columns:
+                    raise QueryObjectValidationError(
+                        _(
+                            "Columns missing in datasource: %(invalid_columns)s",
+                            invalid_columns=invalid_columns,
+                        )
+                    )
                 query_result = self.get_query_result(query_obj)
                 status = query_result["status"]
                 query = query_result["query"]
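
The hunk above is the heart of the fix: before the query runs, gather every column the query object references and compare against the datasource's known column names (the same guard is added to the legacy viz.py path further down). A standalone sketch with plain dicts, using a hypothetical find_invalid_columns helper:

    from typing import Any, Dict, List

    def find_invalid_columns(
        query_obj: Dict[str, Any], column_names: List[str]
    ) -> List[str]:
        # Collect selected columns, groupby columns, and filter columns; the
        # real code also folds in columns referenced by simple ad-hoc metrics
        # via utils.get_column_names_from_metrics.
        referenced = (
            query_obj.get("columns", [])
            + query_obj.get("groupby", [])
            + [flt["col"] for flt in query_obj.get("filter", [])]
        )
        return [col for col in referenced if col not in column_names]

    assert find_invalid_columns(
        {"groupby": ["currentDatabase()"]}, ["name", "num", "state"]
    ) == ["currentDatabase()"]
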
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 2bfa444..98c6f1e 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -90,6 +90,19 @@ class AnnotationDatasource(BaseDatasource):
     cache_timeout = 0
     changed_on = None
     type = "annotation"
+    column_names = [
+        "created_on",
+        "changed_on",
+        "id",
+        "start_dttm",
+        "end_dttm",
+        "layer_id",
+        "short_descr",
+        "long_descr",
+        "json_metadata",
+        "created_by_fk",
+        "changed_by_fk",
+    ]
 
     def query(self, query_obj: QueryObjectDict) -> QueryResult:
         error_message = None
@@ -717,7 +730,7 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
         expression_type = metric.get("expressionType")
         label = utils.get_metric_name(metric)
 
-        if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]:
+        if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
             column_name = metric["column"].get("column_name")
             table_column = columns_by_name.get(column_name)
             if table_column:
@@ -725,7 +738,7 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-at
             else:
                 sqla_column = column(column_name)
             sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
-        elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES["SQL"]:
+        elif expression_type == utils.AdhocMetricExpressionType.SQL:
             sqla_metric = literal_column(metric.get("sqlExpression"))
         else:
             return None
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 297127e..2e2bfcb 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -42,6 +42,7 @@ from types import TracebackType
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Iterable,
     Iterator,
@@ -101,7 +102,6 @@ logging.getLogger("MARKDOWN").setLevel(logging.INFO)
 logger = logging.getLogger(__name__)
 
 DTTM_ALIAS = "__timestamp"
-ADHOC_METRIC_EXPRESSION_TYPES = {"SIMPLE": "SIMPLE", "SQL": "SQL"}
 
 JS_MAX_INTEGER = 9007199254740991  # Largest int JavaScript can handle: 2^53-1
 
@@ -1010,20 +1010,23 @@ def get_example_database() -> "Database":
 
 
 def is_adhoc_metric(metric: Metric) -> bool:
+    if not isinstance(metric, dict):
+        return False
+    metric = cast(Dict[str, Any], metric)
     return bool(
-        isinstance(metric, dict)
-        and (
+        (
             (
-                metric["expressionType"] == ADHOC_METRIC_EXPRESSION_TYPES["SIMPLE"]
-                and metric["column"]
-                and metric["aggregate"]
+                metric.get("expressionType") == AdhocMetricExpressionType.SIMPLE
+                and metric.get("column")
+                and cast(Dict[str, Any], metric["column"]).get("column_name")
+                and metric.get("aggregate")
             )
             or (
-                metric["expressionType"] == ADHOC_METRIC_EXPRESSION_TYPES["SQL"]
-                and metric["sqlExpression"]
+                metric.get("expressionType") == AdhocMetricExpressionType.SQL
+                and metric.get("sqlExpression")
             )
         )
-        and metric["label"]
+        and metric.get("label")
     )
 
 
@@ -1370,6 +1373,37 @@ def get_form_data_token(form_data: Dict[str, Any]) -> str:
     return form_data.get("token") or "token_" + uuid.uuid4().hex[:8]
 
 
+def get_column_name_from_metric(metric: Metric) -> Optional[str]:
+    """
+    Extract the column that a metric references. If the metric isn't a
+    simple ad-hoc metric, return `None`.
+
+    :param metric: Ad-hoc metric
+    :return: column name if simple metric, otherwise None
+    """
+    if is_adhoc_metric(metric):
+        metric = cast(Dict[str, Any], metric)
+        if metric["expressionType"] == AdhocMetricExpressionType.SIMPLE:
+            return cast(Dict[str, Any], metric["column"])["column_name"]
+    return None
+
+
+def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
+    """
+    Extract the columns that a list of metrics reference. Excludes all
+    SQL metrics.
+
+    :param metrics: list of metrics
+    :return: list of column names referenced by simple metrics
+    """
+    columns: List[str] = []
+    for metric in metrics:
+        column_name = get_column_name_from_metric(metric)
+        if column_name:
+            columns.append(column_name)
+    return columns
+
+
 class LenientEnum(Enum):
     """Enums that do not raise ValueError when value is invalid"""
 
@@ -1495,3 +1529,8 @@ class PostProcessingContributionOrientation(str, Enum):
 
     ROW = "row"
     COLUMN = "column"
+
+
+class AdhocMetricExpressionType(str, Enum):
+    SIMPLE = "SIMPLE"
+    SQL = "SQL"
diff --git a/superset/viz.py b/superset/viz.py
index df00d43..365c0c5 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -482,6 +482,24 @@ class BaseViz:
 
         if query_obj and not is_loaded:
             try:
+                invalid_columns = [
+                    col
+                    for col in (query_obj.get("columns") or [])
+                    + (query_obj.get("groupby") or [])
+                    + utils.get_column_names_from_metrics(
+                        cast(
+                            List[Union[str, Dict[str, Any]]], query_obj.get("metrics"),
+                        )
+                    )
+                    if col not in self.datasource.column_names
+                ]
+                if invalid_columns:
+                    raise QueryObjectValidationError(
+                        _(
+                            "Columns missing in datasource: %(invalid_columns)s",
+                            invalid_columns=invalid_columns,
+                        )
+                    )
                 df = self.get_df(query_obj)
                 if self.status != utils.QueryStatus.FAILED:
                     stats_logger.incr("loaded_from_source")
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 2e9ab4b..cc9b3b5 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -1335,6 +1335,25 @@ class TestCore(SupersetTestCase):
         database.extra = json.dumps(extra)
         self.assertEqual(database.explore_database_id, explore_database.id)
 
+    def test_get_column_names_from_metric(self):
+        simple_metric = {
+            "expressionType": utils.AdhocMetricExpressionType.SIMPLE.value,
+            "column": {"column_name": "my_col"},
+            "aggregate": "SUM",
+            "label": "My Simple Label",
+        }
+        assert utils.get_column_name_from_metric(simple_metric) == "my_col"
+
+        sql_metric = {
+            "expressionType": utils.AdhocMetricExpressionType.SQL.value,
+            "sqlExpression": "SUM(my_label)",
+            "label": "My SQL Label",
+        }
+        assert utils.get_column_name_from_metric(sql_metric) is None
+        assert utils.get_column_names_from_metrics([simple_metric, sql_metric]) == [
+            "my_col"
+        ]
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py
index f816bcd..0b0230f 100644
--- a/tests/query_context_tests.py
+++ b/tests/query_context_tests.py
@@ -17,11 +17,12 @@
 import tests.test_app
 from superset import db
 from superset.charts.schemas import ChartDataQueryContextSchema
-from superset.common.query_context import QueryContext
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.utils.core import (
+    AdhocMetricExpressionType,
     ChartDataResultFormat,
     ChartDataResultType,
+    FilterOperator,
     TimeRangeEndpoint,
 )
 from tests.base_tests import SupersetTestCase
@@ -75,7 +76,7 @@ class TestQueryContext(SupersetTestCase):
         payload = get_query_context(table.name, table.id, table.type)
 
         # construct baseline cache_key
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_original = query_context.cache_key(query_object)
 
@@ -92,7 +93,7 @@ class TestQueryContext(SupersetTestCase):
         db.session.commit()
 
         # create new QueryContext with unchanged attributes and extract new cache_key
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_new = query_context.cache_key(query_object)
 
@@ -108,20 +109,20 @@ class TestQueryContext(SupersetTestCase):
         )
 
         # construct baseline cache_key from query_context with post processing operation
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_original = query_context.cache_key(query_object)
 
         # ensure added None post_processing operation doesn't change cache_key
         payload["queries"][0]["post_processing"].append(None)
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_with_null = query_context.cache_key(query_object)
         self.assertEqual(cache_key_original, cache_key_with_null)
 
         # ensure query without post processing operation is different
         payload["queries"][0].pop("post_processing")
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         cache_key_without_post_processing = query_context.cache_key(query_object)
         self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
@@ -136,7 +137,7 @@ class TestQueryContext(SupersetTestCase):
         table = self.get_table_by_name(table_name)
         payload = get_query_context(table.name, table.id, table.type)
         del payload["queries"][0]["extras"]["time_range_endpoints"]
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         query_object = query_context.queries[0]
         extras = query_object.to_dict()["extras"]
         self.assertTrue("time_range_endpoints" in extras)
@@ -155,8 +156,8 @@ class TestQueryContext(SupersetTestCase):
         table = self.get_table_by_name(table_name)
         payload = get_query_context(table.name, table.id, table.type)
         payload["queries"][0]["granularity_sqla"] = "timecol"
-        payload["queries"][0]["having_filters"] = {"col": "a", "op": "==", "val": "b"}
-        query_context = QueryContext(**payload)
+        payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
+        query_context = ChartDataQueryContextSchema().load(payload)
         self.assertEqual(len(query_context.queries), 1)
         query_object = query_context.queries[0]
         self.assertEqual(query_object.granularity, "timecol")
@@ -172,13 +173,79 @@ class TestQueryContext(SupersetTestCase):
         payload = get_query_context(table.name, table.id, table.type)
         payload["result_format"] = ChartDataResultFormat.CSV.value
         payload["queries"][0]["row_limit"] = 10
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         responses = query_context.get_payload()
         self.assertEqual(len(responses), 1)
         data = responses[0]["data"]
         self.assertIn("name,sum__num\n", data)
         self.assertEqual(len(data.split("\n")), 12)
 
+    def test_sql_injection_via_groupby(self):
+        """
+        Ensure that invalid column names in groupby are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = ["currentDatabase()"]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None
+
+    def test_sql_injection_via_columns(self):
+        """
+        Ensure that invalid column names in columns are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = []
+        payload["queries"][0]["metrics"] = []
+        payload["queries"][0]["columns"] = ["*, 'extra'"]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None
+
+    def test_sql_injection_via_filters(self):
+        """
+        Ensure that invalid column names in filters are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = ["name"]
+        payload["queries"][0]["metrics"] = []
+        payload["queries"][0]["filters"] = [
+            {"col": "*", "op": FilterOperator.EQUALS.value, "val": ";"}
+        ]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None
+
+    def test_sql_injection_via_metrics(self):
+        """
+        Ensure that invalid column names in metrics are caught
+        """
+        self.login(username="admin")
+        table_name = "birth_names"
+        table = self.get_table_by_name(table_name)
+        payload = get_query_context(table.name, table.id, table.type)
+        payload["queries"][0]["groupby"] = ["name"]
+        payload["queries"][0]["metrics"] = [
+            {
+                "expressionType": AdhocMetricExpressionType.SIMPLE.value,
+                "column": {"column_name": "invalid_col"},
+                "aggregate": "SUM",
+                "label": "My Simple Label",
+            }
+        ]
+        query_context = ChartDataQueryContextSchema().load(payload)
+        query_payload = query_context.get_payload()
+        assert query_payload[0].get("error") is not None
+
     def test_samples_response_type(self):
         """
         Ensure that samples result type works
@@ -189,7 +256,7 @@ class TestQueryContext(SupersetTestCase):
         payload = get_query_context(table.name, table.id, table.type)
         payload["result_type"] = ChartDataResultType.SAMPLES.value
         payload["queries"][0]["row_limit"] = 5
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         responses = query_context.get_payload()
         self.assertEqual(len(responses), 1)
         data = responses[0]["data"]
@@ -206,7 +273,7 @@ class TestQueryContext(SupersetTestCase):
         table = self.get_table_by_name(table_name)
         payload = get_query_context(table.name, table.id, table.type)
         payload["result_type"] = ChartDataResultType.QUERY.value
-        query_context = QueryContext(**payload)
+        query_context = ChartDataQueryContextSchema().load(payload)
         responses = query_context.get_payload()
         self.assertEqual(len(responses), 1)
         response = responses[0]


[incubator-superset] 04/04: fix: simply is_adhoc_metric (#10964)

Posted by vi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch 0.37
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git

commit 2fd965c6b3fea7ea74225fe7e0b6c9faed544f2f
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Sun Sep 20 13:22:55 2020 +0300

    fix: simply is_adhoc_metric (#10964)
    
    * fix: simply is_adhoc_metric
    
    * address comment
---
 superset/utils/core.py | 19 +------------------
 1 file changed, 1 insertion(+), 18 deletions(-)

diff --git a/superset/utils/core.py b/superset/utils/core.py
index 2e2bfcb..6e9a80d 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1010,24 +1010,7 @@ def get_example_database() -> "Database":
 
 
 def is_adhoc_metric(metric: Metric) -> bool:
-    if not isinstance(metric, dict):
-        return False
-    metric = cast(Dict[str, Any], metric)
-    return bool(
-        (
-            (
-                metric.get("expressionType") == AdhocMetricExpressionType.SIMPLE
-                and metric.get("column")
-                and cast(Dict[str, Any], metric["column"]).get("column_name")
-                and metric.get("aggregate")
-            )
-            or (
-                metric.get("expressionType") == AdhocMetricExpressionType.SQL
-                and metric.get("sqlExpression")
-            )
-        )
-        and metric.get("label")
-    )
+    return isinstance(metric, dict)
 
 
 def get_metric_name(metric: Metric) -> str:
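
After this change `is_adhoc_metric` is purely a type discriminator: any dict-shaped metric counts as ad-hoc, presumably on the assumption that key-level validation happens upstream (for example in the request schemas). A sketch of the resulting behavior:

    from superset.utils.core import is_adhoc_metric

    # Any dict-shaped metric is now considered ad-hoc; the predicate no
    # longer inspects individual keys (that is assumed to happen upstream).
    assert is_adhoc_metric({"expressionType": "SIMPLE"})
    assert is_adhoc_metric({})               # even an empty dict qualifies
    assert not is_adhoc_metric("count")      # saved metrics are plain strings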


[incubator-superset] 03/04: fix(jinja): make context attrs private on SQL templates (#10934)

Posted by vi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch 0.37
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git

commit 315acf481f335ce8d7554388aff3cf6460499a83
Author: Daniel Vaz Gaspar <da...@gmail.com>
AuthorDate: Fri Sep 18 12:56:07 2020 +0100

    fix(jinja): make context attrs private on SQL templates (#10934)
    
    * fix(jinja): make SQLAlchemy models private on SQL templates
    
    * add missing privates
    
    * fix test
---
 superset/jinja_context.py                         | 36 +++++++++++------------
 tests/superset_test_custom_template_processors.py |  2 +-
 2 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index 95ee723..5a35997 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -213,17 +213,17 @@ class BaseTemplateProcessor:  # pylint: disable=too-few-public-methods
         extra_cache_keys: Optional[List[Any]] = None,
         **kwargs: Any,
     ) -> None:
-        self.database = database
-        self.query = query
-        self.schema = None
+        self._database = database
+        self._query = query
+        self._schema = None
         if query and query.schema:
-            self.schema = query.schema
+            self._schema = query.schema
         elif table:
-            self.schema = table.schema
+            self._schema = table.schema
 
         extra_cache = ExtraCache(extra_cache_keys)
 
-        self.context = {
+        self._context = {
             "url_param": extra_cache.url_param,
             "current_user_id": extra_cache.current_user_id,
             "current_username": extra_cache.current_username,
@@ -231,11 +231,11 @@ class BaseTemplateProcessor:  # pylint: disable=too-few-public-methods
             "filter_values": filter_values,
             "form_data": {},
         }
-        self.context.update(kwargs)
-        self.context.update(jinja_base_context)
+        self._context.update(kwargs)
+        self._context.update(jinja_base_context)
         if self.engine:
-            self.context[self.engine] = self
-        self.env = SandboxedEnvironment()
+            self._context[self.engine] = self
+        self._env = SandboxedEnvironment()
 
     def process_template(self, sql: str, **kwargs: Any) -> str:
         """Processes a sql template
@@ -244,8 +244,8 @@ class BaseTemplateProcessor:  # pylint: disable=too-few-public-methods
         >>> process_template(sql)
         "SELECT '2017-01-01T00:00:00'"
         """
-        template = self.env.from_string(sql)
-        kwargs.update(self.context)
+        template = self._env.from_string(sql)
+        kwargs.update(self._context)
         return template.render(kwargs)
 
 
@@ -288,20 +288,20 @@ class PrestoTemplateProcessor(BaseTemplateProcessor):
 
         from superset.db_engine_specs.presto import PrestoEngineSpec
 
-        table_name, schema = self._schema_table(table_name, self.schema)
-        return cast(PrestoEngineSpec, self.database.db_engine_spec).latest_partition(
-            table_name, schema, self.database
+        table_name, schema = self._schema_table(table_name, self._schema)
+        return cast(PrestoEngineSpec, self._database.db_engine_spec).latest_partition(
+            table_name, schema, self._database
         )[1]
 
     def latest_sub_partition(self, table_name: str, **kwargs: Any) -> Any:
-        table_name, schema = self._schema_table(table_name, self.schema)
+        table_name, schema = self._schema_table(table_name, self._schema)
 
         from superset.db_engine_specs.presto import PrestoEngineSpec
 
         return cast(
-            PrestoEngineSpec, self.database.db_engine_spec
+            PrestoEngineSpec, self._database.db_engine_spec
         ).latest_sub_partition(
-            table_name=table_name, schema=schema, database=self.database, **kwargs
+            table_name=table_name, schema=schema, database=self._database, **kwargs
         )
 
     latest_partition = first_latest_partition
diff --git a/tests/superset_test_custom_template_processors.py b/tests/superset_test_custom_template_processors.py
index 28fc65d..2987109 100644
--- a/tests/superset_test_custom_template_processors.py
+++ b/tests/superset_test_custom_template_processors.py
@@ -42,7 +42,7 @@ class CustomPrestoTemplateProcessor(PrestoTemplateProcessor):
         # Add custom macros functions.
         macros = {"DATE": partial(DATE, datetime.utcnow())}  # type: Dict[str, Any]
         # Update with macros defined in context and kwargs.
-        macros.update(self.context)
+        macros.update(self._context)
         macros.update(kwargs)
 
         def replacer(match):
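
The rename to leading-underscore attributes is effective because Jinja's SandboxedEnvironment refuses attribute access on names starting with an underscore, so templates can no longer reach the database and query models through the processor object. A minimal sketch of that sandbox behavior (illustrative class, not Superset code):

    from jinja2.sandbox import SandboxedEnvironment

    class Processor:  # illustrative stand-in for a template processor
        def __init__(self) -> None:
            self._database = "internal model"  # leading underscore: blocked
            self.engine = "presto"             # public attribute: reachable

    env = SandboxedEnvironment()
    print(env.from_string("{{ p.engine }}").render(p=Processor()))  # presto
    # Rendering "{{ p._database }}" instead raises
    # jinja2.exceptions.SecurityError: the sandbox treats leading-underscore
    # attributes as unsafe.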


[incubator-superset] 02/04: fix(legacy-druid): undefined filter key (#10931)

Posted by vi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch 0.37
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git

commit 1f3a93b2c91081c5ebb9ed950f4a8d2a36471f59
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Thu Sep 17 15:10:06 2020 +0300

    fix(legacy-druid): undefined filter key (#10931)
---
 superset/connectors/druid/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py
index e9c6029..5285201 100644
--- a/superset/connectors/druid/models.py
+++ b/superset/connectors/druid/models.py
@@ -1396,7 +1396,7 @@ class DruidDatasource(Model, BaseDatasource):
                 if df is None:
                     df = pd.DataFrame()
                 qry["filter"] = self._add_filter_from_pre_query_data(
-                    df, pre_qry["dimensions"], qry["filter"]
+                    df, pre_qry["dimensions"], filters
                 )
                 qry["limit_spec"] = None
             if row_limit:
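
For context: when building the final Druid query from pre-query results, the old code read `qry["filter"]`, a key that is not always present at that point, while the `filters` local computed earlier in the method holds the intended value. A toy sketch of the failure mode (dict shapes assumed for illustration):

    # Hypothetical shapes: qry was built upstream without a "filter" key.
    qry = {"dimensions": ["gender"], "limit_spec": None}
    filters = None  # computed earlier in the method (name taken from the diff)

    # Old behavior: KeyError, since no filter was ever attached to qry.
    # existing = qry["filter"]
    # Fixed behavior: pass the known local variable instead.
    existing = filters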