You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by jo...@apache.org on 2024/02/23 19:47:42 UTC

(superset) branch master updated: fix(sqlglot): Address regressions introduced in #26476 (#27217)

This is an automated email from the ASF dual-hosted git repository.

johnbodley pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 2c564817f1 fix(sqlglot): Address regressions introduced in #26476 (#27217)
2c564817f1 is described below

commit 2c564817f1978e34770e02034a7a4c02e1bfdc9f
Author: John Bodley <45...@users.noreply.github.com>
AuthorDate: Sat Feb 24 08:47:36 2024 +1300

    fix(sqlglot): Address regressions introduced in #26476 (#27217)
---
 superset/sql_parse.py               | 17 +++++++++++------
 tests/unit_tests/sql_parse_tests.py | 10 ++++++----
 2 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 7b89ab8f0e..c85afc9460 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -28,7 +28,7 @@ import sqlparse
 from sqlalchemy import and_
 from sqlglot import exp, parse, parse_one
 from sqlglot.dialects import Dialects
-from sqlglot.errors import ParseError
+from sqlglot.errors import SqlglotError
 from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
 from sqlparse import keywords
 from sqlparse.lexer import Lexer
@@ -287,7 +287,7 @@ class ParsedQuery:
         """
         try:
             statements = parse(self.stripped(), dialect=self._dialect)
-        except ParseError:
+        except SqlglotError:
             logger.warning("Unable to parse SQL (%s): %s", self._dialect, self.sql)
             return set()
 
@@ -319,12 +319,17 @@ class ParsedQuery:
         elif isinstance(statement, exp.Command):
             # Commands, like `SHOW COLUMNS FROM foo`, have to be converted into a
             # `SELECT` statement in order to extract tables.
-            literal = statement.find(exp.Literal)
-            if not literal:
+            if not (literal := statement.find(exp.Literal)):
                 return set()
 
-            pseudo_query = parse_one(f"SELECT {literal.this}", dialect=self._dialect)
-            sources = pseudo_query.find_all(exp.Table)
+            try:
+                pseudo_query = parse_one(
+                    f"SELECT {literal.this}",
+                    dialect=self._dialect,
+                )
+                sources = pseudo_query.find_all(exp.Table)
+            except SqlglotError:
+                return set()
         else:
             sources = [
                 source
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index f05e16ae85..2fd23f7e8e 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -271,6 +271,7 @@ def test_extract_tables_illdefined() -> None:
     assert extract_tables("SELECT * FROM catalogname..tbname") == {
         Table(table="tbname", schema=None, catalog="catalogname")
     }
+    assert extract_tables('SELECT * FROM "tbname') == set()
 
 
 def test_extract_tables_show_tables_from() -> None:
@@ -558,6 +559,10 @@ def test_extract_tables_multistatement() -> None:
         Table("t1"),
         Table("t2"),
     }
+    assert extract_tables(
+        "ADD JAR file:///hive.jar; SELECT * FROM t1;",
+        engine="hive",
+    ) == {Table("t1")}
 
 
 def test_extract_tables_complex() -> None:
@@ -1815,10 +1820,7 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
     # test falling back to sqlparse
     logger = mocker.patch("superset.sql_parse.logger")
     sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
-    assert extract_table_references(
-        sql,
-        "trino",
-    ) == {
+    assert extract_table_references(sql, "trino") == {
         Table(table="table", schema=None, catalog=None),
         Table(table="other_table", schema=None, catalog=None),
     }