You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink on "here" is not preserved in this plain-text copy).
Posted to commits@superset.apache.org by be...@apache.org on 2023/10/06 16:47:06 UTC

[superset] branch master updated: fix: Apply normalization to all dttm columns (#25147)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 58fcd292a9 fix: Apply normalization to all dttm columns (#25147)
58fcd292a9 is described below

commit 58fcd292a979212a3d6f636917021c12c299fd93
Author: Kamil Gabryjelski <ka...@gmail.com>
AuthorDate: Fri Oct 6 18:47:00 2023 +0200

    fix: Apply normalization to all dttm columns (#25147)
---
 superset/common/query_context_factory.py           |  1 +
 superset/common/query_context_processor.py         |  5 +-
 superset/common/query_object_factory.py            | 66 +++++++++++++++-
 tests/integration_tests/query_context_tests.py     |  8 +-
 .../unit_tests/common/test_query_object_factory.py | 90 +++++++++++++++++++++-
 5 files changed, 160 insertions(+), 10 deletions(-)

diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py
index d6510ccd9a..4fd0de7856 100644
--- a/superset/common/query_context_factory.py
+++ b/superset/common/query_context_factory.py
@@ -185,6 +185,7 @@ class QueryContextFactory:  # pylint: disable=too-few-public-methods
                     filter
                     for filter in query_object.filter
                     if filter["col"] != filter_to_remove
+                    or filter["op"] != "TEMPORAL_RANGE"
                 ]
 
     def _apply_filters(self, query_object: QueryObject) -> None:
diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 5a0468b671..dcf19c0c32 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -282,10 +282,11 @@ class QueryContextProcessor:
         datasource = self._qc_datasource
         labels = tuple(
             label
-            for label in [
+            for label in {
                 *get_base_axis_labels(query_object.columns),
+                *[col for col in query_object.columns or [] if isinstance(col, str)],
                 query_object.granularity,
-            ]
+            }
             if datasource
             # Query datasource didn't support `get_column`
             and hasattr(datasource, "get_column")
diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py
index a2732ae553..33393e88a6 100644
--- a/superset/common/query_object_factory.py
+++ b/superset/common/query_object_factory.py
@@ -16,17 +16,24 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import datetime
 from typing import Any, TYPE_CHECKING
 
 from superset.common.chart_data import ChartDataResultType
 from superset.common.query_object import QueryObject
 from superset.common.utils.time_range_utils import get_since_until_from_time_range
-from superset.utils.core import apply_max_row_limit, DatasourceDict, DatasourceType
+from superset.utils.core import (
+    apply_max_row_limit,
+    DatasourceDict,
+    DatasourceType,
+    FilterOperator,
+    QueryObjectFilterClause,
+)
 
 if TYPE_CHECKING:
     from sqlalchemy.orm import sessionmaker
 
-    from superset.connectors.base.models import BaseDatasource
+    from superset.connectors.base.models import BaseColumn, BaseDatasource
     from superset.daos.datasource import DatasourceDAO
 
 
@@ -66,6 +73,10 @@ class QueryObjectFactory:  # pylint: disable=too-few-public-methods
         )
         kwargs["from_dttm"] = from_dttm
         kwargs["to_dttm"] = to_dttm
+        if datasource_model_instance and kwargs.get("filters", []):
+            kwargs["filters"] = self._process_filters(
+                datasource_model_instance, kwargs["filters"]
+            )
         return QueryObject(
             datasource=datasource_model_instance,
             extras=extras,
@@ -102,3 +113,54 @@ class QueryObjectFactory:  # pylint: disable=too-few-public-methods
     # light version of the view.utils.core
     # import view.utils require application context
     # Todo: move it and the view.utils.core to utils package
+
+    def _process_filters(
+        self, datasource: BaseDatasource, query_filters: list[QueryObjectFilterClause]
+    ) -> list[QueryObjectFilterClause]:
+        def get_dttm_filter_value(
+            value: Any, col: BaseColumn, date_format: str
+        ) -> int | str:
+            if not isinstance(value, int):
+                return value
+            if date_format in {"epoch_ms", "epoch_s"}:
+                if date_format == "epoch_s":
+                    value = str(value)
+                else:
+                    value = str(value * 1000)
+            else:
+                dttm = datetime.utcfromtimestamp(value / 1000)
+                value = dttm.strftime(date_format)
+
+            if col.type in col.num_types:
+                value = int(value)
+            return value
+
+        for query_filter in query_filters:
+            if query_filter.get("op") == FilterOperator.TEMPORAL_RANGE:
+                continue
+            filter_col = query_filter.get("col")
+            if not isinstance(filter_col, str):
+                continue
+            column = datasource.get_column(filter_col)
+            if not column:
+                continue
+            filter_value = query_filter.get("val")
+
+            date_format = column.python_date_format
+            if not date_format and datasource.db_extra:
+                date_format = datasource.db_extra.get(
+                    "python_date_format_by_column_name", {}
+                ).get(column.column_name)
+
+            if column.is_dttm and date_format:
+                if isinstance(filter_value, list):
+                    query_filter["val"] = [
+                        get_dttm_filter_value(value, column, date_format)
+                        for value in filter_value
+                    ]
+                else:
+                    query_filter["val"] = get_dttm_filter_value(
+                        filter_value, column, date_format
+                    )
+
+        return query_filters
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index 8c2082d1c4..00a98b2c21 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -836,11 +836,9 @@ def test_special_chars_in_column_name(app_context, physical_dataset):
 
     query_object = qc.queries[0]
     df = qc.get_df_payload(query_object)["df"]
-    if query_object.datasource.database.backend == "sqlite":
-        # sqlite returns string as timestamp column
-        assert df["time column with spaces"][0] == "2002-01-03 00:00:00"
-        assert df["I_AM_A_TRUNC_COLUMN"][0] == "2002-01-01 00:00:00"
-    else:
+
+    # sqlite doesn't have timestamp columns
+    if query_object.datasource.database.backend != "sqlite":
         assert df["time column with spaces"][0].strftime("%Y-%m-%d") == "2002-01-03"
         assert df["I_AM_A_TRUNC_COLUMN"][0].strftime("%Y-%m-%d") == "2002-01-01"
 
diff --git a/tests/unit_tests/common/test_query_object_factory.py b/tests/unit_tests/common/test_query_object_factory.py
index 02304828dc..4e8fadfe3e 100644
--- a/tests/unit_tests/common/test_query_object_factory.py
+++ b/tests/unit_tests/common/test_query_object_factory.py
@@ -43,9 +43,45 @@ def session_factory() -> Mock:
     return Mock()
 
 
+class SimpleDatasetColumn:
+    def __init__(self, col_params: dict[str, Any]):
+        self.__dict__.update(col_params)
+
+
+TEMPORAL_COLUMN_NAMES = ["temporal_column", "temporal_column_with_python_date_format"]
+TEMPORAL_COLUMNS = {
+    TEMPORAL_COLUMN_NAMES[0]: SimpleDatasetColumn(
+        {
+            "column_name": TEMPORAL_COLUMN_NAMES[0],
+            "is_dttm": True,
+            "python_date_format": None,
+            "type": "string",
+            "num_types": ["BIGINT"],
+        }
+    ),
+    TEMPORAL_COLUMN_NAMES[1]: SimpleDatasetColumn(
+        {
+            "column_name": TEMPORAL_COLUMN_NAMES[1],
+            "type": "BIGINT",
+            "is_dttm": True,
+            "python_date_format": "%Y",
+            "num_types": ["BIGINT"],
+        }
+    ),
+}
+
+
 @fixture
 def connector_registry() -> Mock:
-    return Mock(spec=["get_datasource"])
+    datasource_dao_mock = Mock(spec=["get_datasource"])
+    datasource_dao_mock.get_datasource.return_value = Mock()
+    datasource_dao_mock.get_datasource().get_column = Mock(
+        side_effect=lambda col_name: TEMPORAL_COLUMNS[col_name]
+        if col_name in TEMPORAL_COLUMN_NAMES
+        else Mock()
+    )
+    datasource_dao_mock.get_datasource().db_extra = None
+    return datasource_dao_mock
 
 
 def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
@@ -112,3 +148,55 @@ class TestQueryObjectFactory:
             raw_query_context["result_type"], **raw_query_object
         )
         assert query_object.post_processing == []
+
+    def test_query_context_no_python_date_format_filters(
+        self,
+        query_object_factory: QueryObjectFactory,
+        raw_query_context: dict[str, Any],
+    ):
+        raw_query_object = raw_query_context["queries"][0]
+        raw_query_object["filters"].append(
+            {"col": TEMPORAL_COLUMN_NAMES[0], "op": "==", "val": 315532800000}
+        )
+        query_object = query_object_factory.create(
+            raw_query_context["result_type"],
+            raw_query_context["datasource"],
+            **raw_query_object
+        )
+        assert query_object.filter[3]["val"] == 315532800000
+
+    def test_query_context_python_date_format_filters(
+        self,
+        query_object_factory: QueryObjectFactory,
+        raw_query_context: dict[str, Any],
+    ):
+        raw_query_object = raw_query_context["queries"][0]
+        raw_query_object["filters"].append(
+            {"col": TEMPORAL_COLUMN_NAMES[1], "op": "==", "val": 315532800000}
+        )
+        query_object = query_object_factory.create(
+            raw_query_context["result_type"],
+            raw_query_context["datasource"],
+            **raw_query_object
+        )
+        assert query_object.filter[3]["val"] == 1980
+
+    def test_query_context_python_date_format_filters_list_of_values(
+        self,
+        query_object_factory: QueryObjectFactory,
+        raw_query_context: dict[str, Any],
+    ):
+        raw_query_object = raw_query_context["queries"][0]
+        raw_query_object["filters"].append(
+            {
+                "col": TEMPORAL_COLUMN_NAMES[1],
+                "op": "==",
+                "val": [315532800000, 631152000000],
+            }
+        )
+        query_object = query_object_factory.create(
+            raw_query_context["result_type"],
+            raw_query_context["datasource"],
+            **raw_query_object
+        )
+        assert query_object.filter[3]["val"] == [1980, 1990]