You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this version.
Posted to commits@superset.apache.org by ma...@apache.org on 2018/08/03 16:54:01 UTC

[incubator-superset] branch master updated: Match viz dataframe column case to form_data fields for Snowflake, Oracle and Redshift (#5487)

This is an automated email from the ASF dual-hosted git repository.

maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new e1f4db8  Match viz dataframe column case to form_data fields for Snowflake, Oracle and Redshift (#5487)
e1f4db8 is described below

commit e1f4db8e24db69f9954738cc534a740ee8823721
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Fri Aug 3 19:53:56 2018 +0300

    Match viz dataframe column case to form_data fields for Snowflake, Oracle and Redshift (#5487)
    
    * Add function to fix dataframe column case
    
    * Fix broken handle_nulls method
    
    * Add case sensitivity option to dedup
    
    * Refactor function definition and call location
    
    * Remove added blank line
    
    * Move df column rename logic to db_engine_spec
    
    * Remove redundant variable
    
    * Update comments in db_engine_specs
    
    * Tie df adjustment to db_engine_spec class attribute
    
    * Fix dedup error
    
    * Linting
    
    * Check for db_engine_spec attribute prior to adjustment
    
    * Rename case sensitivity flag
    
    * Linting
    
    * Remove function that was moved to db_engine_specs
    
    * Get metrics names from utils
    
    * Remove double import and rename dedup variable
---
 superset/dataframe.py       | 23 +++++++++-------
 superset/db_engine_specs.py | 65 ++++++++++++++++++++++++++++++++++++++++++++-
 superset/viz.py             |  7 +++--
 tests/dataframe_test.py     | 12 ++++++---
 4 files changed, 91 insertions(+), 16 deletions(-)

diff --git a/superset/dataframe.py b/superset/dataframe.py
index 30ba4c7..2fecad9 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -27,23 +27,26 @@ INFER_COL_TYPES_THRESHOLD = 95
 INFER_COL_TYPES_SAMPLE_SIZE = 100
 
 
-def dedup(l, suffix='__'):
+def dedup(l, suffix='__', case_sensitive=True):
     """De-duplicates a list of string by suffixing a counter
 
     Always returns the same number of entries as provided, and always returns
-    unique values.
+    unique values. Case sensitive comparison by default.
 
-    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
-    foo,bar,bar__1,bar__2
+    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'])))
+    foo,bar,bar__1,bar__2,Bar
+    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False)))
+    foo,bar,bar__1,bar__2,Bar__3
     """
     new_l = []
     seen = {}
     for s in l:
-        if s in seen:
-            seen[s] += 1
-            s += suffix + str(seen[s])
+        s_fixed_case = s if case_sensitive else s.lower()
+        if s_fixed_case in seen:
+            seen[s_fixed_case] += 1
+            s += suffix + str(seen[s_fixed_case])
         else:
-            seen[s] = 0
+            seen[s_fixed_case] = 0
         new_l.append(s)
     return new_l
 
@@ -70,7 +73,9 @@ class SupersetDataFrame(object):
         if cursor_description:
             column_names = [col[0] for col in cursor_description]
 
-        self.column_names = dedup(column_names)
+        case_sensitive = db_engine_spec.consistent_case_sensitivity
+        self.column_names = dedup(column_names,
+                                  case_sensitive=case_sensitive)
 
         data = data or []
         self.df = (
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 4967d30..c07910a 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -101,6 +101,7 @@ class BaseEngineSpec(object):
     time_secondary_columns = False
     inner_joins = True
     allows_subquery = True
+    consistent_case_sensitivity = True  # do results have same case as qry for col names?
 
     @classmethod
     def get_time_grains(cls):
@@ -318,7 +319,6 @@ class BaseEngineSpec(object):
 
         if show_cols:
             fields = [sqla.column(c.get('name')) for c in cols]
-        full_table_name = table_name
         quote = engine.dialect.identifier_preparer.quote
         if schema:
             full_table_name = quote(schema) + '.' + quote(table_name)
@@ -366,6 +366,57 @@ class BaseEngineSpec(object):
     def execute(cursor, query, async=False):
         cursor.execute(query)
 
+    @classmethod
+    def adjust_df_column_names(cls, df, fd):
+        """Based of fields in form_data, return dataframe with new column names
+
+        Usually sqla engines return column names whose case matches that of the
+        original query. For example:
+            SELECT 1 as col1, 2 as COL2, 3 as Col_3
+        will usually result in the following df.columns:
+            ['col1', 'COL2', 'Col_3'].
+        For these engines there is no need to adjust the dataframe column names
+        (default behavior). However, some engines (at least Snowflake, Oracle and
+        Redshift) return column names with different case than in the original query,
+        usually all uppercase. For these the column names need to be adjusted to
+        correspond to the case of the fields specified in the form data for Viz
+        to work properly. This adjustment can be done here.
+        """
+        if cls.consistent_case_sensitivity:
+            return df
+        else:
+            return cls.align_df_col_names_with_form_data(df, fd)
+
+    @staticmethod
+    def align_df_col_names_with_form_data(df, fd):
+        """Helper function to rename columns that have changed case during query.
+
+        Returns a dataframe where column names have been adjusted to correspond with
+        column names in form data (case insensitive). Examples:
+        dataframe: 'col1', form_data: 'col1' -> no change
+        dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
+        dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
+        """
+
+        columns = set()
+        lowercase_mapping = {}
+
+        metrics = utils.get_metric_names(fd.get('metrics', []))
+        groupby = fd.get('groupby', [])
+        other_cols = [utils.DTTM_ALIAS]
+        for col in metrics + groupby + other_cols:
+            columns.add(col)
+            lowercase_mapping[col.lower()] = col
+
+        rename_cols = {}
+        for col in df.columns:
+            if col not in columns:
+                orig_col = lowercase_mapping.get(col.lower())
+                if orig_col:
+                    rename_cols[col] = orig_col
+
+        return df.rename(index=str, columns=rename_cols)
+
 
 class PostgresBaseEngineSpec(BaseEngineSpec):
     """ Abstract class for Postgres 'like' databases """
@@ -414,6 +465,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
 
 class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     engine = 'snowflake'
+    consistent_case_sensitivity = False
     time_grain_functions = {
         None: '{col}',
         'PT1S': "DATE_TRUNC('SECOND', {col})",
@@ -434,6 +486,15 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
         'P1Y': "DATE_TRUNC('YEAR', {col})",
     }
 
+    @classmethod
+    def adjust_database_uri(cls, uri, selected_schema=None):
+        database = uri.database
+        if '/' in uri.database:
+            database = uri.database.split('/')[0]
+        if selected_schema:
+            uri.database = database + '/' + selected_schema
+        return uri
+
 
 class VerticaEngineSpec(PostgresBaseEngineSpec):
     engine = 'vertica'
@@ -441,11 +502,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
 
 class RedshiftEngineSpec(PostgresBaseEngineSpec):
     engine = 'redshift'
+    consistent_case_sensitivity = False
 
 
 class OracleEngineSpec(PostgresBaseEngineSpec):
     engine = 'oracle'
     limit_method = LimitMethod.WRAP_SQL
+    consistent_case_sensitivity = False
 
     time_grain_functions = {
         None: '{col}',
diff --git a/superset/viz.py b/superset/viz.py
index 770d7ea..7ba8645 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -153,7 +153,7 @@ class BaseViz(object):
 
     def handle_nulls(self, df):
         fillna = self.get_fillna_for_columns(df.columns)
-        df = df.fillna(fillna)
+        return df.fillna(fillna)
 
     def get_fillna_for_col(self, col):
         """Returns the value to use as filler for a specific Column.type"""
@@ -217,7 +217,7 @@ class BaseViz(object):
                 self.df_metrics_to_num(df, query_obj.get('metrics') or [])
 
             df.replace([np.inf, -np.inf], np.nan)
-            self.handle_nulls(df)
+            df = self.handle_nulls(df)
         return df
 
     @staticmethod
@@ -382,6 +382,9 @@ class BaseViz(object):
         if query_obj and not is_loaded:
             try:
                 df = self.get_df(query_obj)
+                if hasattr(self.datasource.database, 'db_engine_spec'):
+                    db_engine_spec = self.datasource.database.db_engine_spec
+                    df = db_engine_spec.adjust_df_column_names(df, self.form_data)
                 if self.status != utils.QueryStatus.FAILED:
                     stats_logger.incr('loaded_from_source')
                     is_loaded = True
diff --git a/tests/dataframe_test.py b/tests/dataframe_test.py
index fdba431..a773f08 100644
--- a/tests/dataframe_test.py
+++ b/tests/dataframe_test.py
@@ -16,12 +16,16 @@ class SupersetDataFrameTestCase(SupersetTestCase):
             ['foo', 'bar'],
         )
         self.assertEquals(
-            dedup(['foo', 'bar', 'foo', 'bar']),
-            ['foo', 'bar', 'foo__1', 'bar__1'],
+            dedup(['foo', 'bar', 'foo', 'bar', 'Foo']),
+            ['foo', 'bar', 'foo__1', 'bar__1', 'Foo'],
         )
         self.assertEquals(
-            dedup(['foo', 'bar', 'bar', 'bar']),
-            ['foo', 'bar', 'bar__1', 'bar__2'],
+            dedup(['foo', 'bar', 'bar', 'bar', 'Bar']),
+            ['foo', 'bar', 'bar__1', 'bar__2', 'Bar'],
+        )
+        self.assertEquals(
+            dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False),
+            ['foo', 'bar', 'bar__1', 'bar__2', 'Bar__3'],
         )
 
     def test_get_columns_basic(self):