You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by yo...@apache.org on 2022/08/22 13:00:31 UTC

[superset] branch master updated: feat: generate label map on the backend (#21124)

This is an automated email from the ASF dual-hosted git repository.

yongjiezhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 11bf7b9125 feat: generate label map on the backend (#21124)
11bf7b9125 is described below

commit 11bf7b9125eefd93796a46d964c3f027fbc9ce4d
Author: Yongjie Zhao <yo...@gmail.com>
AuthorDate: Mon Aug 22 21:00:02 2022 +0800

    feat: generate label map on the backend (#21124)
---
 superset/common/query_context_processor.py         | 14 +++++++
 superset/utils/pandas_postprocessing/__init__.py   |  6 +++
 superset/utils/pandas_postprocessing/flatten.py    |  5 ++-
 superset/utils/pandas_postprocessing/utils.py      | 10 +++++
 tests/integration_tests/conftest.py                | 27 +++++++++++++
 tests/integration_tests/query_context_tests.py     | 45 ++++++++++++++++++++++
 .../pandas_postprocessing/test_flatten.py          | 19 +++++++++
 .../unit_tests/pandas_postprocessing/test_utils.py | 30 +++++++++++++++
 8 files changed, 154 insertions(+), 2 deletions(-)

diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 2978eeace4..b253caa6b9 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 
 import copy
 import logging
+import re
 from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
 
 import numpy as np
@@ -57,6 +58,7 @@ from superset.utils.core import (
     TIME_COMPARISON,
 )
 from superset.utils.date_parser import get_past_or_future, normalize_time_delta
+from superset.utils.pandas_postprocessing.utils import unescape_separator
 from superset.views.utils import get_viz
 
 if TYPE_CHECKING:
@@ -142,6 +144,17 @@ class QueryContextProcessor:
                 cache.error_message = str(ex)
                 cache.status = QueryStatus.FAILED
 
+        # the N-dimensional DataFrame has been converted into a flat DataFrame
+        # by the `flatten` operator, and commas in column names were escaped by
+        # `escape_separator`, so the result DataFrame columns must be unescaped
+        label_map = {
+            unescape_separator(col): [
+                unescape_separator(col) for col in re.split(r"(?<!\\),\s", col)
+            ]
+            for col in cache.df.columns.values
+        }
+        cache.df.columns = [unescape_separator(col) for col in cache.df.columns.values]
+
         return {
             "cache_key": cache_key,
             "cached_dttm": cache.cache_dttm,
@@ -157,6 +170,7 @@ class QueryContextProcessor:
             "rowcount": len(cache.df.index),
             "from_dttm": query_obj.from_dttm,
             "to_dttm": query_obj.to_dttm,
+            "label_map": label_map,
         }
 
     def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
diff --git a/superset/utils/pandas_postprocessing/__init__.py b/superset/utils/pandas_postprocessing/__init__.py
index 7902cb3232..e66a52f655 100644
--- a/superset/utils/pandas_postprocessing/__init__.py
+++ b/superset/utils/pandas_postprocessing/__init__.py
@@ -33,6 +33,10 @@ from superset.utils.pandas_postprocessing.resample import resample
 from superset.utils.pandas_postprocessing.rolling import rolling
 from superset.utils.pandas_postprocessing.select import select
 from superset.utils.pandas_postprocessing.sort import sort
+from superset.utils.pandas_postprocessing.utils import (
+    escape_separator,
+    unescape_separator,
+)
 
 __all__ = [
     "aggregate",
@@ -52,4 +56,6 @@ __all__ = [
     "select",
     "sort",
     "flatten",
+    "escape_separator",
+    "unescape_separator",
 ]
diff --git a/superset/utils/pandas_postprocessing/flatten.py b/superset/utils/pandas_postprocessing/flatten.py
index 2874ac5797..db783c4bed 100644
--- a/superset/utils/pandas_postprocessing/flatten.py
+++ b/superset/utils/pandas_postprocessing/flatten.py
@@ -22,6 +22,7 @@ from numpy.distutils.misc_util import is_sequence
 
 from superset.utils.pandas_postprocessing.utils import (
     _is_multi_index_on_columns,
+    escape_separator,
     FLAT_COLUMN_SEPARATOR,
 )
 
@@ -86,8 +87,8 @@ def flatten(
             _cells = []
             for cell in series if is_sequence(series) else [series]:
                 if pd.notnull(cell):
-                    # every cell should be converted to string
-                    _cells.append(str(cell))
+                    # convert every cell to a string and escape its commas
+                    _cells.append(escape_separator(str(cell)))
             _columns.append(FLAT_COLUMN_SEPARATOR.join(_cells))
 
         df.columns = _columns
diff --git a/superset/utils/pandas_postprocessing/utils.py b/superset/utils/pandas_postprocessing/utils.py
index 3d14f643c5..bff62dcb64 100644
--- a/superset/utils/pandas_postprocessing/utils.py
+++ b/superset/utils/pandas_postprocessing/utils.py
@@ -198,3 +198,13 @@ def _append_columns(
         return _base_df
     append_df = append_df.rename(columns=columns)
     return pd.concat([base_df, append_df], axis="columns")
+
+
+def escape_separator(plain_str: str, sep: str = FLAT_COLUMN_SEPARATOR) -> str:
+    char = sep.strip()
+    return plain_str.replace(char, "\\" + char)
+
+
+def unescape_separator(escaped_str: str, sep: str = FLAT_COLUMN_SEPARATOR) -> str:
+    char = sep.strip()
+    return escaped_str.replace("\\" + char, char)
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index 549a987db1..aaa156b5b4 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -358,3 +358,30 @@ def physical_dataset():
     for ds in dataset:
         db.session.delete(ds)
     db.session.commit()
+
+
+@pytest.fixture
+def virtual_dataset_comma_in_column_value():
+    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+
+    dataset = SqlaTable(
+        table_name="virtual_dataset",
+        sql=(
+            "SELECT 'col1,row1' as col1, 'col2, row1' as col2 "
+            "UNION ALL "
+            "SELECT 'col1,row2' as col1, 'col2, row2' as col2 "
+            "UNION ALL "
+            "SELECT 'col1,row3' as col1, 'col2, row3' as col2 "
+        ),
+        database=get_example_database(),
+    )
+    TableColumn(column_name="col1", type="VARCHAR(255)", table=dataset)
+    TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
+
+    SqlMetric(metric_name="count", expression="count(*)", table=dataset)
+    db.session.merge(dataset)
+
+    yield dataset
+
+    db.session.delete(dataset)
+    db.session.commit()
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index abd5d2be8b..b17072f6bc 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -25,6 +25,7 @@ from superset import db
 from superset.charts.schemas import ChartDataQueryContextSchema
 from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
 from superset.common.query_context import QueryContext
+from superset.common.query_context_factory import QueryContextFactory
 from superset.common.query_object import QueryObject
 from superset.connectors.sqla.models import SqlMetric
 from superset.datasource.dao import DatasourceDAO
@@ -35,6 +36,7 @@ from superset.utils.core import (
     DatasourceType,
     QueryStatus,
 )
+from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
 from tests.integration_tests.base_tests import SupersetTestCase
 from tests.integration_tests.fixtures.birth_names_dashboard import (
     load_birth_names_dashboard_with_slices,
@@ -683,3 +685,46 @@ class TestQueryContext(SupersetTestCase):
                     row["sum__num__3 years later"]
                     == df_3_years_later.loc[index]["sum__num"]
                 )
+
+
+def test_get_label_map(app_context, virtual_dataset_comma_in_column_value):
+    qc = QueryContextFactory().create(
+        datasource={
+            "type": virtual_dataset_comma_in_column_value.type,
+            "id": virtual_dataset_comma_in_column_value.id,
+        },
+        queries=[
+            {
+                "columns": ["col1", "col2"],
+                "metrics": ["count"],
+                "post_processing": [
+                    {
+                        "operation": "pivot",
+                        "options": {
+                            "aggregates": {"count": {"operator": "mean"}},
+                            "columns": ["col2"],
+                            "index": ["col1"],
+                        },
+                    },
+                    {"operation": "flatten"},
+                ],
+            }
+        ],
+        result_type=ChartDataResultType.FULL,
+        force=True,
+    )
+    query_object = qc.queries[0]
+    df = qc.get_df_payload(query_object)["df"]
+    label_map = qc.get_df_payload(query_object)["label_map"]
+    assert list(df.columns.values) == [
+        "col1",
+        "count" + FLAT_COLUMN_SEPARATOR + "col2, row1",
+        "count" + FLAT_COLUMN_SEPARATOR + "col2, row2",
+        "count" + FLAT_COLUMN_SEPARATOR + "col2, row3",
+    ]
+    assert label_map == {
+        "col1": ["col1"],
+        "count, col2, row1": ["count", "col2, row1"],
+        "count, col2, row2": ["count", "col2, row2"],
+        "count, col2, row3": ["count", "col2, row3"],
+    }
diff --git a/tests/unit_tests/pandas_postprocessing/test_flatten.py b/tests/unit_tests/pandas_postprocessing/test_flatten.py
index 78a2e3eea4..fea84f7b9f 100644
--- a/tests/unit_tests/pandas_postprocessing/test_flatten.py
+++ b/tests/unit_tests/pandas_postprocessing/test_flatten.py
@@ -156,3 +156,22 @@ def test_flat_integer_column_name():
             }
         )
     )
+
+
+def test_escape_column_name():
+    index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
+    index.name = "__timestamp"
+    columns = pd.MultiIndex.from_arrays(
+        [
+            ["level1,value1", "level1,value2", "level1,value3"],
+            ["level2, value1", "level2, value2", "level2, value3"],
+        ],
+        names=["level1", "level2"],
+    )
+    df = pd.DataFrame(index=index, columns=columns, data=1)
+    assert list(pp.flatten(df).columns.values) == [
+        "__timestamp",
+        "level1\\,value1" + FLAT_COLUMN_SEPARATOR + "level2\\, value1",
+        "level1\\,value2" + FLAT_COLUMN_SEPARATOR + "level2\\, value2",
+        "level1\\,value3" + FLAT_COLUMN_SEPARATOR + "level2\\, value3",
+    ]
diff --git a/tests/unit_tests/pandas_postprocessing/test_utils.py b/tests/unit_tests/pandas_postprocessing/test_utils.py
new file mode 100644
index 0000000000..058cefcd6c
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_utils.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from superset.utils.pandas_postprocessing import escape_separator, unescape_separator
+
+
+def test_escape_separator():
+    assert escape_separator(r" hell \world ") == r" hell \world "
+    assert unescape_separator(r" hell \world ") == r" hell \world "
+
+    escape_string = escape_separator("hello, world")
+    assert escape_string == r"hello\, world"
+    assert unescape_separator(escape_string) == "hello, world"
+
+    escape_string = escape_separator("hello,world")
+    assert escape_string == r"hello\,world"
+    assert unescape_separator(escape_string) == "hello,world"