You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by mi...@apache.org on 2023/10/09 14:35:57 UTC
[superset] 02/02: fix: Apply normalization to all dttm columns (#25147)
This is an automated email from the ASF dual-hosted git repository.
michaelsmolina pushed a commit to branch 3.0
in repository https://gitbox.apache.org/repos/asf/superset.git
commit c205016e9d64cab056addbdb7d4d8aada65a9b62
Author: Kamil Gabryjelski <ka...@gmail.com>
AuthorDate: Fri Oct 6 18:47:00 2023 +0200
fix: Apply normalization to all dttm columns (#25147)
(cherry picked from commit 58fcd292a979212a3d6f636917021c12c299fd93)
---
superset/common/query_context_factory.py | 1 +
superset/common/query_context_processor.py | 5 +-
superset/common/query_object_factory.py | 66 +++++++++++++++-
tests/integration_tests/query_context_tests.py | 8 +-
.../unit_tests/common/test_query_object_factory.py | 90 +++++++++++++++++++++-
5 files changed, 160 insertions(+), 10 deletions(-)
diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py
index e4680ed5ed..62e8b79893 100644
--- a/superset/common/query_context_factory.py
+++ b/superset/common/query_context_factory.py
@@ -186,6 +186,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
filter
for filter in query_object.filter
if filter["col"] != filter_to_remove
+ or filter["op"] != "TEMPORAL_RANGE"
]
def _apply_filters(self, query_object: QueryObject) -> None:
diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index f6152b232a..754c9ae91a 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -285,10 +285,11 @@ class QueryContextProcessor:
datasource = self._qc_datasource
labels = tuple(
label
- for label in [
+ for label in {
*get_base_axis_labels(query_object.columns),
+ *[col for col in query_object.columns or [] if isinstance(col, str)],
query_object.granularity,
- ]
+ }
if datasource
# Query datasource didn't support `get_column`
and hasattr(datasource, "get_column")
diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py
index ae85912cdf..bc773281b8 100644
--- a/superset/common/query_object_factory.py
+++ b/superset/common/query_object_factory.py
@@ -16,17 +16,24 @@
# under the License.
from __future__ import annotations
+from datetime import datetime
from typing import Any, TYPE_CHECKING
from superset.common.chart_data import ChartDataResultType
from superset.common.query_object import QueryObject
from superset.common.utils.time_range_utils import get_since_until_from_time_range
-from superset.utils.core import apply_max_row_limit, DatasourceDict, DatasourceType
+from superset.utils.core import (
+ apply_max_row_limit,
+ DatasourceDict,
+ DatasourceType,
+ FilterOperator,
+ QueryObjectFilterClause,
+)
if TYPE_CHECKING:
from sqlalchemy.orm import sessionmaker
- from superset.connectors.base.models import BaseDatasource
+ from superset.connectors.base.models import BaseColumn, BaseDatasource
from superset.daos.datasource import DatasourceDAO
@@ -66,6 +73,10 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
)
kwargs["from_dttm"] = from_dttm
kwargs["to_dttm"] = to_dttm
+ if datasource_model_instance and kwargs.get("filters", []):
+ kwargs["filters"] = self._process_filters(
+ datasource_model_instance, kwargs["filters"]
+ )
return QueryObject(
datasource=datasource_model_instance,
extras=extras,
@@ -102,3 +113,54 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
# light version of the view.utils.core
# import view.utils require application context
# Todo: move it and the view.utils.core to utils package
+
+    def _process_filters(
+        self, datasource: BaseDatasource, query_filters: list[QueryObjectFilterClause]
+    ) -> list[QueryObjectFilterClause]:
+        """Rewrite epoch-ms temporal filter values in place to the column format."""
+
+        def get_dttm_filter_value(
+            value: Any, col: BaseColumn, date_format: str
+        ) -> int | str:
+            # Ints are epoch milliseconds (cf. `value / 1000` below); the epoch
+            # branches must scale from ms too: s = ms // 1000, ms passes through.
+            if not isinstance(value, int) or isinstance(value, bool):
+                return value
+            if date_format == "epoch_s":
+                value = str(value // 1000)
+            elif date_format == "epoch_ms":
+                value = str(value)
+            else:
+                value = datetime.utcfromtimestamp(value / 1000).strftime(date_format)
+            if col.type in col.num_types:
+                value = int(value)
+            return value
+
+        for query_filter in query_filters:
+            if query_filter.get("op") == FilterOperator.TEMPORAL_RANGE:
+                continue
+            filter_col = query_filter.get("col")
+            if not isinstance(filter_col, str):
+                continue
+            column = datasource.get_column(filter_col)
+            if not column or not column.is_dttm:
+                continue
+            # Fall back to the per-database mapping when the column has none.
+            date_format = column.python_date_format
+            if not date_format and datasource.db_extra:
+                date_format = datasource.db_extra.get(
+                    "python_date_format_by_column_name", {}
+                ).get(column.column_name)
+            if not date_format:
+                continue
+            filter_value = query_filter.get("val")
+            if isinstance(filter_value, list):
+                query_filter["val"] = [
+                    get_dttm_filter_value(value, column, date_format)
+                    for value in filter_value
+                ]
+            else:
+                query_filter["val"] = get_dttm_filter_value(
+                    filter_value, column, date_format
+                )
+        return query_filters
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index 8c2082d1c4..00a98b2c21 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -836,11 +836,9 @@ def test_special_chars_in_column_name(app_context, physical_dataset):
query_object = qc.queries[0]
df = qc.get_df_payload(query_object)["df"]
- if query_object.datasource.database.backend == "sqlite":
- # sqlite returns string as timestamp column
- assert df["time column with spaces"][0] == "2002-01-03 00:00:00"
- assert df["I_AM_A_TRUNC_COLUMN"][0] == "2002-01-01 00:00:00"
- else:
+
+ # sqlite doesn't have timestamp columns
+ if query_object.datasource.database.backend != "sqlite":
assert df["time column with spaces"][0].strftime("%Y-%m-%d") == "2002-01-03"
assert df["I_AM_A_TRUNC_COLUMN"][0].strftime("%Y-%m-%d") == "2002-01-01"
diff --git a/tests/unit_tests/common/test_query_object_factory.py b/tests/unit_tests/common/test_query_object_factory.py
index 02304828dc..4e8fadfe3e 100644
--- a/tests/unit_tests/common/test_query_object_factory.py
+++ b/tests/unit_tests/common/test_query_object_factory.py
@@ -43,9 +43,45 @@ def session_factory() -> Mock:
return Mock()
+class SimpleDatasetColumn:
+ def __init__(self, col_params: dict[str, Any]):
+ self.__dict__.update(col_params)
+
+
+TEMPORAL_COLUMN_NAMES = ["temporal_column", "temporal_column_with_python_date_format"]
+TEMPORAL_COLUMNS = {
+ TEMPORAL_COLUMN_NAMES[0]: SimpleDatasetColumn(
+ {
+ "column_name": TEMPORAL_COLUMN_NAMES[0],
+ "is_dttm": True,
+ "python_date_format": None,
+ "type": "string",
+ "num_types": ["BIGINT"],
+ }
+ ),
+ TEMPORAL_COLUMN_NAMES[1]: SimpleDatasetColumn(
+ {
+ "column_name": TEMPORAL_COLUMN_NAMES[1],
+ "type": "BIGINT",
+ "is_dttm": True,
+ "python_date_format": "%Y",
+ "num_types": ["BIGINT"],
+ }
+ ),
+}
+
+
@fixture
def connector_registry() -> Mock:
- return Mock(spec=["get_datasource"])
+ datasource_dao_mock = Mock(spec=["get_datasource"])
+ datasource_dao_mock.get_datasource.return_value = Mock()
+ datasource_dao_mock.get_datasource().get_column = Mock(
+ side_effect=lambda col_name: TEMPORAL_COLUMNS[col_name]
+ if col_name in TEMPORAL_COLUMN_NAMES
+ else Mock()
+ )
+ datasource_dao_mock.get_datasource().db_extra = None
+ return datasource_dao_mock
def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
@@ -112,3 +148,55 @@ class TestQueryObjectFactory:
raw_query_context["result_type"], **raw_query_object
)
assert query_object.post_processing == []
+
+    def test_query_context_no_python_date_format_filters(
+        self,
+        query_object_factory: QueryObjectFactory,
+        raw_query_context: dict[str, Any],
+    ):
+        raw_query_object = raw_query_context["queries"][0]
+        raw_query_object["filters"].append(
+            {"col": TEMPORAL_COLUMN_NAMES[0], "op": "==", "val": 315532800000}
+        )
+        query_object = query_object_factory.create(
+            raw_query_context["result_type"],
+            raw_query_context["datasource"],
+            **raw_query_object
+        )
+        assert query_object.filter[-1]["val"] == 315532800000  # no format: as-is
+
+    def test_query_context_python_date_format_filters(
+        self,
+        query_object_factory: QueryObjectFactory,
+        raw_query_context: dict[str, Any],
+    ):
+        raw_query_object = raw_query_context["queries"][0]
+        raw_query_object["filters"].append(
+            {"col": TEMPORAL_COLUMN_NAMES[1], "op": "==", "val": 315532800000}
+        )
+        query_object = query_object_factory.create(
+            raw_query_context["result_type"],
+            raw_query_context["datasource"],
+            **raw_query_object
+        )
+        assert query_object.filter[-1]["val"] == 1980  # "%Y" on BIGINT -> int
+
+    def test_query_context_python_date_format_filters_list_of_values(
+        self,
+        query_object_factory: QueryObjectFactory,
+        raw_query_context: dict[str, Any],
+    ):
+        raw_query_object = raw_query_context["queries"][0]
+        raw_query_object["filters"].append(
+            {
+                "col": TEMPORAL_COLUMN_NAMES[1],
+                "op": "IN",
+                "val": [315532800000, 631152000000],
+            }
+        )
+        query_object = query_object_factory.create(
+            raw_query_context["result_type"],
+            raw_query_context["datasource"],
+            **raw_query_object
+        )
+        assert query_object.filter[-1]["val"] == [1980, 1990]  # each element mapped