You are viewing a plain text version of this content. The hyperlink to the canonical version was not preserved in this plain text copy.
Posted to commits@superset.apache.org by be...@apache.org on 2024/03/14 21:28:51 UTC

(superset) 01/01: chore: add annotations to `sql_parse.py`

This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch sqlparse-annotations
in repository https://gitbox.apache.org/repos/asf/superset.git

commit e8bec0a706cf2174319d0b3e80ebe1aa47343514
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Thu Mar 14 17:28:20 2024 -0400

    chore: add annotations to `sql_parse.py`
---
 superset/sql_parse.py | 56 +++++++++++++++++++++++++--------------------------
 1 file changed, 28 insertions(+), 28 deletions(-)

diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 58dc210e2b..eeaecb3ad6 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -17,12 +17,14 @@
 
 # pylint: disable=too-many-lines
 
+from __future__ import annotations
+
 import logging
 import re
 import urllib.parse
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
-from typing import Any, cast, Optional, Union
+from typing import Any, cast
 
 import sqlglot
 import sqlparse
@@ -138,7 +140,7 @@ class CtasMethod(StrEnum):
     VIEW = "VIEW"
 
 
-def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
+def _extract_limit_from_query(statement: TokenList) -> int | None:
     """
     Extract limit clause from SQL statement.
 
@@ -159,9 +161,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
     return None
 
 
-def extract_top_from_query(
-    statement: TokenList, top_keywords: set[str]
-) -> Optional[int]:
+def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None:
     """
     Extract top clause value from SQL statement.
 
@@ -185,7 +185,7 @@ def extract_top_from_query(
     return top
 
 
-def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
+def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
     """
     parse the SQL and return the CTE and rest of the block to the caller
 
@@ -193,7 +193,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     :return: CTE and remainder block to the caller
 
     """
-    cte: Optional[str] = None
+    cte: str | None = None
     remainder = sql
     stmt = sqlparse.parse(sql)[0]
 
@@ -211,7 +211,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     return cte, remainder
 
 
-def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
+def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
     """
     Strips comments from a SQL statement, does a simple test first
     to avoid always instantiating the expensive ParsedQuery constructor
@@ -235,8 +235,8 @@ class Table:
     """
 
     table: str
-    schema: Optional[str] = None
-    catalog: Optional[str] = None
+    schema: str | None = None
+    catalog: str | None = None
 
     def __str__(self) -> str:
         """
@@ -255,7 +255,7 @@ class Table:
 
 def extract_tables_from_statement(
     statement: exp.Expression,
-    dialect: Optional[Dialects],
+    dialect: Dialects | None,
 ) -> set[Table]:
     """
     Extract all table references in a single statement.
@@ -334,7 +334,7 @@ class SQLScript:
     def __init__(
         self,
         query: str,
-        engine: Optional[str] = None,
+        engine: str | None = None,
     ):
         dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
 
@@ -375,8 +375,8 @@ class SQLStatement:
 
     def __init__(
         self,
-        statement: Union[str, exp.Expression],
-        engine: Optional[str] = None,
+        statement: str | exp.Expression,
+        engine: str | None = None,
     ):
         dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
 
@@ -394,7 +394,7 @@ class SQLStatement:
     @staticmethod
     def _parse_statement(
         sql_statement: str,
-        dialect: Optional[Dialects],
+        dialect: Dialects | None,
     ) -> exp.Expression:
         """
         Parse a single SQL statement.
@@ -437,7 +437,7 @@ class ParsedQuery:
         self,
         sql_statement: str,
         strip_comments: bool = False,
-        engine: Optional[str] = None,
+        engine: str | None = None,
     ):
         if strip_comments:
             sql_statement = sqlparse.format(sql_statement, strip_comments=True)
@@ -446,7 +446,7 @@ class ParsedQuery:
         self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
         self._tables: set[Table] = set()
         self._alias_names: set[str] = set()
-        self._limit: Optional[int] = None
+        self._limit: int | None = None
 
         logger.debug("Parsing with sqlparse statement: %s", self.sql)
         self._parsed = sqlparse.parse(self.stripped())
@@ -550,7 +550,7 @@ class ParsedQuery:
         return source.name in ctes_in_scope
 
     @property
-    def limit(self) -> Optional[int]:
+    def limit(self) -> int | None:
         return self._limit
 
     def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
@@ -631,7 +631,7 @@ class ParsedQuery:
 
         return True
 
-    def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
+    def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
         for token in tokens:
             if self._is_identifier(token):
                 for identifier_token in token.tokens:
@@ -695,7 +695,7 @@ class ParsedQuery:
         return statements
 
     @staticmethod
-    def get_table(tlist: TokenList) -> Optional[Table]:
+    def get_table(tlist: TokenList) -> Table | None:
         """
         Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
         construct.
@@ -731,7 +731,7 @@ class ParsedQuery:
     def as_create_table(
         self,
         table_name: str,
-        schema_name: Optional[str] = None,
+        schema_name: str | None = None,
         overwrite: bool = False,
         method: CtasMethod = CtasMethod.TABLE,
     ) -> str:
@@ -891,8 +891,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
 def get_rls_for_table(
     candidate: Token,
     database_id: int,
-    default_schema: Optional[str],
-) -> Optional[TokenList]:
+    default_schema: str | None,
+) -> TokenList | None:
     """
     Given a table name, return any associated RLS predicates.
     """
@@ -938,7 +938,7 @@ def get_rls_for_table(
 def insert_rls_as_subquery(
     token_list: TokenList,
     database_id: int,
-    default_schema: Optional[str],
+    default_schema: str | None,
 ) -> TokenList:
     """
     Update a statement inplace applying any associated RLS predicates.
@@ -954,7 +954,7 @@ def insert_rls_as_subquery(
     This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
     databases.
     """
-    rls: Optional[TokenList] = None
+    rls: TokenList | None = None
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
         # Recurse into child token list
@@ -1030,7 +1030,7 @@ def insert_rls_as_subquery(
 def insert_rls_in_predicate(
     token_list: TokenList,
     database_id: int,
-    default_schema: Optional[str],
+    default_schema: str | None,
 ) -> TokenList:
     """
     Update a statement inplace applying any associated RLS predicates.
@@ -1041,7 +1041,7 @@ def insert_rls_in_predicate(
         after:  SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
 
     """
-    rls: Optional[TokenList] = None
+    rls: TokenList | None = None
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
         # Recurse into child token list
@@ -1175,7 +1175,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
 
 def extract_table_references(
     sql_text: str, sqla_dialect: str, show_warning: bool = True
-) -> set["Table"]:
+) -> set[Table]:
     """
     Return all the dependencies from a SQL sql_text.
     """