You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by el...@apache.org on 2023/07/07 22:41:33 UTC

[superset] 03/12: fix: handle temporal columns in presto partitions (#24054)

This is an automated email from the ASF dual-hosted git repository.

elizabeth pushed a commit to branch 2.1
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 75be3dd7b45ed98ade643d56b05a1ab10d8874b4
Author: Rob Moore <gi...@users.noreply.github.com>
AuthorDate: Fri May 19 21:29:42 2023 +0100

    fix: handle temporal columns in presto partitions (#24054)
---
 superset/db_engine_specs/base.py                |  2 +-
 superset/db_engine_specs/hive.py                |  2 +-
 superset/db_engine_specs/presto.py              | 18 ++++++-----
 tests/unit_tests/db_engine_specs/test_presto.py | 43 ++++++++++++++++++++++++-
 4 files changed, 54 insertions(+), 11 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 27dd34a802..b789bbe70c 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -1168,7 +1168,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         schema: Optional[str],
         database: Database,
         query: Select,
-        columns: Optional[List[Dict[str, str]]] = None,
+        columns: Optional[List[Dict[str, Any]]] = None,
     ) -> Optional[Select]:
         """
         Add a where clause to a query to reference only the most recent partition
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index f07d53518c..44dc435c2c 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -404,7 +404,7 @@ class HiveEngineSpec(PrestoEngineSpec):
         schema: Optional[str],
         database: "Database",
         query: Select,
-        columns: Optional[List[Dict[str, str]]] = None,
+        columns: Optional[List[Dict[str, Any]]] = None,
     ) -> Optional[Select]:
         try:
             col_names, values = cls.latest_partition(
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 6bd556b79e..87f362acc8 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -462,7 +462,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
         schema: Optional[str],
         database: Database,
         query: Select,
-        columns: Optional[List[Dict[str, str]]] = None,
+        columns: Optional[List[Dict[str, Any]]] = None,
     ) -> Optional[Select]:
         try:
             col_names, values = cls.latest_partition(
@@ -480,13 +480,15 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
         }
 
         for col_name, value in zip(col_names, values):
-            if col_name in column_type_by_name:
-                if column_type_by_name.get(col_name) == "TIMESTAMP":
-                    query = query.where(Column(col_name, TimeStamp()) == value)
-                elif column_type_by_name.get(col_name) == "DATE":
-                    query = query.where(Column(col_name, Date()) == value)
-                else:
-                    query = query.where(Column(col_name) == value)
+            col_type = column_type_by_name.get(col_name)
+
+            if isinstance(col_type, types.DATE):
+                col_type = Date()
+            elif isinstance(col_type, types.TIMESTAMP):
+                col_type = TimeStamp()
+
+            query = query.where(Column(col_name, col_type) == value)
+
         return query
 
     @classmethod
diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py
index a30fab94c9..8f55b1c048 100644
--- a/tests/unit_tests/db_engine_specs/test_presto.py
+++ b/tests/unit_tests/db_engine_specs/test_presto.py
@@ -16,10 +16,13 @@
 # under the License.
 from datetime import datetime
 from typing import Any, Dict, Optional, Type
+from unittest import mock
 
 import pytest
 import pytz
-from sqlalchemy import types
+from pyhive.sqlalchemy_presto import PrestoDialect
+from sqlalchemy import sql, text, types
+from sqlalchemy.engine.url import make_url
 
 from superset.utils.core import GenericDataType
 from tests.unit_tests.db_engine_specs.utils import (
@@ -82,3 +85,41 @@ def test_get_column_spec(
     from superset.db_engine_specs.presto import PrestoEngineSpec as spec
 
     assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)
+
+
+@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.latest_partition")
+@pytest.mark.parametrize(
+    ["column_type", "column_value", "expected_value"],
+    [
+        (types.DATE(), "2023-05-01", "DATE '2023-05-01'"),
+        (types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"),
+        (types.VARCHAR(), "2023-05-01", "'2023-05-01'"),
+        (types.INT(), 1234, "1234"),
+    ],
+)
+def test_where_latest_partition(
+    mock_latest_partition: Any,
+    column_type: Any,
+    column_value: Any,
+    expected_value: str,
+) -> None:
+    """
+    Test ``where_latest_partition`` renders typed literals for temporal partitions
+    """
+    from superset.db_engine_specs.presto import PrestoEngineSpec as spec
+
+    mock_latest_partition.return_value = (["partition_key"], [column_value])
+
+    query = sql.select(text("* FROM table"))
+    columns = [{"name": "partition_key", "type": column_type}]
+
+    expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}"""
+    result = spec.where_latest_partition(
+        "table", mock.MagicMock(), mock.MagicMock(), query, columns
+    )
+    assert result is not None
+    actual = result.compile(
+        dialect=PrestoDialect(), compile_kwargs={"literal_binds": True}
+    )
+
+    assert str(actual) == expected