You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@superset.apache.org by jo...@apache.org on 2019/05/02 05:07:07 UTC

[incubator-superset] branch master updated: [fix] Fixing SQL parsing issue (#7374)

This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new fb627ba  [fix] Fixing SQL parsing issue (#7374)
fb627ba is described below

commit fb627ba3769dfeb8f79718790a17a91873239383
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Wed May 1 22:07:01 2019 -0700

    [fix] Fixing SQL parsing issue (#7374)
---
 superset/sql_parse.py    | 39 ++++++++++++++++++++-------------------
 tests/sql_parse_tests.py |  9 +++++++++
 2 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 2f65392..662f6c3 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -18,7 +18,7 @@
 import logging
 
 import sqlparse
-from sqlparse.sql import Identifier, IdentifierList
+from sqlparse.sql import Identifier, IdentifierList, Token, TokenList
 from sqlparse.tokens import Keyword, Name
 
 RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT', 'SELECT'}
@@ -75,32 +75,32 @@ class ParsedQuery(object):
         return statements
 
     @staticmethod
-    def __get_full_name(identifier):
-        if len(identifier.tokens) > 2 and identifier.tokens[1].value == '.':
-            return '{}.{}'.format(identifier.tokens[0].value,
-                                  identifier.tokens[2].value)
-        return identifier.get_real_name()
+    def __get_full_name(tlist: TokenList):
+        if len(tlist.tokens) > 2 and tlist.tokens[1].value == '.':
+            return '{}.{}'.format(tlist.tokens[0].value,
+                                  tlist.tokens[2].value)
+        return tlist.get_real_name()
 
     @staticmethod
-    def __is_identifier(token):
+    def __is_identifier(token: Token):
         return isinstance(token, (IdentifierList, Identifier))
 
-    def __process_identifier(self, identifier):
+    def __process_tokenlist(self, tlist: TokenList):
         # exclude subselects
-        if '(' not in str(identifier):
-            table_name = self.__get_full_name(identifier)
+        if '(' not in str(tlist):
+            table_name = self.__get_full_name(tlist)
             if table_name and not table_name.startswith(CTE_PREFIX):
                 self._table_names.add(table_name)
             return
 
         # store aliases
-        if hasattr(identifier, 'get_alias'):
-            self._alias_names.add(identifier.get_alias())
-        if hasattr(identifier, 'tokens'):
-            # some aliases are not parsed properly
-            if identifier.tokens[0].ttype == Name:
-                self._alias_names.add(identifier.tokens[0].value)
-        self.__extract_from_token(identifier)
+        if tlist.has_alias():
+            self._alias_names.add(tlist.get_alias())
+
+        # some aliases are not parsed properly
+        if tlist.tokens[0].ttype == Name:
+            self._alias_names.add(tlist.tokens[0].value)
+        self.__extract_from_token(tlist)
 
     def as_create_table(self, table_name, overwrite=False):
         """Reformats the query into the create table as query.
@@ -144,10 +144,11 @@ class ParsedQuery(object):
 
             if table_name_preceding_token:
                 if isinstance(item, Identifier):
-                    self.__process_identifier(item)
+                    self.__process_tokenlist(item)
                 elif isinstance(item, IdentifierList):
                     for token in item.get_identifiers():
-                        self.__process_identifier(token)
+                        if isinstance(token, TokenList):
+                            self.__process_tokenlist(token)
             elif isinstance(item, IdentifierList):
                 for token in item.tokens:
                     if not self.__is_identifier(token):
diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py
index 5695939..7096147 100644
--- a/tests/sql_parse_tests.py
+++ b/tests/sql_parse_tests.py
@@ -462,3 +462,12 @@ class SupersetTestCase(unittest.TestCase):
             'SELECT * FROM ab_user LIMIT 1',
         ]
         self.assertEquals(statements, expected)
+
+    def test_identifier_list_with_keyword_as_alias(self):
+        query = """
+        WITH
+            f AS (SELECT * FROM foo),
+            match AS (SELECT * FROM f)
+        SELECT * FROM match
+        """
+        self.assertEquals({'foo'}, self.extract_tables(query))