You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by vi...@apache.org on 2019/05/30 05:28:50 UTC

[incubator-superset] branch master updated: Make timestamp expression native SQLAlchemy element (#7131)

This is an automated email from the ASF dual-hosted git repository.

villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 34407e8  Make timestamp expression native SQLAlchemy element (#7131)
34407e8 is described below

commit 34407e896296a7a98a3bc31bde20ae1f3ac02005
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Thu May 30 08:28:37 2019 +0300

    Make timestamp expression native SQLAlchemy element (#7131)
    
    * Add native sqla component for time expressions
    
    * Add unit tests and remove old tests
    
    * Remove redundant _grains_dict method
    
    * Clarify time_grain logic
    
    * Add docstrings and typing
    
    * Fix flake8 errors
    
    * Add missing typings
    
    * Rename to TimestampExpression
    
    * Remove redundant tests
    
    * Fix broken reference to db.database_name due to refactor
---
 superset/connectors/sqla/models.py | 30 +++++++------
 superset/db_engine_specs.py        | 89 +++++++++++++++++++++++---------------
 superset/models/core.py            | 10 +----
 tests/db_engine_specs_test.py      | 61 +++++++++++++++++++++++---
 tests/model_tests.py               | 59 -------------------------
 5 files changed, 128 insertions(+), 121 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index f178db4..de9f4d1 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -18,6 +18,7 @@
 from collections import namedtuple, OrderedDict
 from datetime import datetime
 import logging
+from typing import Optional, Union
 
 from flask import escape, Markup
 from flask_appbuilder import Model
@@ -32,11 +33,12 @@ from sqlalchemy.exc import CompileError
 from sqlalchemy.orm import backref, relationship
 from sqlalchemy.schema import UniqueConstraint
 from sqlalchemy.sql import column, literal_column, table, text
-from sqlalchemy.sql.expression import TextAsFrom
+from sqlalchemy.sql.expression import Label, TextAsFrom
 import sqlparse
 
 from superset import app, db, security_manager
 from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
+from superset.db_engine_specs import TimestampExpression
 from superset.jinja_context import get_template_processor
 from superset.models.annotations import Annotation
 from superset.models.core import Database
@@ -140,8 +142,14 @@ class TableColumn(Model, BaseColumn):
             l.append(col <= text(self.dttm_sql_literal(end_dttm, is_epoch_in_utc)))
         return and_(*l)
 
-    def get_timestamp_expression(self, time_grain):
-        """Getting the time component of the query"""
+    def get_timestamp_expression(self, time_grain: Optional[str]) \
+            -> Union[TimestampExpression, Label]:
+        """
+        Return a SQLAlchemy Core element representation of self to be used in a query.
+
+        :param time_grain: Optional time grain, e.g. P1Y
+        :return: A TimestampExpression object wrapped in a Label if supported by db
+        """
         label = utils.DTTM_ALIAS
 
         db = self.table.database
@@ -150,16 +158,12 @@ class TableColumn(Model, BaseColumn):
         if not self.expression and not time_grain and not is_epoch:
             sqla_col = column(self.column_name, type_=DateTime)
             return self.table.make_sqla_column_compatible(sqla_col, label)
-        grain = None
-        if time_grain:
-            grain = db.grains_dict().get(time_grain)
-            if not grain:
-                raise NotImplementedError(
-                    f'No grain spec for {time_grain} for database {db.database_name}')
-        col = db.db_engine_spec.get_timestamp_column(self.expression, self.column_name)
-        expr = db.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
-        sqla_col = literal_column(expr, type_=DateTime)
-        return self.table.make_sqla_column_compatible(sqla_col, label)
+        if self.expression:
+            col = literal_column(self.expression)
+        else:
+            col = column(self.column_name)
+        time_expr = db.db_engine_spec.get_timestamp_expr(col, pdf, time_grain)
+        return self.table.make_sqla_column_compatible(time_expr, label)
 
     @classmethod
     def import_obj(cls, i_column):
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 04efef7..b6103c3 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -36,19 +36,20 @@ import os
 import re
 import textwrap
 import time
-from typing import List, Tuple
+from typing import Dict, List, Optional, Tuple
 from urllib import parse
 
 from flask import g
 from flask_babel import lazy_gettext as _
 import pandas
 import sqlalchemy as sqla
-from sqlalchemy import Column, select, types
+from sqlalchemy import Column, DateTime, select, types
 from sqlalchemy.engine import create_engine
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.engine.result import RowProxy
 from sqlalchemy.engine.url import make_url
+from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.sql import quoted_name, text
 from sqlalchemy.sql.expression import ColumnClause
 from sqlalchemy.sql.expression import TextAsFrom
@@ -90,6 +91,24 @@ builtin_time_grains = {
 }
 
 
+class TimestampExpression(ColumnClause):
+    def __init__(self, expr: str, col: ColumnClause, **kwargs):
+        """Sqlalchemy class that can be used to render native column elements
+        respecting engine-specific quoting rules as part of a string-based expression.
+
+        :param expr: Sql expression with '{col}' denoting the locations where the col
+        object will be rendered.
+        :param col: the target column
+        """
+        super().__init__(expr, **kwargs)
+        self.col = col
+
+
+@compiles(TimestampExpression)
+def compile_timegrain_expression(element: TimestampExpression, compiler, **kw):
+    return element.name.replace('{col}', compiler.process(element.col, **kw))
+
+
 def _create_time_grains_tuple(time_grains, time_grain_functions, blacklist):
     ret_list = []
     blacklist = blacklist if blacklist else []
@@ -112,7 +131,7 @@ class BaseEngineSpec(object):
     """Abstract class for database engine specific configurations"""
 
     engine = 'base'  # str as defined in sqlalchemy.engine.engine
-    time_grain_functions: dict = {}
+    time_grain_functions: Dict[Optional[str], str] = {}
     time_groupby_inline = False
     limit_method = LimitMethod.FORCE_LIMIT
     time_secondary_columns = False
@@ -125,16 +144,31 @@ class BaseEngineSpec(object):
     try_remove_schema_from_table_name = True
 
     @classmethod
-    def get_time_expr(cls, expr, pdf, time_grain, grain):
+    def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str],
+                           time_grain: Optional[str]) -> TimestampExpression:
+        """
+        Construct a TimestampExpression to be used in a SQLAlchemy query.
+
+        :param col: Target column for the TimestampExpression
+        :param pdf: date format (seconds or milliseconds)
+        :param time_grain: time grain, e.g. P1Y for 1 year
+        :return: TimestampExpression object
+        """
+        if time_grain:
+            time_expr = cls.time_grain_functions.get(time_grain)
+            if not time_expr:
+                raise NotImplementedError(
+                    f'No grain spec for {time_grain} for database {cls.engine}')
+        else:
+            time_expr = '{col}'
+
         # if epoch, translate to DATE using db specific conf
         if pdf == 'epoch_s':
-            expr = cls.epoch_to_dttm().format(col=expr)
+            time_expr = time_expr.replace('{col}', cls.epoch_to_dttm())
         elif pdf == 'epoch_ms':
-            expr = cls.epoch_ms_to_dttm().format(col=expr)
+            time_expr = time_expr.replace('{col}', cls.epoch_ms_to_dttm())
 
-        if grain:
-            expr = grain.function.format(col=expr)
-        return expr
+        return TimestampExpression(time_expr, col, type_=DateTime)
 
     @classmethod
     def get_time_grains(cls):
@@ -489,13 +523,6 @@ class BaseEngineSpec(object):
             label = label[:cls.max_column_name_length]
         return label
 
-    @staticmethod
-    def get_timestamp_column(expression, column_name):
-        """Return the expression if defined, otherwise return column_name. Some
-        engines require forcing quotes around column name, in which case this method
-        can be overridden."""
-        return expression or column_name
-
 
 class PostgresBaseEngineSpec(BaseEngineSpec):
     """ Abstract class for Postgres 'like' databases """
@@ -543,16 +570,6 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
         tables.extend(inspector.get_foreign_table_names(schema))
         return sorted(tables)
 
-    @staticmethod
-    def get_timestamp_column(expression, column_name):
-        """Postgres is unable to identify mixed case column names unless they
-        are quoted."""
-        if expression:
-            return expression
-        elif column_name.lower() != column_name:
-            return f'"{column_name}"'
-        return column_name
-
 
 class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     engine = 'snowflake'
@@ -794,7 +811,7 @@ class MySQLEngineSpec(BaseEngineSpec):
               'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))',
     }
 
