You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by ma...@apache.org on 2018/08/03 16:54:01 UTC
[incubator-superset] branch master updated: Match viz dataframe
column case to form_data fields for Snowflake, Oracle and Redshift (#5487)
This is an automated email from the ASF dual-hosted git repository.
maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new e1f4db8 Match viz dataframe column case to form_data fields for Snowflake, Oracle and Redshift (#5487)
e1f4db8 is described below
commit e1f4db8e24db69f9954738cc534a740ee8823721
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Fri Aug 3 19:53:56 2018 +0300
Match viz dataframe column case to form_data fields for Snowflake, Oracle and Redshift (#5487)
* Add function to fix dataframe column case
* Fix broken handle_nulls method
* Add case sensitivity option to dedup
* Refactor function definition and call location
* Remove added blank line
* Move df column rename logic to db_engine_spec
* Remove redundant variable
* Update comments in db_engine_specs
* Tie df adjustment to db_engine_spec class attribute
* Fix dedup error
* Linting
* Check for db_engine_spec attribute prior to adjustment
* Rename case sensitivity flag
* Linting
* Remove function that was moved to db_engine_specs
* Get metrics names from utils
* Remove double import and rename dedup variable
---
superset/dataframe.py | 23 +++++++++-------
superset/db_engine_specs.py | 65 ++++++++++++++++++++++++++++++++++++++++++++-
superset/viz.py | 7 +++--
tests/dataframe_test.py | 12 ++++++---
4 files changed, 91 insertions(+), 16 deletions(-)
diff --git a/superset/dataframe.py b/superset/dataframe.py
index 30ba4c7..2fecad9 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -27,23 +27,26 @@ INFER_COL_TYPES_THRESHOLD = 95
INFER_COL_TYPES_SAMPLE_SIZE = 100
-def dedup(l, suffix='__'):
+def dedup(l, suffix='__', case_sensitive=True):
"""De-duplicates a list of string by suffixing a counter
Always returns the same number of entries as provided, and always returns
- unique values.
+ unique values. Case sensitive comparison by default.
- >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
- foo,bar,bar__1,bar__2
+ >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'])))
+ foo,bar,bar__1,bar__2,Bar
+ >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False)))
+ foo,bar,bar__1,bar__2,Bar__3
"""
new_l = []
seen = {}
for s in l:
- if s in seen:
- seen[s] += 1
- s += suffix + str(seen[s])
+ s_fixed_case = s if case_sensitive else s.lower()
+ if s_fixed_case in seen:
+ seen[s_fixed_case] += 1
+ s += suffix + str(seen[s_fixed_case])
else:
- seen[s] = 0
+ seen[s_fixed_case] = 0
new_l.append(s)
return new_l
@@ -70,7 +73,9 @@ class SupersetDataFrame(object):
if cursor_description:
column_names = [col[0] for col in cursor_description]
- self.column_names = dedup(column_names)
+ case_sensitive = db_engine_spec.consistent_case_sensitivity
+ self.column_names = dedup(column_names,
+ case_sensitive=case_sensitive)
data = data or []
self.df = (
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 4967d30..c07910a 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -101,6 +101,7 @@ class BaseEngineSpec(object):
time_secondary_columns = False
inner_joins = True
allows_subquery = True
+ consistent_case_sensitivity = True # do results have same case as qry for col names?
@classmethod
def get_time_grains(cls):
@@ -318,7 +319,6 @@ class BaseEngineSpec(object):
if show_cols:
fields = [sqla.column(c.get('name')) for c in cols]
- full_table_name = table_name
quote = engine.dialect.identifier_preparer.quote
if schema:
full_table_name = quote(schema) + '.' + quote(table_name)
@@ -366,6 +366,57 @@ class BaseEngineSpec(object):
def execute(cursor, query, async=False):
cursor.execute(query)
+ @classmethod
+ def adjust_df_column_names(cls, df, fd):
+ """Based of fields in form_data, return dataframe with new column names
+
+ Usually sqla engines return column names whose case matches that of the
+ original query. For example:
+ SELECT 1 as col1, 2 as COL2, 3 as Col_3
+ will usually result in the following df.columns:
+ ['col1', 'COL2', 'Col_3'].
+ For these engines there is no need to adjust the dataframe column names
+ (default behavior). However, some engines (at least Snowflake, Oracle and
+ Redshift) return column names with different case than in the original query,
+ usually all uppercase. For these the column names need to be adjusted to
+ correspond to the case of the fields specified in the form data for Viz
+ to work properly. This adjustment can be done here.
+ """
+ if cls.consistent_case_sensitivity:
+ return df
+ else:
+ return cls.align_df_col_names_with_form_data(df, fd)
+
+ @staticmethod
+ def align_df_col_names_with_form_data(df, fd):
+ """Helper function to rename columns that have changed case during query.
+
+ Returns a dataframe where column names have been adjusted to correspond with
+ column names in form data (case insensitive). Examples:
+ dataframe: 'col1', form_data: 'col1' -> no change
+ dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
+ dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
+ """
+
+ columns = set()
+ lowercase_mapping = {}
+
+ metrics = utils.get_metric_names(fd.get('metrics', []))
+ groupby = fd.get('groupby', [])
+ other_cols = [utils.DTTM_ALIAS]
+ for col in metrics + groupby + other_cols:
+ columns.add(col)
+ lowercase_mapping[col.lower()] = col
+
+ rename_cols = {}
+ for col in df.columns:
+ if col not in columns:
+ orig_col = lowercase_mapping.get(col.lower())
+ if orig_col:
+ rename_cols[col] = orig_col
+
+ return df.rename(index=str, columns=rename_cols)
+
class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """
@@ -414,6 +465,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake'
+ consistent_case_sensitivity = False
time_grain_functions = {
None: '{col}',
'PT1S': "DATE_TRUNC('SECOND', {col})",
@@ -434,6 +486,15 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
'P1Y': "DATE_TRUNC('YEAR', {col})",
}
+ @classmethod
+ def adjust_database_uri(cls, uri, selected_schema=None):
+ database = uri.database
+ if '/' in uri.database:
+ database = uri.database.split('/')[0]
+ if selected_schema:
+ uri.database = database + '/' + selected_schema
+ return uri
+
class VerticaEngineSpec(PostgresBaseEngineSpec):
engine = 'vertica'
@@ -441,11 +502,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
class RedshiftEngineSpec(PostgresBaseEngineSpec):
engine = 'redshift'
+ consistent_case_sensitivity = False
class OracleEngineSpec(PostgresBaseEngineSpec):
engine = 'oracle'
limit_method = LimitMethod.WRAP_SQL
+ consistent_case_sensitivity = False
time_grain_functions = {
None: '{col}',
diff --git a/superset/viz.py b/superset/viz.py
index 770d7ea..7ba8645 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -153,7 +153,7 @@ class BaseViz(object):
def handle_nulls(self, df):
fillna = self.get_fillna_for_columns(df.columns)
- df = df.fillna(fillna)
+ return df.fillna(fillna)
def get_fillna_for_col(self, col):
"""Returns the value to use as filler for a specific Column.type"""
@@ -217,7 +217,7 @@ class BaseViz(object):
self.df_metrics_to_num(df, query_obj.get('metrics') or [])
df.replace([np.inf, -np.inf], np.nan)
- self.handle_nulls(df)
+ df = self.handle_nulls(df)
return df
@staticmethod
@@ -382,6 +382,9 @@ class BaseViz(object):
if query_obj and not is_loaded:
try:
df = self.get_df(query_obj)
+ if hasattr(self.datasource.database, 'db_engine_spec'):
+ db_engine_spec = self.datasource.database.db_engine_spec
+ df = db_engine_spec.adjust_df_column_names(df, self.form_data)
if self.status != utils.QueryStatus.FAILED:
stats_logger.incr('loaded_from_source')
is_loaded = True
diff --git a/tests/dataframe_test.py b/tests/dataframe_test.py
index fdba431..a773f08 100644
--- a/tests/dataframe_test.py
+++ b/tests/dataframe_test.py
@@ -16,12 +16,16 @@ class SupersetDataFrameTestCase(SupersetTestCase):
['foo', 'bar'],
)
self.assertEquals(
- dedup(['foo', 'bar', 'foo', 'bar']),
- ['foo', 'bar', 'foo__1', 'bar__1'],
+ dedup(['foo', 'bar', 'foo', 'bar', 'Foo']),
+ ['foo', 'bar', 'foo__1', 'bar__1', 'Foo'],
)
self.assertEquals(
- dedup(['foo', 'bar', 'bar', 'bar']),
- ['foo', 'bar', 'bar__1', 'bar__2'],
+ dedup(['foo', 'bar', 'bar', 'bar', 'Bar']),
+ ['foo', 'bar', 'bar__1', 'bar__2', 'Bar'],
+ )
+ self.assertEquals(
+ dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False),
+ ['foo', 'bar', 'bar__1', 'bar__2', 'Bar__3'],
)
def test_get_columns_basic(self):