You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@superset.apache.org by GitBox <gi...@apache.org> on 2018/08/03 16:53:57 UTC

[GitHub] mistercrunch closed pull request #5487: Match viz dataframe column case to form_data fields for Snowflake, Oracle and Redshift

mistercrunch closed pull request #5487: Match viz dataframe column case to form_data fields for Snowflake, Oracle and Redshift
URL: https://github.com/apache/incubator-superset/pull/5487
 
 
   

This pull request was merged from a forked repository.
GitHub hides the original diff of a fork's pull request once it has been
merged, so the diff is reproduced below for the sake of provenance:

diff --git a/superset/dataframe.py b/superset/dataframe.py
index 30ba4c776b..2fecad9264 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -27,23 +27,26 @@
 INFER_COL_TYPES_SAMPLE_SIZE = 100
 
 
-def dedup(l, suffix='__'):
+def dedup(l, suffix='__', case_sensitive=True):
     """De-duplicates a list of string by suffixing a counter
 
     Always returns the same number of entries as provided, and always returns
-    unique values.
+    unique values. Case sensitive comparison by default.
 
-    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
-    foo,bar,bar__1,bar__2
+    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'])))
+    foo,bar,bar__1,bar__2,Bar
+    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False)))
+    foo,bar,bar__1,bar__2,Bar__3
     """
     new_l = []
     seen = {}
     for s in l:
-        if s in seen:
-            seen[s] += 1
-            s += suffix + str(seen[s])
+        s_fixed_case = s if case_sensitive else s.lower()
+        if s_fixed_case in seen:
+            seen[s_fixed_case] += 1
+            s += suffix + str(seen[s_fixed_case])
         else:
-            seen[s] = 0
+            seen[s_fixed_case] = 0
         new_l.append(s)
     return new_l
 
@@ -70,7 +73,9 @@ def __init__(self, data, cursor_description, db_engine_spec):
         if cursor_description:
             column_names = [col[0] for col in cursor_description]
 
