You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by ma...@apache.org on 2018/09/04 05:50:01 UTC

[incubator-superset] branch master updated: Force quoted column aliases for Oracle-like databases (#5686)

This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 77fe9ef  Force quoted column aliases for Oracle-like databases (#5686)
77fe9ef is described below

commit 77fe9ef130a383d6902b63f89743446b5578d3b5
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Tue Sep 4 08:49:58 2018 +0300

    Force quoted column aliases for Oracle-like databases (#5686)
    
    * Replace dataframe label override logic with table column override
    
    * Add mutation to any_date_col
    
    * Linting
    
    * Add mutation to oracle and redshift
    
    * Fine tune how and which labels are mutated
    
    * Implement alias quoting logic for oracle-like databases
    
    * Fix and align column and metric sqla_col methods
    
    * Clean up typos and redundant logic
    
    * Move new attribute to old location
    
    * Linting
    
    * Replace old sqla_col property references with function calls
    
    * Remove redundant calls to mutate_column_label
    
    * Move duplicated logic to common function
    
    * Add db_engine_specs to all sqla_col calls
    
    * Add missing mydb
    
    * Add note about snowflake-sqlalchemy regression
    
    * Make db_engine_spec mandatory in sqla_col
    
    * Small refactoring and cleanup
    
    * Remove db_engine_spec from get_from_clause call
    
    * Make db_engine_spec mandatory in adhoc_metric_to_sa
    
    * Remove redundant mutate_expression_label call
    
    * Add missing db_engine_specs to adhoc_metric_to_sa
    
    * Rename arg label_name to label in get_column_label()
    
    * Rename label function and add docstring
    
    * Remove redundant db_engine_spec args
    
    * Rename col_label to label
    
    * Remove get_column_name wrapper and make direct calls to db_engine_spec
    
    * Remove unneeded db_engine_specs
    
    * Rename sa_ vars to sqla_
---
 docs/installation.rst              |   4 ++
 superset/connectors/sqla/models.py | 106 ++++++++++++++++++++-----------------
 superset/dataframe.py              |   4 +-
 superset/db_engine_specs.py        |  66 +++++------------------
 superset/viz.py                    |   4 --
 5 files changed, 76 insertions(+), 108 deletions(-)

diff --git a/docs/installation.rst b/docs/installation.rst
index 6b08b82..7529323 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -393,6 +393,10 @@ Make sure the user has privileges to access and use all required
 databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does
 not test for user rights during engine creation.
 
+*Note*: At the time of writing, there is a regression in the current stable version (1.1.2) of
+the snowflake-sqlalchemy package that causes problems when used with Superset. It is recommended
+to use version 1.1.0 or try a newer version.
+
 See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_.
 
 Caching
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index f410279..037eb77 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -99,13 +99,13 @@ class TableColumn(Model, BaseColumn):
         s for s in export_fields if s not in ('table_id',)]
     export_parent = 'table'
 
-    @property
-    def sqla_col(self):
-        name = self.column_name
+    def get_sqla_col(self, label=None):
+        db_engine_spec = self.table.database.db_engine_spec
+        label = db_engine_spec.make_label_compatible(label if label else self.column_name)
         if not self.expression:
-            col = column(self.column_name).label(name)
+            col = column(self.column_name).label(label)
         else:
-            col = literal_column(self.expression).label(name)
+            col = literal_column(self.expression).label(label)
         return col
 
     @property
@@ -113,7 +113,7 @@ class TableColumn(Model, BaseColumn):
         return self.table
 
     def get_time_filter(self, start_dttm, end_dttm):
-        col = self.sqla_col.label('__time')
+        col = self.get_sqla_col(label='__time')
         l = []  # noqa: E741
         if start_dttm:
             l.append(col >= text(self.dttm_sql_literal(start_dttm)))
@@ -231,10 +231,10 @@ class SqlMetric(Model, BaseMetric):
         s for s in export_fields if s not in ('table_id', )])
     export_parent = 'table'
 
-    @property
-    def sqla_col(self):
-        name = self.metric_name
-        return literal_column(self.expression).label(name)
+    def get_sqla_col(self, label=None):
+        db_engine_spec = self.table.database.db_engine_spec
+        label = db_engine_spec.make_label_compatible(label if label else self.metric_name)
+        return literal_column(self.expression).label(label)
 
     @property
     def perm(self):
@@ -421,11 +421,10 @@ class SqlaTable(Model, BaseDatasource):
         cols = {col.column_name: col for col in self.columns}
         target_col = cols[column_name]
         tp = self.get_template_processor()
-        db_engine_spec = self.database.db_engine_spec
 
         qry = (
-            select([target_col.sqla_col])
-            .select_from(self.get_from_clause(tp, db_engine_spec))
+            select([target_col.get_sqla_col()])
+            .select_from(self.get_from_clause(tp))
             .distinct()
         )
         if limit:
