You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@superset.apache.org by be...@apache.org on 2024/03/14 21:28:50 UTC

(superset) branch sqlparse-annotations created (now e8bec0a706)

This is an automated email from the ASF dual-hosted git repository.

beto pushed a change to branch sqlparse-annotations
in repository https://gitbox.apache.org/repos/asf/superset.git


      at e8bec0a706 chore: add annotations to `sql_parse.py`

This branch includes the following new commits:

     new e8bec0a706 chore: add annotations to `sql_parse.py`

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  Revisions
listed as "add" were already present in the repository and have only
been added to this reference.



(superset) 01/01: chore: add annotations to `sql_parse.py`

Posted by be...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch sqlparse-annotations
in repository https://gitbox.apache.org/repos/asf/superset.git

commit e8bec0a706cf2174319d0b3e80ebe1aa47343514
Author: Beto Dealmeida <ro...@dealmeida.net>
AuthorDate: Thu Mar 14 17:28:20 2024 -0400

    chore: add annotations to `sql_parse.py`
---
 superset/sql_parse.py | 56 +++++++++++++++++++++++++--------------------------
 1 file changed, 28 insertions(+), 28 deletions(-)

diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 58dc210e2b..eeaecb3ad6 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -17,12 +17,14 @@
 
 # pylint: disable=too-many-lines
 
+from __future__ import annotations
+
 import logging
 import re
 import urllib.parse
 from collections.abc import Iterable, Iterator
 from dataclasses import dataclass
-from typing import Any, cast, Optional, Union
+from typing import Any, cast
 
 import sqlglot
 import sqlparse
@@ -138,7 +140,7 @@ class CtasMethod(StrEnum):
     VIEW = "VIEW"
 
 
-def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
+def _extract_limit_from_query(statement: TokenList) -> int | None:
     """
     Extract limit clause from SQL statement.
 
@@ -159,9 +161,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
     return None
 
 
-def extract_top_from_query(
-    statement: TokenList, top_keywords: set[str]
-) -> Optional[int]:
+def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None:
     """
     Extract top clause value from SQL statement.
 
@@ -185,7 +185,7 @@ def extract_top_from_query(
     return top
 
 
-def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
+def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
     """
     parse the SQL and return the CTE and rest of the block to the caller
 
@@ -193,7 +193,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     :return: CTE and remainder block to the caller
 
     """
-    cte: Optional[str] = None
+    cte: str | None = None
     remainder = sql
     stmt = sqlparse.parse(sql)[0]
 
@@ -211,7 +211,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
     return cte, remainder
 
 
-def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
+def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
     """
     Strips comments from a SQL statement, does a simple test first
     to avoid always instantiating the expensive ParsedQuery constructor
@@ -235,8 +235,8 @@ class Table:
     """
 
     table: str
-    schema: Optional[str] = None
-    catalog: Optional[str] = None
+    schema: str | None = None
+    catalog: str | None = None
 
     def __str__(self) -> str:
         """
@@ -255,7 +255,7 @@ class Table:
 
 def extract_tables_from_statement(
     statement: exp.Expression,
-    dialect: Optional[Dialects],
+    dialect: Dialects | None,
 ) -> set[Table]:
     """
     Extract all table references in a single statement.
@@ -334,7 +334,7 @@ class SQLScript:
     def __init__(
         self,
         query: str,
-        engine: Optional[str] = None,
+        engine: str | None = None,
     ):
         dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
 
@@ -375,8 +375,8 @@ class SQLStatement:
 
     def __init__(
         self,
-        statement: Union[str, exp.Expression],
-        engine: Optional[str] = None,
+        statement: str | exp.Expression,
+        engine: str | None = None,
     ):
         dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
 
@@ -394,7 +394,7 @@ class SQLStatement:
     @staticmethod
     def _parse_statement(
         sql_statement: str,
-        dialect: Optional[Dialects],
+        dialect: Dialects | None,
     ) -> exp.Expression:
         """
         Parse a single SQL statement.
@@ -437,7 +437,7 @@ class ParsedQuery:
         self,
         sql_statement: str,
         strip_comments: bool = False,
-        engine: Optional[str] = None,
+        engine: str | None = None,
     ):
         if strip_comments:
             sql_statement = sqlparse.format(sql_statement, strip_comments=True)
@@ -446,7 +446,7 @@ class ParsedQuery:
         self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
         self._tables: set[Table] = set()
         self._alias_names: set[str] = set()
-        self._limit: Optional[int] = None
+        self._limit: int | None = None
 
         logger.debug("Parsing with sqlparse statement: %s", self.sql)
         self._parsed = sqlparse.parse(self.stripped())
@@ -550,7 +550,7 @@ class ParsedQuery:
         return source.name in ctes_in_scope
 
     @property
-    def limit(self) -> Optional[int]:
+    def limit(self) -> int | None:
         return self._limit
 
     def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
@@ -631,7 +631,7 @@ class ParsedQuery:
 
         return True
 
-    def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
+    def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
         for token in tokens:
             if self._is_identifier(token):
                 for identifier_token in token.tokens:
@@ -695,7 +695,7 @@ class ParsedQuery:
         return statements
 
     @staticmethod
-    def get_table(tlist: TokenList) -> Optional[Table]:
+    def get_table(tlist: TokenList) -> Table | None:
         """
         Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
         construct.
@@ -731,7 +731,7 @@ class ParsedQuery:
     def as_create_table(
         self,
         table_name: str,
-        schema_name: Optional[str] = None,
+        schema_name: str | None = None,
         overwrite: bool = False,
         method: CtasMethod = CtasMethod.TABLE,
     ) -> str:
@@ -891,8 +891,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
 def get_rls_for_table(
     candidate: Token,
     database_id: int,
-    default_schema: Optional[str],
-) -> Optional[TokenList]:
+    default_schema: str | None,
+) -> TokenList | None:
     """
     Given a table name, return any associated RLS predicates.
     """
@@ -938,7 +938,7 @@ def get_rls_for_table(
 def insert_rls_as_subquery(
     token_list: TokenList,
     database_id: int,
-    default_schema: Optional[str],
+    default_schema: str | None,
 ) -> TokenList:
     """
     Update a statement inplace applying any associated RLS predicates.
@@ -954,7 +954,7 @@ def insert_rls_as_subquery(
     This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
     databases.
     """
-    rls: Optional[TokenList] = None
+    rls: TokenList | None = None
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
         # Recurse into child token list
@@ -1030,7 +1030,7 @@ def insert_rls_as_subquery(
 def insert_rls_in_predicate(
     token_list: TokenList,
     database_id: int,
-    default_schema: Optional[str],
+    default_schema: str | None,
 ) -> TokenList:
     """
     Update a statement inplace applying any associated RLS predicates.
@@ -1041,7 +1041,7 @@ def insert_rls_in_predicate(
         after:  SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
 
     """
-    rls: Optional[TokenList] = None
+    rls: TokenList | None = None
     state = InsertRLSState.SCANNING
     for token in token_list.tokens:
         # Recurse into child token list
@@ -1175,7 +1175,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
 
 def extract_table_references(
     sql_text: str, sqla_dialect: str, show_warning: bool = True
-) -> set["Table"]:
+) -> set[Table]:
     """
     Return all the dependencies from a SQL sql_text.
     """