You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by jo...@apache.org on 2020/04/30 15:38:24 UTC

[incubator-superset] branch master updated: [sql] Adding lightweight Table class (#9649)

This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 3b0f8e9  [sql] Adding lightweight Table class (#9649)
3b0f8e9 is described below

commit 3b0f8e9c8abb60982773a0cacc9634c4df5eb702
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Thu Apr 30 08:38:02 2020 -0700

    [sql] Adding lightweight Table class (#9649)
    
    Co-authored-by: John Bodley <jo...@airbnb.com>
---
 requirements.txt                      |   1 +
 setup.cfg                             |   2 +-
 setup.py                              |   1 +
 superset/security/manager.py          |  97 +++++-------------
 superset/sql_parse.py                 |  79 ++++++++++-----
 superset/views/core.py                |   6 +-
 superset/views/database/decorators.py |   3 +-
 tests/sql_parse_tests.py              | 182 +++++++++++++++++++++-------------
 8 files changed, 202 insertions(+), 169 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index dd8b28a..61895fd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,6 +19,7 @@ colorama==0.4.3           # via apache-superset (setup.py), flask-appbuilder
 contextlib2==0.6.0.post1  # via apache-superset (setup.py)
 croniter==0.3.31          # via apache-superset (setup.py)
 cryptography==2.8         # via apache-superset (setup.py)
+dataclasses==0.6          # via apache-superset (setup.py)
 decorator==4.4.1          # via retry
 defusedxml==0.6.0         # via python3-openid
 flask-appbuilder==2.3.2   # via apache-superset (setup.py)
diff --git a/setup.cfg b/setup.cfg
index 9469118..28a3ab3 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -45,7 +45,7 @@ combine_as_imports = true
 include_trailing_comma = true
 line_length = 88
 known_first_party = superset
-known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils [...]
+known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dataclasses,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqla [...]
 multi_line_output = 3
 order_by_type = false
 
diff --git a/setup.py b/setup.py
index 6c14849..e3e3cdc 100644
--- a/setup.py
+++ b/setup.py
@@ -75,6 +75,7 @@ setup(
         "contextlib2",
         "croniter>=0.3.28",
         "cryptography>=2.4.2",
+        "dataclasses<0.7",
         "flask>=1.1.0, <2.0.0",
         "flask-appbuilder>=2.3.2, <2.4.0",
         "flask-caching",
diff --git a/superset/security/manager.py b/superset/security/manager.py
index e3b4b1d..fac9068 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -50,6 +50,7 @@ if TYPE_CHECKING:
     from superset.common.query_context import QueryContext
     from superset.connectors.base.models import BaseDatasource
     from superset.models.core import Database
+    from superset.sql_parse import Table
     from superset.viz import BaseViz
 
 logger = logging.getLogger(__name__)
@@ -290,26 +291,23 @@ class SupersetSecurityManager(SecurityManager):
 
         return conf.get("PERMISSION_INSTRUCTIONS_LINK")
 
-    def get_table_access_error_msg(self, tables: List[str]) -> str:
+    def get_table_access_error_msg(self, tables: Set["Table"]) -> str:
         """
         Return the error message for the denied SQL tables.
 
-        Note the table names conform to the [[cluster.]schema.]table construct.
-
-        :param tables: The list of denied SQL table names
+        :param tables: The set of denied SQL tables
         :returns: The error message
         """
-        quoted_tables = [f"`{t}`" for t in tables]
+
+        quoted_tables = [f"`{table}`" for table in tables]
         return f"""You need access to the following tables: {", ".join(quoted_tables)},
             `all_database_access` or `all_datasource_access` permission"""
 
-    def get_table_access_link(self, tables: List[str]) -> Optional[str]:
+    def get_table_access_link(self, tables: Set["Table"]) -> Optional[str]:
         """
         Return the access link for the denied SQL tables.
 
-        Note the table names conform to the [[cluster.]schema.]table construct.
-
-        :param tables: The list of denied SQL table names
+        :param tables: The set of denied SQL tables
         :returns: The access URL
         """
 
@@ -318,23 +316,19 @@ class SupersetSecurityManager(SecurityManager):
         return conf.get("PERMISSION_INSTRUCTIONS_LINK")
 
     def can_access_datasource(
-        self, database: "Database", table_name: str, schema: Optional[str] = None
-    ) -> bool:
-        return self._datasource_access_by_name(database, table_name, schema=schema)
-
-    def _datasource_access_by_name(
-        self, database: "Database", table_name: str, schema: Optional[str] = None
+        self, database: "Database", table: "Table", schema: Optional[str] = None
     ) -> bool:
         """
         Return True if the user can access the SQL table, False otherwise.
 
         :param database: The SQL database
-        :param table_name: The SQL table name
-        :param schema: The Superset schema
+        :param table: The SQL table
+        :param schema: The fallback SQL schema if not present in the table
         :returns: Whether the use can access the SQL table
         """
 
         from superset import db
+        from superset.connectors.sqla.models import SqlaTable
 
         if self.database_access(database) or self.all_datasource_access():
             return True
@@ -343,74 +337,33 @@ class SupersetSecurityManager(SecurityManager):
         if schema_perm and self.can_access("schema_access", schema_perm):
             return True
 
-        datasources = ConnectorRegistry.query_datasources_by_name(
-            db.session, database, table_name, schema=schema
+        datasources = SqlaTable.query_datasources_by_name(
+            db.session, database, table.table, schema=table.schema or schema
         )
         for datasource in datasources:
             if self.can_access("datasource_access", datasource.perm):
                 return True
         return False
 
-    def _get_schema_and_table(
-        self, table_in_query: str, schema: str
-    ) -> Tuple[str, str]:
+    def rejected_tables(
+        self, sql: str, database: "Database", schema: str
+    ) -> Set["Table"]:
         """
-        Return the SQL schema/table tuple associated with the table extracted from the
-        SQL query.
-
-        Note the table name conforms to the [[cluster.]schema.]table construct.
-
-        :param table_in_query: The SQL table name
-        :param schema: The fallback SQL schema if not present in the table name
-        :returns: The SQL schema/table tuple
-        """
-
-        table_name_pieces = table_in_query.split(".")
-        if len(table_name_pieces) == 3:
-            return tuple(table_name_pieces[1:])  # type: ignore
-        elif len(table_name_pieces) == 2:
-            return tuple(table_name_pieces)  # type: ignore
-        return (schema, table_name_pieces[0])
-
-    def _datasource_access_by_fullname(
-        self, database: "Database", table_in_query: str, schema: str
-    ) -> bool:
-        """
-        Return True if the user can access the table extracted from the SQL query, False
-        otherwise.
-
-        Note the table name conforms to the [[cluster.]schema.]table construct.
-
-        :param database: The Superset database
-        :param table_in_query: The SQL table name
-        :param schema: The fallback SQL schema, i.e., if not present in the table name
-        :returns: Whether the user can access the SQL table
-        """
-
-        table_schema, table_name = self._get_schema_and_table(table_in_query, schema)
-        return self._datasource_access_by_name(
-            database, table_name, schema=table_schema
-        )
-
-    def rejected_tables(self, sql: str, database: "Database", schema: str) -> List[str]:
-        """
-        Return the list of rejected SQL table names.
-
-        Note the rejected table names conform to the [[cluster.]schema.]table construct.
+        Return the list of rejected SQL tables.
 
         :param sql: The SQL statement
         :param database: The SQL database
         :param schema: The SQL database schema
-        :returns: The rejected table names
+        :returns: The rejected tables
         """
 
-        superset_query = sql_parse.ParsedQuery(sql)
+        query = sql_parse.ParsedQuery(sql)
 
-        return [
-            t
-            for t in superset_query.tables
-            if not self._datasource_access_by_fullname(database, t, schema)
-        ]
+        return {
+            table
+            for table in query.tables
+            if not self.can_access_datasource(database, table, schema)
+        }
 
     def get_public_role(self) -> Optional[Any]:  # Optional[self.role_model]
         from superset import conf
@@ -493,7 +446,7 @@ class SupersetSecurityManager(SecurityManager):
                 .filter(or_(SqlaTable.perm.in_(perms)))
                 .distinct()
             )
-            accessible_schemas.update([t.schema for t in tables])
+            accessible_schemas.update([table.schema for table in tables])
 
         return [s for s in schemas if s in accessible_schemas]
 
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 8cac2ff..34747e1 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -16,8 +16,10 @@
 # under the License.
 import logging
 from typing import List, Optional, Set
+from urllib import parse
 
 import sqlparse
+from dataclasses import dataclass
 from sqlparse.sql import (
     Function,
     Identifier,
@@ -57,10 +59,32 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
     return None
 
 
+@dataclass(eq=True, frozen=True)
+class Table:  # pylint: disable=too-few-public-methods
+    """
+    A fully qualified SQL table conforming to [[catalog.]schema.]table.
+    """
+
+    table: str
+    schema: Optional[str] = None
+    catalog: Optional[str] = None
+
+    def __str__(self) -> str:
+        """
+        Return the fully qualified SQL table name.
+        """
+
+        return ".".join(
+            parse.quote(part, safe="").replace(".", "%2E")
+            for part in [self.catalog, self.schema, self.table]
+            if part
+        )
+
+
 class ParsedQuery:
     def __init__(self, sql_statement: str):
         self.sql: str = sql_statement
-        self._table_names: Set[str] = set()
+        self._tables: Set[Table] = set()
         self._alias_names: Set[str] = set()
         self._limit: Optional[int] = None
 
@@ -70,12 +94,15 @@ class ParsedQuery:
             self._limit = _extract_limit_from_query(statement)
 
     @property
-    def tables(self) -> Set[str]:
-        if not self._table_names:
+    def tables(self) -> Set[Table]:
+        if not self._tables:
             for statement in self._parsed:
-                self.__extract_from_token(statement)
-            self._table_names = self._table_names - self._alias_names
-        return self._table_names
+                self._extract_from_token(statement)
+
+            self._tables = {
+                table for table in self._tables if str(table) not in self._alias_names
+            }
+        return self._tables
 
     @property
     def limit(self) -> Optional[int]:
@@ -105,13 +132,13 @@ class ParsedQuery:
         return statements
 
     @staticmethod
-    def __get_full_name(tlist: TokenList) -> Optional[str]:
+    def _get_table(tlist: TokenList) -> Optional[Table]:
         """
-        Return the full unquoted table name if valid, i.e., conforms to the following
-        [[cluster.]schema.]table construct.
+        Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
+        construct.
 
         :param tlist: The SQL tokens
-        :returns: The valid full table name
+        :returns: The table if the name conforms
         """
 
         # Strip the alias if present.
@@ -127,18 +154,18 @@ class ParsedQuery:
 
         if (
             len(tokens) in (1, 3, 5)
-            and all(imt(token, t=[Name, String]) for token in tokens[0::2])
+            and all(imt(token, t=[Name, String]) for token in tokens[::2])
             and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
         ):
-            return ".".join([remove_quotes(token.value) for token in tokens[0::2]])
+            return Table(*[remove_quotes(token.value) for token in tokens[::-2]])
 
         return None
 
     @staticmethod
-    def __is_identifier(token: Token) -> bool:
+    def _is_identifier(token: Token) -> bool:
         return isinstance(token, (IdentifierList, Identifier))
 
-    def __process_tokenlist(self, token_list: TokenList):
+    def _process_tokenlist(self, token_list: TokenList):
         """
         Add table names to table set
 
@@ -146,9 +173,9 @@ class ParsedQuery:
         """
         # exclude subselects
         if "(" not in str(token_list):
-            table_name = self.__get_full_name(token_list)
-            if table_name and not table_name.startswith(CTE_PREFIX):
-                self._table_names.add(table_name)
+            table = self._get_table(token_list)
+            if table and not table.table.startswith(CTE_PREFIX):
+                self._tables.add(table)
             return
 
         # store aliases
@@ -158,7 +185,7 @@ class ParsedQuery:
         # some aliases are not parsed properly
         if token_list.tokens[0].ttype == Name:
             self._alias_names.add(token_list.tokens[0].value)
-        self.__extract_from_token(token_list)
+        self._extract_from_token(token_list)
 
     def as_create_table(
         self,
@@ -184,9 +211,9 @@ class ParsedQuery:
         exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
         return exec_sql
 
-    def __extract_from_token(self, token: Token):  # pylint: disable=too-many-branches
+    def _extract_from_token(self, token: Token):  # pylint: disable=too-many-branches
         """
-        Populate self._table_names from token
+        Populate self._tables from token
 
         :param token: instance of Token or child class, e.g. TokenList, to be processed
         """
@@ -196,8 +223,8 @@ class ParsedQuery:
         table_name_preceding_token = False
 
         for item in token.tokens:
-            if item.is_group and not self.__is_identifier(item):
-                self.__extract_from_token(item)
+            if item.is_group and not self._is_identifier(item):
+                self._extract_from_token(item)
 
             if item.ttype in Keyword and (
                 item.normalized in PRECEDES_TABLE_NAME
@@ -212,15 +239,15 @@ class ParsedQuery:
 
             if table_name_preceding_token:
                 if isinstance(item, Identifier):
-                    self.__process_tokenlist(item)
+                    self._process_tokenlist(item)
                 elif isinstance(item, IdentifierList):
                     for token2 in item.get_identifiers():
                         if isinstance(token2, TokenList):
-                            self.__process_tokenlist(token2)
+                            self._process_tokenlist(token2)
             elif isinstance(item, IdentifierList):
                 for token2 in item.tokens:
-                    if not self.__is_identifier(token2):
-                        self.__extract_from_token(item)
+                    if not self._is_identifier(token2):
+                        self._extract_from_token(item)
 
     def set_or_update_query_limit(self, new_limit: int) -> str:
         """Returns the query with the specified limit.
diff --git a/superset/views/core.py b/superset/views/core.py
index 9688b89..641a68b 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -85,7 +85,7 @@ from superset.security.analytics_db_safety import (
     check_sqlalchemy_uri,
     DBSecurityException,
 )
-from superset.sql_parse import ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
 from superset.sql_validators import get_validator_by_name
 from superset.utils import core as utils, dashboard_import_export
 from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
@@ -2083,7 +2083,9 @@ class Superset(BaseSupersetView):
         schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
         table_name = utils.parse_js_uri_path_item(table_name)
         # Check that the user can access the datasource
-        if not self.appbuilder.sm.can_access_datasource(database, table_name, schema):
+        if not self.appbuilder.sm.can_access_datasource(
+            database, Table(table_name, schema), schema
+        ):
             stats_logger.incr(
                 f"deprecated.{self.__class__.__name__}.select_star.permission_denied"
             )
diff --git a/superset/views/database/decorators.py b/superset/views/database/decorators.py
index 3dd0e2a..322b420 100644
--- a/superset/views/database/decorators.py
+++ b/superset/views/database/decorators.py
@@ -22,6 +22,7 @@ from flask import g
 from flask_babel import lazy_gettext as _
 
 from superset.models.core import Database
+from superset.sql_parse import Table
 from superset.utils.core import parse_js_uri_path_item
 
 logger = logging.getLogger(__name__)
@@ -45,7 +46,7 @@ def check_datasource_access(f):
             return self.response_404()
         # Check that the user can access the datasource
         if not self.appbuilder.sm.can_access_datasource(
-            database, table_name_parsed, schema_name_parsed
+            database, Table(table_name_parsed, schema_name_parsed), schema_name_parsed
         ):
             self.stats_logger.incr(
                 f"permisssion_denied_{self.__class__.__name__}.select_star"
diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py
index 46e54ff..d0ee5d1 100644
--- a/tests/sql_parse_tests.py
+++ b/tests/sql_parse_tests.py
@@ -16,90 +16,102 @@
 # under the License.
 import unittest
 
-from superset import sql_parse
+from superset.sql_parse import ParsedQuery, Table
 
 
 class SupersetTestCase(unittest.TestCase):
     def extract_tables(self, query):
-        sq = sql_parse.ParsedQuery(query)
-        return sq.tables
+        return ParsedQuery(query).tables
+
+    def test_table(self):
+        self.assertEqual(str(Table("tbname")), "tbname")
+        self.assertEqual(str(Table("tbname", "schemaname")), "schemaname.tbname")
+
+        self.assertEqual(
+            str(Table("tbname", "schemaname", "catalogname")),
+            "catalogname.schemaname.tbname",
+        )
+
+        self.assertEqual(
+            str(Table("tb.name", "schema/name", "catalog\name")),
+            "catalog%0Aame.schema%2Fname.tb%2Ename",
+        )
 
     def test_simple_select(self):
         query = "SELECT * FROM tbname"
-        self.assertEqual({"tbname"}, self.extract_tables(query))
+        self.assertEqual({Table("tbname")}, self.extract_tables(query))
 
         query = "SELECT * FROM tbname foo"
-        self.assertEqual({"tbname"}, self.extract_tables(query))
+        self.assertEqual({Table("tbname")}, self.extract_tables(query))
 
         query = "SELECT * FROM tbname AS foo"
-        self.assertEqual({"tbname"}, self.extract_tables(query))
+        self.assertEqual({Table("tbname")}, self.extract_tables(query))
 
         # underscores
         query = "SELECT * FROM tb_name"
-        self.assertEqual({"tb_name"}, self.extract_tables(query))
+        self.assertEqual({Table("tb_name")}, self.extract_tables(query))
 
         # quotes
         query = 'SELECT * FROM "tbname"'
-        self.assertEqual({"tbname"}, self.extract_tables(query))
+        self.assertEqual({Table("tbname")}, self.extract_tables(query))
 
         # unicode encoding
         query = 'SELECT * FROM "tb_name" WHERE city = "Lübeck"'
-        self.assertEqual({"tb_name"}, self.extract_tables(query))
+        self.assertEqual({Table("tb_name")}, self.extract_tables(query))
 
         # schema
         self.assertEqual(
-            {"schemaname.tbname"},
+            {Table("tbname", "schemaname")},
             self.extract_tables("SELECT * FROM schemaname.tbname"),
         )
 
         self.assertEqual(
-            {"schemaname.tbname"},
+            {Table("tbname", "schemaname")},
             self.extract_tables('SELECT * FROM "schemaname"."tbname"'),
         )
 
         self.assertEqual(
-            {"schemaname.tbname"},
+            {Table("tbname", "schemaname")},
             self.extract_tables("SELECT * FROM schemaname.tbname foo"),
         )
 
         self.assertEqual(
-            {"schemaname.tbname"},
+            {Table("tbname", "schemaname")},
             self.extract_tables("SELECT * FROM schemaname.tbname AS foo"),
         )
 
-        # cluster
         self.assertEqual(
-            {"clustername.schemaname.tbname"},
-            self.extract_tables("SELECT * FROM clustername.schemaname.tbname"),
+            {Table("tbname", "schemaname", "catalogname")},
+            self.extract_tables("SELECT * FROM catalogname.schemaname.tbname"),
         )
 
         # Ill-defined cluster/schema/table.
         self.assertEqual(set(), self.extract_tables("SELECT * FROM schemaname."))
 
         self.assertEqual(
-            set(), self.extract_tables("SELECT * FROM clustername.schemaname.")
+            set(), self.extract_tables("SELECT * FROM catalogname.schemaname.")
         )
 
-        self.assertEqual(set(), self.extract_tables("SELECT * FROM clustername.."))
+        self.assertEqual(set(), self.extract_tables("SELECT * FROM catalogname.."))
 
         self.assertEqual(
-            set(), self.extract_tables("SELECT * FROM clustername..tbname")
+            set(), self.extract_tables("SELECT * FROM catalogname..tbname")
         )
 
         # quotes
         query = "SELECT field1, field2 FROM tb_name"
-        self.assertEqual({"tb_name"}, self.extract_tables(query))
+        self.assertEqual({Table("tb_name")}, self.extract_tables(query))
 
         query = "SELECT t1.f1, t2.f2 FROM t1, t2"
-        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
 
     def test_select_named_table(self):
         query = "SELECT a.date, a.field FROM left_table a LIMIT 10"
-        self.assertEqual({"left_table"}, self.extract_tables(query))
+        self.assertEqual({Table("left_table")}, self.extract_tables(query))
 
     def test_reverse_select(self):
         query = "FROM t1 SELECT field"
-        self.assertEqual({"t1"}, self.extract_tables(query))
+        self.assertEqual({Table("t1")}, self.extract_tables(query))
 
     def test_subselect(self):
         query = """
@@ -111,7 +123,9 @@ class SupersetTestCase(unittest.TestCase):
                    ) sub, s2.t2
           WHERE sub.resolution = 'NONE'
         """
-        self.assertEqual({"s1.t1", "s2.t2"}, self.extract_tables(query))
+        self.assertEqual(
+            {Table("t1", "s1"), Table("t2", "s2")}, self.extract_tables(query)
+        )
 
         query = """
           SELECT sub.*
@@ -122,7 +136,7 @@ class SupersetTestCase(unittest.TestCase):
                    ) sub
           WHERE sub.resolution = 'NONE'
         """
-        self.assertEqual({"s1.t1"}, self.extract_tables(query))
+        self.assertEqual({Table("t1", "s1")}, self.extract_tables(query))
 
         query = """
             SELECT * FROM t1
@@ -133,21 +147,24 @@ class SupersetTestCase(unittest.TestCase):
                   WHERE ROW(5*t2.s1,77)=
                     (SELECT 50,11*s1 FROM t4)));
         """
-        self.assertEqual({"t1", "t2", "t3", "t4"}, self.extract_tables(query))
+        self.assertEqual(
+            {Table("t1"), Table("t2"), Table("t3"), Table("t4")},
+            self.extract_tables(query),
+        )
 
     def test_select_in_expression(self):
         query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
-        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
 
     def test_union(self):
         query = "SELECT * FROM t1 UNION SELECT * FROM t2"
-        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
 
         query = "SELECT * FROM t1 UNION ALL SELECT * FROM t2"
-        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
 
         query = "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2"
-        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
 
     def test_select_from_values(self):
         query = "SELECT * FROM VALUES (13, 42)"
@@ -158,25 +175,25 @@ class SupersetTestCase(unittest.TestCase):
             SELECT ARRAY[1, 2, 3] AS my_array
             FROM t1 LIMIT 10
         """
-        self.assertEqual({"t1"}, self.extract_tables(query))
+        self.assertEqual({Table("t1")}, self.extract_tables(query))
 
     def test_select_if(self):
         query = """
             SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
             FROM t1 LIMIT 10
         """
-        self.assertEqual({"t1"}, self.extract_tables(query))
+        self.assertEqual({Table("t1")}, self.extract_tables(query))
 
     # SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)?
     def test_show_tables(self):
         query = "SHOW TABLES FROM s1 like '%order%'"
         # TODO: figure out what should code do here
-        self.assertEqual({"s1"}, self.extract_tables(query))
+        self.assertEqual({Table("s1")}, self.extract_tables(query))
 
     # SHOW COLUMNS (FROM | IN) qualifiedName
     def test_show_columns(self):
         query = "SHOW COLUMNS FROM t1"
-        self.assertEqual({"t1"}, self.extract_tables(query))
+        self.assertEqual({Table("t1")}, self.extract_tables(query))
 
     def test_where_subquery(self):
         query = """
@@ -184,25 +201,25 @@ class SupersetTestCase(unittest.TestCase):
             FROM t1
             WHERE regionkey = (SELECT max(regionkey) FROM t2)
         """
-        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
 
         query = """
           SELECT name
             FROM t1
             WHERE regionkey IN (SELECT regionkey FROM t2)
         """
-        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
 
         query = """
           SELECT name
             FROM t1
             WHERE regionkey EXISTS (SELECT regionkey FROM t2)
         """
-        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
 
     # DESCRIBE | DESC qualifiedName
     def test_describe(self):
-        self.assertEqual({"t1"}, self.extract_tables("DESCRIBE t1"))
+        self.assertEqual({Table("t1")}, self.extract_tables("DESCRIBE t1"))
 
     # SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
     # (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
@@ -211,11 +228,11 @@ class SupersetTestCase(unittest.TestCase):
             SHOW PARTITIONS FROM orders
             WHERE ds >= '2013-01-01' ORDER BY ds DESC;
         """
-        self.assertEqual({"orders"}, self.extract_tables(query))
+        self.assertEqual({Table("orders")}, self.extract_tables(query))
 
     def test_join(self):
         query = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
-        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
 
         # subquery + join
         query = """
@@ -229,7 +246,9 @@ class SupersetTestCase(unittest.TestCase):
                 ) b
                 ON a.date = b.date
         """
-        self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
+        self.assertEqual(
+            {Table("left_table"), Table("right_table")}, self.extract_tables(query)
+        )
 
         query = """
             SELECT a.date, b.name FROM
@@ -242,7 +261,9 @@ class SupersetTestCase(unittest.TestCase):
                 ) b
                 ON a.date = b.date
         """
-        self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
+        self.assertEqual(
+            {Table("left_table"), Table("right_table")}, self.extract_tables(query)
+        )
 
         query = """
             SELECT a.date, b.name FROM
@@ -255,7 +276,9 @@ class SupersetTestCase(unittest.TestCase):
                 ) b
                 ON a.date = b.date
         """
-        self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
+        self.assertEqual(
+            {Table("left_table"), Table("right_table")}, self.extract_tables(query)
+        )
 
         query = """
             SELECT a.date, b.name FROM
