You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by mi...@apache.org on 2024/02/01 16:43:08 UTC

(superset) 01/09: feat(sqlparse): improve table parsing (#26476)

This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 3.0
in repository https://gitbox.apache.org/repos/asf/superset.git

commit fb6410004369c73bae12bb60c5b5ca7981c6cf49
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Mon Jan 22 11:16:50 2024 -0500

    feat(sqlparse): improve table parsing (#26476)
---
 requirements/base.txt                   |   2 +
 requirements/testing.txt                |   6 +-
 setup.py                                |   1 +
 superset/connectors/sqla/models.py      |   2 +-
 superset/connectors/sqla/utils.py       |   2 +-
 superset/datasets/commands/duplicate.py |   5 +-
 superset/db_engine_specs/base.py        |  12 +-
 superset/db_engine_specs/bigquery.py    |   2 +-
 superset/models/helpers.py              |   2 +-
 superset/models/sql_lab.py              |   6 +-
 superset/security/manager.py            |   5 +-
 superset/sql_lab.py                     |  11 +-
 superset/sql_parse.py                   | 245 +++++++++++++++++++++-----------
 superset/sql_validators/presto_db.py    |   4 +-
 superset/sqllab/commands/export.py      |   5 +-
 superset/sqllab/query_render.py         |   6 +-
 tests/unit_tests/sql_parse_tests.py     |  55 +++++--
 17 files changed, 253 insertions(+), 118 deletions(-)

diff --git a/requirements/base.txt b/requirements/base.txt
index c65c7ba77e..a34df1905d 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -283,6 +283,8 @@ sqlalchemy-utils==0.38.3
     # via
     #   apache-superset
     #   flask-appbuilder
+sqlglot==20.8.0
+    # via apache-superset
 sqlparse==0.4.4
     # via apache-superset
 sshtunnel==0.4.0
diff --git a/requirements/testing.txt b/requirements/testing.txt
index 0a4d826c8c..e41ac7d96f 100644
--- a/requirements/testing.txt
+++ b/requirements/testing.txt
@@ -24,10 +24,6 @@ db-dtypes==1.1.1
     # via pandas-gbq
 docker==6.1.1
     # via -r requirements/testing.in
-ephem==4.1.4
-    # via lunarcalendar
-exceptiongroup==1.1.1
-    # via pytest
 flask-testing==0.8.1
     # via -r requirements/testing.in
 fonttools==4.39.4
@@ -119,7 +115,7 @@ pyee==9.0.4
     # via playwright
 pyfakefs==5.2.2
     # via -r requirements/testing.in
-pyhive[presto]==0.6.5
+pyhive[presto]==0.7.0
     # via apache-superset
 pytest==7.3.1
     # via
diff --git a/setup.py b/setup.py
index 211f8c5a80..e428797094 100644
--- a/setup.py
+++ b/setup.py
@@ -121,6 +121,7 @@ setup(
         "slack_sdk>=3.19.0, <4",
         "sqlalchemy>=1.4, <2",
         "sqlalchemy-utils>=0.38.3, <0.39",
+        "sqlglot>=20,<21",
         "sqlparse>=0.4.4, <0.5",
         "tabulate>=0.8.9, <0.9",
         "typing-extensions>=4, <5",
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 5edc724b23..4bcf3fc344 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -848,7 +848,7 @@ class SqlaTable(
             return self.get_sqla_table(), None
 
         from_sql = self.get_rendered_sql(template_processor)
-        parsed_query = ParsedQuery(from_sql)
+        parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
         if not (
             parsed_query.is_unknown()
             or self.db_engine_spec.is_readonly_query(parsed_query)
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index c8a5f9f260..3b64bec807 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -111,7 +111,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
     sql = dataset.get_template_processor().process_template(
         dataset.sql, **dataset.template_params_dict
     )
-    parsed_query = ParsedQuery(sql)
+    parsed_query = ParsedQuery(sql, engine=db_engine_spec.engine)
     if not db_engine_spec.is_readonly_query(parsed_query):
         raise SupersetSecurityException(
             SupersetError(
diff --git a/superset/datasets/commands/duplicate.py b/superset/datasets/commands/duplicate.py
index 9fc05c0960..238f9a2391 100644
--- a/superset/datasets/commands/duplicate.py
+++ b/superset/datasets/commands/duplicate.py
@@ -69,7 +69,10 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
             table.template_params = self._base_model.template_params
             table.normalize_columns = self._base_model.normalize_columns
             table.is_sqllab_view = True
-            table.sql = ParsedQuery(self._base_model.sql).stripped()
+            table.sql = ParsedQuery(
+                self._base_model.sql,
+                engine=database.db_engine_spec.engine,
+            ).stripped()
             db.session.add(table)
             cols = []
             for config_ in self._base_model.columns:
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index c086ce27b7..be0b55b3e9 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -874,7 +874,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             return database.compile_sqla_query(qry)
 
         if cls.limit_method == LimitMethod.FORCE_LIMIT:
-            parsed_query = sql_parse.ParsedQuery(sql)
+            parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
             sql = parsed_query.set_or_update_query_limit(limit, force=force)
 
         return sql
@@ -955,7 +955,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :param sql: SQL query
         :return: Value of limit clause in query
         """
-        parsed_query = sql_parse.ParsedQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
         return parsed_query.limit
 
     @classmethod
@@ -967,7 +967,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :param limit: New limit to insert/replace into query
         :return: Query with new limit
         """
-        parsed_query = sql_parse.ParsedQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
         return parsed_query.set_or_update_query_limit(limit)
 
     @classmethod
@@ -1450,7 +1450,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :param database: Database instance
         :return: Dictionary with different costs
         """
-        parsed_query = ParsedQuery(statement)
+        parsed_query = ParsedQuery(statement, engine=cls.engine)
         sql = parsed_query.stripped()
         sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
         mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
@@ -1483,7 +1483,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         if not cls.get_allow_cost_estimate(extra):
             raise Exception("Database does not support cost estimation")
 
-        parsed_query = sql_parse.ParsedQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
         statements = parsed_query.get_statements()
 
         costs = []
@@ -1544,7 +1544,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         :return:
         """
         if not cls.allows_sql_comments:
-            query = sql_parse.strip_comments_from_sql(query)
+            query = sql_parse.strip_comments_from_sql(query, engine=cls.engine)
 
         if cls.arraysize:
             cursor.arraysize = cls.arraysize
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index 73b7b18d36..125760ec7f 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -436,7 +436,7 @@ class BigQueryEngineSpec(BaseEngineSpec):  # pylint: disable=too-many-public-met
         if not cls.get_allow_cost_estimate(extra):
             raise SupersetException("Database does not support cost estimation")
 
-        parsed_query = sql_parse.ParsedQuery(sql)
+        parsed_query = sql_parse.ParsedQuery(sql, engine=cls.engine)
         statements = parsed_query.get_statements()
         costs = []
         for statement in statements:
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index bf33451e34..f9b30c3d42 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -1085,7 +1085,7 @@ class ExploreMixin:  # pylint: disable=too-many-public-methods
         """
 
         from_sql = self.get_rendered_sql(template_processor)
-        parsed_query = ParsedQuery(from_sql)
+        parsed_query = ParsedQuery(from_sql, engine=self.db_engine_spec.engine)
         if not (
             parsed_query.is_unknown()
             or self.db_engine_spec.is_readonly_query(parsed_query)
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 20df535ad3..156add2529 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -183,7 +183,7 @@ class Query(
 
     @property
     def sql_tables(self) -> list[Table]:
-        return list(ParsedQuery(self.sql).tables)
+        return list(ParsedQuery(self.sql, engine=self.db_engine_spec.engine).tables)
 
     @property
     def columns(self) -> list["TableColumn"]:
@@ -428,7 +428,9 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
 
     @property
     def sql_tables(self) -> list[Table]:
-        return list(ParsedQuery(self.sql).tables)
+        return list(
+            ParsedQuery(self.sql, engine=self.database.db_engine_spec.engine).tables
+        )
 
     @property
     def last_run_humanized(self) -> str:
diff --git a/superset/security/manager.py b/superset/security/manager.py
index 88657e8f21..5c0833fdf9 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -1855,7 +1855,10 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
                 default_schema = database.get_default_schema_for_query(query)
                 tables = {
                     Table(table_.table, table_.schema or default_schema)
-                    for table_ in sql_parse.ParsedQuery(query.sql).tables
+                    for table_ in sql_parse.ParsedQuery(
+                        query.sql,
+                        engine=database.db_engine_spec.engine,
+                    ).tables
                 }
             elif table:
                 tables = {table}
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index ca157b3240..6da1125458 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -203,7 +203,7 @@ def execute_sql_statement(  # pylint: disable=too-many-arguments
     database: Database = query.database
     db_engine_spec = database.db_engine_spec
 
-    parsed_query = ParsedQuery(sql_statement)
+    parsed_query = ParsedQuery(sql_statement, engine=db_engine_spec.engine)
     if is_feature_enabled("RLS_IN_SQLLAB"):
         # Insert any applicable RLS predicates
         parsed_query = ParsedQuery(
@@ -213,7 +213,8 @@ def execute_sql_statement(  # pylint: disable=too-many-arguments
                     database.id,
                     query.schema,
                 )
-            )
+            ),
+            engine=db_engine_spec.engine,
         )
 
     sql = parsed_query.stripped()
@@ -404,7 +405,11 @@ def execute_sql_statements(
         )
 
     # Breaking down into multiple statements
-    parsed_query = ParsedQuery(rendered_query, strip_comments=True)
+    parsed_query = ParsedQuery(
+        rendered_query,
+        strip_comments=True,
+        engine=db_engine_spec.engine,
+    )
     if not db_engine_spec.run_multiple_statements_as_one:
         statements = parsed_query.get_statements()
         logger.info(
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index c196fdabfa..080572eba1 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -14,15 +14,20 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 import logging
 import re
-from collections.abc import Iterator
+import urllib.parse
+from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
 from typing import Any, cast, Optional
-from urllib import parse
 
 import sqlparse
 from sqlalchemy import and_
+from sqlglot import exp, parse, parse_one
+from sqlglot.dialects import Dialects
+from sqlglot.errors import ParseError
+from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
 from sqlparse import keywords
 from sqlparse.lexer import Lexer
 from sqlparse.sql import (
@@ -52,7 +57,7 @@ from superset.utils.backports import StrEnum
 
 try:
     from sqloxide import parse_sql as sqloxide_parse
-except:  # pylint: disable=bare-except
+except (ImportError, ModuleNotFoundError):
     sqloxide_parse = None
 
 RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
@@ -71,6 +76,59 @@ sqlparser_sql_regex.insert(25, (r"'(''|\\\\|\\|[^'])*'", sqlparse.tokens.String.
 lex.set_SQL_REGEX(sqlparser_sql_regex)
 
 
+# mapping between DB engine specs and sqlglot dialects
+SQLGLOT_DIALECTS = {
+    "ascend": Dialects.HIVE,
+    "awsathena": Dialects.PRESTO,
+    "bigquery": Dialects.BIGQUERY,
+    "clickhouse": Dialects.CLICKHOUSE,
+    "clickhousedb": Dialects.CLICKHOUSE,
+    "cockroachdb": Dialects.POSTGRES,
+    # "crate": ???
+    # "databend": ???
+    "databricks": Dialects.DATABRICKS,
+    # "db2": ???
+    # "dremio": ???
+    "drill": Dialects.DRILL,
+    # "druid": ???
+    "duckdb": Dialects.DUCKDB,
+    # "dynamodb": ???
+    # "elasticsearch": ???
+    # "exa": ???
+    # "firebird": ???
+    # "firebolt": ???
+    "gsheets": Dialects.SQLITE,
+    "hana": Dialects.POSTGRES,
+    "hive": Dialects.HIVE,
+    # "ibmi": ???
+    # "impala": ???
+    # "kustokql": ???
+    # "kylin": ???
+    # "mssql": ???
+    "mysql": Dialects.MYSQL,
+    "netezza": Dialects.POSTGRES,
+    # "ocient": ???
+    # "odelasticsearch": ???
+    "oracle": Dialects.ORACLE,
+    # "pinot": ???
+    "postgresql": Dialects.POSTGRES,
+    "presto": Dialects.PRESTO,
+    "pydoris": Dialects.DORIS,
+    "redshift": Dialects.REDSHIFT,
+    # "risingwave": ???
+    # "rockset": ???
+    "shillelagh": Dialects.SQLITE,
+    "snowflake": Dialects.SNOWFLAKE,
+    # "solr": ???
+    "sqlite": Dialects.SQLITE,
+    "starrocks": Dialects.STARROCKS,
+    "superset": Dialects.SQLITE,
+    "teradatasql": Dialects.TERADATA,
+    "trino": Dialects.TRINO,
+    "vertica": Dialects.POSTGRES,
+}
+
+
 class CtasMethod(StrEnum):
     TABLE = "TABLE"
     VIEW = "VIEW"
@@ -149,7 +207,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     return cte, remainder
 
 
-def strip_comments_from_sql(statement: str) -> str:
+def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
     """
     Strips comments from a SQL statement, does a simple test first
     to avoid always instantiating the expensive ParsedQuery constructor
@@ -159,7 +217,11 @@ def strip_comments_from_sql(statement: str) -> str:
     :param statement: A string with the SQL statement
     :return: SQL statement without comments
     """
-    return ParsedQuery(statement).strip_comments() if "--" in statement else statement
+    return (
+        ParsedQuery(statement, engine=engine).strip_comments()
+        if "--" in statement
+        else statement
+    )
 
 
 @dataclass(eq=True, frozen=True)
@@ -178,7 +240,7 @@ class Table:
         """
 
         return ".".join(
-            parse.quote(part, safe="").replace(".", "%2E")
+            urllib.parse.quote(part, safe="").replace(".", "%2E")
             for part in [self.catalog, self.schema, self.table]
             if part
         )
@@ -188,11 +250,17 @@ class Table:
 
 
 class ParsedQuery:
-    def __init__(self, sql_statement: str, strip_comments: bool = False):
+    def __init__(
+        self,
+        sql_statement: str,
+        strip_comments: bool = False,
+        engine: Optional[str] = None,
+    ):
         if strip_comments:
             sql_statement = sqlparse.format(sql_statement, strip_comments=True)
 
         self.sql: str = sql_statement
+        self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
         self._tables: set[Table] = set()
         self._alias_names: set[str] = set()
         self._limit: Optional[int] = None
@@ -205,14 +273,95 @@ class ParsedQuery:
     @property
     def tables(self) -> set[Table]:
         if not self._tables:
-            for statement in self._parsed:
-                self._extract_from_token(statement)
-
-            self._tables = {
-                table for table in self._tables if str(table) not in self._alias_names
-            }
+            self._tables = self._extract_tables_from_sql()
         return self._tables
 
+    def _extract_tables_from_sql(self) -> set[Table]:
+        """
+        Extract all table references in a query.
+
+        Note: this uses sqlglot, since it's better at catching more edge cases.
+        """
+        try:
+            statements = parse(self.sql, dialect=self._dialect)
+        except ParseError:
+            logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
+            return set()
+
+        return {
+            table
+            for statement in statements
+            for table in self._extract_tables_from_statement(statement)
+            if statement
+        }
+
+    def _extract_tables_from_statement(self, statement: exp.Expression) -> set[Table]:
+        """
+        Extract all table references in a single statement.
+
+        Please note that this is not trivial; consider the following queries:
+
+            DESCRIBE some_table;
+            SHOW PARTITIONS FROM some_table;
+            WITH masked_name AS (SELECT * FROM some_table) SELECT * FROM masked_name;
+
+        See the unit tests for other tricky cases.
+        """
+        sources: Iterable[exp.Table]
+
+        if isinstance(statement, exp.Describe):
+            # A `DESCRIBE` query has no sources in sqlglot, so we need to explicitly
+            # query for all tables.
+            sources = statement.find_all(exp.Table)
+        elif isinstance(statement, exp.Command):
+            # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
+            # `SELECT` statement in order to extract tables.
+            literal = statement.find(exp.Literal)
+            if not literal:
+                return set()
+
+            pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect)
+            sources = pseudo_query.find_all(exp.Table)
+        else:
+            sources = [
+                source
+                for scope in traverse_scope(statement)
+                for source in scope.sources.values()
+                if isinstance(source, exp.Table) and not self._is_cte(source, scope)
+            ]
+
+        return {
+            Table(
+                source.name,
+                source.db if source.db != "" else None,
+                source.catalog if source.catalog != "" else None,
+            )
+            for source in sources
+        }
+
+    # pylint: disable=no-self-use
+    def _is_cte(self, source: exp.Table, scope: Scope) -> bool:
+        """
+        Is the source a CTE?
+
+        CTEs in the parent scope look like tables (and are represented by
+        exp.Table objects), but should not be considered as such;
+        otherwise a user with access to table `foo` could access any table
+        with a query like this:
+
+            WITH foo AS (SELECT * FROM target_table) SELECT * FROM foo
+
+        """
+        parent_sources = scope.parent.sources if scope.parent else {}
+        ctes_in_scope = {
+            name
+            for name, parent_scope in parent_sources.items()
+            if isinstance(parent_scope, Scope)
+            and parent_scope.scope_type == ScopeType.CTE
+        }
+
+        return source.name in ctes_in_scope
+
     @property
     def limit(self) -> Optional[int]:
         return self._limit
@@ -393,28 +542,6 @@ class ParsedQuery:
     def _is_identifier(token: Token) -> bool:
         return isinstance(token, (IdentifierList, Identifier))
 
-    def _process_tokenlist(self, token_list: TokenList) -> None:
-        """
-        Add table names to table set
-
-        :param token_list: TokenList to be processed
-        """
-        # exclude subselects
-        if "(" not in str(token_list):
-            table = self.get_table(token_list)
-            if table and not table.table.startswith(CTE_PREFIX):
-                self._tables.add(table)
-            return
-
-        # store aliases
-        if token_list.has_alias():
-            self._alias_names.add(token_list.get_alias())
-
-        # some aliases are not parsed properly
-        if token_list.tokens[0].ttype == Name:
-            self._alias_names.add(token_list.tokens[0].value)
-        self._extract_from_token(token_list)
-
     def as_create_table(
         self,
         table_name: str,
@@ -441,50 +568,6 @@ class ParsedQuery:
         exec_sql += f"CREATE {method} {full_table_name} AS \n{sql}"
         return exec_sql
 
-    def _extract_from_token(self, token: Token) -> None:
-        """
-        <Identifier> store a list of subtokens and <IdentifierList> store lists of
-        subtoken list.
-
-        It extracts <IdentifierList> and <Identifier> from :param token: and loops
-        through all subtokens recursively. It finds table_name_preceding_token and
-        passes <IdentifierList> and <Identifier> to self._process_tokenlist to populate
-        self._tables.
-
-        :param token: instance of Token or child class, e.g. TokenList, to be processed
-        """
-        if not hasattr(token, "tokens"):
-            return
-
-        table_name_preceding_token = False
-
-        for item in token.tokens:
-            if item.is_group and (
-                not self._is_identifier(item) or isinstance(item.tokens[0], Parenthesis)
-            ):
-                self._extract_from_token(item)
-
-            if item.ttype in Keyword and (
-                item.normalized in PRECEDES_TABLE_NAME
-                or item.normalized.endswith(" JOIN")
-            ):
-                table_name_preceding_token = True
-                continue
-
-            if item.ttype in Keyword:
-                table_name_preceding_token = False
-                continue
-            if table_name_preceding_token:
-                if isinstance(item, Identifier):
-                    self._process_tokenlist(item)
-                elif isinstance(item, IdentifierList):
-                    for token2 in item.get_identifiers():
-                        if isinstance(token2, TokenList):
-                            self._process_tokenlist(token2)
-            elif isinstance(item, IdentifierList):
-                if any(not self._is_identifier(token2) for token2 in item.tokens):
-                    self._extract_from_token(item)
-
     def set_or_update_query_limit(self, new_limit: int, force: bool = False) -> str:
         """Returns the query with the specified limit.
 
@@ -779,7 +862,7 @@ def insert_rls(
 
 
 # mapping between sqloxide and SQLAlchemy dialects
-SQLOXITE_DIALECTS = {
+SQLOXIDE_DIALECTS = {
     "ansi": {"trino", "trinonative", "presto"},
     "hive": {"hive", "databricks"},
     "ms": {"mssql"},
@@ -812,7 +895,7 @@ def extract_table_references(
     tree = None
 
     if sqloxide_parse:
-        for dialect, sqla_dialects in SQLOXITE_DIALECTS.items():
+        for dialect, sqla_dialects in SQLOXIDE_DIALECTS.items():
             if sqla_dialect in sqla_dialects:
                 break
         sql_text = RE_JINJA_BLOCK.sub(" ", sql_text)
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index 9d3e7641a6..20b9a8eb98 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -50,7 +50,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
     ) -> Optional[SQLValidationAnnotation]:
         # pylint: disable=too-many-locals
         db_engine_spec = database.db_engine_spec
-        parsed_query = ParsedQuery(statement)
+        parsed_query = ParsedQuery(statement, engine=db_engine_spec.engine)
         sql = parsed_query.stripped()
 
         # Hook to allow environment-specific mutation (usually comments) to the SQL
@@ -156,7 +156,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
         For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE
         VALIDATE) SELECT 1 FROM default.mytable.
         """
-        parsed_query = ParsedQuery(sql)
+        parsed_query = ParsedQuery(sql, engine=database.db_engine_spec.engine)
         statements = parsed_query.get_statements()
 
         logger.info("Validating %i statement(s)", len(statements))
diff --git a/superset/sqllab/commands/export.py b/superset/sqllab/commands/export.py
index 1b9b0e0344..aa6050f27f 100644
--- a/superset/sqllab/commands/export.py
+++ b/superset/sqllab/commands/export.py
@@ -115,7 +115,10 @@ class SqlResultExportCommand(BaseCommand):
                 limit = None
             else:
                 sql = self._query.executed_sql
-                limit = ParsedQuery(sql).limit
+                limit = ParsedQuery(
+                    sql,
+                    engine=self._query.database.db_engine_spec.engine,
+                ).limit
             if limit is not None and self._query.limiting_factor in {
                 LimitingFactor.QUERY,
                 LimitingFactor.DROPDOWN,
diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py
index 95111276fe..5f846ef3b0 100644
--- a/superset/sqllab/query_render.py
+++ b/superset/sqllab/query_render.py
@@ -58,7 +58,11 @@ class SqlQueryRenderImpl(SqlQueryRender):
                 database=query_model.database, query=query_model
             )
 
-            parsed_query = ParsedQuery(query_model.sql, strip_comments=True)
+            parsed_query = ParsedQuery(
+                query_model.sql,
+                strip_comments=True,
+                engine=query_model.database.db_engine_spec.engine,
+            )
             rendered_query = sql_template_processor.process_template(
                 parsed_query.stripped(), **execution_context.template_params
             )
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 341ba9d789..52f0597ce3 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -39,11 +39,11 @@ from superset.sql_parse import (
 )
 
 
-def extract_tables(query: str) -> set[Table]:
+def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]:
     """
     Helper function to extract tables referenced in a query.
     """
-    return ParsedQuery(query).tables
+    return ParsedQuery(query, engine=engine).tables
 
 
 def test_table() -> None:
@@ -95,8 +95,13 @@ def test_extract_tables() -> None:
         Table("left_table")
     }
 
-    # reverse select
-    assert extract_tables("FROM t1 SELECT field") == {Table("t1")}
+    assert extract_tables(
+        "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;"
+    ) == {Table("forbidden_table")}
+
+    assert extract_tables(
+        "select * from (select * from forbidden_table) forbidden_table"
+    ) == {Table("forbidden_table")}
 
 
 def test_extract_tables_subselect() -> None:
@@ -262,14 +267,16 @@ def test_extract_tables_illdefined() -> None:
     assert extract_tables("SELECT * FROM schemaname.") == set()
     assert extract_tables("SELECT * FROM catalogname.schemaname.") == set()
     assert extract_tables("SELECT * FROM catalogname..") == set()
-    assert extract_tables("SELECT * FROM catalogname..tbname") == set()
+    assert extract_tables("SELECT * FROM catalogname..tbname") == {
+        Table(table="tbname", schema=None, catalog="catalogname")
+    }
 
 
 def test_extract_tables_show_tables_from() -> None:
     """
     Test ``SHOW TABLES FROM``.
     """
-    assert extract_tables("SHOW TABLES FROM s1 like '%order%'") == set()
+    assert extract_tables("SHOW TABLES FROM s1 like '%order%'", "mysql") == set()
 
 
 def test_extract_tables_show_columns_from() -> None:
@@ -310,7 +317,7 @@ WHERE regionkey IN (SELECT regionkey FROM t2)
             """
 SELECT name
 FROM t1
-WHERE regionkey EXISTS (SELECT regionkey FROM t2)
+WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
 """
         )
         == {Table("t1"), Table("t2")}
@@ -525,6 +532,18 @@ select * from (select key from q1) a
         == {Table("src")}
     )
 
+    # weird query with circular dependency
+    assert (
+        extract_tables(
+            """
+with src as ( select key from q2 where key = '5'),
+q2 as ( select key from src where key = '5')
+select * from (select key from src) a
+"""
+        )
+        == set()
+    )
+
 
 def test_extract_tables_multistatement() -> None:
     """
@@ -664,7 +683,8 @@ def test_extract_tables_nested_select() -> None:
 select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
 from INFORMATION_SCHEMA.COLUMNS
 WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
-"""
+""",
+            "mysql",
         )
         == {Table("COLUMNS", "INFORMATION_SCHEMA")}
     )
@@ -675,7 +695,8 @@ WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
 select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
 from INFORMATION_SCHEMA.COLUMNS
 WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
-"""
+""",
+            "mysql",
         )
         == {Table("COLUMNS", "INFORMATION_SCHEMA")}
     )
@@ -1305,6 +1326,14 @@ def test_sqlparse_issue_652():
             "(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
             True,
         ),
+        (
+            "SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
+            True,
+        ),
+        (
+            "SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
+            True,
+        ),
     ],
 )
 def test_has_table_query(sql: str, expected: bool) -> None:
@@ -1607,13 +1636,17 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
     assert extract_table_references(
         sql,
         "trino",
-    ) == {Table(table="other_table", schema=None, catalog=None)}
+    ) == {
+        Table(table="table", schema=None, catalog=None),
+        Table(table="other_table", schema=None, catalog=None),
+    }
     logger.warning.assert_called_once()
 
     logger = mocker.patch("superset.migrations.shared.utils.logger")
     sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
     assert extract_table_references(sql, "trino", show_warning=False) == {
-        Table(table="other_table", schema=None, catalog=None)
+        Table(table="table", schema=None, catalog=None),
+        Table(table="other_table", schema=None, catalog=None),
     }
     logger.warning.assert_not_called()