Posted to notifications@superset.apache.org by GitBox <gi...@apache.org> on 2018/07/21 19:01:29 UTC

[GitHub] john-bodley closed pull request #5178: [sql] Correct SQL parameter formatting

john-bodley closed pull request #5178: [sql] Correct SQL parameter formatting 
URL: https://github.com/apache/incubator-superset/pull/5178
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/.pylintrc b/.pylintrc
index 820637dbd0..016b04e367 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -282,7 +282,7 @@ ignored-modules=numpy,pandas,alembic.op,sqlalchemy,alembic.context,flask_appbuil
 # List of class names for which member attributes should not be checked (useful
 # for classes with dynamically set attributes). This supports the use of
 # qualified names.
-ignored-classes=optparse.Values,thread._local,_thread._local,sqlalchemy.orm.scoping.scoped_session
+ignored-classes=contextlib.closing,optparse.Values,thread._local,_thread._local,sqlalchemy.orm.scoping.scoped_session
 
 # List of members which are set dynamically and missed by pylint inference
 # system, and so shouldn't trigger E1101 when accessed. Python regular
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index e08053ccd1..de251a249b 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -12,7 +12,6 @@
 from flask_appbuilder import Model
 from flask_babel import lazy_gettext as _
 import pandas as pd
-import six
 import sqlalchemy as sa
 from sqlalchemy import (
     and_, asc, Boolean, Column, DateTime, desc, ForeignKey, Integer, or_,
@@ -420,14 +419,8 @@ def get_template_processor(self, **kwargs):
             table=self, database=self.database, **kwargs)
 
     def get_query_str(self, query_obj):
-        engine = self.database.get_sqla_engine()
         qry = self.get_sqla_query(**query_obj)
-        sql = six.text_type(
-            qry.compile(
-                engine,
-                compile_kwargs={'literal_binds': True},
-            ),
-        )
+        sql = self.database.compile_sqla_query(qry)
         logging.info(sql)
         sql = sqlparse.format(sql, reindent=True)
         if query_obj['is_prequery']:
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index eb3b3681eb..6e67e3d6df 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -64,7 +64,6 @@ class BaseEngineSpec(object):
     """Abstract class for database engine specific configurations"""
 
     engine = 'base'  # str as defined in sqlalchemy.engine.engine
-    cursor_execute_kwargs = {}
     time_grains = tuple()
     time_groupby_inline = False
     limit_method = LimitMethod.FORCE_LIMIT
@@ -333,6 +332,10 @@ def get_normalized_column_names(cls, cursor_description):
     def normalize_column_name(column_name):
         return column_name
 
+    @staticmethod
+    def execute(cursor, query, async=False):
+        cursor.execute(query)
+
 
 class PostgresBaseEngineSpec(BaseEngineSpec):
     """ Abstract class for Postgres 'like' databases """
@@ -556,7 +559,6 @@ def get_table_names(cls, schema, inspector):
 
 class MySQLEngineSpec(BaseEngineSpec):
     engine = 'mysql'
-    cursor_execute_kwargs = {'args': {}}
     time_grains = (
         Grain('Time Column', _('Time Column'), '{col}', None),
         Grain('second', _('second'), 'DATE_ADD(DATE({col}), '
@@ -619,7 +621,6 @@ def extract_error_message(cls, e):
 
 class PrestoEngineSpec(BaseEngineSpec):
     engine = 'presto'
-    cursor_execute_kwargs = {'parameters': None}
 
     time_grains = (
         Grain('Time Column', _('Time Column'), '{col}', None),
@@ -918,7 +919,6 @@ class HiveEngineSpec(PrestoEngineSpec):
     """Reuses PrestoEngineSpec functionality."""
 
     engine = 'hive'
-    cursor_execute_kwargs = {'async': True}
 
     # Scoping regex at class level to avoid recompiling
     # 17/02/07 19:36:38 INFO ql.Driver: Total jobs = 5
@@ -1180,6 +1180,10 @@ def get_configuration_for_impersonation(cls, uri, impersonate_user, username):
             configuration['hive.server2.proxy.user'] = username
         return configuration
 
+    @staticmethod
+    def execute(cursor, query, async=False):
+        cursor.execute(query, async=async)
+
 
 class MssqlEngineSpec(BaseEngineSpec):
     engine = 'mssql'
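
Note on the new hook: BaseEngineSpec.execute replaces the per-engine
cursor_execute_kwargs dictionaries, so callers hand the cursor and SQL to the
engine spec and engines that need extra arguments (Hive's asynchronous flag)
override the method. A minimal sketch of the dispatch is below; the keyword is
spelled async_ only to keep the sketch valid on Python 3.7+, where async is a
reserved word (the PR itself predates that):

    class BaseEngineSpec(object):
        @staticmethod
        def execute(cursor, query, async_=False):
            # Most DB-API drivers take only the SQL string; no kwargs and,
            # crucially, no `parameters`, so '%' is passed through as-is.
            cursor.execute(query)


    class HiveEngineSpec(BaseEngineSpec):
        @staticmethod
        def execute(cursor, query, async_=False):
            # PyHive's Hive cursor (at the time of this PR) accepted an
            # `async` keyword to run the statement asynchronously.
            cursor.execute(query, **{'async': async_})


    # Call sites such as sql_lab.py no longer need to know which extra
    # kwargs a given driver wants:
    # db_engine_spec.execute(cursor, query.executed_sql, async_=True)
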
diff --git a/superset/migrations/versions/4451805bbaa1_remove_double_percents.py b/superset/migrations/versions/4451805bbaa1_remove_double_percents.py
new file mode 100644
index 0000000000..2e57b39d3f
--- /dev/null
+++ b/superset/migrations/versions/4451805bbaa1_remove_double_percents.py
@@ -0,0 +1,86 @@
+"""remove double percents
+
+Revision ID: 4451805bbaa1
+Revises: afb7730f6a9c
+Create Date: 2018-06-13 10:20:35.846744
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = '4451805bbaa1'
+down_revision = 'bddc498dd179'
+
+
+from alembic import op
+import json
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy import Column, create_engine, ForeignKey, Integer, String, Text
+
+from superset import db
+
+Base = declarative_base()
+
+
+class Slice(Base):
+    __tablename__ = 'slices'
+
+    id = Column(Integer, primary_key=True)
+    datasource_id = Column(Integer, ForeignKey('tables.id'))
+    datasource_type = Column(String(200))
+    params = Column(Text)
+
+
+class Table(Base):
+    __tablename__ = 'tables'
+
+    id = Column(Integer, primary_key=True)
+    database_id = Column(Integer, ForeignKey('dbs.id'))
+
+
+class Database(Base):
+    __tablename__ = 'dbs'
+
+    id = Column(Integer, primary_key=True)
+    sqlalchemy_uri = Column(String(1024))
+
+
+def replace(source, target):
+    bind = op.get_bind()
+    session = db.Session(bind=bind)
+
+    query = (
+        session.query(Slice, Database)
+        .join(Table)
+        .join(Database)
+        .filter(Slice.datasource_type == 'table')
+        .all()
+    )
+
+    for slc, database in query:
+        try:
+            engine = create_engine(database.sqlalchemy_uri)
+
+            if engine.dialect.identifier_preparer._double_percents:
+                params = json.loads(slc.params)
+
+                if 'adhoc_filters' in params:
+                    for filt in params['adhoc_filters']:
+                        if 'sqlExpression' in filt:
+                            filt['sqlExpression'] = (
+                                filt['sqlExpression'].replace(source, target)
+                            )
+
+                    slc.params = json.dumps(params, sort_keys=True)
+        except Exception:
+            pass
+
+    session.commit()
+    session.close()
+
+
+def upgrade():
+    replace('%%', '%')
+
+
+def downgrade():
+    replace('%', '%%')
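
What the migration does, in isolation: for every chart whose database dialect
doubles percent signs (checked via the dialect's
identifier_preparer._double_percents), it rewrites the free-form SQL in adhoc
filters so the stored expressions no longer carry escaped %%. A small
self-contained illustration of that rewrite (the params blob below is invented
for the example):

    import json

    # A made-up `slices.params` value as it would look before the upgrade.
    params = json.loads(
        '{"adhoc_filters": [{"sqlExpression": "name LIKE \'%%bob%%\'"}]}')

    for filt in params['adhoc_filters']:
        if 'sqlExpression' in filt:
            # upgrade() collapses the doubled percents;
            # downgrade() re-doubles them.
            filt['sqlExpression'] = filt['sqlExpression'].replace('%%', '%')

    print(json.dumps(params, sort_keys=True))
    # {"adhoc_filters": [{"sqlExpression": "name LIKE '%bob%'"}]}
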
diff --git a/superset/models/core.py b/superset/models/core.py
index ebce5fcb31..4512e85f50 100644
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -6,6 +6,7 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
+from contextlib import closing
 from copy import copy, deepcopy
 from datetime import datetime
 import functools
@@ -19,6 +20,7 @@
 from future.standard_library import install_aliases
 import numpy
 import pandas as pd
+import six
 import sqlalchemy as sqla
 from sqlalchemy import (
     Boolean, Column, create_engine, DateTime, ForeignKey, Integer,
@@ -692,12 +694,7 @@ def get_quoter(self):
 
     def get_df(self, sql, schema):
         sqls = [str(s).strip().strip(';') for s in sqlparse.parse(sql)]
-        eng = self.get_sqla_engine(schema=schema)
-
-        for i in range(len(sqls) - 1):
-            eng.execute(sqls[i])
-
-        df = pd.read_sql_query(sqls[-1], eng)
+        engine = self.get_sqla_engine(schema=schema)
 
         def needs_conversion(df_series):
             if df_series.empty:
@@ -706,15 +703,35 @@ def needs_conversion(df_series):
                 return True
             return False
 
-        for k, v in df.dtypes.items():
-            if v.type == numpy.object_ and needs_conversion(df[k]):
-                df[k] = df[k].apply(utils.json_dumps_w_dates)
-        return df
+        with closing(engine.raw_connection()) as conn:
+            with closing(conn.cursor()) as cursor:
+                for sql in sqls:
+                    self.db_engine_spec.execute(cursor, sql)
+                df = pd.DataFrame.from_records(
+                    data=list(cursor.fetchall()),
+                    columns=[col_desc[0] for col_desc in cursor.description],
+                    coerce_float=True,
+                )
+
+                for k, v in df.dtypes.items():
+                    if v.type == numpy.object_ and needs_conversion(df[k]):
+                        df[k] = df[k].apply(utils.json_dumps_w_dates)
+                return df
 
     def compile_sqla_query(self, qry, schema=None):
-        eng = self.get_sqla_engine(schema=schema)
-        compiled = qry.compile(eng, compile_kwargs={'literal_binds': True})
-        return '{}'.format(compiled)
+        engine = self.get_sqla_engine(schema=schema)
+
+        sql = six.text_type(
+            qry.compile(
+                engine,
+                compile_kwargs={'literal_binds': True},
+            ),
+        )
+
+        if engine.dialect.identifier_preparer._double_percents:
+            sql = sql.replace('%%', '%')
+
+        return sql
 
     def select_star(
             self, table_name, schema=None, limit=100, show_cols=False,
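
The get_df rewrite above is the heart of the fix: statements now run through a
raw DB-API cursor with no parameters argument, so drivers such as mysqlclient
or psycopg2 never treat % in a LIKE clause as a pyformat placeholder, which is
what previously forced the double-percent escaping. compile_sqla_query
likewise strips the %% that SQLAlchemy emits for such dialects. A condensed
sketch of the new fetch path (single statement, dtype handling omitted; the
helper name fetch_df is only for the sketch):

    from contextlib import closing

    import pandas as pd


    def fetch_df(engine, sql):
        # raw_connection() exposes the underlying DB-API connection, and
        # calling execute() without parameters leaves '%' untouched.
        with closing(engine.raw_connection()) as conn:
            with closing(conn.cursor()) as cursor:
                cursor.execute(sql)
                return pd.DataFrame.from_records(
                    data=list(cursor.fetchall()),
                    columns=[col[0] for col in cursor.description],
                    coerce_float=True,
                )
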
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index c9f07ae906..e08991184c 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -214,8 +214,7 @@ def handle_error(msg):
         cursor = conn.cursor()
         logging.info('Running query: \n{}'.format(executed_sql))
         logging.info(query.executed_sql)
-        cursor.execute(query.executed_sql,
-                       **db_engine_spec.cursor_execute_kwargs)
+        db_engine_spec.execute(cursor, query.executed_sql, async=True)
         logging.info('Handling cursor')
         db_engine_spec.handle_cursor(cursor, query, session)
         logging.info('Fetching data: {}'.format(query.to_dict()))
diff --git a/tests/core_tests.py b/tests/core_tests.py
index dd6e3d891d..5f2d9fef92 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -427,6 +427,15 @@ def test_csv_endpoint(self):
         expected_data = csv.reader(
             io.StringIO('first_name,last_name\nadmin, user\n'))
 
+        sql = "SELECT first_name FROM ab_user WHERE first_name LIKE '%admin%'"
+        client_id = '{}'.format(random.getrandbits(64))[:10]
+        self.run_sql(sql, client_id, raise_on_error=True)
+
+        resp = self.get_resp('/superset/csv/{}'.format(client_id))
+        data = csv.reader(io.StringIO(resp))
+        expected_data = csv.reader(
+            io.StringIO('first_name\nadmin\n'))
+
         self.assertEqual(list(expected_data), list(data))
         self.logout()
 
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index 49926f80de..bde9b3736f 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -249,7 +249,7 @@ def test_sqllab_viz(self):
             'sql': """\
                 SELECT viz_type, count(1) as ccount
                 FROM slices
-                WHERE viz_type LIKE '%%a%%'
+                WHERE viz_type LIKE '%a%'
                 GROUP BY viz_type""",
             'dbId': 1,
         }
diff --git a/tox.ini b/tox.ini
index 2b2678eae4..29026147ae 100644
--- a/tox.ini
+++ b/tox.ini
@@ -36,7 +36,7 @@ setenv =
     SUPERSET_CONFIG = tests.superset_test_config
     SUPERSET_HOME = {envtmpdir}
     py27-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset?charset=utf8
-    py34-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset
+    py{34,36}-mysql: SUPERSET__SQLALCHEMY_DATABASE_URI = mysql://mysqluser:mysqluserpassword@localhost/superset
     py{27,34,36}-postgres: SUPERSET__SQLALCHEMY_DATABASE_URI = postgresql+psycopg2://postgresuser:pguserpassword@localhost/superset
     py{27,34,36}-sqlite: SUPERSET__SQLALCHEMY_DATABASE_URI = sqlite:////{envtmpdir}/superset.db
 whitelist_externals =


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: notifications-unsubscribe@superset.apache.org
For additional commands, e-mail: notifications-help@superset.apache.org