@@ -474,7 +473,7 @@ class SqlaTable(Model, BaseDatasource):
             tbl.schema = self.schema
         return tbl
 
-    def get_from_clause(self, template_processor=None, db_engine_spec=None):
+    def get_from_clause(self, template_processor=None):
         # Supporting arbitrary SQL statements in place of tables
         if self.sql:
             from_sql = self.sql
@@ -484,7 +483,7 @@ class SqlaTable(Model, BaseDatasource):
             return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
         return self.get_sqla_table()
 
-    def adhoc_metric_to_sa(self, metric, cols):
+    def adhoc_metric_to_sqla(self, metric, cols):
         """
         Turn an adhoc metric into a sqlalchemy column.
 
@@ -493,22 +492,25 @@ class SqlaTable(Model, BaseDatasource):
         :returns: The metric defined as a sqlalchemy column
         :rtype: sqlalchemy.sql.column
         """
-        expressionType = metric.get('expressionType')
-        if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
+        expression_type = metric.get('expressionType')
+        db_engine_spec = self.database.db_engine_spec
+        label = db_engine_spec.make_label_compatible(metric.get('label'))
+
+        if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
             column_name = metric.get('column').get('column_name')
-            sa_column = column(column_name)
+            sqla_column = column(column_name)
             table_column = cols.get(column_name)
 
             if table_column:
-                sa_column = table_column.sqla_col
-
-            sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column)
-            sa_metric = sa_metric.label(metric.get('label'))
-            return sa_metric
-        elif expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
-            sa_metric = literal_column(metric.get('sqlExpression'))
-            sa_metric = sa_metric.label(metric.get('label'))
-            return sa_metric
+                sqla_column = table_column.get_sqla_col()
+
+            sqla_metric = self.sqla_aggregations[metric.get('aggregate')](sqla_column)
+            sqla_metric = sqla_metric.label(label)
+            return sqla_metric
+        elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
+            sqla_metric = literal_column(metric.get('sqlExpression'))
+            sqla_metric = sqla_metric.label(label)
+            return sqla_metric
         else:
             return None
 
@@ -566,15 +568,16 @@ class SqlaTable(Model, BaseDatasource):
         metrics_exprs = []
         for m in metrics:
             if utils.is_adhoc_metric(m):
-                metrics_exprs.append(self.adhoc_metric_to_sa(m, cols))
+                metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
             elif m in metrics_dict:
-                metrics_exprs.append(metrics_dict.get(m).sqla_col)
+                metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
             else:
                 raise Exception(_("Metric '{}' is not valid".format(m)))
         if metrics_exprs:
             main_metric_expr = metrics_exprs[0]
         else:
-            main_metric_expr = literal_column('COUNT(*)').label('ccount')
+            main_metric_expr = literal_column('COUNT(*)').label(
+                db_engine_spec.make_label_compatible('count'))
 
         select_exprs = []
         groupby_exprs = []
@@ -585,8 +588,8 @@ class SqlaTable(Model, BaseDatasource):
             inner_groupby_exprs = []
             for s in groupby:
                 col = cols[s]
-                outer = col.sqla_col
-                inner = col.sqla_col.label(col.column_name + '__')
+                outer = col.get_sqla_col()
+                inner = col.get_sqla_col(col.column_name + '__')
 
                 groupby_exprs.append(outer)
                 select_exprs.append(outer)
@@ -594,7 +597,7 @@ class SqlaTable(Model, BaseDatasource):
                 inner_select_exprs.append(inner)
         elif columns:
             for s in columns:
-                select_exprs.append(cols[s].sqla_col)
+                select_exprs.append(cols[s].get_sqla_col())
             metrics_exprs = []
 
         if granularity:
@@ -618,7 +621,7 @@ class SqlaTable(Model, BaseDatasource):
         select_exprs += metrics_exprs
         qry = sa.select(select_exprs)
 
-        tbl = self.get_from_clause(template_processor, db_engine_spec)
+        tbl = self.get_from_clause(template_processor)
 
         if not columns:
             qry = qry.group_by(*groupby_exprs)
@@ -638,9 +641,9 @@ class SqlaTable(Model, BaseDatasource):
                     target_column_is_numeric=col_obj.is_num,
                     is_list_target=is_list_target)
                 if op in ('in', 'not in'):
-                    cond = col_obj.sqla_col.in_(eq)
+                    cond = col_obj.get_sqla_col().in_(eq)
                     if '<NULL>' in eq:
-                        cond = or_(cond, col_obj.sqla_col == None)  # noqa
+                        cond = or_(cond, col_obj.get_sqla_col() == None)  # noqa
                     if op == 'not in':
                         cond = ~cond
                     where_clause_and.append(cond)
