You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@superset.apache.org by GitBox <gi...@apache.org> on 2018/09/04 05:50:00 UTC
[GitHub] mistercrunch closed pull request #5686: Force quoted column aliases
for Oracle-like databases
mistercrunch closed pull request #5686: Force quoted column aliases for Oracle-like databases
URL: https://github.com/apache/incubator-superset/pull/5686
This is a PR merged from a forked repository. Because GitHub hides the
original diff of a foreign (forked) pull request once it is merged, the
diff is reproduced below for the sake of provenance:
diff --git a/docs/installation.rst b/docs/installation.rst
index 008a2648f1..d1d1fd5e40 100644
--- a/docs/installation.rst
+++ b/docs/installation.rst
@@ -389,6 +389,10 @@ Make sure the user has privileges to access and use all required
databases/schemas/tables/views/warehouses, as the Snowflake SQLAlchemy engine does
not test for user rights during engine creation.
+*Note*: At the time of writing, there is a regression in the current stable version (1.1.2) of
+snowflake-sqlalchemy package that causes problems when used with Superset. It is recommended to
+use version 1.1.0 or try a newer version.
+
See `Snowflake SQLAlchemy <https://github.com/snowflakedb/snowflake-sqlalchemy>`_.
Caching
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 44a2cfb1ca..fcfe9e0b3f 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -99,13 +99,13 @@ class TableColumn(Model, BaseColumn):
s for s in export_fields if s not in ('table_id',)]
export_parent = 'table'
- @property
- def sqla_col(self):
- name = self.column_name
+ def get_sqla_col(self, label=None):
+ db_engine_spec = self.table.database.db_engine_spec
+ label = db_engine_spec.make_label_compatible(label if label else self.column_name)
if not self.expression:
- col = column(self.column_name).label(name)
+ col = column(self.column_name).label(label)
else:
- col = literal_column(self.expression).label(name)
+ col = literal_column(self.expression).label(label)
return col
@property
@@ -113,7 +113,7 @@ def datasource(self):
return self.table
def get_time_filter(self, start_dttm, end_dttm):
- col = self.sqla_col.label('__time')
+ col = self.get_sqla_col(label='__time')
l = [] # noqa: E741
if start_dttm:
l.append(col >= text(self.dttm_sql_literal(start_dttm)))
@@ -231,10 +231,10 @@ class SqlMetric(Model, BaseMetric):
s for s in export_fields if s not in ('table_id', )])
export_parent = 'table'
- @property
- def sqla_col(self):
- name = self.metric_name
- return literal_column(self.expression).label(name)
+ def get_sqla_col(self, label=None):
+ db_engine_spec = self.table.database.db_engine_spec
+ label = db_engine_spec.make_label_compatible(label if label else self.metric_name)
+ return literal_column(self.expression).label(label)
@property
def perm(self):
@@ -421,11 +421,10 @@ def values_for_column(self, column_name, limit=10000):
cols = {col.column_name: col for col in self.columns}
target_col = cols[column_name]
tp = self.get_template_processor()
- db_engine_spec = self.database.db_engine_spec
qry = (
- select([target_col.sqla_col])
- .select_from(self.get_from_clause(tp, db_engine_spec))
+ select([target_col.get_sqla_col()])
+ .select_from(self.get_from_clause(tp))
.distinct()
)
if limit:
@@ -474,7 +473,7 @@ def get_sqla_table(self):
tbl.schema = self.schema
return tbl
- def get_from_clause(self, template_processor=None, db_engine_spec=None):
+ def get_from_clause(self, template_processor=None):
# Supporting arbitrary SQL statements in place of tables
if self.sql:
from_sql = self.sql
@@ -484,7 +483,7 @@ def get_from_clause(self, template_processor=None, db_engine_spec=None):
return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
return self.get_sqla_table()
- def adhoc_metric_to_sa(self, metric, cols):
+ def adhoc_metric_to_sqla(self, metric, cols):
"""
Turn an adhoc metric into a sqlalchemy column.
@@ -493,22 +492,25 @@ def adhoc_metric_to_sa(self, metric, cols):
:returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column
"""
- expressionType = metric.get('expressionType')
- if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
+ expression_type = metric.get('expressionType')
+ db_engine_spec = self.database.db_engine_spec
+ label = db_engine_spec.make_label_compatible(metric.get('label'))
+
+ if expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']:
column_name = metric.get('column').get('column_name')
- sa_column = column(column_name)
+ sqla_column = column(column_name)
table_column = cols.get(column_name)
if table_column:
- sa_column = table_column.sqla_col
-
- sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column)
- sa_metric = sa_metric.label(metric.get('label'))
- return sa_metric
- elif expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
- sa_metric = literal_column(metric.get('sqlExpression'))
- sa_metric = sa_metric.label(metric.get('label'))
- return sa_metric
+ sqla_column = table_column.get_sqla_col()
+
+ sqla_metric = self.sqla_aggregations[metric.get('aggregate')](sqla_column)
+ sqla_metric = sqla_metric.label(label)
+ return sqla_metric
+ elif expression_type == utils.ADHOC_METRIC_EXPRESSION_TYPES['SQL']:
+ sqla_metric = literal_column(metric.get('sqlExpression'))
+ sqla_metric = sqla_metric.label(label)
+ return sqla_metric
else:
return None
@@ -566,15 +568,16 @@ def get_sqla_query( # sqla
metrics_exprs = []
for m in metrics:
if utils.is_adhoc_metric(m):
- metrics_exprs.append(self.adhoc_metric_to_sa(m, cols))
+ metrics_exprs.append(self.adhoc_metric_to_sqla(m, cols))
elif m in metrics_dict:
- metrics_exprs.append(metrics_dict.get(m).sqla_col)
+ metrics_exprs.append(metrics_dict.get(m).get_sqla_col())
else:
raise Exception(_("Metric '{}' is not valid".format(m)))
if metrics_exprs:
main_metric_expr = metrics_exprs[0]
else:
- main_metric_expr = literal_column('COUNT(*)').label('ccount')
+ main_metric_expr = literal_column('COUNT(*)').label(
+ db_engine_spec.make_label_compatible('count'))
select_exprs = []
groupby_exprs = []
@@ -585,8 +588,8 @@ def get_sqla_query( # sqla
inner_groupby_exprs = []
for s in groupby:
col = cols[s]
- outer = col.sqla_col
- inner = col.sqla_col.label(col.column_name + '__')
+ outer = col.get_sqla_col()
+ inner = col.get_sqla_col(col.column_name + '__')
groupby_exprs.append(outer)
select_exprs.append(outer)
@@ -594,7 +597,7 @@ def get_sqla_query( # sqla
inner_select_exprs.append(inner)
elif columns:
for s in columns:
- select_exprs.append(cols[s].sqla_col)
+ select_exprs.append(cols[s].get_sqla_col())
metrics_exprs = []
if granularity:
@@ -618,7 +621,7 @@ def get_sqla_query( # sqla
select_exprs += metrics_exprs
qry = sa.select(select_exprs)
- tbl = self.get_from_clause(template_processor, db_engine_spec)
+ tbl = self.get_from_clause(template_processor)
if not columns:
qry = qry.group_by(*groupby_exprs)
@@ -638,9 +641,9 @@ def get_sqla_query( # sqla
target_column_is_numeric=col_obj.is_num,
is_list_target=is_list_target)
if op in ('in', 'not in'):
- cond = col_obj.sqla_col.in_(eq)
+ cond = col_obj.get_sqla_col().in_(eq)
if '<NULL>' in eq:
- cond = or_(cond, col_obj.sqla_col == None) # noqa
+ cond = or_(cond, col_obj.get_sqla_col() == None) # noqa
if op == 'not in':
cond = ~cond
where_clause_and.append(cond)
@@ -648,23 +651,24 @@ def get_sqla_query( # sqla
if col_obj.is_num:
eq = utils.string_to_num(flt['val'])
if op == '==':
- where_clause_and.append(col_obj.sqla_col == eq)
+ where_clause_and.append(col_obj.get_sqla_col() == eq)
elif op == '!=':
- where_clause_and.append(col_obj.sqla_col != eq)
+ where_clause_and.append(col_obj.get_sqla_col() != eq)
elif op == '>':
- where_clause_and.append(col_obj.sqla_col > eq)
+ where_clause_and.append(col_obj.get_sqla_col() > eq)
elif op == '<':
- where_clause_and.append(col_obj.sqla_col < eq)
+ where_clause_and.append(col_obj.get_sqla_col() < eq)
elif op == '>=':
- where_clause_and.append(col_obj.sqla_col >= eq)
+ where_clause_and.append(col_obj.get_sqla_col() >= eq)
elif op == '<=':
- where_clause_and.append(col_obj.sqla_col <= eq)
+ where_clause_and.append(col_obj.get_sqla_col() <= eq)
elif op == 'LIKE':
- where_clause_and.append(col_obj.sqla_col.like(eq))
+ where_clause_and.append(col_obj.get_sqla_col().like(eq))
elif op == 'IS NULL':
- where_clause_and.append(col_obj.sqla_col == None) # noqa
+ where_clause_and.append(col_obj.get_sqla_col() == None) # noqa
elif op == 'IS NOT NULL':
- where_clause_and.append(col_obj.sqla_col != None) # noqa
+ where_clause_and.append(
+ col_obj.get_sqla_col() != None) # noqa
if extras:
where = extras.get('where')
if where:
@@ -686,7 +690,7 @@ def get_sqla_query( # sqla
for col, ascending in orderby:
direction = asc if ascending else desc
if utils.is_adhoc_metric(col):
- col = self.adhoc_metric_to_sa(col, cols)
+ col = self.adhoc_metric_to_sqla(col, cols)
qry = qry.order_by(direction(col))
if row_limit:
@@ -712,12 +716,12 @@ def get_sqla_query( # sqla
ob = inner_main_metric_expr
if timeseries_limit_metric:
if utils.is_adhoc_metric(timeseries_limit_metric):
- ob = self.adhoc_metric_to_sa(timeseries_limit_metric, cols)
+ ob = self.adhoc_metric_to_sqla(timeseries_limit_metric, cols)
elif timeseries_limit_metric in metrics_dict:
timeseries_limit_metric = metrics_dict.get(
timeseries_limit_metric,
)
- ob = timeseries_limit_metric.sqla_col
+ ob = timeseries_limit_metric.get_sqla_col()
else:
raise Exception(_("Metric '{}' is not valid".format(m)))
direction = desc if order_desc else asc
@@ -762,7 +766,7 @@ def _get_top_groups(self, df, dimensions):
group = []
for dimension in dimensions:
col_obj = cols.get(dimension)
- group.append(col_obj.sqla_col == row[dimension])
+ group.append(col_obj.get_sqla_col() == row[dimension])
groups.append(and_(*group))
return or_(*groups)
@@ -816,6 +820,7 @@ def fetch_metadata(self):
.filter(or_(TableColumn.column_name == col.name
for col in table.columns)))
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
+ db_engine_spec = self.database.db_engine_spec
for col in table.columns:
try:
@@ -848,6 +853,9 @@ def fetch_metadata(self):
))
if not self.main_dttm_col:
self.main_dttm_col = any_date_col
+ for metric in metrics:
+ metric.metric_name = db_engine_spec.mutate_expression_label(
+ metric.metric_name)
self.add_missing_metrics(metrics)
db.session.merge(self)
db.session.commit()
diff --git a/superset/dataframe.py b/superset/dataframe.py
index 834f118047..1678dd97f7 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -73,9 +73,7 @@ def __init__(self, data, cursor_description, db_engine_spec):
if cursor_description:
column_names = [col[0] for col in cursor_description]
- case_sensitive = db_engine_spec.consistent_case_sensitivity
- self.column_names = dedup(column_names,
- case_sensitive=case_sensitive)
+ self.column_names = dedup(column_names)
data = data or []
self.df = (
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 13eb69502b..a6ae8ce603 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -35,7 +35,7 @@
from sqlalchemy import select
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.url import make_url
-from sqlalchemy.sql import text
+from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import TextAsFrom
import sqlparse
from tableschema import Table
@@ -101,7 +101,7 @@ class BaseEngineSpec(object):
time_secondary_columns = False
inner_joins = True
allows_subquery = True
- consistent_case_sensitivity = True # do results have same case as qry for col names?
+ force_column_alias_quotes = False
arraysize = None
@classmethod
@@ -376,55 +376,15 @@ def execute(cls, cursor, query, async=False):
cursor.execute(query)
@classmethod
- def adjust_df_column_names(cls, df, fd):
- """Based of fields in form_data, return dataframe with new column names
-
- Usually sqla engines return column names whose case matches that of the
- original query. For example:
- SELECT 1 as col1, 2 as COL2, 3 as Col_3
- will usually result in the following df.columns:
- ['col1', 'COL2', 'Col_3'].
- For these engines there is no need to adjust the dataframe column names
- (default behavior). However, some engines (at least Snowflake, Oracle and
- Redshift) return column names with different case than in the original query,
- usually all uppercase. For these the column names need to be adjusted to
- correspond to the case of the fields specified in the form data for Viz
- to work properly. This adjustment can be done here.
+ def make_label_compatible(cls, label):
"""
- if cls.consistent_case_sensitivity:
- return df
- else:
- return cls.align_df_col_names_with_form_data(df, fd)
-
- @staticmethod
- def align_df_col_names_with_form_data(df, fd):
- """Helper function to rename columns that have changed case during query.
-
- Returns a dataframe where column names have been adjusted to correspond with
- column names in form data (case insensitive). Examples:
- dataframe: 'col1', form_data: 'col1' -> no change
- dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
- dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
+ Return a sqlalchemy.sql.elements.quoted_name if the engine requires
+ quoting of aliases to ensure that select query and query results
+ have same case.
"""
-
- columns = set()
- lowercase_mapping = {}
-
- metrics = utils.get_metric_names(fd.get('metrics', []))
- groupby = fd.get('groupby', [])
- other_cols = [utils.DTTM_ALIAS]
- for col in metrics + groupby + other_cols:
- columns.add(col)
- lowercase_mapping[col.lower()] = col
-
- rename_cols = {}
- for col in df.columns:
- if col not in columns:
- orig_col = lowercase_mapping.get(col.lower())
- if orig_col:
- rename_cols[col] = orig_col
-
- return df.rename(index=str, columns=rename_cols)
+ if cls.force_column_alias_quotes is True:
+ return quoted_name(label, True)
+ return label
@staticmethod
def mutate_expression_label(label):
@@ -478,7 +438,8 @@ def get_table_names(cls, schema, inspector):
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake'
- consistent_case_sensitivity = False
+ force_column_alias_quotes = True
+
time_grain_functions = {
None: '{col}',
'PT1S': "DATE_TRUNC('SECOND', {col})",
@@ -515,13 +476,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
class RedshiftEngineSpec(PostgresBaseEngineSpec):
engine = 'redshift'
- consistent_case_sensitivity = False
+ force_column_alias_quotes = True
class OracleEngineSpec(PostgresBaseEngineSpec):
engine = 'oracle'
limit_method = LimitMethod.WRAP_SQL
- consistent_case_sensitivity = False
+ force_column_alias_quotes = True
time_grain_functions = {
None: '{col}',
@@ -545,6 +506,7 @@ def convert_dttm(cls, target_type, dttm):
class Db2EngineSpec(BaseEngineSpec):
engine = 'ibm_db_sa'
limit_method = LimitMethod.WRAP_SQL
+ force_column_alias_quotes = True
time_grain_functions = {
None: '{col}',
'PT1S': 'CAST({col} as TIMESTAMP)'
diff --git a/superset/viz.py b/superset/viz.py
index 5f4cea8498..df2bdaf0a5 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -386,10 +386,6 @@ def get_df_payload(self, query_obj=None):
if query_obj and not is_loaded:
try:
df = self.get_df(query_obj)
- if hasattr(self.datasource, 'database') and \
- hasattr(self.datasource.database, 'db_engine_spec'):
- db_engine_spec = self.datasource.database.db_engine_spec
- df = db_engine_spec.adjust_df_column_names(df, self.form_data)
if self.status != utils.QueryStatus.FAILED:
stats_logger.incr('loaded_from_source')
is_loaded = True
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: notifications-unsubscribe@superset.apache.org
For additional commands, e-mail: notifications-help@superset.apache.org