You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by jo...@apache.org on 2020/04/30 15:38:24 UTC
[incubator-superset] branch master updated: [sql] Adding lightweight
Table class (#9649)
This is an automated email from the ASF dual-hosted git repository.
johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-superset.git
The following commit(s) were added to refs/heads/master by this push:
new 3b0f8e9 [sql] Adding lightweight Table class (#9649)
3b0f8e9 is described below
commit 3b0f8e9c8abb60982773a0cacc9634c4df5eb702
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Thu Apr 30 08:38:02 2020 -0700
[sql] Adding lightweight Table class (#9649)
Co-authored-by: John Bodley <jo...@airbnb.com>
---
requirements.txt | 1 +
setup.cfg | 2 +-
setup.py | 1 +
superset/security/manager.py | 97 +++++-------------
superset/sql_parse.py | 79 ++++++++++-----
superset/views/core.py | 6 +-
superset/views/database/decorators.py | 3 +-
tests/sql_parse_tests.py | 182 +++++++++++++++++++++-------------
8 files changed, 202 insertions(+), 169 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index dd8b28a..61895fd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,6 +19,7 @@ colorama==0.4.3 # via apache-superset (setup.py), flask-appbuilder
contextlib2==0.6.0.post1 # via apache-superset (setup.py)
croniter==0.3.31 # via apache-superset (setup.py)
cryptography==2.8 # via apache-superset (setup.py)
+dataclasses==0.6 # via apache-superset (setup.py)
decorator==4.4.1 # via retry
defusedxml==0.6.0 # via python3-openid
flask-appbuilder==2.3.2 # via apache-superset (setup.py)
diff --git a/setup.cfg b/setup.cfg
index 9469118..28a3ab3 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -45,7 +45,7 @@ combine_as_imports = true
include_trailing_comma = true
line_length = 88
known_first_party = superset
-known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqlalchemy_utils [...]
+known_third_party =alembic,apispec,backoff,bleach,celery,click,colorama,contextlib2,croniter,cryptography,dataclasses,dateutil,flask,flask_appbuilder,flask_babel,flask_caching,flask_compress,flask_login,flask_migrate,flask_sqlalchemy,flask_talisman,flask_testing,flask_wtf,geohash,geopy,humanize,isodate,jinja2,markdown,markupsafe,marshmallow,msgpack,numpy,pandas,parsedatetime,pathlib2,polyline,prison,pyarrow,pyhive,pytz,retry,selenium,setuptools,simplejson,sphinx_rtd_theme,sqlalchemy,sqla [...]
multi_line_output = 3
order_by_type = false
diff --git a/setup.py b/setup.py
index 6c14849..e3e3cdc 100644
--- a/setup.py
+++ b/setup.py
@@ -75,6 +75,7 @@ setup(
"contextlib2",
"croniter>=0.3.28",
"cryptography>=2.4.2",
+ "dataclasses<0.7",
"flask>=1.1.0, <2.0.0",
"flask-appbuilder>=2.3.2, <2.4.0",
"flask-caching",
diff --git a/superset/security/manager.py b/superset/security/manager.py
index e3b4b1d..fac9068 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -50,6 +50,7 @@ if TYPE_CHECKING:
from superset.common.query_context import QueryContext
from superset.connectors.base.models import BaseDatasource
from superset.models.core import Database
+ from superset.sql_parse import Table
from superset.viz import BaseViz
logger = logging.getLogger(__name__)
@@ -290,26 +291,23 @@ class SupersetSecurityManager(SecurityManager):
return conf.get("PERMISSION_INSTRUCTIONS_LINK")
- def get_table_access_error_msg(self, tables: List[str]) -> str:
+ def get_table_access_error_msg(self, tables: Set["Table"]) -> str:
"""
Return the error message for the denied SQL tables.
- Note the table names conform to the [[cluster.]schema.]table construct.
-
- :param tables: The list of denied SQL table names
+ :param tables: The set of denied SQL tables
:returns: The error message
"""
- quoted_tables = [f"`{t}`" for t in tables]
+
+ quoted_tables = [f"`{table}`" for table in tables]
return f"""You need access to the following tables: {", ".join(quoted_tables)},
`all_database_access` or `all_datasource_access` permission"""
- def get_table_access_link(self, tables: List[str]) -> Optional[str]:
+ def get_table_access_link(self, tables: Set["Table"]) -> Optional[str]:
"""
Return the access link for the denied SQL tables.
- Note the table names conform to the [[cluster.]schema.]table construct.
-
- :param tables: The list of denied SQL table names
+ :param tables: The set of denied SQL tables
:returns: The access URL
"""
@@ -318,23 +316,19 @@ class SupersetSecurityManager(SecurityManager):
return conf.get("PERMISSION_INSTRUCTIONS_LINK")
def can_access_datasource(
- self, database: "Database", table_name: str, schema: Optional[str] = None
- ) -> bool:
- return self._datasource_access_by_name(database, table_name, schema=schema)
-
- def _datasource_access_by_name(
- self, database: "Database", table_name: str, schema: Optional[str] = None
+ self, database: "Database", table: "Table", schema: Optional[str] = None
) -> bool:
"""
Return True if the user can access the SQL table, False otherwise.
:param database: The SQL database
- :param table_name: The SQL table name
- :param schema: The Superset schema
+ :param table: The SQL table
+ :param schema: The fallback SQL schema if not present in the table
:returns: Whether the use can access the SQL table
"""
from superset import db
+ from superset.connectors.sqla.models import SqlaTable
if self.database_access(database) or self.all_datasource_access():
return True
@@ -343,74 +337,33 @@ class SupersetSecurityManager(SecurityManager):
if schema_perm and self.can_access("schema_access", schema_perm):
return True
- datasources = ConnectorRegistry.query_datasources_by_name(
- db.session, database, table_name, schema=schema
+ datasources = SqlaTable.query_datasources_by_name(
+ db.session, database, table.table, schema=table.schema or schema
)
for datasource in datasources:
if self.can_access("datasource_access", datasource.perm):
return True
return False
- def _get_schema_and_table(
- self, table_in_query: str, schema: str
- ) -> Tuple[str, str]:
+ def rejected_tables(
+ self, sql: str, database: "Database", schema: str
+ ) -> Set["Table"]:
"""
- Return the SQL schema/table tuple associated with the table extracted from the
- SQL query.
-
- Note the table name conforms to the [[cluster.]schema.]table construct.
-
- :param table_in_query: The SQL table name
- :param schema: The fallback SQL schema if not present in the table name
- :returns: The SQL schema/table tuple
- """
-
- table_name_pieces = table_in_query.split(".")
- if len(table_name_pieces) == 3:
- return tuple(table_name_pieces[1:]) # type: ignore
- elif len(table_name_pieces) == 2:
- return tuple(table_name_pieces) # type: ignore
- return (schema, table_name_pieces[0])
-
- def _datasource_access_by_fullname(
- self, database: "Database", table_in_query: str, schema: str
- ) -> bool:
- """
- Return True if the user can access the table extracted from the SQL query, False
- otherwise.
-
- Note the table name conforms to the [[cluster.]schema.]table construct.
-
- :param database: The Superset database
- :param table_in_query: The SQL table name
- :param schema: The fallback SQL schema, i.e., if not present in the table name
- :returns: Whether the user can access the SQL table
- """
-
- table_schema, table_name = self._get_schema_and_table(table_in_query, schema)
- return self._datasource_access_by_name(
- database, table_name, schema=table_schema
- )
-
- def rejected_tables(self, sql: str, database: "Database", schema: str) -> List[str]:
- """
- Return the list of rejected SQL table names.
-
- Note the rejected table names conform to the [[cluster.]schema.]table construct.
+ Return the list of rejected SQL tables.
:param sql: The SQL statement
:param database: The SQL database
:param schema: The SQL database schema
- :returns: The rejected table names
+ :returns: The rejected tables
"""
- superset_query = sql_parse.ParsedQuery(sql)
+ query = sql_parse.ParsedQuery(sql)
- return [
- t
- for t in superset_query.tables
- if not self._datasource_access_by_fullname(database, t, schema)
- ]
+ return {
+ table
+ for table in query.tables
+ if not self.can_access_datasource(database, table, schema)
+ }
def get_public_role(self) -> Optional[Any]: # Optional[self.role_model]
from superset import conf
@@ -493,7 +446,7 @@ class SupersetSecurityManager(SecurityManager):
.filter(or_(SqlaTable.perm.in_(perms)))
.distinct()
)
- accessible_schemas.update([t.schema for t in tables])
+ accessible_schemas.update([table.schema for table in tables])
return [s for s in schemas if s in accessible_schemas]
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 8cac2ff..34747e1 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -16,8 +16,10 @@
# under the License.
import logging
from typing import List, Optional, Set
+from urllib import parse
import sqlparse
+from dataclasses import dataclass
from sqlparse.sql import (
Function,
Identifier,
@@ -57,10 +59,32 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
return None
+@dataclass(eq=True, frozen=True)
+class Table: # pylint: disable=too-few-public-methods
+ """
+ A fully qualified SQL table conforming to [[catalog.]schema.]table.
+ """
+
+ table: str
+ schema: Optional[str] = None
+ catalog: Optional[str] = None
+
+ def __str__(self) -> str:
+ """
+ Return the fully qualified SQL table name.
+ """
+
+ return ".".join(
+ parse.quote(part, safe="").replace(".", "%2E")
+ for part in [self.catalog, self.schema, self.table]
+ if part
+ )
+
+
class ParsedQuery:
def __init__(self, sql_statement: str):
self.sql: str = sql_statement
- self._table_names: Set[str] = set()
+ self._tables: Set[Table] = set()
self._alias_names: Set[str] = set()
self._limit: Optional[int] = None
@@ -70,12 +94,15 @@ class ParsedQuery:
self._limit = _extract_limit_from_query(statement)
@property
- def tables(self) -> Set[str]:
- if not self._table_names:
+ def tables(self) -> Set[Table]:
+ if not self._tables:
for statement in self._parsed:
- self.__extract_from_token(statement)
- self._table_names = self._table_names - self._alias_names
- return self._table_names
+ self._extract_from_token(statement)
+
+ self._tables = {
+ table for table in self._tables if str(table) not in self._alias_names
+ }
+ return self._tables
@property
def limit(self) -> Optional[int]:
@@ -105,13 +132,13 @@ class ParsedQuery:
return statements
@staticmethod
- def __get_full_name(tlist: TokenList) -> Optional[str]:
+ def _get_table(tlist: TokenList) -> Optional[Table]:
"""
- Return the full unquoted table name if valid, i.e., conforms to the following
- [[cluster.]schema.]table construct.
+ Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
+ construct.
:param tlist: The SQL tokens
- :returns: The valid full table name
+ :returns: The table if the name conforms
"""
# Strip the alias if present.
@@ -127,18 +154,18 @@ class ParsedQuery:
if (
len(tokens) in (1, 3, 5)
- and all(imt(token, t=[Name, String]) for token in tokens[0::2])
+ and all(imt(token, t=[Name, String]) for token in tokens[::2])
and all(imt(token, m=(Punctuation, ".")) for token in tokens[1::2])
):
- return ".".join([remove_quotes(token.value) for token in tokens[0::2]])
+ return Table(*[remove_quotes(token.value) for token in tokens[::-2]])
return None
@staticmethod
- def __is_identifier(token: Token) -> bool:
+ def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier))
- def __process_tokenlist(self, token_list: TokenList):
+ def _process_tokenlist(self, token_list: TokenList):
"""
Add table names to table set
@@ -146,9 +173,9 @@ class ParsedQuery:
"""
# exclude subselects
if "(" not in str(token_list):
- table_name = self.__get_full_name(token_list)
- if table_name and not table_name.startswith(CTE_PREFIX):
- self._table_names.add(table_name)
+ table = self._get_table(token_list)
+ if table and not table.table.startswith(CTE_PREFIX):
+ self._tables.add(table)
return
# store aliases
@@ -158,7 +185,7 @@ class ParsedQuery:
# some aliases are not parsed properly
if token_list.tokens[0].ttype == Name:
self._alias_names.add(token_list.tokens[0].value)
- self.__extract_from_token(token_list)
+ self._extract_from_token(token_list)
def as_create_table(
self,
@@ -184,9 +211,9 @@ class ParsedQuery:
exec_sql += f"CREATE TABLE {full_table_name} AS \n{sql}"
return exec_sql
- def __extract_from_token(self, token: Token): # pylint: disable=too-many-branches
+ def _extract_from_token(self, token: Token): # pylint: disable=too-many-branches
"""
- Populate self._table_names from token
+ Populate self._tables from token
:param token: instance of Token or child class, e.g. TokenList, to be processed
"""
@@ -196,8 +223,8 @@ class ParsedQuery:
table_name_preceding_token = False
for item in token.tokens:
- if item.is_group and not self.__is_identifier(item):
- self.__extract_from_token(item)
+ if item.is_group and not self._is_identifier(item):
+ self._extract_from_token(item)
if item.ttype in Keyword and (
item.normalized in PRECEDES_TABLE_NAME
@@ -212,15 +239,15 @@ class ParsedQuery:
if table_name_preceding_token:
if isinstance(item, Identifier):
- self.__process_tokenlist(item)
+ self._process_tokenlist(item)
elif isinstance(item, IdentifierList):
for token2 in item.get_identifiers():
if isinstance(token2, TokenList):
- self.__process_tokenlist(token2)
+ self._process_tokenlist(token2)
elif isinstance(item, IdentifierList):
for token2 in item.tokens:
- if not self.__is_identifier(token2):
- self.__extract_from_token(item)
+ if not self._is_identifier(token2):
+ self._extract_from_token(item)
def set_or_update_query_limit(self, new_limit: int) -> str:
"""Returns the query with the specified limit.
diff --git a/superset/views/core.py b/superset/views/core.py
index 9688b89..641a68b 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -85,7 +85,7 @@ from superset.security.analytics_db_safety import (
check_sqlalchemy_uri,
DBSecurityException,
)
-from superset.sql_parse import ParsedQuery
+from superset.sql_parse import ParsedQuery, Table
from superset.sql_validators import get_validator_by_name
from superset.utils import core as utils, dashboard_import_export
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
@@ -2083,7 +2083,9 @@ class Superset(BaseSupersetView):
schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
table_name = utils.parse_js_uri_path_item(table_name)
# Check that the user can access the datasource
- if not self.appbuilder.sm.can_access_datasource(database, table_name, schema):
+ if not self.appbuilder.sm.can_access_datasource(
+ database, Table(table_name, schema), schema
+ ):
stats_logger.incr(
f"deprecated.{self.__class__.__name__}.select_star.permission_denied"
)
diff --git a/superset/views/database/decorators.py b/superset/views/database/decorators.py
index 3dd0e2a..322b420 100644
--- a/superset/views/database/decorators.py
+++ b/superset/views/database/decorators.py
@@ -22,6 +22,7 @@ from flask import g
from flask_babel import lazy_gettext as _
from superset.models.core import Database
+from superset.sql_parse import Table
from superset.utils.core import parse_js_uri_path_item
logger = logging.getLogger(__name__)
@@ -45,7 +46,7 @@ def check_datasource_access(f):
return self.response_404()
# Check that the user can access the datasource
if not self.appbuilder.sm.can_access_datasource(
- database, table_name_parsed, schema_name_parsed
+ database, Table(table_name_parsed, schema_name_parsed), schema_name_parsed
):
self.stats_logger.incr(
f"permisssion_denied_{self.__class__.__name__}.select_star"
diff --git a/tests/sql_parse_tests.py b/tests/sql_parse_tests.py
index 46e54ff..d0ee5d1 100644
--- a/tests/sql_parse_tests.py
+++ b/tests/sql_parse_tests.py
@@ -16,90 +16,102 @@
# under the License.
import unittest
-from superset import sql_parse
+from superset.sql_parse import ParsedQuery, Table
class SupersetTestCase(unittest.TestCase):
def extract_tables(self, query):
- sq = sql_parse.ParsedQuery(query)
- return sq.tables
+ return ParsedQuery(query).tables
+
+ def test_table(self):
+ self.assertEqual(str(Table("tbname")), "tbname")
+ self.assertEqual(str(Table("tbname", "schemaname")), "schemaname.tbname")
+
+ self.assertEqual(
+ str(Table("tbname", "schemaname", "catalogname")),
+ "catalogname.schemaname.tbname",
+ )
+
+ self.assertEqual(
+ str(Table("tb.name", "schema/name", "catalog\name")),
+ "catalog%0Aame.schema%2Fname.tb%2Ename",
+ )
def test_simple_select(self):
query = "SELECT * FROM tbname"
- self.assertEqual({"tbname"}, self.extract_tables(query))
+ self.assertEqual({Table("tbname")}, self.extract_tables(query))
query = "SELECT * FROM tbname foo"
- self.assertEqual({"tbname"}, self.extract_tables(query))
+ self.assertEqual({Table("tbname")}, self.extract_tables(query))
query = "SELECT * FROM tbname AS foo"
- self.assertEqual({"tbname"}, self.extract_tables(query))
+ self.assertEqual({Table("tbname")}, self.extract_tables(query))
# underscores
query = "SELECT * FROM tb_name"
- self.assertEqual({"tb_name"}, self.extract_tables(query))
+ self.assertEqual({Table("tb_name")}, self.extract_tables(query))
# quotes
query = 'SELECT * FROM "tbname"'
- self.assertEqual({"tbname"}, self.extract_tables(query))
+ self.assertEqual({Table("tbname")}, self.extract_tables(query))
# unicode encoding
query = 'SELECT * FROM "tb_name" WHERE city = "Lübeck"'
- self.assertEqual({"tb_name"}, self.extract_tables(query))
+ self.assertEqual({Table("tb_name")}, self.extract_tables(query))
# schema
self.assertEqual(
- {"schemaname.tbname"},
+ {Table("tbname", "schemaname")},
self.extract_tables("SELECT * FROM schemaname.tbname"),
)
self.assertEqual(
- {"schemaname.tbname"},
+ {Table("tbname", "schemaname")},
self.extract_tables('SELECT * FROM "schemaname"."tbname"'),
)
self.assertEqual(
- {"schemaname.tbname"},
+ {Table("tbname", "schemaname")},
self.extract_tables("SELECT * FROM schemaname.tbname foo"),
)
self.assertEqual(
- {"schemaname.tbname"},
+ {Table("tbname", "schemaname")},
self.extract_tables("SELECT * FROM schemaname.tbname AS foo"),
)
- # cluster
self.assertEqual(
- {"clustername.schemaname.tbname"},
- self.extract_tables("SELECT * FROM clustername.schemaname.tbname"),
+ {Table("tbname", "schemaname", "catalogname")},
+ self.extract_tables("SELECT * FROM catalogname.schemaname.tbname"),
)
# Ill-defined cluster/schema/table.
self.assertEqual(set(), self.extract_tables("SELECT * FROM schemaname."))
self.assertEqual(
- set(), self.extract_tables("SELECT * FROM clustername.schemaname.")
+ set(), self.extract_tables("SELECT * FROM catalogname.schemaname.")
)
- self.assertEqual(set(), self.extract_tables("SELECT * FROM clustername.."))
+ self.assertEqual(set(), self.extract_tables("SELECT * FROM catalogname.."))
self.assertEqual(
- set(), self.extract_tables("SELECT * FROM clustername..tbname")
+ set(), self.extract_tables("SELECT * FROM catalogname..tbname")
)
# quotes
query = "SELECT field1, field2 FROM tb_name"
- self.assertEqual({"tb_name"}, self.extract_tables(query))
+ self.assertEqual({Table("tb_name")}, self.extract_tables(query))
query = "SELECT t1.f1, t2.f2 FROM t1, t2"
- self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+ self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
def test_select_named_table(self):
query = "SELECT a.date, a.field FROM left_table a LIMIT 10"
- self.assertEqual({"left_table"}, self.extract_tables(query))
+ self.assertEqual({Table("left_table")}, self.extract_tables(query))
def test_reverse_select(self):
query = "FROM t1 SELECT field"
- self.assertEqual({"t1"}, self.extract_tables(query))
+ self.assertEqual({Table("t1")}, self.extract_tables(query))
def test_subselect(self):
query = """
@@ -111,7 +123,9 @@ class SupersetTestCase(unittest.TestCase):
) sub, s2.t2
WHERE sub.resolution = 'NONE'
"""
- self.assertEqual({"s1.t1", "s2.t2"}, self.extract_tables(query))
+ self.assertEqual(
+ {Table("t1", "s1"), Table("t2", "s2")}, self.extract_tables(query)
+ )
query = """
SELECT sub.*
@@ -122,7 +136,7 @@ class SupersetTestCase(unittest.TestCase):
) sub
WHERE sub.resolution = 'NONE'
"""
- self.assertEqual({"s1.t1"}, self.extract_tables(query))
+ self.assertEqual({Table("t1", "s1")}, self.extract_tables(query))
query = """
SELECT * FROM t1
@@ -133,21 +147,24 @@ class SupersetTestCase(unittest.TestCase):
WHERE ROW(5*t2.s1,77)=
(SELECT 50,11*s1 FROM t4)));
"""
- self.assertEqual({"t1", "t2", "t3", "t4"}, self.extract_tables(query))
+ self.assertEqual(
+ {Table("t1"), Table("t2"), Table("t3"), Table("t4")},
+ self.extract_tables(query),
+ )
def test_select_in_expression(self):
query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
- self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+ self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
def test_union(self):
query = "SELECT * FROM t1 UNION SELECT * FROM t2"
- self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+ self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = "SELECT * FROM t1 UNION ALL SELECT * FROM t2"
- self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+ self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2"
- self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+ self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
def test_select_from_values(self):
query = "SELECT * FROM VALUES (13, 42)"
@@ -158,25 +175,25 @@ class SupersetTestCase(unittest.TestCase):
SELECT ARRAY[1, 2, 3] AS my_array
FROM t1 LIMIT 10
"""
- self.assertEqual({"t1"}, self.extract_tables(query))
+ self.assertEqual({Table("t1")}, self.extract_tables(query))
def test_select_if(self):
query = """
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
FROM t1 LIMIT 10
"""
- self.assertEqual({"t1"}, self.extract_tables(query))
+ self.assertEqual({Table("t1")}, self.extract_tables(query))
# SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)?
def test_show_tables(self):
query = "SHOW TABLES FROM s1 like '%order%'"
# TODO: figure out what should code do here
- self.assertEqual({"s1"}, self.extract_tables(query))
+ self.assertEqual({Table("s1")}, self.extract_tables(query))
# SHOW COLUMNS (FROM | IN) qualifiedName
def test_show_columns(self):
query = "SHOW COLUMNS FROM t1"
- self.assertEqual({"t1"}, self.extract_tables(query))
+ self.assertEqual({Table("t1")}, self.extract_tables(query))
def test_where_subquery(self):
query = """
@@ -184,25 +201,25 @@ class SupersetTestCase(unittest.TestCase):
FROM t1
WHERE regionkey = (SELECT max(regionkey) FROM t2)
"""
- self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+ self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = """
SELECT name
FROM t1
WHERE regionkey IN (SELECT regionkey FROM t2)
"""
- self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+ self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = """
SELECT name
FROM t1
WHERE regionkey EXISTS (SELECT regionkey FROM t2)
"""
- self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+ self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
# DESCRIBE | DESC qualifiedName
def test_describe(self):
- self.assertEqual({"t1"}, self.extract_tables("DESCRIBE t1"))
+ self.assertEqual({Table("t1")}, self.extract_tables("DESCRIBE t1"))
# SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
# (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
@@ -211,11 +228,11 @@ class SupersetTestCase(unittest.TestCase):
SHOW PARTITIONS FROM orders
WHERE ds >= '2013-01-01' ORDER BY ds DESC;
"""
- self.assertEqual({"orders"}, self.extract_tables(query))
+ self.assertEqual({Table("orders")}, self.extract_tables(query))
def test_join(self):
query = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
- self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+ self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
# subquery + join
query = """
@@ -229,7 +246,9 @@ class SupersetTestCase(unittest.TestCase):
) b
ON a.date = b.date
"""
- self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
+ self.assertEqual(
+ {Table("left_table"), Table("right_table")}, self.extract_tables(query)
+ )
query = """
SELECT a.date, b.name FROM
@@ -242,7 +261,9 @@ class SupersetTestCase(unittest.TestCase):
) b
ON a.date = b.date
"""
- self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
+ self.assertEqual(
+ {Table("left_table"), Table("right_table")}, self.extract_tables(query)
+ )
query = """
SELECT a.date, b.name FROM
@@ -255,7 +276,9 @@ class SupersetTestCase(unittest.TestCase):
) b
ON a.date = b.date
"""
- self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
+ self.assertEqual(
+ {Table("left_table"), Table("right_table")}, self.extract_tables(query)
+ )
query = """
SELECT a.date, b.name FROM
@@ -268,7 +291,9 @@ class SupersetTestCase(unittest.TestCase):
) b
ON a.date = b.date
"""
- self.assertEqual({"left_table", "right_table"}, self.extract_tables(query))
+ self.assertEqual(
+ {Table("left_table"), Table("right_table")}, self.extract_tables(query)
+ )
# TODO: add SEMI join support, SQL Parse does not handle it.
# query = """
@@ -296,13 +321,16 @@ class SupersetTestCase(unittest.TestCase):
WHERE ROW(5*t3.s1,77)=
(SELECT 50,11*s1 FROM t4)));
"""
- self.assertEqual({"t1", "t3", "t4", "t6"}, self.extract_tables(query))
+ self.assertEqual(
+ {Table("t1"), Table("t3"), Table("t4"), Table("t6")},
+ self.extract_tables(query),
+ )
query = """
SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS)
AS S1) AS S2) AS S3;
"""
- self.assertEqual({"EmployeeS"}, self.extract_tables(query))
+ self.assertEqual({Table("EmployeeS")}, self.extract_tables(query))
def test_with(self):
query = """
@@ -312,7 +340,9 @@ class SupersetTestCase(unittest.TestCase):
z AS (SELECT b AS c FROM t3)
SELECT c FROM z;
"""
- self.assertEqual({"t1", "t2", "t3"}, self.extract_tables(query))
+ self.assertEqual(
+ {Table("t1"), Table("t2"), Table("t3")}, self.extract_tables(query)
+ )
query = """
WITH
@@ -321,7 +351,7 @@ class SupersetTestCase(unittest.TestCase):
z AS (SELECT b AS c FROM y)
SELECT c FROM z;
"""
- self.assertEqual({"t1"}, self.extract_tables(query))
+ self.assertEqual({Table("t1")}, self.extract_tables(query))
def test_reusing_aliases(self):
query = """
@@ -329,22 +359,22 @@ class SupersetTestCase(unittest.TestCase):
q2 as ( select key from src where key = '5')
select * from (select key from q1) a;
"""
- self.assertEqual({"src"}, self.extract_tables(query))
+ self.assertEqual({Table("src")}, self.extract_tables(query))
def test_multistatement(self):
query = "SELECT * FROM t1; SELECT * FROM t2"
- self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+ self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
query = "SELECT * FROM t1; SELECT * FROM t2;"
- self.assertEqual({"t1", "t2"}, self.extract_tables(query))
+ self.assertEqual({Table("t1"), Table("t2")}, self.extract_tables(query))
def test_update_not_select(self):
- sql = sql_parse.ParsedQuery("UPDATE t1 SET col1 = NULL")
+ sql = ParsedQuery("UPDATE t1 SET col1 = NULL")
self.assertEqual(False, sql.is_select())
self.assertEqual(False, sql.is_readonly())
def test_explain(self):
- sql = sql_parse.ParsedQuery("EXPLAIN SELECT 1")
+ sql = ParsedQuery("EXPLAIN SELECT 1")
self.assertEqual(True, sql.is_explain())
self.assertEqual(False, sql.is_select())
@@ -367,7 +397,12 @@ class SupersetTestCase(unittest.TestCase):
ORDER BY "sum__m_example" DESC
LIMIT 10;"""
self.assertEqual(
- {"my_l_table", "my_b_table", "my_t_table", "inner_table"},
+ {
+ Table("my_l_table"),
+ Table("my_b_table"),
+ Table("my_t_table"),
+ Table("inner_table"),
+ },
self.extract_tables(query),
)
@@ -375,13 +410,19 @@ class SupersetTestCase(unittest.TestCase):
query = """SELECT *
FROM table_a AS a, table_b AS b, table_c as c
WHERE a.id = b.id and b.id = c.id"""
- self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query))
+ self.assertEqual(
+ {Table("table_a"), Table("table_b"), Table("table_c")},
+ self.extract_tables(query),
+ )
def test_mixed_from_clause(self):
query = """SELECT *
FROM table_a AS a, (select * from table_b) AS b, table_c as c
WHERE a.id = b.id and b.id = c.id"""
- self.assertEqual({"table_a", "table_b", "table_c"}, self.extract_tables(query))
+ self.assertEqual(
+ {Table("table_a"), Table("table_b"), Table("table_c")},
+ self.extract_tables(query),
+ )
def test_nested_selects(self):
query = """
@@ -389,13 +430,17 @@ class SupersetTestCase(unittest.TestCase):
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
"""
- self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query))
+ self.assertEqual(
+ {Table("COLUMNS", "INFORMATION_SCHEMA")}, self.extract_tables(query)
+ )
query = """
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achivement_daily"),0x7e)));
"""
- self.assertEqual({"INFORMATION_SCHEMA.COLUMNS"}, self.extract_tables(query))
+ self.assertEqual(
+ {Table("COLUMNS", "INFORMATION_SCHEMA")}, self.extract_tables(query)
+ )
def test_complex_extract_tables3(self):
query = """SELECT somecol AS somecol
@@ -431,7 +476,10 @@ class SupersetTestCase(unittest.TestCase):
WHERE 2=2
GROUP BY last_col
LIMIT 50000;"""
- self.assertEqual({"a", "b", "c", "d", "e", "f"}, self.extract_tables(query))
+ self.assertEqual(
+ {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")},
+ self.extract_tables(query),
+ )
def test_complex_cte_with_prefix(self):
query = """
@@ -446,23 +494,23 @@ class SupersetTestCase(unittest.TestCase):
GROUP BY SalesYear, SalesPersonID
ORDER BY SalesPersonID, SalesYear;
"""
- self.assertEqual({"SalesOrderHeader"}, self.extract_tables(query))
+ self.assertEqual({Table("SalesOrderHeader")}, self.extract_tables(query))
def test_get_query_with_new_limit_comment(self):
sql = "SELECT * FROM birth_names -- SOME COMMENT"
- parsed = sql_parse.ParsedQuery(sql)
+ parsed = ParsedQuery(sql)
newsql = parsed.set_or_update_query_limit(1000)
self.assertEqual(newsql, sql + "\nLIMIT 1000")
def test_get_query_with_new_limit_comment_with_limit(self):
sql = "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555"
- parsed = sql_parse.ParsedQuery(sql)
+ parsed = ParsedQuery(sql)
newsql = parsed.set_or_update_query_limit(1000)
self.assertEqual(newsql, sql + "\nLIMIT 1000")
def test_get_query_with_new_limit_lower(self):
sql = "SELECT * FROM birth_names LIMIT 555"
- parsed = sql_parse.ParsedQuery(sql)
+ parsed = ParsedQuery(sql)
newsql = parsed.set_or_update_query_limit(1000)
# not applied as new limit is higher
expected = "SELECT * FROM birth_names LIMIT 555"
@@ -470,7 +518,7 @@ class SupersetTestCase(unittest.TestCase):
def test_get_query_with_new_limit_upper(self):
sql = "SELECT * FROM birth_names LIMIT 1555"
- parsed = sql_parse.ParsedQuery(sql)
+ parsed = ParsedQuery(sql)
newsql = parsed.set_or_update_query_limit(1000)
# applied as new limit is lower
expected = "SELECT * FROM birth_names LIMIT 1000"
@@ -481,7 +529,7 @@ class SupersetTestCase(unittest.TestCase):
SELECT * FROM birth_names;
SELECT * FROM birth_names LIMIT 1;
"""
- parsed = sql_parse.ParsedQuery(multi_sql)
+ parsed = ParsedQuery(multi_sql)
statements = parsed.get_statements()
self.assertEqual(len(statements), 2)
expected = ["SELECT * FROM birth_names", "SELECT * FROM birth_names LIMIT 1"]
@@ -494,7 +542,7 @@ class SupersetTestCase(unittest.TestCase):
SELECT * FROM birth_names;;;
SELECT * FROM birth_names LIMIT 1
"""
- parsed = sql_parse.ParsedQuery(multi_sql)
+ parsed = ParsedQuery(multi_sql)
statements = parsed.get_statements()
self.assertEqual(len(statements), 4)
expected = [
@@ -512,4 +560,4 @@ class SupersetTestCase(unittest.TestCase):
match AS (SELECT * FROM f)
SELECT * FROM match
"""
- self.assertEqual({"foo"}, self.extract_tables(query))
+ self.assertEqual({Table("foo")}, self.extract_tables(query))