-        self.column_names = dedup(column_names)
+        case_sensitive = db_engine_spec.consistent_case_sensitivity
+        self.column_names = dedup(column_names,
+                                  case_sensitive=case_sensitive)
 
         data = data or []
         self.df = (
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 4967d30647..c07910af4d 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -101,6 +101,7 @@ class BaseEngineSpec(object):
     time_secondary_columns = False
     inner_joins = True
     allows_subquery = True
+    consistent_case_sensitivity = True  # do results have same case as qry for col names?
 
     @classmethod
     def get_time_grains(cls):
@@ -318,7 +319,6 @@ def select_star(cls, my_db, table_name, engine, schema=None, limit=100,
 
         if show_cols:
             fields = [sqla.column(c.get('name')) for c in cols]
-        full_table_name = table_name
         quote = engine.dialect.identifier_preparer.quote
         if schema:
             full_table_name = quote(schema) + '.' + quote(table_name)
@@ -366,6 +366,57 @@ def get_configuration_for_impersonation(cls, uri, impersonate_user, username):
     def execute(cursor, query, async=False):
         cursor.execute(query)
 
+    @classmethod
+    def adjust_df_column_names(cls, df, fd):
+        """Based on fields in form_data, return dataframe with new column names
+
+        Usually sqla engines return column names whose case matches that of the
+        original query. For example:
+            SELECT 1 as col1, 2 as COL2, 3 as Col_3
+        will usually result in the following df.columns:
+            ['col1', 'COL2', 'Col_3'].
+        For these engines there is no need to adjust the dataframe column names
+        (default behavior). However, some engines (at least Snowflake, Oracle and
+        Redshift) return column names with different case than in the original query,
+        usually all uppercase. For these the column names need to be adjusted to
+        correspond to the case of the fields specified in the form data for Viz
+        to work properly. This adjustment can be done here.
+        """
+        if cls.consistent_case_sensitivity:
+            return df
+        else:
+            return cls.align_df_col_names_with_form_data(df, fd)
+
+    @staticmethod
+    def align_df_col_names_with_form_data(df, fd):
+        """Helper function to rename columns that have changed case during query.
+
+        Returns a dataframe where column names have been adjusted to correspond with
+        column names in form data (case insensitive). Examples:
+        dataframe: 'col1', form_data: 'col1' -> no change
+        dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
+        dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
+        """
+
+        columns = set()
+        lowercase_mapping = {}
+
+        metrics = utils.get_metric_names(fd.get('metrics', []))
+        groupby = fd.get('groupby', [])
+        other_cols = [utils.DTTM_ALIAS]
+        for col in metrics + groupby + other_cols:
+            columns.add(col)
+            lowercase_mapping[col.lower()] = col
+
+        rename_cols = {}
+        for col in df.columns:
+            if col not in columns:
+                orig_col = lowercase_mapping.get(col.lower())
+                if orig_col:
+                    rename_cols[col] = orig_col
+
+        return df.rename(index=str, columns=rename_cols)
+
 
 class PostgresBaseEngineSpec(BaseEngineSpec):
     """ Abstract class for Postgres 'like' databases """
@@ -414,6 +465,7 @@ def get_table_names(cls, schema, inspector):
 
 class SnowflakeEngineSpec(PostgresBaseEngineSpec):
     engine = 'snowflake'
+    consistent_case_sensitivity = False
     time_grain_functions = {
         None: '{col}',
         'PT1S': "DATE_TRUNC('SECOND', {col})",
@@ -434,6 +486,15 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
         'P1Y': "DATE_TRUNC('YEAR', {col})",
     }
 
+    @classmethod
+    def adjust_database_uri(cls, uri, selected_schema=None):
+        database = uri.database
+        if '/' in uri.database:
+            database = uri.database.split('/')[0]
+        if selected_schema:
+            uri.database = database + '/' + selected_schema
+        return uri
+
 
 class VerticaEngineSpec(PostgresBaseEngineSpec):
     engine = 'vertica'
@@ -441,11 +502,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
 
 class RedshiftEngineSpec(PostgresBaseEngineSpec):
     engine = 'redshift'
+    consistent_case_sensitivity = False
 
 
 class OracleEngineSpec(PostgresBaseEngineSpec):
     engine = 'oracle'
     limit_method = LimitMethod.WRAP_SQL
+    consistent_case_sensitivity = False
 
     time_grain_functions = {
         None: '{col}',
diff --git a/superset/viz.py b/superset/viz.py
index 770d7eaff0..7ba8645bbc 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -153,7 +153,7 @@ def run_extra_queries(self):
 
     def handle_nulls(self, df):
         fillna = self.get_fillna_for_columns(df.columns)
-        df = df.fillna(fillna)
+        return df.fillna(fillna)
 
     def get_fillna_for_col(self, col):
         """Returns the value to use as filler for a specific Column.type"""
@@ -217,7 +217,7 @@ def get_df(self, query_obj=None):
                 self.df_metrics_to_num(df, query_obj.get('metrics') or [])
 
             df.replace([np.inf, -np.inf], np.nan)
-            self.handle_nulls(df)
+            df = self.handle_nulls(df)
         return df
 
     @staticmethod
@@ -382,6 +382,9 @@ def get_df_payload(self, query_obj=None):
         if query_obj and not is_loaded:
             try:
                 df = self.get_df(query_obj)
+                if hasattr(self.datasource.database, 'db_engine_spec'):
+                    db_engine_spec = self.datasource.database.db_engine_spec
+                    df = db_engine_spec.adjust_df_column_names(df, self.form_data)
                 if self.status != utils.QueryStatus.FAILED:
                     stats_logger.incr('loaded_from_source')
                     is_loaded = True
diff --git a/tests/dataframe_test.py b/tests/dataframe_test.py
index fdba431491..a773f08c4e 100644
--- a/tests/dataframe_test.py
+++ b/tests/dataframe_test.py
@@ -16,12 +16,16 @@ def test_dedup(self):
             ['foo', 'bar'],
         )
         self.assertEquals(
-            dedup(['foo', 'bar', 'foo', 'bar']),
-            ['foo', 'bar', 'foo__1', 'bar__1'],
+            dedup(['foo', 'bar', 'foo', 'bar', 'Foo']),
+            ['foo', 'bar', 'foo__1', 'bar__1', 'Foo'],
         )
         self.assertEquals(
-            dedup(['foo', 'bar', 'bar', 'bar']),
-            ['foo', 'bar', 'bar__1', 'bar__2'],
+            dedup(['foo', 'bar', 'bar', 'bar', 'Bar']),
+            ['foo', 'bar', 'bar__1', 'bar__2', 'Bar'],
+        )
+        self.assertEquals(
+            dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False),
+            ['foo', 'bar', 'bar__1', 'bar__2', 'Bar__3'],
         )
 
     def test_get_columns_basic(self):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: notifications-unsubscribe@superset.apache.org
For additional commands, e-mail: notifications-help@superset.apache.org