@@ -268,7 +291,9 @@ class SupersetTestCase(unittest.TestCase):
                 ) b
                 ON a.date = b.date
         """
-        self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
+        self.assertEqual(
+            {Table("left_table"), Table("right_table")}, self.extract_tables(query)
+        )
 
         # TODO: add SEMI join support, SQL Parse does not handle it.
         # query = """
@@ -296,13 +321,16 @@ class SupersetTestCase(unittest.TestCase):
                   WHERE ROW(5*t3.s1,77)=
                     (SELECT 50,11*s1 FROM t4)));
         """
-        self.assertEqual({"t1", "t3", "t4", "t6"}, self.extract_tables(query))
+        self.assertEqual(
+            {Table("t1"), Table("t3"), Table("t4"), Table("t6")},
+            self.extract_tables(query),
+        )
 
         query = """
         SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS)
             AS S1) AS S2) AS S3;
         """
-        self.assertEqual({"EmployeeS"}, self.extract_tables(query))
+        self.assertEqual({Table("EmployeeS")}, self.extract_tables(query))
 
     def test_with(self):
         query = """
@@ -312,7 +340,9 @@ class SupersetTestCase(unittest.TestCase):
               z AS (SELECT b AS c FROM t3)
             SELECT c FROM z;
         """
-        self.assertEqual({"t1", "t2", "t3"}, self.extract_tables(query))
+        self.assertEqual(
+            {Table("t1"), Table("t2"), Table("t3")}, self.extract_tables(query)
+        )
 
         query = """
             WITH
