You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by jo...@apache.org on 2019/03/05 17:36:16 UTC

[incubator-superset] branch master updated: [sql-parse] Fixing LIMIT exceptions (#6963)

This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 3e076cb  [sql-parse] Fixing LIMIT exceptions (#6963)
3e076cb is described below

commit 3e076cb60b385e675ed1c9a8053493375e43370b
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Tue Mar 5 09:36:08 2019 -0800

    [sql-parse] Fixing LIMIT exceptions (#6963)
---
 superset/sql_parse.py         | 19 ++++++++-----------
 tests/db_engine_specs_test.py | 14 ++++++++++++++
 2 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index d1ad23d..b653c88 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -154,18 +154,15 @@ class ParsedQuery(object):
                     if not self.__is_identifier(token):
                         self.__extract_from_token(item, depth=depth + 1)
 
-    def _get_limit_from_token(self, token):
-        if token.ttype == sqlparse.tokens.Literal.Number.Integer:
-            return int(token.value)
-        elif token.is_group:
-            return int(token.get_token_at_offset(1).value)
-
     def _extract_limit_from_query(self, statement):
-        limit_token = None
-        for pos, item in enumerate(statement.tokens):
-            if item.ttype in Keyword and item.value.lower() == 'limit':
-                limit_token = statement.tokens[pos + 2]
-                return self._get_limit_from_token(limit_token)
+        idx, _ = statement.token_next_by(m=(Keyword, 'LIMIT'))
+        if idx is not None:
+            _, token = statement.token_next(idx=idx)
+            if token:
+                if isinstance(token, IdentifierList):
+                    _, token = token.token_next(idx=-1)
+                if token and token.ttype == sqlparse.tokens.Literal.Number.Integer:
+                    return int(token.value)
 
     def get_query_with_new_limit(self, new_limit):
         """returns the query with the specified limit"""
diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py
index a48012d..1390f50 100644
--- a/tests/db_engine_specs_test.py
+++ b/tests/db_engine_specs_test.py
@@ -142,12 +142,26 @@ class DbEngineSpecsTestCase(SupersetTestCase):
         q2 = 'select * from (select * from my_subquery limit 10) where col=1 limit 20'
         q3 = 'select * from (select * from my_subquery limit 10);'
         q4 = 'select * from (select * from my_subquery limit 10) where col=1 limit 20;'
+        q5 = 'select * from mytable limit 10, 20'
+        q6 = 'select * from mytable limit 10 offset 20'
+        q7 = 'select * from mytable limit'
+        q8 = 'select * from mytable limit 10.0'
+        q9 = 'select * from mytable limit x'
+        q10 = 'select * from mytable limit x, 20'
+        q11 = 'select * from mytable limit x offset 20'
 
         self.assertEqual(engine_spec_class.get_limit_from_sql(q0), None)
         self.assertEqual(engine_spec_class.get_limit_from_sql(q1), 10)
         self.assertEqual(engine_spec_class.get_limit_from_sql(q2), 20)
         self.assertEqual(engine_spec_class.get_limit_from_sql(q3), None)
         self.assertEqual(engine_spec_class.get_limit_from_sql(q4), 20)
+        self.assertEqual(engine_spec_class.get_limit_from_sql(q5), 10)
+        self.assertEqual(engine_spec_class.get_limit_from_sql(q6), 10)
+        self.assertEqual(engine_spec_class.get_limit_from_sql(q7), None)
+        self.assertEqual(engine_spec_class.get_limit_from_sql(q8), None)
+        self.assertEqual(engine_spec_class.get_limit_from_sql(q9), None)
+        self.assertEqual(engine_spec_class.get_limit_from_sql(q10), None)
+        self.assertEqual(engine_spec_class.get_limit_from_sql(q11), None)
 
     def test_wrapped_query(self):
         self.sql_limit_regex(