-    type_code_map: dict = {}  # loaded from get_datatype only if needed
+    type_code_map: Dict[int, str] = {}  # loaded from get_datatype only if needed
 
     @classmethod
     def convert_dttm(cls, target_type, dttm):
@@ -1812,20 +1829,21 @@ class PinotEngineSpec(BaseEngineSpec):
     inner_joins = False
     supports_column_aliases = False
 
-    _time_grain_to_datetimeconvert = {
+    # Pinot does its own conversion below
+    time_grain_functions: Dict[Optional[str], str] = {
         'PT1S': '1:SECONDS',
         'PT1M': '1:MINUTES',
         'PT1H': '1:HOURS',
         'P1D': '1:DAYS',
-        'P1Y': '1:YEARS',
+        'P1W': '1:WEEKS',
         'P1M': '1:MONTHS',
+        'P0.25Y': '3:MONTHS',
+        'P1Y': '1:YEARS',
     }
 
-    # Pinot does its own conversion below
-    time_grain_functions = {k: None for k in _time_grain_to_datetimeconvert.keys()}
-
     @classmethod
-    def get_time_expr(cls, expr, pdf, time_grain, grain):
+    def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str],
+                           time_grain: Optional[str]) -> TimestampExpression:
         is_epoch = pdf in ('epoch_s', 'epoch_ms')
         if not is_epoch:
             raise NotImplementedError('Pinot currently only supports epochs')
