You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@superset.apache.org by mi...@apache.org on 2024/03/08 20:04:51 UTC
(superset) 02/02: Revert "feat(sqlparse): improve table parsing (#26476)"
This is an automated email from the ASF dual-hosted git repository.
michaelsmolina pushed a commit to branch 3.1
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 82895af3ce7988548bf8f727e643b38ec789c1e2
Author: Michael S. Molina <mi...@gmail.com>
AuthorDate: Fri Mar 8 17:00:38 2024 -0300
Revert "feat(sqlparse): improve table parsing (#26476)"
This reverts commit 1d9cfdabd1816c71f716c8d7e213558d5b7ff05e.
---
requirements/base.txt | 15 +-
requirements/testing.txt | 4 +
setup.py | 1 -
superset/commands/dataset/duplicate.py | 5 +-
superset/commands/sql_lab/export.py | 5 +-
superset/connectors/sqla/models.py | 2 +-
superset/connectors/sqla/utils.py | 2 +-
superset/db_engine_specs/base.py | 12 +-
superset/db_engine_specs/bigquery.py | 2 +-
superset/models/helpers.py | 2 +-
superset/models/sql_lab.py | 6 +-
superset/security/manager.py | 5 +-
superset/sql_lab.py | 11 +-
superset/sql_parse.py | 246 +++++++++++----------------------
superset/sql_validators/presto_db.py | 4 +-
superset/sqllab/query_render.py | 6 +-
tests/unit_tests/sql_parse_tests.py | 55 ++------
17 files changed, 120 insertions(+), 263 deletions(-)
diff --git a/requirements/base.txt b/requirements/base.txt
index de25938a01..fbb82b46f5 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -141,9 +141,7 @@ geographiclib==1.52
geopy==2.2.0
# via apache-superset
greenlet==2.0.2
- # via
- # shillelagh
- # sqlalchemy
+ # via shillelagh
gunicorn==21.2.0
# via apache-superset
hashids==1.3.1
@@ -157,10 +155,7 @@ idna==3.2
# email-validator
# requests
importlib-metadata==6.6.0
- # via
- # apache-superset
- # flask
- # shillelagh
+ # via apache-superset
importlib-resources==5.12.0
# via limits
isodate==0.6.0
@@ -335,8 +330,6 @@ sqlalchemy-utils==0.38.3
# via
# apache-superset
# flask-appbuilder
-sqlglot==20.8.0
- # via apache-superset
sqlparse==0.4.4
# via apache-superset
sshtunnel==0.4.0
@@ -387,9 +380,7 @@ wtforms-json==0.3.5
xlsxwriter==3.0.7
# via apache-superset
zipp==3.15.0
- # via
- # importlib-metadata
- # importlib-resources
+ # via importlib-metadata
# The following packages are considered to be unsafe in a requirements file:
# setuptools
diff --git a/requirements/testing.txt b/requirements/testing.txt
index fce953f8e4..725df3ac3c 100644
--- a/requirements/testing.txt
+++ b/requirements/testing.txt
@@ -24,6 +24,10 @@ db-dtypes==1.1.1
# via pandas-gbq
docker==6.1.1
# via -r requirements/testing.in
+exceptiongroup==1.1.1
+ # via pytest
+ephem==4.1.4
+ # via lunarcalendar
flask-testing==0.8.1
# via -r requirements/testing.in
fonttools==4.39.4
diff --git a/setup.py b/setup.py
index 7050d7b497..15ac83417b 100644
--- a/setup.py
+++ b/setup.py
@@ -126,7 +126,6 @@ setup(
"slack_sdk>=3.19.0, <4",
"sqlalchemy>=1.4, <2",
"sqlalchemy-utils>=0.38.3, <0.39",
- "sqlglot>=20,<21",
"sqlparse>=0.4.4, <0.5",
"tabulate>=0.8.9, <0.9",
"typing-extensions>=4, <5",
diff --git a/superset/commands/dataset/duplicate.py b/superset/commands/dataset/duplicate.py
index 850290422e..0ae47c35bc 100644
--- a/superset/commands/dataset/duplicate.py
+++ b/superset/commands/dataset/duplicate.py
@@ -70,10 +70,7 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
table.normalize_columns = self._base_model.normalize_columns
table.always_filter_main_dttm = self._base_model.always_filter_main_dttm
table.is_sqllab_view = True
- table.sql = ParsedQuery(
- self._base_model.sql,
- engine=database.db_engine_spec.engine,
- ).stripped()
+ table.sql = ParsedQuery(self._base_model.sql).stripped()
db.session.add(table)
cols = []
for config_ in self._base_model.columns:
diff --git a/superset/commands/sql_lab/export.py b/superset/commands/sql_lab/export.py
index aa6050f27f..1b9b0e0344 100644
--- a/superset/commands/sql_lab/export.py
+++ b/superset/commands/sql_lab/export.py
@@ -115,10 +115,7 @@ class SqlResultExportCommand(BaseCommand):
limit = None
else:
sql = self._query.executed_sql
- limit = ParsedQuery(
- sql,
- engine=self._query.database.db_engine_spec.engine,
- ).limit
+ limit = ParsedQuery(sql).limit
if limit is not None and self._query.limiting_factor in {
LimitingFactor.QUERY,
LimitingFactor.DROPDOWN,
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index bd54032d5d..598bc6741b 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1458,7 +1458,7 @@ class SqlaTable(
return self.get_sqla_table(), None
from_sql = self.get_rendered_sql(template_processor)
- parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
+ parsed_query = ParsedQuery(from_sql)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 688be53515..66594084c8 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
sql = dataset.get_template_processor().process_template(
dataset.sql, **dataset.template_params_dict
)
- parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
+ parsed_query = ParsedQuery(sql)
if not db_engine_spec.is_readonly_query(parsed_query):
raise SupersetSecurityException(
SupersetError(
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 66293ccf52..ce67cb448c 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -900,7 +900,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return database.compile_sqla_query(qry)
if cls.limit_method == LimitMethod.FORCE_LIMIT:
- parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
+ parsed_query = sql_parse.ParsedQuery(sql)
sql = parsed_query.set_or_update_query_limit(limit, force=force)
return sql
@@ -981,7 +981,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param sql: SQL query
:return: Value of limit clause in query
"""
- parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
+ parsed_query = sql_parse.ParsedQuery(sql)
return parsed_query.limit
@classmethod
@@ -993,7 +993,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param limit: New limit to insert/replace into query
:return: Query with new limit
"""
- parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
+ parsed_query = sql_parse.ParsedQuery(sql)
return parsed_query.set_or_update_query_limit(limit)
@classmethod
@@ -1490,7 +1490,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param database: Database instance
:return: Dictionary with different costs
"""
- parsed_query = ParsedQuery(statement, engine=cls.engine)
+ parsed_query = ParsedQuery(statement)
sql = parsed_query.stripped()
sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
@@ -1525,7 +1525,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"Database does not support cost estimation"
)
- parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
+ parsed_query = sql_parse.ParsedQuery(sql)
statements = parsed_query.get_statements()
costs = []
@@ -1586,7 +1586,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:return:
"""
if not cls.allows_sql_comments:
- query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
+ query = sql_parse.strip_comments_from_sql(query)
if cls.arraysize:
cursor.arraysize = cls.arraysize
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index a8d834276e..8e7ed0bf7d 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -435,7 +435,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
if not cls.get_allow_cost_estimate(extra):
raise SupersetException("Database does not support cost estimation")
- parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
+ parsed_query = sql_parse.ParsedQuery(sql)
statements = parsed_query.get_statements()
costs = []
for statement in statements:
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 4ff206882e..1dc5a57da5 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -1094,7 +1094,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
"""
from_sql = self.get_rendered_sql(template_processor)
- parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
+ parsed_query = ParsedQuery(from_sql)
if not (
parsed_query.is_unknown()
or self.db_engine_spec.is_readonly_query(parsed_query)
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index aff8c1ce3d..7e63e984df 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -183,7 +183,7 @@ class Query(
@property
def sql_tables(self) -> list[Table]:
- return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)
+ return list(ParsedQuery(self.sql).tables)
@property
def columns(self) -> list["TableColumn"]:
@@ -427,9 +427,7 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
@property
def sql_tables(self) -> list[Table]:
- return list(
- ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
- )
+ return list(ParsedQuery(self.sql).tables)
@property
def last_run_humanized(self) -> str:
diff --git a/superset/security/manager.py b/superset/security/manager.py
index e6eb77e645..501c8cf6a6 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -1909,10 +1909,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
default_schema = database.get_default_schema_for_query(query)
tables = {
Table(table_.table, table_.schema or default_schema)
- for table_ in sql_parse.ParsedQuery(
- query.sql,
- engine=database.db_engine_spec.engine,
- ).tables
+ for table_ in sql_parse.ParsedQuery(query.sql).tables
}
elif table:
tables = {table}
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index e9b4d406f8..efbef6560a 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -208,7 +208,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local
database: Database = query.database
db_engine_spec = database.db_engine_spec
- parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
+ parsed_query = ParsedQuery(sql_statement)
if is_feature_enabled("RLS_IN_SQLLAB"):
# There are two ways to insert RLS: either replacing the table with a subquery
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
@@ -228,8 +228,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments, too-many-local
database.id,
query.schema,
)
- ),
- engine=db_engine_spec.engine,
+ )
)
sql = parsed_query.stripped()
@@ -420,11 +419,7 @@ def execute_sql_statements(
)
# Breaking down into multiple statements
- parsed_query = ParsedQuery(
- rendered_query,
- strip_comments=True,
- engine=db_engine_spec.engine,
- )
+ parsed_query = ParsedQuery(rendered_query, strip_comments=True)
if not db_engine_spec.run_multiple_statements_as_one:
statements = parsed_query.get_statements()
logger.info(
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 7b89ab8f0e..b9af21c8c3 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -14,22 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-# pylint: disable=too-many-lines
-
import logging
import re
-import urllib.parse
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any, cast, Optional
+from urllib import parse
import sqlparse
from sqlalchemy import and_
-from sqlglot import exp, parse, parse_one
-from sqlglot.dialects import Dialects
-from sqlglot.errors import ParseError
-from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
from sqlparse import keywords
from sqlparse.lexer import Lexer
from sqlparse.sql import (
@@ -60,7 +53,7 @@ from superset.utils.backports import StrEnum
try:
from sqloxide import parse_sql as sqloxide_parse
-except (ImportError, ModuleNotFoundError):
+except: # pylint: disable=bare-except
sqloxide_parse = None
RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
@@ -79,59 +72,6 @@ sqlparser_sql_regex.insert(25, (r"'(''|\\\\|\\|[^'])*'", sqlparse.tokens.String.
lex.set_SQL_REGEX(sqlparser_sql_regex)
-# mapping between DB engine specs and sqlglot dialects
-SQLGLOT_DIALECTS = {
- "ascend": Dialects.HIVE,
- "awsathena": Dialects.PRESTO,
- "bigquery": Dialects.BIGQUERY,
- "clickhouse": Dialects.CLICKHOUSE,
- "clickhousedb": Dialects.CLICKHOUSE,
- "cockroachdb": Dialects.POSTGRES,
- # "crate": ???
- # "databend": ???
- "databricks": Dialects.DATABRICKS,
- # "db2": ???
- # "dremio": ???
- "drill": Dialects.DRILL,
- # "druid": ???
- "duckdb": Dialects.DUCKDB,
- # "dynamodb": ???
- # "elasticsearch": ???
- # "exa": ???
- # "firebird": ???
- # "firebolt": ???
- "gsheets": Dialects.SQLITE,
- "hana": Dialects.POSTGRES,
- "hive": Dialects.HIVE,
- # "ibmi": ???
- # "impala": ???
- # "kustokql": ???
- # "kylin": ???
- # "mssql": ???
- "mysql": Dialects.MYSQL,
- "netezza": Dialects.POSTGRES,
- # "ocient": ???
- # "odelasticsearch": ???
- "oracle": Dialects.ORACLE,
- # "pinot": ???
- "postgresql": Dialects.POSTGRES,
- "presto": Dialects.PRESTO,
- "pydoris": Dialects.DORIS,
- "redshift": Dialects.REDSHIFT,
- # "risingwave": ???
- # "rockset": ???
- "shillelagh": Dialects.SQLITE,
- "snowflake": Dialects.SNOWFLAKE,
- # "solr": ???
- "sqlite": Dialects.SQLITE,
- "starrocks": Dialects.STARROCKS,
- "superset": Dialects.SQLITE,
- "teradatasql": Dialects.TERADATA,
- "trino": Dialects.TRINO,
- "vertica": Dialects.POSTGRES,
-}
-
-
class CtasMethod(StrEnum):
TABLE = "TABLE"
VIEW = "VIEW"
@@ -210,7 +150,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
return cte, remainder
-def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
+def strip_comments_from_sql(statement: str) -> str:
"""
Strips comments from a SQL statement, does a simple test first
to avoid always instantiating the expensive ParsedQuery constructor
@@ -220,11 +160,7 @@ def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str
:param statement: A string with the SQL statement
:return: SQL statement without comments
"""
- return (
- ParsedQuery(statement, engine=engine).strip_comments()
- if "--" in statement
- else statement
- )
+ return ParsedQuery(statement).strip_comments() if "--" in statement else statement
@dataclass(eq=True, frozen=True)
@@ -243,7 +179,7 @@ class Table:
"""
return ".".join(
- urllib.parse.quote(part, safe="").replace(".", "%2E")
+ parse.quote(part, safe="").replace(".", "%2E")
for part in [self.catalog, self.schema, self.table]
if part
)
@@ -253,17 +189,11 @@ class Table:
class ParsedQuery:
- def __init__(
- self,
- sql_statement: str,
- strip_comments: bool = False,
- engine: Optional[str] = None,
- ):
+ def __init__(self, sql_statement: str, strip_comments: bool = False):
if strip_comments:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
self.sql: str = sql_statement
- self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
self._tables: set[Table] = set()
self._alias_names: set[str] = set()
self._limit: Optional[int] = None
@@ -276,93 +206,13 @@ class ParsedQuery:
@property
def tables(self) -> set[Table]:
if not self._tables:
- self._tables = self._extract_tables_from_sql()
- return self._tables
-
- def _extract_tables_from_sql(self) -> set[Table]:
- """
- Extract all table references in a query.
-
- Note: this uses sqlglot, since it's better at catching more edge cases.
- """
- try:
- statements = parse(self.stripped(), dialect=self._dialect)
- except ParseError:
- logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
- return set()
-
- return {
- table
- for statement in statements
- for table in self._extract_tables_from_statement(statement)
- if statement
- }
-
- def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
- """
- Extract all table references in a single statement.
-
- Please not that this is not trivial; consider the following queries:
-
- DESCRIBE some_table;
- SHOW PARTITIONS FROM some_table;
- WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
-
- See the unit tests for other tricky cases.
- """
- sources: Iterable[exp.Table]
-
- if isinstance(statement, exp.Describe):
- # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
- # query for all tables.
- sources = statement.find_all(exp.Table)
- elif isinstance(statement, exp.Command):
- # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
- # `SELECT` statetement in order to extract tables.
- literal = statement.find(exp.Literal)
- if not literal:
- return set()
-
- pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect)
- sources = pseudo_query.find_all(exp.Table)
- else:
- sources = [
- source
- for scope in traverse_scope(statement)
- for source in scope.sources.values()
- if isinstance(source, exp.Table) and not self._is_cte(source, scope)
- ]
-
- return {
- Table(
- source.name,
- source.db if source.db != "" else None,
- source.catalog if source.catalog != "" else None,
- )
- for source in sources
- }
-
- def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
- """
- Is the source a CTE?
+ for statement in self._parsed:
+ self._extract_from_token(statement)
- CTEs in the parent scope look like tables (and are represented by
- exp.Table objects), but should not be considered as such;
- otherwise a user with access to table `foo` could access any table
- with a query like this:
-
- WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
-
- """
- parent_sources = scope.parent.sources if scope.parent else {}
- ctes_in_scope = {
- name
- for name, parent_scope in parent_sources.items()
- if isinstance(parent_scope, Scope)
- and parent_scope.scope_type == ScopeType.CTE
- }
-
- return source.name in ctes_in_scope
+ self._tables = {
+ table for table in self._tables if str(table) not in self._alias_names
+ }
+ return self._tables
@property
def limit(self) -> Optional[int]:
@@ -543,6 +393,28 @@ class ParsedQuery:
def _is_identifier(token: Token) -> bool:
return isinstance(token, (IdentifierList, Identifier))
+ def _process_tokenlist(self, token_list: TokenList) -> None:
+ """
+ Add table names to table set
+
+ :param token_list: TokenList to be processed
+ """
+ # exclude subselects
+ if "(" not in str(token_list):
+ table = self.get_table(token_list)
+ if table and not table.table.startswith(CTE_PREFIX):
+ self._tables.add(table)
+ return
+
+ # store aliases
+ if token_list.has_alias():
+ self._alias_names.add(token_list.get_alias())
+
+ # some aliases are not parsed properly
+ if token_list.tokens[0].ttype == Name:
+ self._alias_names.add(token_list.tokens[0].value)
+ self._extract_from_token(token_list)
+
def as_create_table(
self,
table_name: str,
@@ -569,6 +441,50 @@ class ParsedQuery:
exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
return exec_sql
+ def _extract_from_token(self, token: Token) -> None:
+ """
+ <Identifier> store a list of subtokens and <IdentifierList> store lists of
+ subtoken list.
+
+ It extracts <IdentifierList> and <Identifier> from :param token: and loops
+ through all subtokens recursively. It finds table_name_preceding_token and
+ passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
+ self._tables.
+
+ :param token: instance of Token or child class, e.g. TokenList, to be processed
+ """
+ if not hasattr(token, "tokens"):
+ return
+
+ table_name_preceding_token = False
+
+ for item in token.tokens:
+ if item.is_group and (
+ not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
+ ):
+ self._extract_from_token(item)
+
+ if item.ttype in Keyword and (
+ item.normalized in PRECEDES_TABLE_NAME
+ or item.normalized.endswith(" JOIN")
+ ):
+ table_name_preceding_token = True
+ continue
+
+ if item.ttype in Keyword:
+ table_name_preceding_token = False
+ continue
+ if table_name_preceding_token:
+ if isinstance(item, Identifier):
+ self._process_tokenlist(item)
+ elif isinstance(item, IdentifierList):
+ for token2 in item.get_identifiers():
+ if isinstance(token2, TokenList):
+ self._process_tokenlist(token2)
+ elif isinstance(item, IdentifierList):
+ if any(not self._is_identifier(token2) for token2 in item.tokens):
+ self._extract_from_token(item)
+
def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
"""Returns the query with the specified limit.
@@ -965,7 +881,7 @@ def insert_rls_in_predicate(
# mapping between sqloxide and SQLAlchemy dialects
-SQLOXIDE_DIALECTS = {
+SQLOXITE_DIALECTS = {
"ansi": {"trino", "trinonative", "presto"},
"hive": {"hive", "databricks"},
"ms": {"mssql"},
@@ -998,7 +914,7 @@ def extract_table_references(
tree = None
if sqloxide_parse:
- for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items():
+ for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
if sqla_dialect in sqla_dialects:
break
sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index fed1ff3bfa..c01b938671 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -50,7 +50,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
) -> Optional[SQLValidationAnnotation]:
# pylint: disable=too-many-locals
db_engine_spec = database.db_engine_spec
- parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
+ parsed_query = ParsedQuery(statement)
sql = parsed_query.stripped()
# Hook to allow environment-specific mutation (usually comments) to the SQL
@@ -154,7 +154,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
VALIDATE) SELECT 1 FROM default.mytable.
"""
- parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
+ parsed_query = ParsedQuery(sql)
statements = parsed_query.get_statements()
logger.info("Validating %i statement(s)", len(statements))
diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py
index 5597bcb086..f4c1c26c6e 100644
--- a/superset/sqllab/query_render.py
+++ b/superset/sqllab/query_render.py
@@ -58,11 +58,7 @@ class SqlQueryRenderImpl(SqlQueryRender):
database=query_model.database, query=query_model
)
- parsed_query = ParsedQuery(
- query_model.sql,
- strip_comments=True,
- engine=query_model.database.db_engine_spec.engine,
- )
+ parsed_query = ParsedQuery(query_model.sql, strip_comments=True)
rendered_query = sql_template_processor.process_template(
parsed_query.stripped(), **execution_context.template_params
)
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index f650b77734..efd8838101 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -40,11 +40,11 @@ from superset.sql_parse import (
)
-def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]:
+def extract_tables(query: str) -> set[Table]:
"""
Helper function to extract tables referenced in a query.
"""
- return ParsedQuery(query, engine=engine).tables
+ return ParsedQuery(query).tables
def test_table() -> None:
@@ -96,13 +96,8 @@ def test_extract_tables() -> None:
Table("left_table")
}
- assert extract_tables(
- "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;"
- ) == {Table("forbidden_table")}
-
- assert extract_tables(
- "select * from (select * from forbidden_table) forbidden_table"
- ) == {Table("forbidden_table")}
+ # reverse select
+ assert extract_tables("FROM t1 SELECT field") == {Table("t1")}
def test_extract_tables_subselect() -> None:
@@ -268,16 +263,14 @@ def test_extract_tables_illdefined() -> None:
assert extract_tables("SELECT * FROM schemaname.") == set()
assert extract_tables("SELECT * FROM catalogname.schemaname.") == set()
assert extract_tables("SELECT * FROM catalogname..") == set()
- assert extract_tables("SELECT * FROM catalogname..tbname") == {
- Table(table="tbname", schema=None, catalog="catalogname")
- }
+ assert extract_tables("SELECT * FROM catalogname..tbname") == set()
def test_extract_tables_show_tables_from() -> None:
"""
Test ``SHOW TABLES FROM``.
"""
- assert extract_tables("SHOW TABLES FROM s1 like '%order%'", "mysql") == set()
+ assert extract_tables("SHOW TABLES FROM s1 like '%order%'") == set()
def test_extract_tables_show_columns_from() -> None:
@@ -318,7 +311,7 @@ WHERE regionkey IN (SELECT regionkey FROM t2)
"""
SELECT name
FROM t1
-WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
+WHERE regionkey EXISTS (SELECT regionkey FROM t2)
"""
)
== {Table("t1"), Table("t2")}
@@ -533,18 +526,6 @@ select * from (select key from q1) a
== {Table("src")}
)
- # weird query with circular dependency
- assert (
- extract_tables(
- """
-with src as ( select key from q2 where key = '5'),
-q2 as ( select key from src where key = '5')
-select * from (select key from src) a
-"""
- )
- == set()
- )
-
def test_extract_tables_multistatement() -> None:
"""
@@ -684,8 +665,7 @@ def test_extract_tables_nested_select() -> None:
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
-""",
- "mysql",
+"""
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
@@ -696,8 +676,7 @@ WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
-""",
- "mysql",
+"""
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
@@ -1327,14 +1306,6 @@ def test_sqlparse_issue_652():
"(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
True,
),
- (
- "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
- True,
- ),
- (
- "SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
- True,
- ),
],
)
def test_has_table_query(sql: str, expected: bool) -> None:
@@ -1819,17 +1790,13 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
assert extract_table_references(
sql,
"trino",
- ) == {
- Table(table="table", schema=None, catalog=None),
- Table(table="other_table", schema=None, catalog=None),
- }
+ ) == {Table(table="other_table", schema=None, catalog=None)}
logger.warning.assert_called_once()
logger = mocker.patch("superset.migrations.shared.utils.logger")
sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
assert extract_table_references(sql, "trino", show_warning=False) == {
- Table(table="table", schema=None, catalog=None),
- Table(table="other_table", schema=None, catalog=None),
+ Table(table="other_table", schema=None, catalog=None)
}
logger.warning.assert_not_called()