You are viewing a plain text version of this content. The canonical link for it is here.
Posted to notifications@superset.apache.org by GitBox <gi...@apache.org> on 2018/07/16 22:27:32 UTC

[GitHub] timifasubaa closed pull request #5295: [sqllab] Fix sqllab limit regex issue with sqlparse

timifasubaa closed pull request #5295: [sqllab] Fix sqllab limit regex issue with sqlparse
URL: https://github.com/apache/incubator-superset/pull/5295
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub does not otherwise display the diff of a merged fork pull request):

diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 4181c49d67..bb7893e9cc 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -41,7 +41,7 @@
 from tableschema import Table
 from werkzeug.utils import secure_filename
 
-from superset import app, cache_util, conf, db, utils
+from superset import app, cache_util, conf, db, sql_parse, utils
 from superset.exceptions import SupersetTemplateException
 from superset.utils import QueryStatus
 
@@ -110,32 +110,19 @@ def apply_limit_to_sql(cls, sql, limit, database):
             )
             return database.compile_sqla_query(qry)
         elif LimitMethod.FORCE_LIMIT:
-            sql_without_limit = cls.get_query_without_limit(sql)
-            return '{sql_without_limit} LIMIT {limit}'.format(**locals())
+            parsed_query = sql_parse.SupersetQuery(sql)
+            sql = parsed_query.get_query_with_new_limit(limit)
         return sql
 
     @classmethod
     def get_limit_from_sql(cls, sql):
-        limit_pattern = re.compile(r"""
-                (?ix)          # case insensitive, verbose
-                \s+            # whitespace
-                LIMIT\s+(\d+)  # LIMIT $ROWS
-                ;?             # optional semi-colon
-                (\s|;)*$       # remove trailing spaces tabs or semicolons
-                """)
-        matches = limit_pattern.findall(sql)
-        if matches:
-            return int(matches[0][0])
-
-    @classmethod
-    def get_query_without_limit(cls, sql):
-        return re.sub(r"""
-                (?ix)        # case insensitive, verbose
-                \s+          # whitespace
-                LIMIT\s+\d+  # LIMIT $ROWS
-                ;?           # optional semi-colon
-                (\s|;)*$     # remove trailing spaces tabs or semicolons
-                """, '', sql)
+        parsed_query = sql_parse.SupersetQuery(sql)
+        return parsed_query.limit
+
+    @classmethod
+    def get_query_with_new_limit(cls, sql, limit):
+        parsed_query = sql_parse.SupersetQuery(sql)
+        return parsed_query.get_query_with_new_limit(limit)
 
     @staticmethod
     def csv_to_df(**kwargs):
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index ea1c9c3885..7b5103924c 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -20,18 +20,24 @@ def __init__(self, sql_statement):
         self.sql = sql_statement
         self._table_names = set()
         self._alias_names = set()
+        self._limit = None
         # TODO: multistatement support
 
         logging.info('Parsing with sqlparse statement {}'.format(self.sql))
         self._parsed = sqlparse.parse(self.sql)
         for statement in self._parsed:
             self.__extract_from_token(statement)
+            self._limit = self._extract_limit_from_query(statement)
         self._table_names = self._table_names - self._alias_names
 
     @property
     def tables(self):
         return self._table_names
 
+    @property
+    def limit(self):
+        return self._limit
+
     def is_select(self):
         return self._parsed[0].get_type() == 'SELECT'
 
@@ -128,3 +134,41 @@ def __extract_from_token(self, token):
                 for token in item.tokens:
                     if self.__is_identifier(token):
                         self.__process_identifier(token)
