You are viewing a plain-text version of this content. The canonical link for it was given as a hyperlink ("here") that is not preserved in this copy.
Posted to commits@superset.apache.org by jo...@apache.org on 2022/12/19 19:27:27 UTC
[superset] branch master updated: chore: Re-add inheritance of Presto macros for Trino et al. (#22435)
This is an automated email from the ASF dual-hosted git repository.
johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 71982ee174 chore: Re-add inheritance of Presto macros for Trino et al. (#22435)
71982ee174 is described below
commit 71982ee174a0dce811a108987fd72aa0aa391903
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Tue Dec 20 08:27:20 2022 +1300
chore: Re-add inheritance of Presto macros for Trino et al. (#22435)
---
superset/db_engine_specs/presto.py | 411 +++++++++++----------
superset/db_engine_specs/trino.py | 33 +-
.../db_engine_specs/presto_tests.py | 3 +-
.../db_engine_specs/trino_tests.py | 16 +
4 files changed, 253 insertions(+), 210 deletions(-)
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 6755039734..2a3acb8bb5 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -311,6 +311,203 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
"""
return database.get_df("SHOW FUNCTIONS")["Function"].tolist()
+ @classmethod
+ def _partition_query( # pylint: disable=too-many-arguments,too-many-locals
+ cls,
+ table_name: str,
+ database: Database,
+ limit: int = 0,
+ order_by: Optional[List[Tuple[str, bool]]] = None,
+ filters: Optional[Dict[Any, Any]] = None,
+ ) -> str:
+ """Returns a partition query
+
+ :param table_name: the name of the table to get partitions from
+ :type table_name: str
+ :param limit: the number of partitions to be returned
+ :type limit: int
+ :param order_by: a list of tuples of field name and a boolean
+ that determines if that field should be sorted in descending
+ order
+ :type order_by: list of (str, bool) tuples
+ :param filters: dict of field name and filter value combinations
+ """
+ limit_clause = "LIMIT {}".format(limit) if limit else ""
+ order_by_clause = ""
+ if order_by:
+ l = []
+ for field, desc in order_by:
+ l.append(field + " DESC" if desc else "")
+ order_by_clause = "ORDER BY " + ", ".join(l)
+
+ where_clause = ""
+ if filters:
+ l = []
+ for field, value in filters.items():
+ l.append(f"{field} = '{value}'")
+ where_clause = "WHERE " + " AND ".join(l)
+
+ presto_version = database.get_extra().get("version")
+
+ # Partition select syntax changed in v0.199, so check here.
+ # Default to the new syntax if version is unset.
+ partition_select_clause = (
+ f'SELECT * FROM "{table_name}$partitions"'
+ if not presto_version
+ or StrictVersion(presto_version) >= StrictVersion("0.199")
+ else f"SHOW PARTITIONS FROM {table_name}"
+ )
+
+ sql = dedent(
+ f"""\
+ {partition_select_clause}
+ {where_clause}
+ {order_by_clause}
+ {limit_clause}
+ """
+ )
+ return sql
+
+ @classmethod
+ def where_latest_partition( # pylint: disable=too-many-arguments
+ cls,
+ table_name: str,
+ schema: Optional[str],
+ database: Database,
+ query: Select,
+ columns: Optional[List[Dict[str, str]]] = None,
+ ) -> Optional[Select]:
+ try:
+ col_names, values = cls.latest_partition(
+ table_name, schema, database, show_first=True
+ )
+ except Exception: # pylint: disable=broad-except
+ # table is not partitioned
+ return None
+
+ if values is None:
+ return None
+
+ column_type_by_name = {
+ column.get("name"): column.get("type") for column in columns or []
+ }
+
+ for col_name, value in zip(col_names, values):
+ if col_name in column_type_by_name:
+ if column_type_by_name.get(col_name) == "TIMESTAMP":
+ query = query.where(Column(col_name, TimeStamp()) == value)
+ elif column_type_by_name.get(col_name) == "DATE":
+ query = query.where(Column(col_name, Date()) == value)
+ else:
+ query = query.where(Column(col_name) == value)
+ return query
+
+ @classmethod
+ def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
+ if not df.empty:
+ return df.to_records(index=False)[0].item()
+ return None
+
+ @classmethod
+ @cache_manager.data_cache.memoize(timeout=60)
+ def latest_partition(
+ cls,
+ table_name: str,
+ schema: Optional[str],
+ database: Database,
+ show_first: bool = False,
+ ) -> Tuple[List[str], Optional[List[str]]]:
+ """Returns col name and the latest (max) partition value for a table
+
+ :param table_name: the name of the table
+ :param schema: schema / database / namespace
+ :param database: database query will be run against
+ :type database: models.Database
+ :param show_first: displays the value for the first partitioning key
+ if there are many partitioning keys
+ :type show_first: bool
+
+ >>> latest_partition('foo_table')
+ (['ds'], ('2018-01-01',))
+ """
+ indexes = database.get_indexes(table_name, schema)
+ if not indexes:
+ raise SupersetTemplateException(
+ f"Error getting partition for {schema}.{table_name}. "
+ "Verify that this table has a partition."
+ )
+
+ if len(indexes[0]["column_names"]) < 1:
+ raise SupersetTemplateException(
+ "The table should have one partitioned field"
+ )
+
+ if not show_first and len(indexes[0]["column_names"]) > 1:
+ raise SupersetTemplateException(
+ "The table should have a single partitioned field "
+ "to use this function. You may want to use "
+ "`presto.latest_sub_partition`"
+ )
+
+ column_names = indexes[0]["column_names"]
+ part_fields = [(column_name, True) for column_name in column_names]
+ sql = cls._partition_query(table_name, database, 1, part_fields)
+ df = database.get_df(sql, schema)
+ return column_names, cls._latest_partition_from_df(df)
+
+ @classmethod
+ def latest_sub_partition(
+ cls, table_name: str, schema: Optional[str], database: Database, **kwargs: Any
+ ) -> Any:
+ """Returns the latest (max) partition value for a table
+
+ A filtering criteria should be passed for all fields that are
+ partitioned except for the field to be returned. For example,
+ if a table is partitioned by (``ds``, ``event_type`` and
+ ``event_category``) and you want the latest ``ds``, you'll want
+ to provide a filter as keyword arguments for both
+ ``event_type`` and ``event_category`` as in
+ ``latest_sub_partition('my_table',
+ event_category='page', event_type='click')``
+
+ :param table_name: the name of the table, can be just the table
+ name or a fully qualified table name as ``schema_name.table_name``
+ :type table_name: str
+ :param schema: schema / database / namespace
+ :type schema: str
+ :param database: database query will be run against
+ :type database: models.Database
+
+ :param kwargs: keyword arguments define the filtering criteria
+ on the partition list. There can be many of these.
+ :type kwargs: str
+ >>> latest_sub_partition('sub_partition_table', event_type='click')
+ '2018-01-01'
+ """
+ indexes = database.get_indexes(table_name, schema)
+ part_fields = indexes[0]["column_names"]
+ for k in kwargs.keys(): # pylint: disable=consider-iterating-dictionary
+ if k not in k in part_fields: # pylint: disable=comparison-with-itself
+ msg = f"Field [{k}] is not part of the portioning key"
+ raise SupersetTemplateException(msg)
+ if len(kwargs.keys()) != len(part_fields) - 1:
+ msg = (
+ "A filter needs to be specified for {} out of the " "{} fields."
+ ).format(len(part_fields) - 1, len(part_fields))
+ raise SupersetTemplateException(msg)
+
+ for field in part_fields:
+ if field not in kwargs.keys():
+ field_to_return = field
+
+ sql = cls._partition_query(
+ table_name, database, 1, [(field_to_return, True)], kwargs
+ )
+ df = database.get_df(sql, schema)
+ if df.empty:
+ return ""
+ return df.to_dict()[field_to_return][0]
+
class PrestoEngineSpec(PrestoBaseEngineSpec):
engine = "presto"
@@ -958,21 +1155,24 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
indexes = database.get_indexes(table_name, schema_name)
if indexes:
- cols = indexes[0].get("column_names", [])
- full_table_name = table_name
- if schema_name and "." not in table_name:
- full_table_name = "{}.{}".format(schema_name, table_name)
- pql = cls._partition_query(full_table_name, database)
col_names, latest_parts = cls.latest_partition(
table_name, schema_name, database, show_first=True
)
if not latest_parts:
latest_parts = tuple([None] * len(col_names))
+
metadata["partitions"] = {
- "cols": cols,
+ "cols": sorted(indexes[0].get("column_names", [])),
"latest": dict(zip(col_names, latest_parts)),
- "partitionQuery": pql,
+ "partitionQuery": cls._partition_query(
+ table_name=(
+ f"{schema_name}.{table_name}"
+ if schema_name and "." not in table_name
+ else table_name
+ ),
+ database=database,
+ ),
}
# flake8 is not matching `Optional[str]` to `Any` for some reason...
@@ -1085,203 +1285,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return error_dict.get("message", _("Unknown Presto Error"))
return utils.error_msg_from_exception(ex)
- @classmethod
- def _partition_query( # pylint: disable=too-many-arguments,too-many-locals
- cls,
- table_name: str,
- database: Database,
- limit: int = 0,
- order_by: Optional[List[Tuple[str, bool]]] = None,
- filters: Optional[Dict[Any, Any]] = None,
- ) -> str:
- """Returns a partition query
-
- :param table_name: the name of the table to get partitions from
- :type table_name: str
- :param limit: the number of partitions to be returned
- :type limit: int
- :param order_by: a list of tuples of field name and a boolean
- that determines if that field should be sorted in descending
- order
- :type order_by: list of (str, bool) tuples
- :param filters: dict of field name and filter value combinations
- """
- limit_clause = "LIMIT {}".format(limit) if limit else ""
- order_by_clause = ""
- if order_by:
- l = []
- for field, desc in order_by:
- l.append(field + " DESC" if desc else "")
- order_by_clause = "ORDER BY " + ", ".join(l)
-
- where_clause = ""
- if filters:
- l = []
- for field, value in filters.items():
- l.append(f"{field} = '{value}'")
- where_clause = "WHERE " + " AND ".join(l)
-
- presto_version = database.get_extra().get("version")
-
- # Partition select syntax changed in v0.199, so check here.
- # Default to the new syntax if version is unset.
- partition_select_clause = (
- f'SELECT * FROM "{table_name}$partitions"'
- if not presto_version
- or StrictVersion(presto_version) >= StrictVersion("0.199")
- else f"SHOW PARTITIONS FROM {table_name}"
- )
-
- sql = dedent(
- f"""\
- {partition_select_clause}
- {where_clause}
- {order_by_clause}
- {limit_clause}
- """
- )
- return sql
-
- @classmethod
- def where_latest_partition( # pylint: disable=too-many-arguments
- cls,
- table_name: str,
- schema: Optional[str],
- database: Database,
- query: Select,
- columns: Optional[List[Dict[str, str]]] = None,
- ) -> Optional[Select]:
- try:
- col_names, values = cls.latest_partition(
- table_name, schema, database, show_first=True
- )
- except Exception: # pylint: disable=broad-except
- # table is not partitioned
- return None
-
- if values is None:
- return None
-
- column_type_by_name = {
- column.get("name"): column.get("type") for column in columns or []
- }
-
- for col_name, value in zip(col_names, values):
- if col_name in column_type_by_name:
- if column_type_by_name.get(col_name) == "TIMESTAMP":
- query = query.where(Column(col_name, TimeStamp()) == value)
- elif column_type_by_name.get(col_name) == "DATE":
- query = query.where(Column(col_name, Date()) == value)
- else:
- query = query.where(Column(col_name) == value)
- return query
-
- @classmethod
- def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
- if not df.empty:
- return df.to_records(index=False)[0].item()
- return None
-
- @classmethod
- @cache_manager.data_cache.memoize(timeout=60)
- def latest_partition(
- cls,
- table_name: str,
- schema: Optional[str],
- database: Database,
- show_first: bool = False,
- ) -> Tuple[List[str], Optional[List[str]]]:
- """Returns col name and the latest (max) partition value for a table
-
- :param table_name: the name of the table
- :param schema: schema / database / namespace
- :param database: database query will be run against
- :type database: models.Database
- :param show_first: displays the value for the first partitioning key
- if there are many partitioning keys
- :type show_first: bool
-
- >>> latest_partition('foo_table')
- (['ds'], ('2018-01-01',))
- """
- indexes = database.get_indexes(table_name, schema)
- if not indexes:
- raise SupersetTemplateException(
- f"Error getting partition for {schema}.{table_name}. "
- "Verify that this table has a partition."
- )
-
- if len(indexes[0]["column_names"]) < 1:
- raise SupersetTemplateException(
- "The table should have one partitioned field"
- )
-
- if not show_first and len(indexes[0]["column_names"]) > 1:
- raise SupersetTemplateException(
- "The table should have a single partitioned field "
- "to use this function. You may want to use "
- "`presto.latest_sub_partition`"
- )
-
- column_names = indexes[0]["column_names"]
- part_fields = [(column_name, True) for column_name in column_names]
- sql = cls._partition_query(table_name, database, 1, part_fields)
- df = database.get_df(sql, schema)
- return column_names, cls._latest_partition_from_df(df)
-
- @classmethod
- def latest_sub_partition(
- cls, table_name: str, schema: Optional[str], database: Database, **kwargs: Any
- ) -> Any:
- """Returns the latest (max) partition value for a table
-
- A filtering criteria should be passed for all fields that are
- partitioned except for the field to be returned. For example,
- if a table is partitioned by (``ds``, ``event_type`` and
- ``event_category``) and you want the latest ``ds``, you'll want
- to provide a filter as keyword arguments for both
- ``event_type`` and ``event_category`` as in
- ``latest_sub_partition('my_table',
- event_category='page', event_type='click')``
-
- :param table_name: the name of the table, can be just the table
- name or a fully qualified table name as ``schema_name.table_name``
- :type table_name: str
- :param schema: schema / database / namespace
- :type schema: str
- :param database: database query will be run against
- :type database: models.Database
-
- :param kwargs: keyword arguments define the filtering criteria
- on the partition list. There can be many of these.
- :type kwargs: str
- >>> latest_sub_partition('sub_partition_table', event_type='click')
- '2018-01-01'
- """
- indexes = database.get_indexes(table_name, schema)
- part_fields = indexes[0]["column_names"]
- for k in kwargs.keys(): # pylint: disable=consider-iterating-dictionary
- if k not in k in part_fields: # pylint: disable=comparison-with-itself
- msg = f"Field [{k}] is not part of the portioning key"
- raise SupersetTemplateException(msg)
- if len(kwargs.keys()) != len(part_fields) - 1:
- msg = (
- "A filter needs to be specified for {} out of the " "{} fields."
- ).format(len(part_fields) - 1, len(part_fields))
- raise SupersetTemplateException(msg)
-
- for field in part_fields:
- if field not in kwargs.keys():
- field_to_return = field
-
- sql = cls._partition_query(
- table_name, database, 1, [(field_to_return, True)], kwargs
- )
- df = database.get_df(sql, schema)
- if df.empty:
- return ""
- return df.to_dict()[field_to_return][0]
-
@classmethod
def get_column_spec(
cls,
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index c6faf6db6c..2a1d8cc639 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -83,11 +83,34 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
indexes = database.get_indexes(table_name, schema_name)
if indexes:
- partitions_columns = []
- for index in indexes:
- if index.get("name") == "partition":
- partitions_columns += index.get("column_names", [])
- metadata["partitions"] = {"cols": partitions_columns}
+ col_names, latest_parts = cls.latest_partition(
+ table_name, schema_name, database, show_first=True
+ )
+
+ if not latest_parts:
+ latest_parts = tuple([None] * len(col_names))
+
+ metadata["partitions"] = {
+ "cols": sorted(
+ list(
+ set(
+ column_name
+ for index in indexes
+ if index.get("name") == "partition"
+ for column_name in index.get("column_names", [])
+ )
+ )
+ ),
+ "latest": dict(zip(col_names, latest_parts)),
+ "partitionQuery": cls._partition_query(
+ table_name=(
+ f"{schema_name}.{table_name}"
+ if schema_name and "." not in table_name
+ else table_name
+ ),
+ database=database,
+ ),
+ }
if database.has_view_by_name(table_name, schema_name):
metadata["view"] = database.inspector.get_view_definition(
diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py
index 4a76d59a46..eef3bb8d36 100644
--- a/tests/integration_tests/db_engine_specs/presto_tests.py
+++ b/tests/integration_tests/db_engine_specs/presto_tests.py
@@ -492,7 +492,8 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
db.get_df = mock.Mock(return_value=df)
PrestoEngineSpec.get_create_view = mock.Mock(return_value=None)
result = PrestoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")
- self.assertEqual({"ds": "01-01-19", "hour": 1}, result["partitions"]["latest"])
+ assert result["partitions"]["cols"] == ["ds", "hour"]
+ assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}
def test_presto_where_latest_partition(self):
db = mock.Mock()
diff --git a/tests/integration_tests/db_engine_specs/trino_tests.py b/tests/integration_tests/db_engine_specs/trino_tests.py
index 41a4f4e0f3..6379d013b2 100644
--- a/tests/integration_tests/db_engine_specs/trino_tests.py
+++ b/tests/integration_tests/db_engine_specs/trino_tests.py
@@ -16,8 +16,10 @@
# under the License.
import json
from typing import Any, Dict
+from unittest import mock
from unittest.mock import Mock, patch
+import pandas as pd
import pytest
from sqlalchemy import types
@@ -196,3 +198,17 @@ class TestTrinoDbEngineSpec(TestDbEngineSpec):
TrinoEngineSpec.convert_dttm("DATE", dttm),
"DATE '2019-01-02'",
)
+
+ def test_extra_table_metadata(self):
+ db = mock.Mock()
+ db.get_indexes = mock.Mock(
+ return_value=[{"column_names": ["ds", "hour"], "name": "partition"}]
+ )
+ db.get_extra = mock.Mock(return_value={})
+ db.has_view_by_name = mock.Mock(return_value=None)
+ db.get_df = mock.Mock(
+ return_value=pd.DataFrame({"ds": ["01-01-19"], "hour": [1]})
+ )
+ result = TrinoEngineSpec.extra_table_metadata(db, "test_table", "test_schema")
+ assert result["partitions"]["cols"] == ["ds", "hour"]
+ assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1}