You are viewing a plain-text version of this content. The canonical link to it (the "here" hyperlink on the original page) is not preserved in this copy.
Posted to commits@superset.apache.org by mi...@apache.org on 2023/01/05 19:57:22 UTC

[superset] 02/14: fix: allow adhoc columns in non-aggregate query (#21729)

This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 1.5
in repository https://gitbox.apache.org/repos/asf/superset.git

commit f676a890d9bd7b563835b7919b5b15da8d50df4c
Author: Mayur <ma...@gmail.com>
AuthorDate: Mon Oct 10 09:38:33 2022 +0530

    fix: allow adhoc columns in non-aggregate query (#21729)
---
 superset/connectors/sqla/models.py               | 18 ++++++++---
 superset/superset_typing.py                      |  4 +--
 superset/utils/core.py                           |  4 ++-
 tests/integration_tests/charts/data/api_tests.py | 41 ++++++++++++++++++++++++
 4 files changed, 60 insertions(+), 7 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 7289a7c3fb..4697530328 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -96,6 +96,7 @@ from superset.sql_parse import ParsedQuery, sanitize_clause
 from superset.superset_typing import (
     AdhocColumn,
     AdhocMetric,
+    Column as ColumnTyping,
     Metric,
     OrderBy,
     QueryObjectDict,
@@ -1067,7 +1068,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
     def get_sqla_query(  # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
         self,
         apply_fetch_values_predicate: bool = False,
-        columns: Optional[List[Column]] = None,
+        columns: Optional[List[ColumnTyping]] = None,
         extras: Optional[Dict[str, Any]] = None,
         filter: Optional[  # pylint: disable=redefined-builtin
             List[QueryObjectFilterClause]
@@ -1261,15 +1262,24 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-metho
                 select_exprs.append(outer)
         elif columns:
             for selected in columns:
+                if is_adhoc_column(selected):
+                    _sql = selected["sqlExpression"]
+                    _column_label = selected["label"]
+                elif isinstance(selected, str):
+                    _sql = selected
+                    _column_label = selected
+
                 selected = validate_adhoc_subquery(
-                    selected,
+                    _sql,
                     self.database_id,
                     self.schema,
                 )
                 select_exprs.append(
                     columns_by_name[selected].get_sqla_col()
-                    if selected in columns_by_name
-                    else self.make_sqla_column_compatible(literal_column(selected))
+                    if isinstance(selected, str) and selected in columns_by_name
+                    else self.make_sqla_column_compatible(
+                        literal_column(selected), _column_label
+                    )
                 )
             metrics_exprs = []
 
diff --git a/superset/superset_typing.py b/superset/superset_typing.py
index 1af04494d0..cd2103154e 100644
--- a/superset/superset_typing.py
+++ b/superset/superset_typing.py
@@ -53,8 +53,8 @@ class AdhocMetric(TypedDict, total=False):
 
 class AdhocColumn(TypedDict, total=False):
     hasCustomLabel: Optional[bool]
-    label: Optional[str]
-    sqlExpression: Optional[str]
+    label: str
+    sqlExpression: str
 
 
 class ResultSetColumnType(TypedDict):
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 6d74cf459b..899576bd1b 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1247,7 +1247,9 @@ def is_adhoc_metric(metric: Metric) -> TypeGuard[AdhocMetric]:
 
 
 def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]:
-    return isinstance(column, dict)
+    return isinstance(column, dict) and ({"label", "sqlExpression"}).issubset(
+        column.keys()
+    )
 
 
 def get_column_name(
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index e7787f6e4d..e8f258421f 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -776,6 +776,47 @@ class TestPostChartDataApi(BaseTestChartDataApi):
         assert "':xyz:qwerty'" in result["query"]
         assert "':qwerty:'" in result["query"]
 
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_with_table_columns_without_metrics(self):
+        request_payload = self.query_context_payload
+        request_payload["queries"][0]["columns"] = ["name", "gender"]
+        request_payload["queries"][0]["metrics"] = None
+        request_payload["queries"][0]["orderby"] = []
+
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        result = rv.json["result"][0]
+
+        assert rv.status_code == 200
+        assert "name" in result["colnames"]
+        assert "gender" in result["colnames"]
+        assert "name" in result["query"]
+        assert "gender" in result["query"]
+        assert list(result["data"][0].keys()) == ["name", "gender"]
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_with_adhoc_column_without_metrics(self):
+        request_payload = self.query_context_payload
+        request_payload["queries"][0]["columns"] = [
+            "name",
+            {
+                "label": "num divide by 10",
+                "sqlExpression": "num/10",
+                "expressionType": "SQL",
+            },
+        ]
+        request_payload["queries"][0]["metrics"] = None
+        request_payload["queries"][0]["orderby"] = []
+
+        rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
+        result = rv.json["result"][0]
+
+        assert rv.status_code == 200
+        assert "num divide by 10" in result["colnames"]
+        assert "name" in result["colnames"]
+        assert "num divide by 10" in result["query"]
+        assert "name" in result["query"]
+        assert list(result["data"][0].keys()) == ["name", "num divide by 10"]
+
 
 @pytest.mark.chart_data_flow
 class TestGetChartDataApi(BaseTestChartDataApi):