You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by mi...@apache.org on 2024/02/01 16:43:08 UTC
(superset) 01/09: feat(sqlparse): improve table parsing (#26476)
This is an automated email from the ASF dual-hosted git repository.
michaelsmolina pushed a commit to branch 3.0
in repository https://gitbox.apache.org/repos/asf/superset.git
commit fb6410004369c73bae12bb60c5b5ca7981c6cf49
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Mon Jan 22 11:16:50 2024 -0500
feat(sqlparse): improve table parsing (#26476)
---
requirements/base.txt | 2 +
requirements/testing.txt | 6 +-
setup.py | 1 +
superset/connectors/sqla/models.py | 2 +-
superset/connectors/sqla/utils.py | 2 +-
superset/datasets/commands/duplicate.py | 5 +-
superset/db_engine_specs/base.py | 12 +-
superset/db_engine_specs/bigquery.py | 2 +-
superset/models/helpers.py | 2 +-
superset/models/sql_lab.py | 6 +-
superset/security/manager.py | 5 +-
superset/sql_lab.py | 11 +-
superset/sql_parse.py | 245 +++++++++++++++++++++-----------
superset/sql_validators/presto_db.py | 4 +-
superset/sqllab/commands/export.py | 5 +-
superset/sqllab/query_render.py | 6 +-
tests/unit_tests/sql_parse_tests.py | 55 +++++--
17 files changed, 253 insertions(+), 118 deletions(-)
diff --git a/requirements/base.txt b/requirements/base.txt
index c65c7ba77e..a34df1905d 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -283,6 +283,8 @@ sqlalchemy-utils==0.38.3
# via
# apache-superset
# flask-appbuilder
+sqlglot==20.8.0
+ # via apache-superset
sqlparse==0.4.4
# via apache-superset
sshtunnel==0.4.0
diff --git a/requirements/testing.txt b/requirements/testing.txt
index 0a4d826c8c..e41ac7d96f 100644
--- a/requirements/testing.txt
+++ b/requirements/testing.txt
@@ -24,10 +24,6 @@ db-dtypes==1.1.1
# via pandas-gbq
docker==6.1.1
# via -r requirements/testing.in
-ephem==4.1.4
- # via lunarcalendar
-exceptiongroup==1.1.1
- # via pytest
flask-testing==0.8.1
# via -r requirements/testing.in
fonttools==4.39.4
@@ -119,7 +115,7 @@ pyee==9.0.4
# via playwright
pyfakefs==5.2.2
# via -r requirements/testing.in
-pyhive[presto]==0.6.5
+pyhive[presto]==0.7.0
# via apache-superset
pytest==7.3.1
# via
diff --git a/setup.py b/setup.py
index 211f8c5a80..e428797094 100644
--- a/setup.py
+++ b/setup.py
@@ -121,6 +121,7 @@ setup(
"slack_sdk>=3.19.0, <4",
"sqlalchemy>=1.4, <2",
"sqlalchemy-utils>=0.38.3, <0.39",
+ "sqlglot>=20,<21",
"sqlparse>=0.4.4, <0.5",
"tabulate>=0.8.9, <0.9",
"typing-extensions>=4, <5",
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 5edc724b23..4bcf3fc344 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -848,7 +848,7 @@ class SqlaTable(
return self.get_sqla_table(), None
from_sql = self.get_rendered_sql(template_processor)
- parsed_query = ParsedQuery(from_sql)
+ parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index c8a5f9f260..3b64bec807 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict
)
- parsed_query = ParsedQuery(sql)
+ parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
if not db_engine_spec.is_readonly_query(parsed_query):
raise SupersetSecurityException(
SupersetError(
diff --git a/superset/datasets/commands/duplicate.py b/superset/datasets/commands/duplicate.py
index 9fc05c0960..238f9a2391 100644
--- a/superset/datasets/commands/duplicate.py
+++ b/superset/datasets/commands/duplicate.py
@@ -69,7 +69,10 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
table.template_params = self._base_model.template_params
table.normalize_columns = self._base_model.normalize_columns
table.is_sqllab_view = True
- table.sql = ParsedQuery(self._base_model.sql).stripped()
+ table.sql = ParsedQuery(
+ self._base_model.sql,
+ engine=database.db_engine_spec.engine,
+ ).stripped()
db.session.add(table)
cols = []
for config_ in self._base_model.columns:
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index c086ce27b7..be0b55b3e9 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -874,7 +874,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return database.compile_sqla_query(qry)
if cls.limit_method == LimitMethod.FORCE_LIMIT:
- parsed_query = sql_parse.ParsedQuery(sql)
+ parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
sql = parsed_query.set_or_update_query_limit(limit, force=force)
return sql
@@ -955,7 +955,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param sql: SQL query
:return: Value of limit clause in query
"""
- parsed_query = sql_parse.ParsedQuery(sql)
+ parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
return parsed_query.limit
@classmethod
@@ -967,7 +967,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param limit: New limit to insert/replace into query
:return: Query with new limit
"""
- parsed_query = sql_parse.ParsedQuery(sql)
+ parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
return parsed_query.set_or_update_query_limit(limit)
@classmethod
@@ -1450,7 +1450,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param database: Database instance
:return: Dictionary with different costs
"""
- parsed_query = ParsedQuery(statement)
+ parsed_query = ParsedQuery(statement, engine=cls.engine)
sql = parsed_query.stripped()
sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
@@ -1483,7 +1483,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
if not cls.get_allow_cost_estimate(extra):
raise Exception("Database does not support cost estimation")
- parsed_query = sql_parse.ParsedQuery(sql)
+ parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
costs = []
@@ -1544,7 +1544,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:return:
"""
if not cls.allows_sql_comments:
- query = sql_parse.strip_comments_from_sql(query)
+ query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
if cls.arraysize:
cursor.arraysize = cls.arraysize
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index 73b7b18d36..125760ec7f 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -436,7 +436,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
if not cls.get_allow_cost_estimate(extra):
raise SupersetException("Database does not support cost estimation")
- parsed_query = sql_parse.ParsedQuery(sql)
+ parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
statements = parsed_query.get_statements()
costs = []
for statement in statements:
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index bf33451e34..f9b30c3d42 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -1085,7 +1085,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
"""
from_sql = self.get_rendered_sql(template_processor)
- parsed_query = ParsedQuery(from_sql)
+ parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 20df535ad3..156add2529 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -183,7 +183,7 @@ class Query(
@property
def sql_tables(self) -> list[Table]:
- return list(ParsedQuery(self.sql).tables)
+ return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)
@property
def columns(self) -> list["TableColumn"]:
@@ -428,7 +428,9 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
@property
def sql_tables(self) -> list[Table]:
- return list(ParsedQuery(self.sql).tables)
+ return list(
+ ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
+ )
@property
def last_run_humanized(self) -> str:
diff --git a/superset/security/manager.py b/superset/security/manager.py
index 88657e8f21..5c0833fdf9 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -1855,7 +1855,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
default_schema = database.get_default_schema_for_query(query)
tables = {
Table(table_.table, table_.schema or default_schema)
- for table_ in sql_parse.ParsedQuery(query.sql).tables
+ for table_ in sql_parse.ParsedQuery(
+ query.sql,
+ engine=database.db_engine_spec.engine,
+ ).tables
}
elif table:
tables = {table}
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index ca157b3240..6da1125458 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -203,7 +203,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments
database: Database = query.database
db_engine_spec = database.db_engine_spec
- parsed_query = ParsedQuery(sql_statement)
+ parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
if is_feature_enabled("RLS_IN_SQLLAB"):
# Insert any applicable RLS predicates
parsed_query = ParsedQuery(
@@ -213,7 +213,8 @@ def execute_sql_statement( # pylint: disable=too-many-arguments
database.id,
query.schema,
)
- )
+ ),
+ engine=db_engine_spec.engine,
)
sql = parsed_query.stripped()
@@ -404,7 +405,11 @@ def execute_sql_statements(
)
# Breaking down into multiple statements
- parsed_query = ParsedQuery(rendered_query, strip_comments=True)
+ parsed_query = ParsedQuery(
+ rendered_query,
+ strip_comments=True,
+ engine=db_engine_spec.engine,
+ )
if not db_engine_spec.run_multiple_statements_as_one:
statements = parsed_query.get_statements()
logger.info(
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index c196fdabfa..080572eba1 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -14,15 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
import logging
import re
-from collections.abc import Iterator
+import urllib.parse
+from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast, Optional
-from urllib import parse
import sqlparse
from sqlalchemy import and_
+from sqlglot import exp, parse, parse_one
+from sqlglot.dialects import Dialects
+from sqlglot.errors import ParseError
+from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
@@ -52,7 +57,7 @@ from superset.utils.backports import StrEnum
try:
from sqloxide import parse_sql as sqloxide_parse
-except: # pylint: disable=bare-except
+except (ImportError, ModuleNotFoundError):
sqloxide_parse = None
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
@@ -71,6 +76,59 @@ sqlparser_sql_regex.insert(25, (r"'(''|\\\\|\\|[^'])*'", sqlparse.tokens.String.
lex.set_SQL_REGEX(sqlparser_sql_regex)
+# mapping between DB engine specs and sqlglot dialects
+SQLGLOT_DIALECTS = {
+ "ascend": Dialects.HIVE,
+ "awsathena": Dialects.PRESTO,
+ "bigquery": Dialects.BIGQUERY,
+ "clickhouse": Dialects.CLICKHOUSE,
+ "clickhousedb": Dialects.CLICKHOUSE,
+ "cockroachdb": Dialects.POSTGRES,
+ # "crate": ???
+ # "databend": ???
+ "databricks": Dialects.DATABRICKS,
+ # "db2": ???
+ # "dremio": ???
+ "drill": Dialects.DRILL,
+ # "druid": ???
+ "duckdb": Dialects.DUCKDB,
+ # "dynamodb": ???
+ # "elasticsearch": ???
+ # "exa": ???
+ # "firebird": ???
+ # "firebolt": ???
+ "gsheets": Dialects.SQLITE,
+ "hana": Dialects.POSTGRES,
+ "hive": Dialects.HIVE,
+ # "ibmi": ???
+ # "impala": ???
+ # "kustokql": ???
+ # "kylin": ???
+ # "mssql": ???
+ "mysql": Dialects.MYSQL,
+ "netezza": Dialects.POSTGRES,
+ # "ocient": ???
+ # "odelasticsearch": ???
+ "oracle": Dialects.ORACLE,
+ # "pinot": ???
+ "postgresql": Dialects.POSTGRES,
+ "presto": Dialects.PRESTO,
+ "pydoris": Dialects.DORIS,
+ "redshift": Dialects.REDSHIFT,
+ # "risingwave": ???
+ # "rockset": ???
+ "shillelagh": Dialects.SQLITE,
+ "snowflake": Dialects.SNOWFLAKE,
+ # "solr": ???
+ "sqlite": Dialects.SQLITE,
+ "starrocks": Dialects.STARROCKS,
+ "superset": Dialects.SQLITE,
+ "teradatasql": Dialects.TERADATA,
+ "trino": Dialects.TRINO,
+ "vertica": Dialects.POSTGRES,
+}
+
+
class CtasMethod(StrEnum):
TABLE = "TABLE"
VIEW = "VIEW"
@@ -149,7 +207,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
return cte, remainder
-def strip_comments_from_sql(statement: str) -> str:
+def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
"""
Strips comments from a SQL statement, does a simple test first
to avoid always instantiating the expensive ParsedQuery constructor
@@ -159,7 +217,11 @@ def strip_comments_from_sql(statement: str) -> str:
:param statement: A string with the SQL statement
:return: SQL statement without comments
"""
- return ParsedQuery(statement).strip_comments() if "--" in statement else statement
+ return (
+ ParsedQuery(statement, engine=engine).strip_comments()
+ if "--" in statement
+ else statement
+ )
@dataclass(eq=True, frozen=True)
@@ -178,7 +240,7 @@ class Table:
"""
return ".".join(
- parse.quote(part, safe="").replace(".", "%2E")
+ urllib.parse.quote(part, safe="").replace(".", "%2E")
for part in [self.catalog, self.schema, self.table]
if part
)
@@ -188,11 +250,17 @@ class Table:
class ParsedQuery:
- def __init__(self, sql_statement: str, strip_comments: bool = False):
+ def __init__(
+ self,
+ sql_statement: str,
+ strip_comments: bool = False,
+ engine: Optional[str] = None,
+ ):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
self.sql: str = sql_statement
+ self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._tables: set[Table] = set()
self._alias_names: set[str] = set()
self._limit: Optional[int] = None
@@ -205,14 +273,95 @@ class ParsedQuery:
@property
def tables(self) -> set[Table]:
if not self._tables:
- for statement in self._parsed:
- self._extract_from_token(statement)
-
- self._tables = {
- table for table in self._tables if str(table) not in self._alias_names
- }
+ self._tables = self._extract_tables_from_sql()
return self._tables
+ def _extract_tables_from_sql(self) -> set[Table]:
+ """
+ Extract all table references in a query.
+
+ Note: this uses sqlglot, since it's better at catching more edge cases.
+ """
+ try:
+ statements = parse(self.sql, dialect=self._dialect)
+ except ParseError:
+ logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
+ return set()
+
+ return {
+ table
+ for statement in statements
+ for table in self._extract_tables_from_statement(statement)
+ if statement
+ }
+
+ def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
+ """
+ Extract all table references in a single statement.
+
+ Please note that this is not trivial; consider the following queries:
+
+ DESCRIBE some_table;
+ SHOW PARTITIONS FROM some_table;
+ WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
+
+ See the unit tests for other tricky cases.
+ """
+ sources: Iterable[exp.Table]
+
+ if isinstance(statement, exp.Describe):
+ # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
+ # query for all tables.
+ sources = statement.find_all(exp.Table)
+ elif isinstance(statement, exp.Command):
+ # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
+ # `SELECT` statement in order to extract tables. NOTE(review): the parse_one call below may raise ParseError, which is not caught here.
+ literal = statement.find(exp.Literal)
+ if not literal:
+ return set()
+
+ pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect)
+ sources = pseudo_query.find_all(exp.Table)
+ else:
+ sources = [
+ source
+ for scope in traverse_scope(statement)
+ for source in scope.sources.values()
+ if isinstance(source, exp.Table) and not self._is_cte(source, scope)
+ ]
+
+ return {
+ Table(
+ source.name,
+ source.db if source.db != "" else None,
+ source.catalog if source.catalog != "" else None,
+ )
+ for source in sources
+ }
+
+ # pylint: disable=no-self-use
+ def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
+ """
+ Is the source a CTE?
+
+ CTEs in the parent scope look like tables (and are represented by
+ exp.Table objects), but should not be considered as such;
+ otherwise a user with access to table `foo` could access any table
+ with a query like this:
+
+ WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
+
+ """
+ parent_sources = scope.parent.sources if scope.parent else {}
+ ctes_in_scope = {
+ name
+ for name, parent_scope in parent_sources.items()
+ if isinstance(parent_scope, Scope)
+ and parent_scope.scope_type == ScopeType.CTE
+ }
+
+ return source.name in ctes_in_scope
+
@property
def limit(self) -> Optional[int]:
return self._limit
@@ -393,28 +542,6 @@ class ParsedQuery:
def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier))
- def _process_tokenlist(self, token_list: TokenList) -> None:
- """
- Add table names to table set
-
- :param token_list: TokenList to be processed
- """
- # exclude subselects
- if "(" not in str(token_list):
- table = self.get_table(token_list)
- if table and not table.table.startswith(CTE_PREFIX):
- self._tables.add(table)
- return
-
- # store aliases
- if token_list.has_alias():
- self._alias_names.add(token_list.get_alias())
-
- # some aliases are not parsed properly
- if token_list.tokens[0].ttype == Name:
- self._alias_names.add(token_list.tokens[0].value)
- self._extract_from_token(token_list)
-
def as_create_table(
self,
table_name: str,
@@ -441,50 +568,6 @@ class ParsedQuery:
exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
return exec_sql
- def _extract_from_token(self, token: Token) -> None:
- """
- <Identifier> store a list of subtokens and <IdentifierList> store lists of
- subtoken list.
-
- It extracts <IdentifierList> and <Identifier> from :param token: and loops
- through all subtokens recursively. It finds table_name_preceding_token and
- passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
- self._tables.
-
- :param token: instance of Token or child class, e.g. TokenList, to be processed
- """
- if not hasattr(token, "tokens"):
- return
-
- table_name_preceding_token = False
-
- for item in token.tokens:
- if item.is_group and (
- not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
- ):
- self._extract_from_token(item)
-
- if item.ttype in Keyword and (
- item.normalized in PRECEDES_TABLE_NAME
- or item.normalized.endswith(" JOIN")
- ):
- table_name_preceding_token = True
- continue
-
- if item.ttype in Keyword:
- table_name_preceding_token = False
- continue
- if table_name_preceding_token:
- if isinstance(item, Identifier):
- self._process_tokenlist(item)
- elif isinstance(item, IdentifierList):
- for token2 in item.get_identifiers():
- if isinstance(token2, TokenList):
- self._process_tokenlist(token2)
- elif isinstance(item, IdentifierList):
- if any(not self._is_identifier(token2) for token2 in item.tokens):
- self._extract_from_token(item)
-
def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
"""Returns the query with the specified limit.
@@ -779,7 +862,7 @@ def insert_rls(
# mapping between sqloxide and SQLAlchemy dialects
-SQLOXITE_DIALECTS = {
+SQLOXIDE_DIALECTS = {
"ansi": {"trino", "trinonative", "presto"},
"hive": {"hive", "databricks"},
"ms": {"mssql"},
@@ -812,7 +895,7 @@ def extract_table_references(
tree = None
if sqloxide_parse:
- for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
+ for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items():
if sqla_dialect in sqla_dialects:
break
sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index 9d3e7641a6..20b9a8eb98 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -50,7 +50,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
) -> Optional[SQLValidationAnnotation]:
# pylint: disable=too-many-locals
db_engine_spec = database.db_engine_spec
- parsed_query = ParsedQuery(statement)
+ parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
sql = parsed_query.stripped()
# Hook to allow environment-specific mutation (usually comments) to the SQL
@@ -156,7 +156,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
VALIDATE) SELECT 1 FROM default.mytable.
"""
- parsed_query = ParsedQuery(sql)
+ parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
statements = parsed_query.get_statements()
logger.info("Validating %i statement(s)", len(statements))
diff --git a/superset/sqllab/commands/export.py b/superset/sqllab/commands/export.py
index 1b9b0e0344..aa6050f27f 100644
--- a/superset/sqllab/commands/export.py
+++ b/superset/sqllab/commands/export.py
@@ -115,7 +115,10 @@ class SqlResultExportCommand(BaseCommand):
limit = None
else:
sql = self._query.executed_sql
- limit = ParsedQuery(sql).limit
+ limit = ParsedQuery(
+ sql,
+ engine=self._query.database.db_engine_spec.engine,
+ ).limit
if limit is not None and self._query.limiting_factor in {
LimitingFactor.QUERY,
LimitingFactor.DROPDOWN,
diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py
index 95111276fe..5f846ef3b0 100644
--- a/superset/sqllab/query_render.py
+++ b/superset/sqllab/query_render.py
@@ -58,7 +58,11 @@ class SqlQueryRenderImpl(SqlQueryRender):
database=query_model.database, query=query_model
)
- parsed_query = ParsedQuery(query_model.sql, strip_comments=True)
+ parsed_query = ParsedQuery(
+ query_model.sql,
+ strip_comments=True,
+ engine=query_model.database.db_engine_spec.engine,
+ )
rendered_query = sql_template_processor.process_template(
parsed_query.stripped(), **execution_context.template_params
)
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 341ba9d789..52f0597ce3 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -39,11 +39,11 @@ from superset.sql_parse import (
)
-def extract_tables(query: str) -> set[Table]:
+def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]:
"""
Helper function to extract tables referenced in a query.
"""
- return ParsedQuery(query).tables
+ return ParsedQuery(query, engine=engine).tables
def test_table() -> None:
@@ -95,8 +95,13 @@ def test_extract_tables() -> None:
Table("left_table")
}
- # reverse select
- assert extract_tables("FROM t1 SELECT field") == {Table("t1")}
+ assert extract_tables(
+ "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;"
+ ) == {Table("forbidden_table")}
+
+ assert extract_tables(
+ "select * from (select * from forbidden_table) forbidden_table"
+ ) == {Table("forbidden_table")}
def test_extract_tables_subselect() -> None:
@@ -262,14 +267,16 @@ def test_extract_tables_illdefined() -> None:
assert extract_tables("SELECT * FROM schemaname.") == set()
assert extract_tables("SELECT * FROM catalogname.schemaname.") == set()
assert extract_tables("SELECT * FROM catalogname..") == set()
- assert extract_tables("SELECT * FROM catalogname..tbname") == set()
+ assert extract_tables("SELECT * FROM catalogname..tbname") == {
+ Table(table="tbname", schema=None, catalog="catalogname")
+ }
def test_extract_tables_show_tables_from() -> None:
"""
Test ``SHOW TABLES FROM``.
"""
- assert extract_tables("SHOW TABLES FROM s1 like '%order%'") == set()
+ assert extract_tables("SHOW TABLES FROM s1 like '%order%'", "mysql") == set()
def test_extract_tables_show_columns_from() -> None:
@@ -310,7 +317,7 @@ WHERE regionkey IN (SELECT regionkey FROM t2)
"""
SELECT name
FROM t1
-WHERE regionkey EXISTS (SELECT regionkey FROM t2)
+WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
"""
)
== {Table("t1"), Table("t2")}
@@ -525,6 +532,18 @@ select * from (select key from q1) a
== {Table("src")}
)
+ # weird query with circular dependency
+ assert (
+ extract_tables(
+ """
+with src as ( select key from q2 where key = '5'),
+q2 as ( select key from src where key = '5')
+select * from (select key from src) a
+"""
+ )
+ == set()
+ )
+
def test_extract_tables_multistatement() -> None:
"""
@@ -664,7 +683,8 @@ def test_extract_tables_nested_select() -> None:
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
-"""
+""",
+ "mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
@@ -675,7 +695,8 @@ WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
-"""
+""",
+ "mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
@@ -1305,6 +1326,14 @@ def test_sqlparse_issue_652():
"(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
True,
),
+ (
+ "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
+ True,
+ ),
+ (
+ "SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
+ True,
+ ),
],
)
def test_has_table_query(sql: str, expected: bool) -> None:
@@ -1607,13 +1636,17 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
assert extract_table_references(
sql,
"trino",
- ) == {Table(table="other_table", schema=None, catalog=None)}
+ ) == {
+ Table(table="table", schema=None, catalog=None),
+ Table(table="other_table", schema=None, catalog=None),
+ }
logger.warning.assert_called_once()
logger = mocker.patch("superset.migrations.shared.utils.logger")
sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
assert extract_table_references(sql, "trino", show_warning=False) == {
- Table(table="other_table", schema=None, catalog=None)
+ Table(table="table", schema=None, catalog=None),
+ Table(table="other_table", schema=None, catalog=None),
}
logger.warning.assert_not_called()