@@ -321,7 +351,7 @@ class SupersetTestCase(unittest.TestCase):
               z AS (SELECT b AS c FROM y)
             SELECT c FROM z;
         """
-        self.assertEqual({"t1"}, self.extract_tables(query))
+        self.assertEqual({Table("t1")}, self.extract_tables(query))
 
     def test_reusing_aliases(self):
         query = """
@@ -329,22 +359,22 @@ class SupersetTestCase(unittest.TestCase):
             q2 as ( select key from src where key = '5')
             select * from (select key from q1) a;
         """
-        self.assertEqual({"src"}, self.extract_tables(query))
+        self.assertEqual({Table("src")}, self.extract_tables(query))
 
     def test_multistatement(self):
         query = "SELECT * FROM t1; SELECT * FROM t2"
-        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
 
         query = "SELECT * FROM t1; SELECT * FROM t2;"
-        self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+        self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
 
     def test_update_not_select(self):
-        sql = sql_parse.ParsedQuery("UPDATE t1 SET col1 = NULL")
+        sql = ParsedQuery("UPDATE t1 SET col1 = NULL")
         self.assertEqual(False, sql.is_select())
         self.assertEqual(False, sql.is_readonly())
 
     def test_explain(self):
-        sql = sql_parse.ParsedQuery("EXPLAIN SELECT 1")
+        sql = ParsedQuery("EXPLAIN SELECT 1")
 
         self.assertEqual(True, sql.is_explain())
         self.assertEqual(False, sql.is_select())
@@ -367,7 +397,12 @@ class SupersetTestCase(unittest.TestCase):
             ORDER BY "sum__m_example" DESC
             LIMIT 10;"""
         self.assertEqual(
-            {"my_l_table", "my_b_table", "my_t_table", "inner_table"},
+            {
+                Table("my_l_table"),
+                Table("my_b_table"),
+                Table("my_t_table"),
+                Table("inner_table"),
+            },
             self.extract_tables(query),
         )
 
