You are viewing a plain-text version of this content. The hyperlink to the canonical version was not preserved in this copy.
Posted to commits@superset.apache.org by jo...@apache.org on 2022/12/19 19:27:27 UTC

[superset] branch master updated: chore: Re-add inheritance of Presto macros for Trino et al. (#22435)

This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 71982ee174 chore: Re-add inheritance of Presto macros for Trino et al. (#22435)
71982ee174 is described below

commit 71982ee174a0dce811a108987fd72aa0aa391903
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Tue Dec 20 08:27:20 2022 +1300

    chore: Re-add inheritance of Presto macros for Trino et al. (#22435)
---
 superset/db_engine_specs/presto.py                 | 411 +++++++++++----------
 superset/db_engine_specs/trino.py                  |  33 +-
 .../db_engine_specs/presto_tests.py                |   3 +-
 .../db_engine_specs/trino_tests.py                 |  16 +
 4 files changed, 253 insertions(+), 210 deletions(-)

diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 6755039734..2a3acb8bb5 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -311,6 +311,203 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
         """
         return database.get_df("SHOW FUNCTIONS")["Function"].tolist()
 
+    @classmethod
+    def _partition_query(  # pylint: disable=too-many-arguments,too-many-locals
+        cls,
+        table_name: str,
+        database: Database,
+        limit: int = 0,
+        order_by: Optional[List[Tuple[str, bool]]] = None,
+        filters: Optional[Dict[Any, Any]] = None,
+    ) -> str:
+        """Returns a partition query
+
+        :param table_name: the name of the table to get partitions from
+        :type table_name: str
+        :param limit: the number of partitions to be returned
+        :type limit: int
+        :param order_by: a list of tuples of field name and a boolean
+            that determines if that field should be sorted in descending
+            order
+        :type order_by: list of (str, bool) tuples
+        :param filters: dict of field name and filter value combinations
+        """
+        limit_clause = "LIMIT {}".format(limit) if limit else ""
+        order_by_clause = ""
+        if order_by:
+            l = []
+            for field, desc in order_by:
+                l.append(field + " DESC" if desc else "")
+            order_by_clause = "ORDER BY " + ", ".join(l)
+
+        where_clause = ""
+        if filters:
+            l = []
+            for field, value in filters.items():
+                l.append(f"{field} = '{value}'")
+            where_clause = "WHERE " + " AND ".join(l)
+
+        presto_version = database.get_extra().get("version")
+
+        # Partition select syntax changed in v0.199, so check here.
+        # Default to the new syntax if version is unset.
+        partition_select_clause = (
+            f'SELECT * FROM "{table_name}$partitions"'
+            if not presto_version
+            or StrictVersion(presto_version) >= StrictVersion("0.199")
+            else f"SHOW PARTITIONS FROM {table_name}"
+        )
+
+        sql = dedent(
+            f"""\
+            {partition_select_clause}
+            {where_clause}
+            {order_by_clause}
+            {limit_clause}
+        """
+        )
+        return sql
+
+    @classmethod
+    def where_latest_partition(  # pylint: disable=too-many-arguments
+        cls,
+        table_name: str,
+        schema: Optional[str],
+        database: Database,
+        query: Select,
+        columns: Optional[List[Dict[str, str]]] = None,
+    ) -> Optional[Select]:
+        try:
+            col_names, values = cls.latest_partition(
+                table_name, schema, database, show_first=True
+            )
+        except Exception:  # pylint: disable=broad-except
+            # table is not partitioned
+            return None
+
+        if values is None:
+            return None
+
+        column_type_by_name = {
+            column.get("name"): column.get("type") for column in columns or []
+        }
+
+        for col_name, value in zip(col_names, values):
+            if col_name in column_type_by_name:
+                if column_type_by_name.get(col_name) == "TIMESTAMP":
+                    query = query.where(Column(col_name, TimeStamp()) == value)
+                elif column_type_by_name.get(col_name) == "DATE":
+                    query = query.where(Column(col_name, Date()) == value)
+                else:
+                    query = query.where(Column(col_name) == value)
+        return query
+
+    @classmethod
+    def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
+        if not df.empty:
+            return df.to_records(index=False)[0].item()
+        return None
+
+    @classmethod
+    @cache_manager.data_cache.memoize(timeout=60)
+    def latest_partition(
+        cls,
+        table_name: str,
+        schema: Optional[str],
+        database: Database,
+        show_first: bool = False,
+    ) -> Tuple[List[str], Optional[List[str]]]:
+        """Returns col name and the latest (max) partition value for a table
+
+        :param table_name: the name of the table
+        :param schema: schema / database / namespace
+        :param database: database query will be run against
+        :type database: models.Database
+        :param show_first: displays the value for the first partitioning key
+          if there are many partitioning keys
+        :type show_first: bool
+
+        >>> latest_partition('foo_table')
+        (['ds'], ('2018-01-01',))
+        """
+        indexes = database.get_indexes(table_name, schema)
+        if not indexes:
+            raise SupersetTemplateException(
+                f"Error getting partition for {schema}.{table_name}. "
+                "Verify that this table has a partition."
+            )
+
+        if len(indexes[0]["column_names"]) < 1:
+            raise SupersetTemplateException(
+                "The table should have one partitioned field"
+            )
+
+        if not show_first and len(indexes[0]["column_names"]) > 1:
+            raise SupersetTemplateException(
+                "The table should have a single partitioned field "
+                "to use this function. You may want to use "
+                "`presto.latest_sub_partition`"
+            )
+
+        column_names = indexes[0]["column_names"]
+        part_fields = [(column_name, True) for column_name in column_names]
+        sql = cls._partition_query(table_name, database, 1, part_fields)
+        df = database.get_df(sql, schema)
+        return column_names, cls._latest_partition_from_df(df)
+
+    @classmethod
+    def latest_sub_partition(
+        cls, table_name: str, schema: Optional[str], database: Database, **kwargs: Any
+    ) -> Any:
+        """Returns the latest (max) partition value for a table
+
+        A filtering criteria should be passed for all fields that are
+        partitioned except for the field to be returned. For example,
+        if a table is partitioned by (``ds``, ``event_type`` and
+        ``event_category``) and you want the latest ``ds``, you'll want
+        to provide a filter as keyword arguments for both
+        ``event_type`` and ``event_category`` as in
+        ``latest_sub_partition('my_table',
+            event_category='page', event_type='click')``
+
+        :param table_name: the name of the table, can be just the table
+            name or a fully qualified table name as ``schema_name.table_name``
+        :type table_name: str
+        :param schema: schema / database / namespace
+        :type schema: str
+        :param database: database query will be run against
+        :type database: models.Database
+
+        :param kwargs: keyword arguments define the filtering criteria
+            on the partition list. There can be many of these.
+        :type kwargs: str
+        >>> latest_sub_partition('sub_partition_table', event_type='click')
+        '2018-01-01'
+        """
+        indexes = database.get_indexes(table_name, schema)
+        part_fields = indexes[0]["column_names"]
+        for k in kwargs.keys():  # pylint: disable=consider-iterating-dictionary
+            if k not in k in part_fields:  # pylint: disable=comparison-with-itself
+                msg = f"Field [{k}] is not part of the portioning key"
+                raise SupersetTemplateException(msg)
+        if len(kwargs.keys()) != len(part_fields) - 1:
+            msg = (
+                "A filter needs to be specified for {} out of the " "{} fields."
+            ).format(len(part_fields) - 1, len(part_fields))
+            raise SupersetTemplateException(msg)
+
+        for field in part_fields:
+            if field not in kwargs.keys():
+                field_to_return = field
+
+        sql = cls._partition_query(
+            table_name, database, 1, [(field_to_return, True)], kwargs
+        )
+        df = database.get_df(sql, schema)
+        if df.empty:
+            return ""
+        return df.to_dict()[field_to_return][0]
+
 
 class PrestoEngineSpec(PrestoBaseEngineSpec):
     engine = "presto"
@@ -958,21 +1155,24 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
 
         indexes = database.get_indexes(table_name, schema_name)
         if indexes:
-            cols = indexes[0].get("column_names", [])
-            full_table_name = table_name
-            if schema_name and "." not in table_name:
-                full_table_name = "{}.{}".format(schema_name, table_name)
-            pql = cls._partition_query(full_table_name, database)
             col_names, latest_parts = cls.latest_partition(
                 table_name, schema_name, database, show_first=True
             )
 
             if not latest_parts:
                 latest_parts = tuple([None] * len(col_names))
+
             metadata["partitions"] = {
-                "cols": cols,
+                "cols": sorted(indexes[0].get("column_names", [])),
                 "latest": dict(zip(col_names, latest_parts)),
-                "partitionQuery": pql,
+                "partitionQuery": cls._partition_query(
+                    table_name=(
+                        f"{schema_name}.{table_name}"
+                        if schema_name and "." not in table_name
+                        else table_name
+                    ),
+                    database=database,
+                ),
             }
 
         # flake8 is not matching `Optional[str]` to `Any` for some reason...
@@ -1085,203 +1285,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
             return error_dict.get("message", _("Unknown Presto Error"))
         return utils.error_msg_from_exception(ex)
 
-    @classmethod
-    def _partition_query(  # pylint: disable=too-many-arguments,too-many-locals
-        cls,
-        table_name: str,
-        database: Database,
-        limit: int = 0,
-        order_by: Optional[List[Tuple[str, bool]]] = None,
-        filters: Optional[Dict[Any, Any]] = None,
-    ) -> str:
-        """Returns a partition query
-
-        :param table_name: the name of the table to get partitions from
-        :type table_name: str
-        :param limit: the number of partitions to be returned
-        :type limit: int
-        :param order_by: a list of tuples of field name and a boolean
-            that determines if that field should be sorted in descending
-            order
-        :type order_by: list of (str, bool) tuples
-        :param filters: dict of field name and filter value combinations
-        """
-        limit_clause = "LIMIT {}".format(limit) if limit else ""
-        order_by_clause = ""
-        if order_by:
-            l = []
-            for field, desc in order_by:
-                l.append(field + " DESC" if desc else "")
-            order_by_clause = "ORDER BY " + ", ".join(l)
-
-        where_clause = ""
-        if filters:
-            l = []
-            for field, value in filters.items():
-                l.append(f"{field} = '{value}'")
-            where_clause = "WHERE " + " AND ".join(l)
-
-        presto_version = database.get_extra().get("version")
-
-        # Partition select syntax changed in v0.199, so check here.
-        # Default to the new syntax if version is unset.
-        partition_select_clause = (
-            f'SELECT * FROM "{table_name}$partitions"'
-            if not presto_version
-            or StrictVersion(presto_version) >= StrictVersion("0.199")
-            else f"SHOW PARTITIONS FROM {table_name}"
-        )
-
-        sql = dedent(
-            f"""\
-            {partition_select_clause}
-            {where_clause}
-            {order_by_clause}
-            {limit_clause}
-        """
-        )
-        return sql
-
-    @classmethod
-    def where_latest_partition(  # pylint: disable=too-many-arguments
-        cls,
-        table_name: str,
-        schema: Optional[str],
-        database: Database,
-        query: Select,
-        columns: Optional[List[Dict[str, str]]] = None,
-    ) -> Optional[Select]:
-        try:
-            col_names, values = cls.latest_partition(
-                table_name, schema, database, show_first=True
-            )
-        except Exception:  # pylint: disable=broad-except
-            # table is not partitioned
-            return None
-
-        if values is None:
-            return None
-
-        column_type_by_name = {
-            column.get("name"): column.get("type") for column in columns or []
-        }
-
-        for col_name, value in zip(col_names, values):
-            if col_name in column_type_by_name:
-                if column_type_by_name.get(col_name) == "TIMESTAMP":
-                    query = query.where(Column(col_name, TimeStamp()) == value)
-                elif column_type_by_name.get(col_name) == "DATE":
-                    query = query.where(Column(col_name, Date()) == value)
-                else:
-                    query = query.where(Column(col_name) == value)
-        return query
-
-    @classmethod
-    def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
-        if not df.empty:
-            return df.to_records(index=False)[0].item()
-        return None
-
-    @classmethod
-    @cache_manager.data_cache.memoize(timeout=60)
-    def latest_partition(
-        cls,
-        table_name: str,
-        schema: Optional[str],
-        database: Database,
-        show_first: bool = False,
-    ) -> Tuple[List[str], Optional[List[str]]]:
-        """Returns col name and the latest (max) partition value for a table
-
-        :param table_name: the name of the table
-        :param schema: schema / database / namespace
-        :param database: database query will be run against
-        :type database: models.Database
-        :param show_first: displays the value for the first partitioning key
-          if there are many partitioning keys
-        :type show_first: bool
-
-        >>> latest_partition('foo_table')
-        (['ds'], ('2018-01-01',))
-        """
-        indexes = database.get_indexes(table_name, schema)
-        if not indexes:
-            raise SupersetTemplateException(
-                f"Error getting partition for {schema}.{table_name}. "
-                "Verify that this table has a partition."
-            )
-
-        if len(indexes[0]["column_names"]) < 1:
-            raise SupersetTemplateException(
-                "The table should have one partitioned field"
-            )
-
-        if not show_first and len(indexes[0]["column_names"]) > 1:
-            raise SupersetTemplateException(
-                "The table should have a single partitioned field "
-                "to use this function. You may want to use "
-                "`presto.latest_sub_partition`"
-            )
-
-        column_names = indexes[0]["column_names"]
-        part_fields = [(column_name, True) for column_name in column_names]
-        sql = cls._partition_query(table_name, database, 1, part_fields)
-        df = database.get_df(sql, schema)
-        return column_names, cls._latest_partition_from_df(df)
-
-    @classmethod
-    def latest_sub_partition(
-        cls, table_name: str, schema: Optional[str], database: Database, **kwargs: Any
-    ) -> Any:
-        """Returns the latest (max) partition value for a table
-
-        A filtering criteria should be passed for all fields that are
-        partitioned except for the field to be returned. For example,
-        if a table is partitioned by (``ds``, ``event_type`` and
-        ``event_category``) and you want the latest ``ds``, you'll want
-        to provide a filter as keyword arguments for both
-        ``event_type`` and ``event_category`` as in
-        ``latest_sub_partition('my_table',
-            event_category='page', event_type='click')``
-
-        :param table_name: the name of the table, can be just the table
-            name or a fully qualified table name as ``schema_name.table_name``
-        :type table_name: str
-        :param schema: schema / database / namespace
-        :type schema: str
-        :param database: database query will be run against
-        :type database: models.Database
-
-        :param kwargs: keyword arguments define the filtering criteria
-            on the partition list. There can be many of these.
-        :type kwargs: str
-        >>> latest_sub_partition('sub_partition_table', event_type='click')
-        '2018-01-01'
-        """
-        indexes = database.get_indexes(table_name, schema)
-        part_fields = indexes[0]["column_names"]
-        for k in kwargs.keys():  # pylint: disable=consider-iterating-dictionary
-            if k not in k in part_fields:  # pylint: disable=comparison-with-itself
-                msg = f"Field [{k}] is not part of the portioning key"
-                raise SupersetTemplateException(msg)
-        if len(kwargs.keys()) != len(part_fields) - 1:
-            msg = (
-                "A filter needs to be specified for {} out of the " "{} fields."
-            ).format(len(part_fields) - 1, len(part_fields))
-            raise SupersetTemplateException(msg)
-
-        for field in part_fields:
-            if field not in kwargs.keys():
-                field_to_return = field
-
-        sql = cls._partition_query(
-            table_name, database, 1, [(field_to_return, True)], kwargs
-        )
-        df = database.get_df(sql, schema)
-        if df.empty:
-            return ""
-        return df.to_dict()[field_to_return][0]
-
     @classmethod
     def get_column_spec(
         cls,
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index c6faf6db6c..2a1d8cc639 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -83,11 +83,34 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
 
         indexes = database.get_indexes(table_name, schema_name)
         if indexes:
-            partitions_columns = []
-            for index in indexes:
-                if index.get("name") == "partition":
-                    partitions_columns += index.get("column_names", [])
-            metadata["partitions"] = {"cols": partitions_columns}
+            col_names, latest_parts = cls.latest_partition(
+                table_name, schema_name, database, show_first=True
+            )
+
+            if not latest_parts:
+                latest_parts = tuple([None] * len(col_names))
+
+            metadata["partitions"] = {
+                "cols": sorted(
+                    list(
+                        set(
+                            column_name
+                            for index in indexes
+                            if index.get("name") == "partition"
+                            for column_name in index.get("column_names", [])
+                        )
+                    )
+                ),
+                "latest": dict(zip(col_names, latest_parts)),
+                "partitionQuery": cls._partition_query(
+                    table_name=(
+                        f"{schema_name}.{table_name}"
+                        if schema_name and "." not in table_name
+                        else table_name
+                    ),
+                    database=database,
+                ),
+            }
 
         if database.has_view_by_name(table_name, schema_name):
             metadata["view"] = database.inspector.get_view_definition(
diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py
index 4a76d59a46..eef3bb8d36 100644
--- a/tests/integration_tests/db_engine_specs/presto_tests.py
+++ b/tests/integration_tests/db_engine_specs/presto_tests.py
@@ -492,7 +492,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
         db.get_df = mock.Mock(return_value=df)
         PrestoEngineSpec.get_create_view = mock.Mock(return_value=None)
         result = PrestoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")
-        self.assertEqual({"ds": "01-01-19", "hour": 1}, result["partitions"]["latest"])
+        assert result["partitions"]["cols"] == ["ds", "hour"]
+        assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}
 
     def test_presto_where_latest_partition(self):
         db = mock.Mock()
diff --git a/tests/integration_tests/db_engine_specs/trino_tests.py b/tests/integration_tests/db_engine_specs/trino_tests.py
index 41a4f4e0f3..6379d013b2 100644
--- a/tests/integration_tests/db_engine_specs/trino_tests.py
+++ b/tests/integration_tests/db_engine_specs/trino_tests.py
@@ -16,8 +16,10 @@
 # under the License.
 import json
 from typing import Any, Dict
+from unittest import mock
 from unittest.mock import Mock, patch
 
+import pandas as pd
 import pytest
 from sqlalchemy import types
 
@@ -196,3 +198,17 @@ class TestTrinoDbEngineSpec(TestDbEngineSpec):
             TrinoEngineSpec.convert_dttm("DATE", dttm),
             "DATE '2019-01-02'",
         )
+
+    def test_extra_table_metadata(self):
+        db = mock.Mock()
+        db.get_indexes = mock.Mock(
+            return_value=[{"column_names": ["ds", "hour"], "name": "partition"}]
+        )
+        db.get_extra = mock.Mock(return_value={})
+        db.has_view_by_name = mock.Mock(return_value=None)
+        db.get_df = mock.Mock(
+            return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
+        )
+        result = TrinoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")
+        assert result["partitions"]["cols"] == ["ds", "hour"]
+        assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}