@@ -648,23 +651,24 @@ class SqlaTable(Model, BaseDatasource):
                     if col_obj.is_num:
                         eq = utils.string_to_num(flt['val'])
                     if op == '==':
-                        where_clause_and.append(col_obj.sqla_col == eq)
+                        where_clause_and.append(col_obj.get_sqla_col() == eq)
                     elif op == '!=':
-                        where_clause_and.append(col_obj.sqla_col != eq)
+                        where_clause_and.append(col_obj.get_sqla_col() != eq)
                     elif op == '>':
-                        where_clause_and.append(col_obj.sqla_col > eq)
+                        where_clause_and.append(col_obj.get_sqla_col() > eq)
                     elif op == '<':
-                        where_clause_and.append(col_obj.sqla_col < eq)
+                        where_clause_and.append(col_obj.get_sqla_col() < eq)
                     elif op == '>=':
-                        where_clause_and.append(col_obj.sqla_col >= eq)
+                        where_clause_and.append(col_obj.get_sqla_col() >= eq)
                     elif op == '<=':
-                        where_clause_and.append(col_obj.sqla_col <= eq)
+                        where_clause_and.append(col_obj.get_sqla_col() <= eq)
                     elif op == 'LIKE':
-                        where_clause_and.append(col_obj.sqla_col.like(eq))
+                        where_clause_and.append(col_obj.get_sqla_col().like(eq))
                     elif op == 'IS NULL':
-                        where_clause_and.append(col_obj.sqla_col == None)  # noqa
+                        where_clause_and.append(col_obj.get_sqla_col() == None)  # noqa
                     elif op == 'IS NOT NULL':
-                        where_clause_and.append(col_obj.sqla_col != None)  # noqa
+                        where_clause_and.append(
+                            col_obj.get_sqla_col() != None)  # noqa
         if extras:
             where = extras.get('where')
             if where:
@@ -686,7 +690,7 @@ class SqlaTable(Model, BaseDatasource):
         for col, ascending in orderby:
             direction = asc if ascending else desc
             if utils.is_adhoc_metric(col):
-                col = self.adhoc_metric_to_sa(col, cols)
+                col = self.adhoc_metric_to_sqla(col, cols)
             qry = qry.order_by(direction(col))
 
         if row_limit:
@@ -712,12 +716,12 @@ class SqlaTable(Model, BaseDatasource):
                 ob = inner_main_metric_expr
                 if timeseries_limit_metric:
                     if utils.is_adhoc_metric(timeseries_limit_metric):
-                        ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols)
+                        ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
                     elif timeseries_limit_metric in metrics_dict:
                         timeseries_limit_metric = metrics_dict.get(
                             timeseries_limit_metric,
                         )
-                        ob = timeseries_limit_metric.sqla_col
+                        ob = timeseries_limit_metric.get_sqla_col()
                     else:
                         raise Exception(_("Metric '{}' is not valid".format(m)))
                 direction = desc if order_desc else asc
@@ -762,7 +766,7 @@ class SqlaTable(Model, BaseDatasource):
             group = []
             for dimension in dimensions:
                 col_obj = cols.get(dimension)
-                group.append(col_obj.sqla_col == row[dimension])
+                group.append(col_obj.get_sqla_col() == row[dimension])
             groups.append(and_(*group))
 
         return or_(*groups)
@@ -816,6 +820,7 @@ class SqlaTable(Model, BaseDatasource):
             .filter(or_(TableColumn.column_name == col.name
                         for col in table.columns)))
         dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
+        db_engine_spec = self.database.db_engine_spec
 
         for col in table.columns:
             try:
@@ -848,6 +853,9 @@ class SqlaTable(Model, BaseDatasource):
         ))
         if not self.main_dttm_col:
             self.main_dttm_col = any_date_col
+        for metric in metrics:
+            metric.metric_name = db_engine_spec.mutate_expression_label(
+                metric.metric_name)
         self.add_missing_metrics(metrics)
         db.session.merge(self)
         db.session.commit()
diff --git a/superset/dataframe.py b/superset/dataframe.py
index 834f118..1678dd9 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -73,9 +73,7 @@ class SupersetDataFrame(object):
         if cursor_description:
             column_names = [col[0] for col in cursor_description]
 