@@ -375,13 +410,19 @@ class SupersetTestCase(unittest.TestCase):
         query = """SELECT *
             FROM table_a AS a, table_b AS b, table_c as c
             WHERE a.id = b.id and b.id = c.id"""
-        self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query))
+        self.assertEqual(
+            {Table("table_a"), Table("table_b"), Table("table_c")},
+            self.extract_tables(query),
+        )
 
     def test_mixed_from_clause(self):
         query = """SELECT *
             FROM table_a AS a, (select * from table_b) AS b, table_c as c
             WHERE a.id = b.id and b.id = c.id"""
-        self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query))
+        self.assertEqual(
+            {Table("table_a"), Table("table_b"), Table("table_c")},
+            self.extract_tables(query),
+        )
 
     def test_nested_selects(self):
         query = """
@@ -389,13 +430,17 @@ class SupersetTestCase(unittest.TestCase):
             from INFORMATION_SCHEMA.COLUMNS
             WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
         """
-        self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query))
+        self.assertEqual(
+            {Table("COLUMNS", "INFORMATION_SCHEMA")}, self.extract_tables(query)
+        )
         query = """
             select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
             from INFORMATION_SCHEMA.COLUMNS
             WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
         """
-        self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query))
+        self.assertEqual(
+            {Table("COLUMNS", "INFORMATION_SCHEMA")}, self.extract_tables(query)
+        )
 
     def test_complex_extract_tables3(self):
         query = """SELECT somecol AS somecol
@@ -431,7 +476,10 @@ class SupersetTestCase(unittest.TestCase):
             WHERE 2=2
             GROUP BY last_col
             LIMIT 50000;"""
