You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by ma...@apache.org on 2018/09/04 05:50:01 UTC
[incubator-superset] branch master updated: Force quoted column
aliases for Oracle-like databases (#5686)
This is an automated email from the ASF dual-hosted git repository.
maximebeauchemin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 77fe9ef Force quoted column aliases for Oracle-like databases (#5686)
77fe9ef is described below
commit 77fe9ef130a383d6902b63f89743446b5578d3b5
Author: Ville Brofeldt <33...@users.noreply.github.com>
AuthorDate: Tue Sep 4 08:49:58 2018 +0300
Force quoted column aliases for Oracle-like databases (#5686)
* Replace dataframe label override logic with table column override
* Add mutation to any_date_col
* Linting
* Add mutation to oracle and redshift
* Fine tune how and which labels are mutated
* Implement alias quoting logic for oracle-like databases
* Fix and align column and metric sqla_col methods
* Clean up typos and redundant logic
* Move new attribute to old location
* Linting
* Replace old sqla_col property references with function calls
* Remove redundant calls to mutate_column_label
* Move duplicated logic to common function
* Add db_engine_specs to all sqla_col calls
* Add missing mydb
* Add note about snowflake-sqlalchemy regression
* Make db_engine_spec mandatory in sqla_col
* Small refactoring and cleanup
* Remove db_engine_spec from get_from_clause call
* Make db_engine_spec mandatory in adhoc_metric_to_sa
* Remove redundant mutate_expression_label call
* Add missing db_engine_specs to adhoc_metric_to_sa
* Rename arg label_name to label in get_column_label()
* Rename label function and add docstring
* Remove redundant db_engine_spec args
* Rename col_label to label
* Remove get_column_name wrapper and make direct calls to db_engine_spec
* Remove unneeded db_engine_specs
* Rename sa_ vars to sqla_
---
docs/installation.rst | 4 ++
superset/connectors/sqla/models.py | 106 ++++++++++++++++++++-----------------
superset/dataframe.py | 4 +-
superset/db_engine_specs.py | 66 +++++------------------
superset/viz.py | 4 --
5 files changed, 76 insertions(+), 108 deletions(-)
diff --git a/docs/installation.rst b/docs/installation.rst
index 6b08b82..7529323 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -393,6 +393,10 @@ Make sure the user has privileges to access and use all required
databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does
not test for user rights during engine creation.
+*Note*: At the time of writing, there is a regression in the current stable version (1.1.2) of
+the snowflake-sqlalchemy package that causes problems when used with Superset. It is recommended to
+use version 1.1.0 or to try a newer version.
+
See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_.
Caching
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index f410279..037eb77 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -99,13 +99,13 @@ class TableColumn(Model, BaseColumn):
s for s in export_fields if s not in ('table_id',)]
export_parent = 'table'
- @property
- def sqla_col(self):
- name = self.column_name
+ def get_sqla_col(self, label=None):
+ db_engine_spec = self.table.database.db_engine_spec
+ label = db_engine_spec.make_label_compatible(label if label else self.column_name)
if not self.expression:
- col = column(self.column_name).label(name)
+ col = column(self.column_name).label(label)
else:
- col = literal_column(self.expression).label(name)
+ col = literal_column(self.expression).label(label)
return col
@property
@@ -113,7 +113,7 @@ class TableColumn(Model, BaseColumn):
return self.table
def get_time_filter(self, start_dttm, end_dttm):
- col = self.sqla_col.label('__time')
+ col = self.get_sqla_col(label='__time')
l = [] # noqa: E741
if start_dttm:
l.append(col >= text(self.dttm_sql_literal(start_dttm)))
@@ -231,10 +231,10 @@ class SqlMetric(Model, BaseMetric):
s for s in export_fields if s not in ('table_id', )])
export_parent = 'table'
- @property
- def sqla_col(self):
- name = self.metric_name
- return literal_column(self.expression).label(name)
+ def get_sqla_col(self, label=None):
+ db_engine_spec = self.table.database.db_engine_spec
+ label = db_engine_spec.make_label_compatible(label if label else self.metric_name)
+ return literal_column(self.expression).label(label)
@property
def perm(self):
@@ -421,11 +421,10 @@ class SqlaTable(Model, BaseDatasource):
cols = {col.column_name: col for col in self.columns}
target_col = cols[column_name]
tp = self.get_template_processor()
- db_engine_spec = self.database.db_engine_spec
qry = (
- select([target_col.sqla_col])
- .select_from(self.get_from_clause(tp, db_engine_spec))
+ select([target_col.get_sqla_col()])
+ .select_from(self.get_from_clause(tp))
.distinct()
)
if limit:
@@ -474,7 +473,7 @@ class SqlaTable(Model, BaseDatasource):
tbl.schema = self.schema
return tbl
- def get_from_clause(self, template_processor=None, db_engine_spec=None):
+ def get_from_clause(self, template_processor=None):
# Supporting arbitrary SQL statements in place of tables
if self.sql:
from_sql = self.sql
@@ -484,7 +483,7 @@ class SqlaTable(Model, BaseDatasource):
return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
return self.get_sqla_table()
- def adhoc_metric_to_sa(self, metric, cols):
+ def adhoc_metric_to_sqla(self, metric, cols):
"""
Turn an adhoc metric into a sqlalchemy column.
@@ -493,22 +492,25 @@ class SqlaTable(Model, BaseDatasource):
:returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column
"""
- expressionType = metric.get('expressionType')
- if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
+ expression_type = metric.get('expressionType')
+ db_engine_spec = self.database.db_engine_spec
+ label = db_engine_spec.make_label_compatible(metric.get('label'))
+
+ if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
column_name = metric.get('column').get('column_name')
- sa_column = column(column_name)
+ sqla_column = column(column_name)
table_column = cols.get(column_name)
if table_column:
- sa_column = table_column.sqla_col
-
- sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column)
- sa_metric = sa_metric.label(metric.get('label'))
- return sa_metric
- elif expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
- sa_metric = literal_column(metric.get('sqlExpression'))
- sa_metric = sa_metric.label(metric.get('label'))
- return sa_metric
+ sqla_column = table_column.get_sqla_col()
+
+ sqla_metric = self.sqla_aggregations[metric.get('aggregate')](sqla_column)
+ sqla_metric = sqla_metric.label(label)
+ return sqla_metric
+ elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
+ sqla_metric = literal_column(metric.get('sqlExpression'))
+ sqla_metric = sqla_metric.label(label)
+ return sqla_metric
else:
return None
@@ -566,15 +568,16 @@ class SqlaTable(Model, BaseDatasource):
metrics_exprs = []
for m in metrics:
if utils.is_adhoc_metric(m):
- metrics_exprs.append(self.adhoc_metric_to_sa(m, cols))
+ metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
elif m in metrics_dict:
- metrics_exprs.append(metrics_dict.get(m).sqla_col)
+ metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
else:
raise Exception(_("Metric '{}' is not valid".format(m)))
if metrics_exprs:
main_metric_expr = metrics_exprs[0]
else:
- main_metric_expr = literal_column('COUNT(*)').label('ccount')
+ main_metric_expr = literal_column('COUNT(*)').label(
+ db_engine_spec.make_label_compatible('count'))
select_exprs = []
groupby_exprs = []
@@ -585,8 +588,8 @@ class SqlaTable(Model, BaseDatasource):
inner_groupby_exprs = []
for s in groupby:
col = cols[s]
- outer = col.sqla_col
- inner = col.sqla_col.label(col.column_name + '__')
+ outer = col.get_sqla_col()
+ inner = col.get_sqla_col(col.column_name + '__')
groupby_exprs.append(outer)
select_exprs.append(outer)
@@ -594,7 +597,7 @@ class SqlaTable(Model, BaseDatasource):
inner_select_exprs.append(inner)
elif columns:
for s in columns:
- select_exprs.append(cols[s].sqla_col)
+ select_exprs.append(cols[s].get_sqla_col())
metrics_exprs = []
if granularity:
@@ -618,7 +621,7 @@ class SqlaTable(Model, BaseDatasource):
select_exprs += metrics_exprs
qry = sa.select(select_exprs)
- tbl = self.get_from_clause(template_processor, db_engine_spec)
+ tbl = self.get_from_clause(template_processor)
if not columns:
qry = qry.group_by(*groupby_exprs)
@@ -638,9 +641,9 @@ class SqlaTable(Model, BaseDatasource):
target_column_is_numeric=col_obj.is_num,
is_list_target=is_list_target)
if op in ('in', 'not in'):
- cond = col_obj.sqla_col.in_(eq)
+ cond = col_obj.get_sqla_col().in_(eq)
if '<NULL>' in eq:
- cond = or_(cond, col_obj.sqla_col == None) # noqa
+ cond = or_(cond, col_obj.get_sqla_col() == None) # noqa
if op == 'not in':
cond = ~cond
where_clause_and.append(cond)
@@ -648,23 +651,24 @@ class SqlaTable(Model, BaseDatasource):
if col_obj.is_num:
eq = utils.string_to_num(flt['val'])
if op == '==':
- where_clause_and.append(col_obj.sqla_col == eq)
+ where_clause_and.append(col_obj.get_sqla_col() == eq)
elif op == '!=':
- where_clause_and.append(col_obj.sqla_col != eq)
+ where_clause_and.append(col_obj.get_sqla_col() != eq)
elif op == '>':
- where_clause_and.append(col_obj.sqla_col > eq)
+ where_clause_and.append(col_obj.get_sqla_col() > eq)
elif op == '<':
- where_clause_and.append(col_obj.sqla_col < eq)
+ where_clause_and.append(col_obj.get_sqla_col() < eq)
elif op == '>=':
- where_clause_and.append(col_obj.sqla_col >= eq)
+ where_clause_and.append(col_obj.get_sqla_col() >= eq)
elif op == '<=':
- where_clause_and.append(col_obj.sqla_col <= eq)
+ where_clause_and.append(col_obj.get_sqla_col() <= eq)
elif op == 'LIKE':
- where_clause_and.append(col_obj.sqla_col.like(eq))
+ where_clause_and.append(col_obj.get_sqla_col().like(eq))
elif op == 'IS NULL':
- where_clause_and.append(col_obj.sqla_col == None) # noqa
+ where_clause_and.append(col_obj.get_sqla_col() == None) # noqa
elif op == 'IS NOT NULL':
- where_clause_and.append(col_obj.sqla_col != None) # noqa
+ where_clause_and.append(
+ col_obj.get_sqla_col() != None) # noqa
if extras:
where = extras.get('where')
if where:
@@ -686,7 +690,7 @@ class SqlaTable(Model, BaseDatasource):
for col, ascending in orderby:
direction = asc if ascending else desc
if utils.is_adhoc_metric(col):
- col = self.adhoc_metric_to_sa(col, cols)
+ col = self.adhoc_metric_to_sqla(col, cols)
qry = qry.order_by(direction(col))
if row_limit:
@@ -712,12 +716,12 @@ class SqlaTable(Model, BaseDatasource):
ob = inner_main_metric_expr
if timeseries_limit_metric:
if utils.is_adhoc_metric(timeseries_limit_metric):
- ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols)
+ ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
elif timeseries_limit_metric in metrics_dict:
timeseries_limit_metric = metrics_dict.get(
timeseries_limit_metric,
)
- ob = timeseries_limit_metric.sqla_col
+ ob = timeseries_limit_metric.get_sqla_col()
else:
raise Exception(_("Metric '{}' is not valid".format(m)))
direction = desc if order_desc else asc
@@ -762,7 +766,7 @@ class SqlaTable(Model, BaseDatasource):
group = []
for dimension in dimensions:
col_obj = cols.get(dimension)
- group.append(col_obj.sqla_col == row[dimension])
+ group.append(col_obj.get_sqla_col() == row[dimension])
groups.append(and_(*group))
return or_(*groups)
@@ -816,6 +820,7 @@ class SqlaTable(Model, BaseDatasource):
.filter(or_(TableColumn.column_name == col.name
for col in table.columns)))
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
+ db_engine_spec = self.database.db_engine_spec
for col in table.columns:
try:
@@ -848,6 +853,9 @@ class SqlaTable(Model, BaseDatasource):
))
if not self.main_dttm_col:
self.main_dttm_col = any_date_col
+ for metric in metrics:
+ metric.metric_name = db_engine_spec.mutate_expression_label(
+ metric.metric_name)
self.add_missing_metrics(metrics)
db.session.merge(self)
db.session.commit()
diff --git a/superset/dataframe.py b/superset/dataframe.py
index 834f118..1678dd9 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -73,9 +73,7 @@ class SupersetDataFrame(object):
if cursor_description:
column_names = [col[0] for col in cursor_description]
- case_sensitive = db_engine_spec.consistent_case_sensitivity
- self.column_names = dedup(column_names,
- case_sensitive=case_sensitive)
+ self.column_names = dedup(column_names)
data = data or []
self.df = (
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 2ce7db0..a8a9faa 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -35,7 +35,7 @@ import sqlalchemy as sqla
from sqlalchemy import select
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.url import make_url
-from sqlalchemy.sql import text
+from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import TextAsFrom
import sqlparse
from tableschema import Table
@@ -101,7 +101,7 @@ class BaseEngineSpec(object):
time_secondary_columns = False
inner_joins = True
allows_subquery = True
- consistent_case_sensitivity = True # do results have same case as qry for col names?
+ force_column_alias_quotes = False
arraysize = None
@classmethod
@@ -376,55 +376,15 @@ class BaseEngineSpec(object):
cursor.execute(query)
@classmethod
- def adjust_df_column_names(cls, df, fd):
- """Based of fields in form_data, return dataframe with new column names
-
- Usually sqla engines return column names whose case matches that of the
- original query. For example:
- SELECT 1 as col1, 2 as COL2, 3 as Col_3
- will usually result in the following df.columns:
- ['col1', 'COL2', 'Col_3'].
- For these engines there is no need to adjust the dataframe column names
- (default behavior). However, some engines (at least Snowflake, Oracle and
- Redshift) return column names with different case than in the original query,
- usually all uppercase. For these the column names need to be adjusted to
- correspond to the case of the fields specified in the form data for Viz
- to work properly. This adjustment can be done here.
+ def make_label_compatible(cls, label):
"""
- if cls.consistent_case_sensitivity:
- return df
- else:
- return cls.align_df_col_names_with_form_data(df, fd)
-
- @staticmethod
- def align_df_col_names_with_form_data(df, fd):
- """Helper function to rename columns that have changed case during query.
-
- Returns a dataframe where column names have been adjusted to correspond with
- column names in form data (case insensitive). Examples:
- dataframe: 'col1', form_data: 'col1' -> no change
- dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
- dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
+ Return a sqlalchemy.sql.elements.quoted_name if the engine requires
+ quoting of aliases to ensure that select query and query results
+ have same case.
"""
-
- columns = set()
- lowercase_mapping = {}
-
- metrics = utils.get_metric_names(fd.get('metrics', []))
- groupby = fd.get('groupby', [])
- other_cols = [utils.DTTM_ALIAS]
- for col in metrics + groupby + other_cols:
- columns.add(col)
- lowercase_mapping[col.lower()] = col
-
- rename_cols = {}
- for col in df.columns:
- if col not in columns:
- orig_col = lowercase_mapping.get(col.lower())
- if orig_col:
- rename_cols[col] = orig_col
-
- return df.rename(index=str, columns=rename_cols)
+ if cls.force_column_alias_quotes is True:
+ return quoted_name(label, True)
+ return label
@staticmethod
def mutate_expression_label(label):
@@ -478,7 +438,8 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake'
- consistent_case_sensitivity = False
+ force_column_alias_quotes = True
+
time_grain_functions = {
None: '{col}',
'PT1S': "DATE_TRUNC('SECOND', {col})",
@@ -515,13 +476,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
class RedshiftEngineSpec(PostgresBaseEngineSpec):
engine = 'redshift'
- consistent_case_sensitivity = False
+ force_column_alias_quotes = True
class OracleEngineSpec(PostgresBaseEngineSpec):
engine = 'oracle'
limit_method = LimitMethod.WRAP_SQL
- consistent_case_sensitivity = False
+ force_column_alias_quotes = True
time_grain_functions = {
None: '{col}',
@@ -545,6 +506,7 @@ class OracleEngineSpec(PostgresBaseEngineSpec):
class Db2EngineSpec(BaseEngineSpec):
engine = 'ibm_db_sa'
limit_method = LimitMethod.WRAP_SQL
+ force_column_alias_quotes = True
time_grain_functions = {
None: '{col}',
'PT1S': 'CAST({col} as TIMESTAMP)'
diff --git a/superset/viz.py b/superset/viz.py
index 90c209e..6a18dfa 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -391,10 +391,6 @@ class BaseViz(object):
if query_obj and not is_loaded:
try:
df = self.get_df(query_obj)
- if hasattr(self.datasource, 'database') and \
- hasattr(self.datasource.database, 'db_engine_spec'):
- db_engine_spec = self.datasource.database.db_engine_spec
- df = db_engine_spec.adjust_df_column_names(df, self.form_data)
if self.status != utils.QueryStatus.FAILED:
stats_logger.incr('loaded_from_source')
is_loaded = True