@@ -1834,11 +1852,12 @@ class PinotEngineSpec(BaseEngineSpec):
         # We are not really converting any time units, just bucketing them.
         seconds_or_ms = 'MILLISECONDS' if pdf == 'epoch_ms' else 'SECONDS'
         tf = f'1:{seconds_or_ms}:EPOCH'
-        granularity = cls._time_grain_to_datetimeconvert.get(time_grain)
+        granularity = cls.time_grain_functions.get(time_grain)
         if not granularity:
             raise NotImplementedError('No pinot grain spec for ' + str(time_grain))
         # In pinot the output is a string since there is no timestamp column like pg
-        return f'DATETIMECONVERT({expr}, "{tf}", "{tf}", "{granularity}")'
+        time_expr = f'DATETIMECONVERT({{col}}, "{tf}", "{tf}", "{granularity}")'
+        return TimestampExpression(time_expr, col)
 
     @classmethod
     def make_select_compatible(cls, groupby_exprs, select_exprs):
diff --git a/superset/models/core.py b/superset/models/core.py
index 047a3dd..b379af7 100644
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -1029,21 +1029,13 @@ class Database(Model, AuditMixinNullable, ImportMixin):
         """Defines time granularity database-specific expressions.
 
         The idea here is to make it easy for users to change the time grain
-        form a datetime (maybe the source grain is arbitrary timestamps, daily
+        from a datetime (maybe the source grain is arbitrary timestamps, daily
         or 5 minutes increments) to another, "truncated" datetime. Since
         each database has slightly different but similar datetime functions,
         this allows a mapping between database engines and actual functions.
         """
         return self.db_engine_spec.get_time_grains()
 
-    def grains_dict(self):
-        """Allowing to lookup grain by either label or duration
-
-        For backward compatibility"""
-        d = {grain.duration: grain for grain in self.grains()}
-        d.update({grain.label: grain for grain in self.grains()})
-        return d
-
     def get_extra(self):
         extra = {}
         if self.extra:
diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py
index 0372366..43f89c1 100644
--- a/tests/db_engine_specs_test.py
+++ b/tests/db_engine_specs_test.py
@@ -17,15 +17,16 @@
 import inspect
 from unittest import mock
 
-from sqlalchemy import column, select, table
-from sqlalchemy.dialects.mssql import pymssql
+from sqlalchemy import column, literal_column, select, table
+from sqlalchemy.dialects import mssql, oracle, postgresql
 from sqlalchemy.engine.result import RowProxy
 from sqlalchemy.types import String, UnicodeText
 
 from superset import db_engine_specs
 from superset.db_engine_specs import (
     BaseEngineSpec, BQEngineSpec, HiveEngineSpec, MssqlEngineSpec,
-    MySQLEngineSpec, OracleEngineSpec, PrestoEngineSpec,
+    MySQLEngineSpec, OracleEngineSpec, PinotEngineSpec, PostgresEngineSpec,
+    PrestoEngineSpec,
 )
 from superset.models.core import Database
 from .base_tests import SupersetTestCase
@@ -451,7 +452,7 @@ class DbEngineSpecsTestCase(SupersetTestCase):
         assert_type('NTEXT', UnicodeText)
 
     def test_mssql_where_clause_n_prefix(self):
-        dialect = pymssql.dialect()
+        dialect = mssql.dialect()
         spec = MssqlEngineSpec
         str_col = column('col', type_=spec.get_sqla_column_type('VARCHAR(10)'))
         unicode_col = column('unicode_col', type_=spec.get_sqla_column_type('NTEXT'))
@@ -462,7 +463,9 @@ class DbEngineSpecsTestCase(SupersetTestCase):
             where(unicode_col == 'abc')
 
         query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True}))
-        query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'"  # noqa
+        query_expected = 'SELECT col, unicode_col \n' \
+                         'FROM tbl \n' \
+                         "WHERE col = 'abc' AND unicode_col = N'abc'"
         self.assertEqual(query, query_expected)
 
     def test_get_table_names(self):
@@ -483,3 +486,51 @@ class DbEngineSpecsTestCase(SupersetTestCase):
         pg_result = db_engine_specs.PostgresEngineSpec.get_table_names(
             schema='schema', inspector=inspector)
         self.assertListEqual(pg_result_expected, pg_result)
