You are viewing a plain-text version of this content. The canonical link for it was a hyperlink on the word "here" in the original message and is not preserved in this plain-text version.
Posted to commits@superset.apache.org by mi...@apache.org on 2024/03/26 19:37:54 UTC

(superset) 01/02: fix: Leverage actual database for rendering Jinjarized SQL (#27646)

This is an automated email from the ASF dual-hosted git repository.

michaelsmolina pushed a commit to branch 4.0
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 51ad63426c3ad3e227b3393cbc377621f98f8f89
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Wed Mar 27 08:12:25 2024 +1300

    fix: Leverage actual database for rendering Jinjarized SQL (#27646)
---
 superset/models/sql_lab.py          |  2 +-
 superset/security/manager.py        |  4 +---
 superset/sql_parse.py               | 15 ++++++++-------
 tests/unit_tests/sql_parse_tests.py |  6 +++++-
 4 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index 2d7384a74e..f22d774e88 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -73,7 +73,7 @@ class SqlTablesMixin:  # pylint: disable=too-few-public-methods
             return list(
                 extract_tables_from_jinja_sql(
                     self.sql,  # type: ignore
-                    self.database.db_engine_spec.engine,  # type: ignore
+                    self.database,  # type: ignore
                 )
             )
         except SupersetSecurityException:
diff --git a/superset/security/manager.py b/superset/security/manager.py
index 2833e88645..e5a32e97a7 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -1963,9 +1963,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
                 default_schema = database.get_default_schema_for_query(query)
                 tables = {
                     Table(table_.table, table_.schema or default_schema)
-                    for table_ in extract_tables_from_jinja_sql(
-                        query.sql, database.db_engine_spec.engine
-                    )
+                    for table_ in extract_tables_from_jinja_sql(query.sql, database)
                 }
             elif table:
                 tables = {table}
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 58bca48a6e..6df5dbc089 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -23,8 +23,7 @@ import re
 import urllib.parse
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
-from typing import Any, cast
-from unittest.mock import Mock
+from typing import Any, cast, TYPE_CHECKING
 
 import sqlparse
 from flask_babel import gettext as __
@@ -71,6 +70,9 @@ try:
 except (ImportError, ModuleNotFoundError):
     sqloxide_parse = None
 
+if TYPE_CHECKING:
+    from superset.models.core import Database
+
 RESULT_OPERATIONS = {"UNION", "INTERSECT", "EXCEPT", "SELECT"}
 ON_KEYWORD = "ON"
 PRECEDES_TABLE_NAME = {"FROM", "JOIN", "DESCRIBE", "WITH", "LEFT JOIN", "RIGHT JOIN"}
@@ -1054,7 +1056,7 @@ def extract_table_references(
     }
 
 
-def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]:
+def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]:
     """
     Extract all table references in the Jinjafied SQL statement.
 
@@ -1067,7 +1069,7 @@ def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Ta
     SQLGlot.
 
     :param sql: The Jinjafied SQL statement
-    :param engine: The associated database engine
+    :param database: The database associated with the SQL statement
     :returns: The set of tables referenced in the SQL statement
     :raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement
     """
@@ -1076,8 +1078,7 @@ def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Ta
         get_template_processor,
     )
 
-    # Mock the required database as the processor signature is exposed publically.
-    processor = get_template_processor(database=Mock(backend=engine))
+    processor = get_template_processor(database)
     template = processor.env.parse(sql)
 
     tables = set()
@@ -1107,6 +1108,6 @@ def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Ta
         tables
         | ParsedQuery(
             sql_statement=processor.process_template(template),
-            engine=engine,
+            engine=database.db_engine_spec.engine,
         ).tables
     )
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index 81ea0e5a7a..dab5dbf9c7 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name, redefined-outer-name, too-many-lines
 
 from typing import Optional
+from unittest.mock import Mock
 
 import pytest
 import sqlparse
@@ -1912,6 +1913,9 @@ def test_extract_tables_from_jinja_sql(
     expected: set[Table],
 ) -> None:
     assert (
-        extract_tables_from_jinja_sql(sql.format(engine=engine, macro=macro), engine)
+        extract_tables_from_jinja_sql(
+            sql=sql.format(engine=engine, macro=macro),
+            database=Mock(),
+        )
         == expected
     )