-        case_sensitive = db_engine_spec.consistent_case_sensitivity
-        self.column_names = dedup(column_names,
-                                  case_sensitive=case_sensitive)
+        self.column_names = dedup(column_names)
 
         data = data or []
         self.df = (
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 2ce7db0..a8a9faa 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -35,7 +35,7 @@ import sqlalchemy as sqla
 from sqlalchemy import select
 from sqlalchemy.engine import create_engine
 from sqlalchemy.engine.url import make_url
-from sqlalchemy.sql import text
+from sqlalchemy.sql import quoted_name, text
 from sqlalchemy.sql.expression import TextAsFrom
 import sqlparse
 from tableschema import Table
@@ -101,7 +101,7 @@ class BaseEngineSpec(object):
     time_secondary_columns = False
     inner_joins = True
     allows_subquery = True
-    consistent_case_sensitivity = True  # do results have same case as qry for col names?
+    force_column_alias_quotes = False
     arraysize = None
 
     @classmethod
@@ -376,55 +376,15 @@ class BaseEngineSpec(object):
         cursor.execute(query)
 
     @classmethod
-    def adjust_df_column_names(cls, df, fd):
-        """Based of fields in form_data, return dataframe with new column names
-
-        Usually sqla engines return column names whose case matches that of the
-        original query. For example:
-            SELECT 1 as col1, 2 as COL2, 3 as Col_3
-        will usually result in the following df.columns:
-            ['col1', 'COL2', 'Col_3'].
-        For these engines there is no need to adjust the dataframe column names
-        (default behavior). However, some engines (at least Snowflake, Oracle and
-        Redshift) return column names with different case than in the original query,
-        usually all uppercase. For these the column names need to be adjusted to
-        correspond to the case of the fields specified in the form data for Viz
-        to work properly. This adjustment can be done here.
+    def make_label_compatible(cls, label):
         """
-        if cls.consistent_case_sensitivity:
-            return df
-        else:
-            return cls.align_df_col_names_with_form_data(df, fd)
-
-    @staticmethod
-    def align_df_col_names_with_form_data(df, fd):
-        """Helper function to rename columns that have changed case during query.
-
-        Returns a dataframe where column names have been adjusted to correspond with
-        column names in form data (case insensitive). Examples:
-        dataframe: 'col1', form_data: 'col1' -> no change
-        dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
-        dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
+        Return a sqlalchemy.sql.elements.quoted_name if the engine requires
+        quoting of aliases to ensure that select query and query results
+        have same case.
         """
-
-        columns = set()
-        lowercase_mapping = {}
-
-        metrics = utils.get_metric_names(fd.get('metrics', []))
-        groupby = fd.get('groupby', [])
-        other_cols = [utils.DTTM_ALIAS]
-        for col in metrics + groupby + other_cols:
-            columns.add(col)
-            lowercase_mapping[col.lower()] = col
-
-        rename_cols = {}
-        for col in df.columns:
-            if col not in columns:
-                orig_col = lowercase_mapping.get(col.lower())
-                if orig_col:
-                    rename_cols[col] = orig_col
-
-        return df.rename(index=str, columns=rename_cols)
+        if cls.force_column_alias_quotes is True:
+            return quoted_name(label, True)
+        return label
 
     @staticmethod
     def mutate_expression_label(label):
@@ -478,7 +438,8 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
 
 class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     engine = 'snowflake'
-    consistent_case_sensitivity = False
+    force_column_alias_quotes = True
+
     time_grain_functions = {
         None: '{col}',
         'PT1S': "DATE_TRUNC('SECOND', {col})",
@@ -515,13 +476,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
 
 class RedshiftEngineSpec(PostgresBaseEngineSpec):
     engine = 'redshift'
-    consistent_case_sensitivity = False
+    force_column_alias_quotes = True
 
 
 class OracleEngineSpec(PostgresBaseEngineSpec):
     engine = 'oracle'
     limit_method = LimitMethod.WRAP_SQL
-    consistent_case_sensitivity = False
+    force_column_alias_quotes = True
 
     time_grain_functions = {
         None: '{col}',
@@ -545,6 +506,7 @@ class OracleEngineSpec(PostgresBaseEngineSpec):
 class Db2EngineSpec(BaseEngineSpec):
     engine = 'ibm_db_sa'
     limit_method = LimitMethod.WRAP_SQL
+    force_column_alias_quotes = True
     time_grain_functions = {
         None: '{col}',
         'PT1S': 'CAST({col} as TIMESTAMP)'
diff --git a/superset/viz.py b/superset/viz.py
index 90c209e..6a18dfa 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -391,10 +391,6 @@ class BaseViz(object):
         if query_obj and not is_loaded:
             try:
                 df = self.get_df(query_obj)
-                if hasattr(self.datasource, 'database') and \
-                        hasattr(self.datasource.database, 'db_engine_spec'):
-                    db_engine_spec = self.datasource.database.db_engine_spec
-                    df = db_engine_spec.adjust_df_column_names(df, self.form_data)
                 if self.status != utils.QueryStatus.FAILED:
                     stats_logger.incr('loaded_from_source')
                     is_loaded = True