You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@superset.apache.org by GitBox <gi...@apache.org> on 2018/09/04 05:50:00 UTC

[GitHub] mistercrunch closed pull request #5686: Force quoted column aliases for Oracle-like databases

mistercrunch closed pull request #5686: Force quoted column aliases for Oracle-like databases
URL: https://github.com/apache/incubator-superset/pull/5686
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub no longer shows its
diff once it has been merged, so the diff is reproduced below:

diff --git a/docs/installation.rst b/docs/installation.rst
index 008a2648f1..d1d1fd5e40 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -389,6 +389,10 @@ Make sure the user has privileges to access and use all required
 databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does
 not test for user rights during engine creation.
 
+*Note*: At the time of writing, there is a regression in the current stable version (1.1.2) of
+the snowflake-sqlalchemy package that causes problems when used with Superset. It is recommended
+to use version 1.1.0 or try a newer version.
+
 See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_.
 
 Caching
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 44a2cfb1ca..fcfe9e0b3f 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -99,13 +99,13 @@ class TableColumn(Model, BaseColumn):
         s for s in export_fields if s not in ('table_id',)]
     export_parent = 'table'
 
-    @property
-    def sqla_col(self):
-        name = self.column_name
+    def get_sqla_col(self, label=None):
+        db_engine_spec = self.table.database.db_engine_spec
+        label = db_engine_spec.make_label_compatible(label if label else self.column_name)
         if not self.expression:
-            col = column(self.column_name).label(name)
+            col = column(self.column_name).label(label)
         else:
-            col = literal_column(self.expression).label(name)
+            col = literal_column(self.expression).label(label)
         return col
 
     @property
@@ -113,7 +113,7 @@ def datasource(self):
         return self.table
 
     def get_time_filter(self, start_dttm, end_dttm):
-        col = self.sqla_col.label('__time')
+        col = self.get_sqla_col(label='__time')
         l = []  # noqa: E741
         if start_dttm:
             l.append(col >= text(self.dttm_sql_literal(start_dttm)))
@@ -231,10 +231,10 @@ class SqlMetric(Model, BaseMetric):
         s for s in export_fields if s not in ('table_id', )])
     export_parent = 'table'
 
-    @property
-    def sqla_col(self):
-        name = self.metric_name
-        return literal_column(self.expression).label(name)
+    def get_sqla_col(self, label=None):
+        db_engine_spec = self.table.database.db_engine_spec
+        label = db_engine_spec.make_label_compatible(label if label else self.metric_name)
+        return literal_column(self.expression).label(label)
 
     @property
     def perm(self):
@@ -421,11 +421,10 @@ def values_for_column(self, column_name, limit=10000):
         cols = {col.column_name: col for col in self.columns}
         target_col = cols[column_name]
         tp = self.get_template_processor()
-        db_engine_spec = self.database.db_engine_spec
 
         qry = (
-            select([target_col.sqla_col])
-            .select_from(self.get_from_clause(tp, db_engine_spec))
+            select([target_col.get_sqla_col()])
+            .select_from(self.get_from_clause(tp))
             .distinct()
         )
         if limit:
@@ -474,7 +473,7 @@ def get_sqla_table(self):
             tbl.schema = self.schema
         return tbl
 
-    def get_from_clause(self, template_processor=None, db_engine_spec=None):
+    def get_from_clause(self, template_processor=None):
         # Supporting arbitrary SQL statements in place of tables
         if self.sql:
             from_sql = self.sql
@@ -484,7 +483,7 @@ def get_from_clause(self, template_processor=None, db_engine_spec=None):
             return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
         return self.get_sqla_table()
 
-    def adhoc_metric_to_sa(self, metric, cols):
+    def adhoc_metric_to_sqla(self, metric, cols):
         """
         Turn an adhoc metric into a sqlalchemy column.
 
@@ -493,22 +492,25 @@ def adhoc_metric_to_sa(self, metric, cols):
         :returns: The metric defined as a sqlalchemy column
         :rtype: sqlalchemy.sql.column
         """
-        expressionType = metric.get('expressionType')
-        if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
+        expression_type = metric.get('expressionType')
+        db_engine_spec = self.database.db_engine_spec
+        label = db_engine_spec.make_label_compatible(metric.get('label'))
+
+        if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
             column_name = metric.get('column').get('column_name')