+
+    def _get_limit_from_token(self, token):
+        if token.ttype == sqlparse.tokens.Literal.Number.Integer:
+            return int(token.value)
+        elif token.is_group:
+            return int(token.get_token_at_offset(1).value)
+
+    def _extract_limit_from_query(self, statement):
+        limit_token = None
+        for pos, item in enumerate(statement.tokens):
+            if item.ttype in Keyword and item.value.lower() == 'limit':
+                limit_token = statement.tokens[pos + 2]
+                return self._get_limit_from_token(limit_token)
+
+    def get_query_with_new_limit(self, new_limit):
+        """returns the query with the specified limit"""
+        """does not change the underlying query"""
+        if not self._limit:
+            return self.sql + ' LIMIT ' + str(new_limit)
+        limit_pos = None
+        tokens = self._parsed[0].tokens
+        # Add all items to before_str until there is a limit
+        for pos, item in enumerate(tokens):
+            if item.ttype in Keyword and item.value.lower() == 'limit':
+                limit_pos = pos
+                break
+        limit = tokens[limit_pos + 2]
+        if limit.ttype == sqlparse.tokens.Literal.Number.Integer:
+            tokens[limit_pos + 2].value = new_limit
+        elif limit.is_group:
+            tokens[limit_pos + 2].value = (
+                '{}, {}'.format(next(limit.get_identifiers()), new_limit)
+            )
+
+        str_res = ''
+        for i in tokens:
+            str_res += str(i.value)
+        return str_res
diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py
index 447914ed5f..c85e23a26c 100644
--- a/tests/db_engine_specs_test.py
+++ b/tests/db_engine_specs_test.py
@@ -4,8 +4,6 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
-import textwrap
-
 from superset.db_engine_specs import (
     BaseEngineSpec, HiveEngineSpec, MssqlEngineSpec,
     MySQLEngineSpec, PrestoEngineSpec,
@@ -143,18 +141,6 @@ def test_modify_limit_query(self):
             'SELECT * FROM a LIMIT 1000',
         )
 
-    def test_modify_newline_query(self):
-        self.sql_limit_regex(
-            'SELECT * FROM a\nLIMIT 9999',
-            'SELECT * FROM a LIMIT 1000',
-        )
-
-    def test_modify_lcase_limit_query(self):
-        self.sql_limit_regex(
-            'SELECT * FROM a\tlimit 9999',
-            'SELECT * FROM a LIMIT 1000',
-        )
-
     def test_limit_query_with_limit_subquery(self):
         self.sql_limit_regex(
             'SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999',
@@ -163,37 +149,38 @@ def test_limit_query_with_limit_subquery(self):
 
     def test_limit_with_expr(self):
         self.sql_limit_regex(
-            textwrap.dedent("""\
-                SELECT
-                    'LIMIT 777' AS a
-                    , b
-                FROM
-                table
-                LIMIT
-                99990"""),
-            textwrap.dedent("""\
+            """
+            SELECT
+                'LIMIT 777' AS a
+                , b
+            FROM
+            table
+            LIMIT 99990""",
+            """
             SELECT
                 'LIMIT 777' AS a
                 , b
             FROM
-            table LIMIT 1000"""),
+            table
+            LIMIT 1000""",
         )
 
     def test_limit_expr_and_semicolon(self):
         self.sql_limit_regex(
-            textwrap.dedent("""\
+            """
                 SELECT
                     'LIMIT 777' AS a
                     , b
                 FROM
                 table
-                LIMIT         99990            ;"""),
-            textwrap.dedent("""\
+                LIMIT         99990            ;""",
+            """
                 SELECT
                     'LIMIT 777' AS a
                     , b
                 FROM
-                table LIMIT 1000"""),
+                table
+                LIMIT         1000            ;""",
         )
 
     def test_get_datatype(self):
@@ -201,3 +188,51 @@ def test_get_datatype(self):
         self.assertEquals('TINY', MySQLEngineSpec.get_datatype(1))
         self.assertEquals('VARCHAR', MySQLEngineSpec.get_datatype(15))
         self.assertEquals('VARCHAR', BaseEngineSpec.get_datatype('VARCHAR'))
+
+    def test_limit_with_implicit_offset(self):
+        self.sql_limit_regex(
+            """
+                SELECT
+                    'LIMIT 777' AS a
+                    , b
+                FROM
+                table
+                LIMIT 99990, 999999""",
+            """
+                SELECT
+                    'LIMIT 777' AS a
+                    , b
+                FROM
+                table
+                LIMIT 99990, 1000""",
+        )
+
+    def test_limit_with_explicit_offset(self):
+        self.sql_limit_regex(
+            """
+                SELECT
+                    'LIMIT 777' AS a
+                    , b
+                FROM
+                table
+                LIMIT 99990
+                OFFSET 999999""",
+            """
+                SELECT
+                    'LIMIT 777' AS a
+                    , b
+                FROM
+                table
+                LIMIT 1000
+                OFFSET 999999""",
+        )
+
+    def test_limit_with_non_token_limit(self):
+        self.sql_limit_regex(
+            """
+                SELECT
+                    'LIMIT 777'""",
+            """
+                SELECT
+                    'LIMIT 777' LIMIT 1000""",
+        )


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: notifications-unsubscribe@superset.apache.org
For additional commands, e-mail: notifications-help@superset.apache.org