-        self.assertEqual({"a", "b", "c", "d", "e", "f"}, self.extract_tables(query))
+        self.assertEqual(
+            {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")},
+            self.extract_tables(query),
+        )
 
     def test_complex_cte_with_prefix(self):
         query = """
@@ -446,23 +494,23 @@ class SupersetTestCase(unittest.TestCase):
         GROUP BY SalesYear, SalesPersonID
         ORDER BY SalesPersonID, SalesYear;
         """
-        self.assertEqual({"SalesOrderHeader"}, self.extract_tables(query))
+        self.assertEqual({Table("SalesOrderHeader")}, self.extract_tables(query))
 
     def test_get_query_with_new_limit_comment(self):
         sql = "SELECT * FROM birth_names -- SOME COMMENT"
-        parsed = sql_parse.ParsedQuery(sql)
+        parsed = ParsedQuery(sql)
         newsql = parsed.set_or_update_query_limit(1000)
         self.assertEqual(newsql, sql + "\nLIMIT 1000")
 
     def test_get_query_with_new_limit_comment_with_limit(self):
         sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
-        parsed = sql_parse.ParsedQuery(sql)
+        parsed = ParsedQuery(sql)
         newsql = parsed.set_or_update_query_limit(1000)
         self.assertEqual(newsql, sql + "\nLIMIT 1000")
 
     def test_get_query_with_new_limit_lower(self):
         sql = "SELECT * FROM birth_names LIMIT 555"
-        parsed = sql_parse.ParsedQuery(sql)
+        parsed = ParsedQuery(sql)
         newsql = parsed.set_or_update_query_limit(1000)
         # not applied as new limit is higher
         expected = "SELECT * FROM birth_names LIMIT 555"
@@ -470,7 +518,7 @@ class SupersetTestCase(unittest.TestCase):
 
     def test_get_query_with_new_limit_upper(self):
         sql = "SELECT * FROM birth_names LIMIT 1555"
-        parsed = sql_parse.ParsedQuery(sql)
+        parsed = ParsedQuery(sql)
         newsql = parsed.set_or_update_query_limit(1000)
         # applied as new limit is lower
         expected = "SELECT * FROM birth_names LIMIT 1000"
@@ -481,7 +529,7 @@ class SupersetTestCase(unittest.TestCase):
         SELECT * FROM birth_names;
         SELECT * FROM birth_names LIMIT 1;
         """
-        parsed = sql_parse.ParsedQuery(multi_sql)
+        parsed = ParsedQuery(multi_sql)
         statements = parsed.get_statements()
         self.assertEqual(len(statements), 2)
         expected = ["SELECT * FROM birth_names", "SELECT * FROM birth_names LIMIT 1"]
@@ -494,7 +542,7 @@ class SupersetTestCase(unittest.TestCase):
         SELECT * FROM birth_names;;;
         SELECT * FROM birth_names LIMIT 1
         """
-        parsed = sql_parse.ParsedQuery(multi_sql)
+        parsed = ParsedQuery(multi_sql)
         statements = parsed.get_statements()
         self.assertEqual(len(statements), 4)
         expected = [
@@ -512,4 +560,4 @@ class SupersetTestCase(unittest.TestCase):
             match AS (SELECT * FROM f)
         SELECT * FROM match
         """
-        self.assertEqual({"foo"}, self.extract_tables(query))
+        self.assertEqual({Table("foo")}, self.extract_tables(query))