You are viewing a plain-text version of this content. (The link to the canonical version was lost in this plain-text rendering.)
Posted to commits@superset.apache.org by yo...@apache.org on 2022/08/22 13:00:31 UTC
[superset] branch master updated: feat: generate label map on the backend (#21124)
This is an automated email from the ASF dual-hosted git repository.
yongjiezhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 11bf7b9125 feat: generate label map on the backend (#21124)
11bf7b9125 is described below
commit 11bf7b9125eefd93796a46d964c3f027fbc9ce4d
Author: Yongjie Zhao <yo...@gmail.com>
AuthorDate: Mon Aug 22 21:00:02 2022 +0800
feat: generate label map on the backend (#21124)
---
superset/common/query_context_processor.py | 14 +++++++
superset/utils/pandas_postprocessing/__init__.py | 6 +++
superset/utils/pandas_postprocessing/flatten.py | 5 ++-
superset/utils/pandas_postprocessing/utils.py | 10 +++++
tests/integration_tests/conftest.py | 27 +++++++++++++
tests/integration_tests/query_context_tests.py | 45 ++++++++++++++++++++++
.../pandas_postprocessing/test_flatten.py | 19 +++++++++
.../unit_tests/pandas_postprocessing/test_utils.py | 30 +++++++++++++++
8 files changed, 154 insertions(+), 2 deletions(-)
diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 2978eeace4..b253caa6b9 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -18,6 +18,7 @@ from __future__ import annotations
import copy
import logging
+import re
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
import numpy as np
@@ -57,6 +58,7 @@ from superset.utils.core import (
TIME_COMPARISON,
)
from superset.utils.date_parser import get_past_or_future, normalize_time_delta
+from superset.utils.pandas_postprocessing.utils import unescape_separator
from superset.views.utils import get_viz
if TYPE_CHECKING:
@@ -142,6 +144,17 @@ class QueryContextProcessor:
cache.error_message = str(ex)
cache.status = QueryStatus.FAILED
+ # the N-dimensional DataFrame has been converted into a flat DataFrame
+ # by the `flatten` operator, which escapes commas in column names with `escape_separator`;
+ # the resulting DataFrame columns therefore need to be unescaped
+ label_map = {
+ unescape_separator(col): [
+ unescape_separator(col) for col in re.split(r"(?<!\\),\s", col)
+ ]
+ for col in cache.df.columns.values
+ }
+ cache.df.columns = [unescape_separator(col) for col in cache.df.columns.values]
+
return {
"cache_key": cache_key,
"cached_dttm": cache.cache_dttm,
@@ -157,6 +170,7 @@ class QueryContextProcessor:
"rowcount": len(cache.df.index),
"from_dttm": query_obj.from_dttm,
"to_dttm": query_obj.to_dttm,
+ "label_map": label_map,
}
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
diff --git a/superset/utils/pandas_postprocessing/__init__.py b/superset/utils/pandas_postprocessing/__init__.py
index 7902cb3232..e66a52f655 100644
--- a/superset/utils/pandas_postprocessing/__init__.py
+++ b/superset/utils/pandas_postprocessing/__init__.py
@@ -33,6 +33,10 @@ from superset.utils.pandas_postprocessing.resample import resample
from superset.utils.pandas_postprocessing.rolling import rolling
from superset.utils.pandas_postprocessing.select import select
from superset.utils.pandas_postprocessing.sort import sort
+from superset.utils.pandas_postprocessing.utils import (
+ escape_separator,
+ unescape_separator,
+)
__all__ = [
"aggregate",
@@ -52,4 +56,6 @@ __all__ = [
"select",
"sort",
"flatten",
+ "escape_separator",
+ "unescape_separator",
]
diff --git a/superset/utils/pandas_postprocessing/flatten.py b/superset/utils/pandas_postprocessing/flatten.py
index 2874ac5797..db783c4bed 100644
--- a/superset/utils/pandas_postprocessing/flatten.py
+++ b/superset/utils/pandas_postprocessing/flatten.py
@@ -22,6 +22,7 @@ from numpy.distutils.misc_util import is_sequence
from superset.utils.pandas_postprocessing.utils import (
_is_multi_index_on_columns,
+ escape_separator,
FLAT_COLUMN_SEPARATOR,
)
@@ -86,8 +87,8 @@ def flatten(
_cells = []
for cell in series if is_sequence(series) else [series]:
if pd.notnull(cell):
- # every cell should be converted to string
- _cells.append(str(cell))
+ # every cell should be converted to string and escape comma
+ _cells.append(escape_separator(str(cell)))
_columns.append(FLAT_COLUMN_SEPARATOR.join(_cells))
df.columns = _columns
diff --git a/superset/utils/pandas_postprocessing/utils.py b/superset/utils/pandas_postprocessing/utils.py
index 3d14f643c5..bff62dcb64 100644
--- a/superset/utils/pandas_postprocessing/utils.py
+++ b/superset/utils/pandas_postprocessing/utils.py
@@ -198,3 +198,13 @@ def _append_columns(
return _base_df
append_df = append_df.rename(columns=columns)
return pd.concat([base_df, append_df], axis="columns")
+
+
+def escape_separator(plain_str: str, sep: str = FLAT_COLUMN_SEPARATOR) -> str:  # prefix every occurrence of sep's non-blank part in plain_str with a backslash
+ char = sep.strip()  # e.g. ", " -> ","; surrounding whitespace of the separator is not escaped
+ return plain_str.replace(char, "\\" + char)  # NOTE(review): existing backslashes are not escaped, so a value ending in "\" makes the next separator look escaped to the `(?<!\\),\s` split in query_context_processor — confirm acceptable
+
+
+def unescape_separator(escaped_str: str, sep: str = FLAT_COLUMN_SEPARATOR) -> str:  # inverse of escape_separator: drop the backslash in front of each escaped separator char
+ char = sep.strip()  # same non-blank separator char(s) that escape_separator escapes
+ return escaped_str.replace("\\" + char, char)  # NOTE(review): get_df_payload unescapes every df column name; a non-str name (e.g. int) would fail here, and a literal "\," in a name flatten never escaped is altered too — confirm
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index 549a987db1..aaa156b5b4 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -358,3 +358,30 @@ def physical_dataset():
for ds in dataset:
db.session.delete(ds)
db.session.commit()
+
+
+@pytest.fixture
+def virtual_dataset_comma_in_column_value():  # three-row virtual dataset whose col1 values contain "," and col2 values contain ", "
+ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
+
+ dataset = SqlaTable(
+ table_name="virtual_dataset",
+ sql=(
+ "SELECT 'col1,row1' as col1, 'col2, row1' as col2 "  # col1: comma without space; col2: comma + space, same as the flatten separator
+ "UNION ALL "
+ "SELECT 'col1,row2' as col1, 'col2, row2' as col2 "
+ "UNION ALL "
+ "SELECT 'col1,row3' as col1, 'col2, row3' as col2 "
+ ),
+ database=get_example_database(),
+ )
+ TableColumn(column_name="col1", type="VARCHAR(255)", table=dataset)  # return value unused; presumably attached to dataset through table=
+ TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
+
+ SqlMetric(metric_name="count", expression="count(*)", table=dataset)  # the "count" metric used by test_get_label_map
+ db.session.merge(dataset)  # NOTE(review): merge()'s return value is discarded while the original `dataset` is yielded and later deleted — confirm it is attached to the session
+
+ yield dataset
+
+ db.session.delete(dataset)  # teardown after the test finishes
+ db.session.commit()
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index abd5d2be8b..b17072f6bc 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -25,6 +25,7 @@ from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
+from superset.common.query_context_factory import QueryContextFactory
from superset.common.query_object import QueryObject
from superset.connectors.sqla.models import SqlMetric
from superset.datasource.dao import DatasourceDAO
@@ -35,6 +36,7 @@ from superset.utils.core import (
DatasourceType,
QueryStatus,
)
+from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@@ -683,3 +685,46 @@ class TestQueryContext(SupersetTestCase):
row["sum__num__3 years later"]
== df_3_years_later.loc[index]["sum__num"]
)
+
+
+def test_get_label_map(app_context, virtual_dataset_comma_in_column_value):  # pivot on col2, flatten, then check unescaped column names and label_map
+ qc = QueryContextFactory().create(
+ datasource={
+ "type": virtual_dataset_comma_in_column_value.type,
+ "id": virtual_dataset_comma_in_column_value.id,
+ },
+ queries=[
+ {
+ "columns": ["col1", "col2"],
+ "metrics": ["count"],
+ "post_processing": [
+ {
+ "operation": "pivot",
+ "options": {
+ "aggregates": {"count": {"operator": "mean"}},
+ "columns": ["col2"],
+ "index": ["col1"],
+ },
+ },
+ {"operation": "flatten"},  # joins level values with FLAT_COLUMN_SEPARATOR after escaping their commas
+ ],
+ }
+ ],
+ result_type=ChartDataResultType.FULL,
+ force=True,
+ )
+ query_object = qc.queries[0]
+ df = qc.get_df_payload(query_object)["df"]  # NOTE(review): get_df_payload is called twice (force=True); a single payload could supply both "df" and "label_map"
+ label_map = qc.get_df_payload(query_object)["label_map"]
+ assert list(df.columns.values) == [  # payload column names are already unescaped
+ "col1",
+ "count" + FLAT_COLUMN_SEPARATOR + "col2, row1",
+ "count" + FLAT_COLUMN_SEPARATOR + "col2, row2",
+ "count" + FLAT_COLUMN_SEPARATOR + "col2, row3",
+ ]
+ assert label_map == {  # each key maps to its parts, split only on unescaped separators
+ "col1": ["col1"],
+ "count, col2, row1": ["count", "col2, row1"],
+ "count, col2, row2": ["count", "col2, row2"],
+ "count, col2, row3": ["count", "col2, row3"],
+ }
diff --git a/tests/unit_tests/pandas_postprocessing/test_flatten.py b/tests/unit_tests/pandas_postprocessing/test_flatten.py
index 78a2e3eea4..fea84f7b9f 100644
--- a/tests/unit_tests/pandas_postprocessing/test_flatten.py
+++ b/tests/unit_tests/pandas_postprocessing/test_flatten.py
@@ -156,3 +156,22 @@ def test_flat_integer_column_name():
}
)
)
+
+
+def test_escape_column_name():  # commas inside each level value are backslash-escaped before the levels are joined
+ index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
+ index.name = "__timestamp"
+ columns = pd.MultiIndex.from_arrays(
+ [
+ ["level1,value1", "level1,value2", "level1,value3"],  # comma without a following space
+ ["level2, value1", "level2, value2", "level2, value3"],  # comma followed by a space
+ ],
+ names=["level1", "level2"],
+ )
+ df = pd.DataFrame(index=index, columns=columns, data=1)
+ assert list(pp.flatten(df).columns.values) == [  # the index comes first, then one escaped, joined name per column
+ "__timestamp",
+ "level1\\,value1" + FLAT_COLUMN_SEPARATOR + "level2\\, value1",
+ "level1\\,value2" + FLAT_COLUMN_SEPARATOR + "level2\\, value2",
+ "level1\\,value3" + FLAT_COLUMN_SEPARATOR + "level2\\, value3",
+ ]
diff --git a/tests/unit_tests/pandas_postprocessing/test_utils.py b/tests/unit_tests/pandas_postprocessing/test_utils.py
new file mode 100644
index 0000000000..058cefcd6c
--- /dev/null
+++ b/tests/unit_tests/pandas_postprocessing/test_utils.py
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from superset.utils.pandas_postprocessing import escape_separator, unescape_separator
+
+
+def test_escape_separator():  # escape/unescape round trips for ", " and ",", plus a comma-free string
+ assert escape_separator(r" hell \world ") == r" hell \world "  # no comma: spaces and backslashes pass through unchanged
+ assert unescape_separator(r" hell \world ") == r" hell \world "
+
+ escape_string = escape_separator("hello, world")  # comma followed by a space
+ assert escape_string == r"hello\, world"
+ assert unescape_separator(escape_string) == "hello, world"
+
+ escape_string = escape_separator("hello,world")  # comma without a space
+ assert escape_string == r"hello\,world"
+ assert unescape_separator(escape_string) == "hello,world"