-            sa_column = column(column_name)
+            sqla_column = column(column_name)
             table_column = cols.get(column_name)
 
             if table_column:
-                sa_column = table_column.sqla_col
-
-            sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column)
-            sa_metric = sa_metric.label(metric.get('label'))
-            return sa_metric
-        elif expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
-            sa_metric = literal_column(metric.get('sqlExpression'))
-            sa_metric = sa_metric.label(metric.get('label'))
-            return sa_metric
+                sqla_column = table_column.get_sqla_col()
+
+            sqla_metric = self.sqla_aggregations[metric.get('aggregate')](sqla_column)
+            sqla_metric = sqla_metric.label(label)
+            return sqla_metric
+        elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
+            sqla_metric = literal_column(metric.get('sqlExpression'))
+            sqla_metric = sqla_metric.label(label)
+            return sqla_metric
         else:
             return None
 
@@ -566,15 +568,16 @@ def get_sqla_query(  # sqla
         metrics_exprs = []
         for m in metrics:
             if utils.is_adhoc_metric(m):
-                metrics_exprs.append(self.adhoc_metric_to_sa(m, cols))
+                metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
             elif m in metrics_dict:
-                metrics_exprs.append(metrics_dict.get(m).sqla_col)
+                metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
             else:
                 raise Exception(_("Metric '{}' is not valid".format(m)))
         if metrics_exprs:
             main_metric_expr = metrics_exprs[0]
         else:
-            main_metric_expr = literal_column('COUNT(*)').label('ccount')
+            main_metric_expr = literal_column('COUNT(*)').label(
+                db_engine_spec.make_label_compatible('count'))
 
         select_exprs = []
         groupby_exprs = []
@@ -585,8 +588,8 @@ def get_sqla_query(  # sqla
             inner_groupby_exprs = []
             for s in groupby:
                 col = cols[s]
-                outer = col.sqla_col
-                inner = col.sqla_col.label(col.column_name + '__')
+                outer = col.get_sqla_col()
+                inner = col.get_sqla_col(col.column_name + '__')
 
                 groupby_exprs.append(outer)
                 select_exprs.append(outer)
@@ -594,7 +597,7 @@ def get_sqla_query(  # sqla
                 inner_select_exprs.append(inner)
         elif columns:
             for s in columns:
-                select_exprs.append(cols[s].sqla_col)
+                select_exprs.append(cols[s].get_sqla_col())
             metrics_exprs = []
 
         if granularity:
@@ -618,7 +621,7 @@ def get_sqla_query(  # sqla
         select_exprs += metrics_exprs
         qry = sa.select(select_exprs)
 
-        tbl = self.get_from_clause(template_processor, db_engine_spec)
+        tbl = self.get_from_clause(template_processor)
 
         if not columns:
             qry = qry.group_by(*groupby_exprs)
@@ -638,9 +641,9 @@ def get_sqla_query(  # sqla
                     target_column_is_numeric=col_obj.is_num,
                     is_list_target=is_list_target)
                 if op in ('in', 'not in'):
-                    cond = col_obj.sqla_col.in_(eq)
+                    cond = col_obj.get_sqla_col().in_(eq)
                     if '<NULL>' in eq:
-                        cond = or_(cond, col_obj.sqla_col == None)  # noqa
+                        cond = or_(cond, col_obj.get_sqla_col() == None)  # noqa
                     if op == 'not in':
                         cond = ~cond
                     where_clause_and.append(cond)
@@ -648,23 +651,24 @@ def get_sqla_query(  # sqla
                     if col_obj.is_num:
                         eq = utils.string_to_num(flt['val'])
                     if op == '==':
-                        where_clause_and.append(col_obj.sqla_col == eq)
+                        where_clause_and.append(col_obj.get_sqla_col() == eq)
                     elif op == '!=':
-                        where_clause_and.append(col_obj.sqla_col != eq)
+                        where_clause_and.append(col_obj.get_sqla_col() != eq)
                     elif op == '>':
-                        where_clause_and.append(col_obj.sqla_col > eq)
+                        where_clause_and.append(col_obj.get_sqla_col() > eq)
                     elif op == '<':
-                        where_clause_and.append(col_obj.sqla_col < eq)
+                        where_clause_and.append(col_obj.get_sqla_col() < eq)
                     elif op == '>=':
-                        where_clause_and.append(col_obj.sqla_col >= eq)
+                        where_clause_and.append(col_obj.get_sqla_col() >= eq)
                     elif op == '<=':
-                        where_clause_and.append(col_obj.sqla_col <= eq)
+                        where_clause_and.append(col_obj.get_sqla_col() <= eq)
                     elif op == 'LIKE':
-                        where_clause_and.append(col_obj.sqla_col.like(eq))
+                        where_clause_and.append(col_obj.get_sqla_col().like(eq))
                     elif op == 'IS NULL':
-                        where_clause_and.append(col_obj.sqla_col == None)  # noqa
+                        where_clause_and.append(col_obj.get_sqla_col() == None)  # noqa
                     elif op == 'IS NOT NULL':
-                        where_clause_and.append(col_obj.sqla_col != None)  # noqa
+                        where_clause_and.append(
+                            col_obj.get_sqla_col() != None)  # noqa
         if extras:
             where = extras.get('where')
             if where:
@@ -686,7 +690,7 @@ def get_sqla_query(  # sqla
         for col, ascending in orderby:
             direction = asc if ascending else desc
             if utils.is_adhoc_metric(col):
-                col = self.adhoc_metric_to_sa(col, cols)
+                col = self.adhoc_metric_to_sqla(col, cols)
             qry = qry.order_by(direction(col))
 
         if row_limit:
@@ -712,12 +716,12 @@ def get_sqla_query(  # sqla
                 ob = inner_main_metric_expr
                 if timeseries_limit_metric:
                     if utils.is_adhoc_metric(timeseries_limit_metric):
-                        ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols)
+                        ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
                     elif timeseries_limit_metric in metrics_dict:
                         timeseries_limit_metric = metrics_dict.get(
                             timeseries_limit_metric,
                         )
-                        ob = timeseries_limit_metric.sqla_col
+                        ob = timeseries_limit_metric.get_sqla_col()
                     else:
                         raise Exception(_("Metric '{}' is not valid".format(m)))
                 direction = desc if order_desc else asc
@@ -762,7 +766,7 @@ def _get_top_groups(self, df, dimensions):
             group = []
             for dimension in dimensions:
                 col_obj = cols.get(dimension)
-                group.append(col_obj.sqla_col == row[dimension])
+                group.append(col_obj.get_sqla_col() == row[dimension])
             groups.append(and_(*group))
 
         return or_(*groups)
@@ -816,6 +820,7 @@ def fetch_metadata(self):
             .filter(or_(TableColumn.column_name == col.name
                         for col in table.columns)))
         dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
+        db_engine_spec = self.database.db_engine_spec
 
         for col in table.columns:
             try:
@@ -848,6 +853,9 @@ def fetch_metadata(self):
         ))
         if not self.main_dttm_col:
             self.main_dttm_col = any_date_col
+        for metric in metrics:
+            metric.metric_name = db_engine_spec.mutate_expression_label(
+                metric.metric_name)
         self.add_missing_metrics(metrics)
         db.session.merge(self)
         db.session.commit()
diff --git a/superset/dataframe.py b/superset/dataframe.py
index 834f118047..1678dd97f7 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -73,9 +73,7 @@ def __init__(self, data, cursor_description, db_engine_spec):
         if cursor_description:
             column_names = [col[0] for col in cursor_description]
 
-        case_sensitive = db_engine_spec.consistent_case_sensitivity
-        self.column_names = dedup(column_names,
-                                  case_sensitive=case_sensitive)
+        self.column_names = dedup(column_names)
 
         data = data or []
         self.df = (
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 13eb69502b..a6ae8ce603 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -35,7 +35,7 @@
 from sqlalchemy import select
 from sqlalchemy.engine import create_engine
 from sqlalchemy.engine.url import make_url
-from sqlalchemy.sql import text
+from sqlalchemy.sql import quoted_name, text
 from sqlalchemy.sql.expression import TextAsFrom
 import sqlparse
 from tableschema import Table
@@ -101,7 +101,7 @@ class BaseEngineSpec(object):
     time_secondary_columns = False
     inner_joins = True
     allows_subquery = True
-    consistent_case_sensitivity = True  # do results have same case as qry for col names?
+    force_column_alias_quotes = False
     arraysize = None
 
     @classmethod
@@ -376,55 +376,15 @@ def execute(cls, cursor, query, async=False):
         cursor.execute(query)
 
     @classmethod
-    def adjust_df_column_names(cls, df, fd):
-        """Based of fields in form_data, return dataframe with new column names
-
-        Usually sqla engines return column names whose case matches that of the
-        original query. For example:
-            SELECT 1 as col1, 2 as COL2, 3 as Col_3
-        will usually result in the following df.columns:
-            ['col1', 'COL2', 'Col_3'].
-        For these engines there is no need to adjust the dataframe column names
-        (default behavior). However, some engines (at least Snowflake, Oracle and
-        Redshift) return column names with different case than in the original query,
-        usually all uppercase. For these the column names need to be adjusted to
-        correspond to the case of the fields specified in the form data for Viz
-        to work properly. This adjustment can be done here.
+    def make_label_compatible(cls, label):
         """
-        if cls.consistent_case_sensitivity:
-            return df
-        else:
-            return cls.align_df_col_names_with_form_data(df, fd)
-
-    @staticmethod
-    def align_df_col_names_with_form_data(df, fd):
-        """Helper function to rename columns that have changed case during query.
-
-        Returns a dataframe where column names have been adjusted to correspond with
-        column names in form data (case insensitive). Examples:
-        dataframe: 'col1', form_data: 'col1' -> no change
-        dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
-        dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
+        Return a sqlalchemy.sql.elements.quoted_name if the engine requires
+        quoting of aliases to ensure that the select query and the query
+        results have the same case.
         """
-
-        columns = set()
-        lowercase_mapping = {}
-
-        metrics = utils.get_metric_names(fd.get('metrics', []))
-        groupby = fd.get('groupby', [])
-        other_cols = [utils.DTTM_ALIAS]
-        for col in metrics + groupby + other_cols:
-            columns.add(col)
-            lowercase_mapping[col.lower()] = col
-
-        rename_cols = {}
-        for col in df.columns:
-            if col not in columns:
-                orig_col = lowercase_mapping.get(col.lower())
-                if orig_col:
-                    rename_cols[col] = orig_col
-
-        return df.rename(index=str, columns=rename_cols)
+        if cls.force_column_alias_quotes is True:
+            return quoted_name(label, True)
+        return label
 
     @staticmethod
     def mutate_expression_label(label):
@@ -478,7 +438,8 @@ def get_table_names(cls, schema, inspector):
 
 class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     engine = 'snowflake'
-    consistent_case_sensitivity = False
+    force_column_alias_quotes = True
+
     time_grain_functions = {
         None: '{col}',
         'PT1S': "DATE_TRUNC('SECOND', {col})",
@@ -515,13 +476,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
 
 class RedshiftEngineSpec(PostgresBaseEngineSpec):
     engine = 'redshift'
-    consistent_case_sensitivity = False
+    force_column_alias_quotes = True
 
 
 class OracleEngineSpec(PostgresBaseEngineSpec):
     engine = 'oracle'
     limit_method = LimitMethod.WRAP_SQL
-    consistent_case_sensitivity = False
+    force_column_alias_quotes = True
 
     time_grain_functions = {
         None: '{col}',
@@ -545,6 +506,7 @@ def convert_dttm(cls, target_type, dttm):
 class Db2EngineSpec(BaseEngineSpec):
     engine = 'ibm_db_sa'
     limit_method = LimitMethod.WRAP_SQL
+    force_column_alias_quotes = True
     time_grain_functions = {
         None: '{col}',
         'PT1S': 'CAST({col} as TIMESTAMP)'
diff --git a/superset/viz.py b/superset/viz.py
index 5f4cea8498..df2bdaf0a5 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -386,10 +386,6 @@ def get_df_payload(self, query_obj=None):
         if query_obj and not is_loaded:
             try:
                 df = self.get_df(query_obj)
-                if hasattr(self.datasource, 'database') and \
-                        hasattr(self.datasource.database, 'db_engine_spec'):
-                    db_engine_spec = self.datasource.database.db_engine_spec
-                    df = db_engine_spec.adjust_df_column_names(df, self.form_data)
                 if self.status != utils.QueryStatus.FAILED:
                     stats_logger.incr('loaded_from_source')
                     is_loaded = True


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: notifications-unsubscribe@superset.apache.org
For additional commands, e-mail: notifications-help@superset.apache.org