+
+    def test_pg_time_expression_literal_no_grain(self):
+        col = literal_column('COALESCE(a, b)')
+        expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
+        result = str(expr.compile(dialect=postgresql.dialect()))
+        self.assertEqual(result, 'COALESCE(a, b)')
+
+    def test_pg_time_expression_literal_1y_grain(self):
+        col = literal_column('COALESCE(a, b)')
+        expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y')
+        result = str(expr.compile(dialect=postgresql.dialect()))
+        self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))")
+
+    def test_pg_time_expression_lower_column_no_grain(self):
+        col = column('lower_case')
+        expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
+        result = str(expr.compile(dialect=postgresql.dialect()))
+        self.assertEqual(result, 'lower_case')
+
+    def test_pg_time_expression_lower_case_column_sec_1y_grain(self):
+        col = column('lower_case')
+        expr = PostgresEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1Y')
+        result = str(expr.compile(dialect=postgresql.dialect()))
+        self.assertEqual(result, "DATE_TRUNC('year', (timestamp 'epoch' + lower_case * interval '1 second'))")  # noqa
+
+    def test_pg_time_expression_mixed_case_column_1y_grain(self):
+        col = column('MixedCase')
+        expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y')
+        result = str(expr.compile(dialect=postgresql.dialect()))
+        self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
+
+    def test_mssql_time_expression_mixed_case_column_1y_grain(self):
+        col = column('MixedCase')
+        expr = MssqlEngineSpec.get_timestamp_expr(col, None, 'P1Y')
+        result = str(expr.compile(dialect=mssql.dialect()))
+        self.assertEqual(result, 'DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)')
+
+    def test_oracle_time_expression_reserved_keyword_1m_grain(self):
+        col = column('decimal')
+        expr = OracleEngineSpec.get_timestamp_expr(col, None, 'P1M')
+        result = str(expr.compile(dialect=oracle.dialect()))
+        self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
+
+    def test_pinot_time_expression_sec_1m_grain(self):
+        col = column('tstamp')
+        expr = PinotEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1M')
+        result = str(expr.compile())
+        self.assertEqual(result, 'DATETIMECONVERT(tstamp, "1:SECONDS:EPOCH", "1:SECONDS:EPOCH", "1:MONTHS")')  # noqa
diff --git a/tests/model_tests.py b/tests/model_tests.py
index 0fe03de..53e53cc 100644
--- a/tests/model_tests.py
+++ b/tests/model_tests.py
@@ -109,47 +109,6 @@ class DatabaseModelTestCase(SupersetTestCase):
         LIMIT 100""")
         assert sql.startswith(expected)
 
-    def test_grains_dict(self):
-        uri = 'mysql://root@localhost'
-        database = Database(sqlalchemy_uri=uri)
-        d = database.grains_dict()
-        self.assertEquals(d.get('day').function, 'DATE({col})')
-        self.assertEquals(d.get('P1D').function, 'DATE({col})')
-        self.assertEquals(d.get('Time Column').function, '{col}')
-
-    def test_postgres_expression_time_grain(self):
-        uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
-        database = Database(sqlalchemy_uri=uri)
-        pdf, time_grain = '', 'P1D'
-        expression, column_name = 'COALESCE(lowercase_col, "MixedCaseCol")', ''
-        grain = database.grains_dict().get(time_grain)
-        col = database.db_engine_spec.get_timestamp_column(expression, column_name)
-        grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
-        grain_expr_expected = grain.function.replace('{col}', expression)
-        self.assertEqual(grain_expr, grain_expr_expected)
-
-    def test_postgres_lowercase_col_time_grain(self):
-        uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
-        database = Database(sqlalchemy_uri=uri)
-        pdf, time_grain = '', 'P1D'
-        expression, column_name = '', 'lowercase_col'
-        grain = database.grains_dict().get(time_grain)
-        col = database.db_engine_spec.get_timestamp_column(expression, column_name)
-        grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
-        grain_expr_expected = grain.function.replace('{col}', column_name)
-        self.assertEqual(grain_expr, grain_expr_expected)
-
-    def test_postgres_mixedcase_col_time_grain(self):
-        uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
-        database = Database(sqlalchemy_uri=uri)
-        pdf, time_grain = '', 'P1D'
-        expression, column_name = '', 'MixedCaseCol'
-        grain = database.grains_dict().get(time_grain)
-        col = database.db_engine_spec.get_timestamp_column(expression, column_name)
-        grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
-        grain_expr_expected = grain.function.replace('{col}', f'"{column_name}"')
-        self.assertEqual(grain_expr, grain_expr_expected)
-
     def test_single_statement(self):
         main_db = get_main_database(db.session)
 
@@ -217,24 +176,6 @@ class SqlaTableModelTestCase(SupersetTestCase):
             self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))')
         ds_col.expression = prev_ds_expr
 
-    def test_get_timestamp_expression_backward(self):
-        tbl = self.get_table_by_name('birth_names')
-        ds_col = tbl.get_column('ds')
-
-        ds_col.expression = None
-        ds_col.python_date_format = None
-        sqla_literal = ds_col.get_timestamp_expression('day')
-        compiled = '{}'.format(sqla_literal.compile())
-        if tbl.database.backend == 'mysql':
-            self.assertEquals(compiled, 'DATE(ds)')
-
-        ds_col.expression = None
-        ds_col.python_date_format = None
-        sqla_literal = ds_col.get_timestamp_expression('Time Column')
-        compiled = '{}'.format(sqla_literal.compile())
-        if tbl.database.backend == 'mysql':
-            self.assertEquals(compiled, 'ds')
-
     def query_with_expr_helper(self, is_timeseries, inner_join=True):
         tbl = self.get_table_by_name('birth_names')
         ds_col = tbl.get_column('ds')