Posted to commits@airflow.apache.org by po...@apache.org on 2022/07/22 16:38:12 UTC
[airflow] branch main updated: Unify DbApiHook.run() method with the methods which override it (#23971)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new df00436569 Unify DbApiHook.run() method with the methods which override it (#23971)
df00436569 is described below
commit df00436569bb6fb79ce8c0b7ca71dddf02b854ef
Author: Dmytro Kazanzhy <dk...@gmail.com>
AuthorDate: Fri Jul 22 19:38:02 2022 +0300
Unify DbApiHook.run() method with the methods which override it (#23971)
---
airflow/operators/sql.py | 2 +-
.../providers/amazon/aws/operators/redshift_sql.py | 4 +-
.../amazon/aws/transfers/redshift_to_s3.py | 2 +-
.../amazon/aws/transfers/s3_to_redshift.py | 4 +-
airflow/providers/apache/drill/operators/drill.py | 8 +-
airflow/providers/apache/drill/provider.yaml | 1 -
airflow/providers/apache/pinot/hooks/pinot.py | 6 +-
airflow/providers/common/sql/hooks/sql.py | 68 +++++++++++---
airflow/providers/common/sql/provider.yaml | 3 +-
.../providers/databricks/hooks/databricks_sql.py | 76 ++++++++-------
.../databricks/operators/databricks_sql.py | 12 ++-
airflow/providers/exasol/hooks/exasol.py | 64 +++++++------
airflow/providers/exasol/operators/exasol.py | 6 +-
.../providers/google/cloud/operators/bigquery.py | 2 +-
.../providers/google/cloud/operators/cloud_sql.py | 6 +-
.../google/suite/transfers/sql_to_sheets.py | 2 +-
airflow/providers/jdbc/operators/jdbc.py | 14 +--
.../providers/microsoft/mssql/operators/mssql.py | 4 +-
airflow/providers/mysql/operators/mysql.py | 6 +-
airflow/providers/neo4j/operators/neo4j.py | 2 +-
airflow/providers/oracle/operators/oracle.py | 6 +-
airflow/providers/postgres/operators/postgres.py | 6 +-
airflow/providers/presto/hooks/presto.py | 45 +++++----
airflow/providers/snowflake/hooks/snowflake.py | 52 ++++++-----
airflow/providers/snowflake/operators/snowflake.py | 19 ++--
airflow/providers/sqlite/operators/sqlite.py | 6 +-
airflow/providers/trino/hooks/trino.py | 103 +++++++++------------
airflow/providers/vertica/operators/vertica.py | 4 +-
generated/provider_dependencies.json | 7 +-
.../databricks/hooks/test_databricks_sql.py | 11 ++-
.../databricks/operators/test_databricks_sql.py | 9 +-
tests/providers/jdbc/operators/test_jdbc.py | 3 +-
tests/providers/oracle/hooks/test_oracle.py | 8 +-
33 files changed, 307 insertions(+), 264 deletions(-)
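At a glance, every run() override touched by this change converges on one result-dispatch rule. A minimal standalone sketch of that rule (the function name and sample data are illustrative, not part of the commit):

from typing import Any, Callable, Iterable, List, Optional, Union


def dispatch_results(
    sql: Union[str, Iterable[str]],
    results: List[Any],
    handler: Optional[Callable],
    return_last: bool = True,
) -> Optional[Union[Any, List[Any]]]:
    # Mirrors the tail of each run() below: no handler -> None; a plain string
    # with return_last=True -> only the last statement's result; else all results.
    scalar_return_last = isinstance(sql, str) and return_last
    if handler is None:
        return None
    elif scalar_return_last:
        return results[-1]
    else:
        return results


# dispatch_results("SELECT 1; SELECT 2", [[(1,)], [(2,)]], handler=list)
# returns [(2,)] -- only the last split statement's rows.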
diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py
index cb8b664875..fbce9b85a1 100644
--- a/airflow/operators/sql.py
+++ b/airflow/operators/sql.py
@@ -496,7 +496,7 @@ class BranchSQLOperator(BaseSQLOperator, SkipMixin):
follow_task_ids_if_false: List[str],
conn_id: str = "default_conn_id",
database: Optional[str] = None,
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
**kwargs,
) -> None:
super().__init__(conn_id=conn_id, database=database, **kwargs)
diff --git a/airflow/providers/amazon/aws/operators/redshift_sql.py b/airflow/providers/amazon/aws/operators/redshift_sql.py
index c7ad77acb5..aa324b40bc 100644
--- a/airflow/providers/amazon/aws/operators/redshift_sql.py
+++ b/airflow/providers/amazon/aws/operators/redshift_sql.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
@@ -55,7 +55,7 @@ class RedshiftSQLOperator(BaseOperator):
*,
sql: Union[str, Iterable[str]],
redshift_conn_id: str = 'redshift_default',
- parameters: Optional[dict] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
autocommit: bool = True,
**kwargs,
) -> None:
diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index 0bfdda44e7..f5a3cd9bf6 100644
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -93,7 +93,7 @@ class RedshiftToS3Operator(BaseOperator):
unload_options: Optional[List] = None,
autocommit: bool = False,
include_header: bool = False,
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
table_as_file_name: bool = True, # Set to True by default for not breaking current workflows
**kwargs,
) -> None:
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 014e23ec07..747c97e1f1 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -16,7 +16,7 @@
# under the License.
import warnings
-from typing import TYPE_CHECKING, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Union
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -140,7 +140,7 @@ class S3ToRedshiftOperator(BaseOperator):
copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options)
- sql: Union[list, str]
+ sql: Union[str, Iterable[str]]
if self.method == 'REPLACE':
sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, "COMMIT"]
diff --git a/airflow/providers/apache/drill/operators/drill.py b/airflow/providers/apache/drill/operators/drill.py
index 791ed546c3..6dad45cc3c 100644
--- a/airflow/providers/apache/drill/operators/drill.py
+++ b/airflow/providers/apache/drill/operators/drill.py
@@ -17,8 +17,6 @@
# under the License.
from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
-import sqlparse
-
from airflow.models import BaseOperator
from airflow.providers.apache.drill.hooks.drill import DrillHook
@@ -52,7 +50,7 @@ class DrillOperator(BaseOperator):
*,
sql: str,
drill_conn_id: str = 'drill_default',
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -64,6 +62,4 @@ class DrillOperator(BaseOperator):
def execute(self, context: 'Context'):
self.log.info('Executing: %s on %s', self.sql, self.drill_conn_id)
self.hook = DrillHook(drill_conn_id=self.drill_conn_id)
- sql = sqlparse.split(sqlparse.format(self.sql, strip_comments=True))
- no_term_sql = [s[:-1] for s in sql if s[-1] == ';']
- self.hook.run(no_term_sql, parameters=self.parameters)
+ self.hook.run(self.sql, parameters=self.parameters, split_statements=True)
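DrillOperator now delegates the splitting to the hook; what split_statements=True triggers is the sqlparse-based split_sql_string added to common.sql later in this diff. A minimal demonstration of those two steps (requires sqlparse; the SQL is illustrative):

import sqlparse

sql = "DROP TABLE IF EXISTS t; -- cleanup\nCREATE TABLE t AS SELECT * FROM s;"
# Same steps as DbApiHook.split_sql_string below: strip comments, split,
# keep statements ending in ';' with the terminator removed.
splits = sqlparse.split(sqlparse.format(sql, strip_comments=True))
print([s.rstrip(';') for s in splits if s.endswith(';')])
# -> ['DROP TABLE IF EXISTS t', 'CREATE TABLE t AS SELECT * FROM s']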
diff --git a/airflow/providers/apache/drill/provider.yaml b/airflow/providers/apache/drill/provider.yaml
index 33235850b8..0e26ae5186 100644
--- a/airflow/providers/apache/drill/provider.yaml
+++ b/airflow/providers/apache/drill/provider.yaml
@@ -34,7 +34,6 @@ dependencies:
- apache-airflow>=2.2.0
- apache-airflow-providers-common-sql
- sqlalchemy-drill>=1.1.0
- - sqlparse>=0.4.1
integrations:
- integration-name: Apache Drill
diff --git a/airflow/providers/apache/pinot/hooks/pinot.py b/airflow/providers/apache/pinot/hooks/pinot.py
index fa31b9f33d..794646e46d 100644
--- a/airflow/providers/apache/pinot/hooks/pinot.py
+++ b/airflow/providers/apache/pinot/hooks/pinot.py
@@ -18,7 +18,7 @@
import os
import subprocess
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Iterable, List, Mapping, Optional, Union
from pinotdb import connect
@@ -275,7 +275,7 @@ class PinotDbApiHook(DbApiHook):
endpoint = conn.extra_dejson.get('endpoint', 'query/sql')
return f'{conn_type}://{host}/{endpoint}'
- def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any:
+ def get_records(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
"""
Executes the sql and returns a set of records.
@@ -287,7 +287,7 @@ class PinotDbApiHook(DbApiHook):
cur.execute(sql)
return cur.fetchall()
- def get_first(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any:
+ def get_first(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
"""
Executes the sql and returns the first resulting row.
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index efd4a9dcfe..e6687fa938 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -17,8 +17,9 @@
import warnings
from contextlib import closing
from datetime import datetime
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Tuple, Union
+import sqlparse
from sqlalchemy import create_engine
from typing_extensions import Protocol
@@ -27,6 +28,17 @@ from airflow.hooks.base import BaseHook
from airflow.providers_manager import ProvidersManager
from airflow.utils.module_loading import import_string
+if TYPE_CHECKING:
+ from sqlalchemy.engine import CursorResult
+
+
+def fetch_all_handler(cursor: 'CursorResult') -> Optional[List[Tuple]]:
+ """Handler for DbApiHook.run() to return results"""
+ if cursor.returns_rows:
+ return cursor.fetchall()
+ else:
+ return None
+
def _backported_get_hook(connection, *, hook_params=None):
"""Return hook based on conn_type
@@ -201,7 +213,31 @@ class DbApiHook(BaseHook):
cur.execute(sql)
return cur.fetchone()
- def run(self, sql, autocommit=False, parameters=None, handler=None):
+ @staticmethod
+ def strip_sql_string(sql: str) -> str:
+ return sql.strip().rstrip(';')
+
+ @staticmethod
+ def split_sql_string(sql: str) -> List[str]:
+ """
+ Splits a string into multiple SQL expressions
+
+ :param sql: SQL string potentially consisting of multiple expressions
+ :return: list of individual expressions
+ """
+ splits = sqlparse.split(sqlparse.format(sql, strip_comments=True))
+ statements = [s.rstrip(';') for s in splits if s.endswith(';')]
+ return statements
+
+ def run(
+ self,
+ sql: Union[str, Iterable[str]],
+ autocommit: bool = False,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
+ handler: Optional[Callable] = None,
+ split_statements: bool = False,
+ return_last: bool = True,
+ ) -> Optional[Union[Any, List[Any]]]:
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
@@ -213,14 +249,19 @@ class DbApiHook(BaseHook):
before executing the query.
:param parameters: The parameters to render the SQL query with.
:param handler: The result handler which is called with the result of each statement.
- :return: query results if handler was provided.
+ :param split_statements: Whether to split a single SQL string into statements and run them separately
+ :param return_last: Whether to return the result of only the last statement or of all statements after splitting
+ :return: if a handler was provided, the last statement's result when sql is a string and return_last is True, otherwise a list of all results
"""
- scalar = isinstance(sql, str)
- if scalar:
- sql = [sql]
+ scalar_return_last = isinstance(sql, str) and return_last
+ if isinstance(sql, str):
+ if split_statements:
+ sql = self.split_sql_string(sql)
+ else:
+ sql = [self.strip_sql_string(sql)]
if sql:
- self.log.debug("Executing %d statements", len(sql))
+ self.log.debug("Executing following statements against DB: %s", list(sql))
else:
raise ValueError("List of SQL statements is empty")
@@ -232,22 +273,21 @@ class DbApiHook(BaseHook):
results = []
for sql_statement in sql:
self._run_command(cur, sql_statement, parameters)
+
if handler is not None:
result = handler(cur)
results.append(result)
- # If autocommit was set to False for db that supports autocommit,
- # or if db does not supports autocommit, we do a manual commit.
+ # If autocommit was set to False or db does not support autocommit, we do a manual commit.
if not self.get_autocommit(conn):
conn.commit()
if handler is None:
return None
-
- if scalar:
- return results[0]
-
- return results
+ elif scalar_return_last:
+ return results[-1]
+ else:
+ return results
def _run_command(self, cur, sql_statement, parameters):
"""Runs a statement using an already open cursor."""
diff --git a/airflow/providers/common/sql/provider.yaml b/airflow/providers/common/sql/provider.yaml
index a277f327cc..39c8d483e4 100644
--- a/airflow/providers/common/sql/provider.yaml
+++ b/airflow/providers/common/sql/provider.yaml
@@ -24,7 +24,8 @@ description: |
versions:
- 1.0.0
-dependencies: []
+dependencies:
+ - sqlparse>=0.4.2
additional-extras:
- name: pandas
diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py
index 6c5800170d..7a888438e9 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -15,10 +15,9 @@
# specific language governing permissions and limitations
# under the License.
-import re
from contextlib import closing
from copy import copy
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
from databricks import sql # type: ignore[attr-defined]
from databricks.sql.client import Connection # type: ignore[attr-defined]
@@ -139,19 +138,15 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
)
return self._sql_conn
- @staticmethod
- def maybe_split_sql_string(sql: str) -> List[str]:
- """
- Splits strings consisting of multiple SQL expressions into an
- TODO: do we need something more sophisticated?
-
- :param sql: SQL string potentially consisting of multiple expressions
- :return: list of individual expressions
- """
- splits = [s.strip() for s in re.split(";\\s*\r?\n", sql) if s.strip() != ""]
- return splits
-
- def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, handler=None):
+ def run(
+ self,
+ sql: Union[str, Iterable[str]],
+ autocommit: bool = False,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
+ handler: Optional[Callable] = None,
+ split_statements: bool = True,
+ return_last: bool = True,
+ ) -> Optional[Union[Tuple[str, Any], List[Tuple[str, Any]]]]:
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
@@ -163,41 +158,44 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
before executing the query.
:param parameters: The parameters to render the SQL query with.
:param handler: The result handler which is called with the result of each statement.
- :return: query results.
+ :param split_statements: Whether to split a single SQL string into statements and run them separately
+ :param return_last: Whether to return the result of only the last statement or of all statements after splitting
+ :return: if a handler was provided, the last statement's result when sql is a string and return_last is True, otherwise a list of all results
"""
+ scalar_return_last = isinstance(sql, str) and return_last
if isinstance(sql, str):
- sql = self.maybe_split_sql_string(sql)
+ if split_statements:
+ sql = self.split_sql_string(sql)
+ else:
+ sql = [self.strip_sql_string(sql)]
if sql:
- self.log.debug("Executing %d statements", len(sql))
+ self.log.debug("Executing following statements against Databricks DB: %s", list(sql))
else:
raise ValueError("List of SQL statements is empty")
- conn = None
+ results = []
for sql_statement in sql:
# when using AAD tokens, it could expire if previous query run longer than token lifetime
- conn = self.get_conn()
- with closing(conn.cursor()) as cur:
- self.log.info("Executing statement: '%s', parameters: '%s'", sql_statement, parameters)
- if parameters:
- cur.execute(sql_statement, parameters)
- else:
- cur.execute(sql_statement)
- schema = cur.description
- results = []
- if handler is not None:
- cur = handler(cur)
- for row in cur:
- self.log.debug("Statement results: %s", row)
- results.append(row)
-
- self.log.info("Rows affected: %s", cur.rowcount)
- if conn:
- conn.close()
+ with closing(self.get_conn()) as conn:
+ self.set_autocommit(conn, autocommit)
+
+ with closing(conn.cursor()) as cur:
+ self._run_command(cur, sql_statement, parameters)
+
+ if handler is not None:
+ result = handler(cur)
+ schema = cur.description
+ results.append((schema, result))
+
self._sql_conn = None
- # Return only result of the last SQL expression
- return schema, results
+ if handler is None:
+ return None
+ elif scalar_return_last:
+ return results[-1]
+ else:
+ return results
def test_connection(self):
"""Test the Databricks SQL connection by running a simple query."""
diff --git a/airflow/providers/databricks/operators/databricks_sql.py b/airflow/providers/databricks/operators/databricks_sql.py
index 9e6298bc21..ad4add30fe 100644
--- a/airflow/providers/databricks/operators/databricks_sql.py
+++ b/airflow/providers/databricks/operators/databricks_sql.py
@@ -20,12 +20,13 @@
import csv
import json
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union, cast
from databricks.sql.utils import ParamEscaper
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
if TYPE_CHECKING:
@@ -71,11 +72,11 @@ class DatabricksSqlOperator(BaseOperator):
def __init__(
self,
*,
- sql: Union[str, List[str]],
+ sql: Union[str, Iterable[str]],
databricks_conn_id: str = DatabricksSqlHook.default_conn_name,
http_path: Optional[str] = None,
sql_endpoint_name: Optional[str] = None,
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
session_configuration=None,
http_headers: Optional[List[Tuple[str, str]]] = None,
catalog: Optional[str] = None,
@@ -147,10 +148,11 @@ class DatabricksSqlOperator(BaseOperator):
else:
raise AirflowException(f"Unsupported output format: '{self._output_format}'")
- def execute(self, context: 'Context') -> Any:
+ def execute(self, context: 'Context'):
self.log.info('Executing: %s', self.sql)
hook = self._get_hook()
- schema, results = hook.run(self.sql, parameters=self.parameters)
+ response = hook.run(self.sql, parameters=self.parameters, handler=fetch_all_handler)
+ schema, results = cast(List[Tuple[Any, Any]], response)[0]
# self.log.info('Schema: %s', schema)
# self.log.info('Results: %s', results)
self._format_output(schema, results)
diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py
index 784c57cde0..537f2fcb6d 100644
--- a/airflow/providers/exasol/hooks/exasol.py
+++ b/airflow/providers/exasol/hooks/exasol.py
@@ -17,7 +17,7 @@
# under the License.
from contextlib import closing
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
import pandas as pd
import pyexasol
@@ -64,9 +64,7 @@ class ExasolHook(DbApiHook):
conn = pyexasol.connect(**conn_args)
return conn
- def get_pandas_df(
- self, sql: Union[str, list], parameters: Optional[dict] = None, **kwargs
- ) -> pd.DataFrame:
+ def get_pandas_df(self, sql: str, parameters: Optional[dict] = None, **kwargs) -> pd.DataFrame:
"""
Executes the sql and returns a pandas dataframe
@@ -79,9 +77,7 @@ class ExasolHook(DbApiHook):
df = conn.export_to_pandas(sql, query_params=parameters, **kwargs)
return df
- def get_records(
- self, sql: Union[str, list], parameters: Optional[dict] = None
- ) -> List[Union[dict, Tuple[Any, ...]]]:
+ def get_records(self, sql: str, parameters: Optional[dict] = None) -> List[Union[dict, Tuple[Any, ...]]]:
"""
Executes the sql and returns a set of records.
@@ -93,7 +89,7 @@ class ExasolHook(DbApiHook):
with closing(conn.execute(sql, parameters)) as cur:
return cur.fetchall()
- def get_first(self, sql: Union[str, list], parameters: Optional[dict] = None) -> Optional[Any]:
+ def get_first(self, sql: str, parameters: Optional[dict] = None) -> Optional[Any]:
"""
Executes the sql and returns the first resulting row.
@@ -133,8 +129,14 @@ class ExasolHook(DbApiHook):
self.log.info("Data saved to %s", filename)
def run(
- self, sql: Union[str, list], autocommit: bool = False, parameters: Optional[dict] = None, handler=None
- ) -> Optional[list]:
+ self,
+ sql: Union[str, Iterable[str]],
+ autocommit: bool = False,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
+ handler: Optional[Callable] = None,
+ split_statements: bool = False,
+ return_last: bool = True,
+ ) -> Optional[Union[Any, List[Any]]]:
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
@@ -146,38 +148,44 @@ class ExasolHook(DbApiHook):
before executing the query.
:param parameters: The parameters to render the SQL query with.
:param handler: The result handler which is called with the result of each statement.
+ :param split_statements: Whether to split a single SQL string into statements and run them separately
+ :param return_last: Whether to return the result of only the last statement or of all statements after splitting
+ :return: if a handler was provided, the last statement's result when sql is a string and return_last is True, otherwise a list of all results
"""
+ scalar_return_last = isinstance(sql, str) and return_last
if isinstance(sql, str):
- sql = [sql]
+ if split_statements:
+ sql = self.split_sql_string(sql)
+ else:
+ sql = [self.strip_sql_string(sql)]
if sql:
- self.log.debug("Executing %d statements against Exasol DB", len(sql))
+ self.log.debug("Executing following statements against Exasol DB: %s", list(sql))
else:
raise ValueError("List of SQL statements is empty")
with closing(self.get_conn()) as conn:
- if self.supports_autocommit:
- self.set_autocommit(conn, autocommit)
-
- for query in sql:
- self.log.info(query)
- with closing(conn.execute(query, parameters)) as cur:
- results = []
-
+ self.set_autocommit(conn, autocommit)
+ results = []
+ for sql_statement in sql:
+ with closing(conn.execute(sql_statement, parameters)) as cur:
+ self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
if handler is not None:
- cur = handler(cur)
+ result = handler(cur)
+ results.append(result)
- for row in cur:
- self.log.info("Statement execution info - %s", row)
- results.append(row)
+ self.log.info("Rows affected: %s", cur.rowcount)
- self.log.info(cur.row_count)
- # If autocommit was set to False for db that supports autocommit,
- # or if db does not support autocommit, we do a manual commit.
+ # If autocommit was set to False or db does not support autocommit, we do a manual commit.
if not self.get_autocommit(conn):
conn.commit()
- return results
+ if handler is None:
+ return None
+ elif scalar_return_last:
+ return results[-1]
+ else:
+ return results
def set_autocommit(self, conn, autocommit: bool) -> None:
"""
diff --git a/airflow/providers/exasol/operators/exasol.py b/airflow/providers/exasol/operators/exasol.py
index eecf44885e..33d4ab55c6 100644
--- a/airflow/providers/exasol/operators/exasol.py
+++ b/airflow/providers/exasol/operators/exasol.py
@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Optional, Sequence
+from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
from airflow.models import BaseOperator
from airflow.providers.exasol.hooks.exasol import ExasolHook
@@ -46,10 +46,10 @@ class ExasolOperator(BaseOperator):
def __init__(
self,
*,
- sql: str,
+ sql: Union[str, Iterable[str]],
exasol_conn_id: str = 'exasol_default',
autocommit: bool = False,
- parameters: Optional[dict] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
schema: Optional[str] = None,
**kwargs,
) -> None:
diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py
index b7e12dabce..550c317406 100644
--- a/airflow/providers/google/cloud/operators/bigquery.py
+++ b/airflow/providers/google/cloud/operators/bigquery.py
@@ -550,7 +550,7 @@ class BigQueryExecuteQueryOperator(BaseOperator):
def __init__(
self,
*,
- sql: Union[str, Iterable],
+ sql: Union[str, Iterable[str]],
destination_dataset_table: Optional[str] = None,
write_disposition: str = 'WRITE_EMPTY',
allow_large_results: bool = False,
diff --git a/airflow/providers/google/cloud/operators/cloud_sql.py b/airflow/providers/google/cloud/operators/cloud_sql.py
index e90a21e296..a5d40c2ff4 100644
--- a/airflow/providers/google/cloud/operators/cloud_sql.py
+++ b/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains Google Cloud SQL operators."""
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
from googleapiclient.errors import HttpError
@@ -1054,9 +1054,9 @@ class CloudSQLExecuteQueryOperator(BaseOperator):
def __init__(
self,
*,
- sql: Union[List[str], str],
+ sql: Union[str, Iterable[str]],
autocommit: bool = False,
- parameters: Optional[Union[Dict, Iterable]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
gcp_conn_id: str = 'google_cloud_default',
gcp_cloudsql_conn_id: str = 'google_cloud_sql_default',
**kwargs,
diff --git a/airflow/providers/google/suite/transfers/sql_to_sheets.py b/airflow/providers/google/suite/transfers/sql_to_sheets.py
index 8384868199..8626fb7227 100644
--- a/airflow/providers/google/suite/transfers/sql_to_sheets.py
+++ b/airflow/providers/google/suite/transfers/sql_to_sheets.py
@@ -68,7 +68,7 @@ class SQLToGoogleSheetsOperator(BaseSQLOperator):
sql: str,
spreadsheet_id: str,
sql_conn_id: str,
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
database: Optional[str] = None,
spreadsheet_range: str = "Sheet1",
gcp_conn_id: str = "google_cloud_default",
diff --git a/airflow/providers/jdbc/operators/jdbc.py b/airflow/providers/jdbc/operators/jdbc.py
index 2c023d9afe..6b38366b41 100644
--- a/airflow/providers/jdbc/operators/jdbc.py
+++ b/airflow/providers/jdbc/operators/jdbc.py
@@ -16,20 +16,16 @@
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
from airflow.models import BaseOperator
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.jdbc.hooks.jdbc import JdbcHook
if TYPE_CHECKING:
from airflow.utils.context import Context
-def fetch_all_handler(cursor):
- """Handler for DbApiHook.run() to return results"""
- return cursor.fetchall()
-
-
class JdbcOperator(BaseOperator):
"""
Executes sql code in a database using jdbc driver.
@@ -57,10 +53,10 @@ class JdbcOperator(BaseOperator):
def __init__(
self,
*,
- sql: Union[str, List[str]],
+ sql: Union[str, Iterable[str]],
jdbc_conn_id: str = 'jdbc_default',
autocommit: bool = False,
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -70,7 +66,7 @@ class JdbcOperator(BaseOperator):
self.autocommit = autocommit
self.hook = None
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: 'Context'):
self.log.info('Executing: %s', self.sql)
hook = JdbcHook(jdbc_conn_id=self.jdbc_conn_id)
return hook.run(self.sql, self.autocommit, parameters=self.parameters, handler=fetch_all_handler)
diff --git a/airflow/providers/microsoft/mssql/operators/mssql.py b/airflow/providers/microsoft/mssql/operators/mssql.py
index 5a5738eb68..3a6434704f 100644
--- a/airflow/providers/microsoft/mssql/operators/mssql.py
+++ b/airflow/providers/microsoft/mssql/operators/mssql.py
@@ -57,9 +57,9 @@ class MsSqlOperator(BaseOperator):
def __init__(
self,
*,
- sql: str,
+ sql: Union[str, Iterable[str]],
mssql_conn_id: str = 'mssql_default',
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
autocommit: bool = False,
database: Optional[str] = None,
**kwargs,
diff --git a/airflow/providers/mysql/operators/mysql.py b/airflow/providers/mysql/operators/mysql.py
index d51a97e6fe..975586cd52 100644
--- a/airflow/providers/mysql/operators/mysql.py
+++ b/airflow/providers/mysql/operators/mysql.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
import ast
-from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
from airflow.models import BaseOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook
@@ -59,9 +59,9 @@ class MySqlOperator(BaseOperator):
def __init__(
self,
*,
- sql: Union[str, List[str]],
+ sql: Union[str, Iterable[str]],
mysql_conn_id: str = 'mysql_default',
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
autocommit: bool = False,
database: Optional[str] = None,
**kwargs,
diff --git a/airflow/providers/neo4j/operators/neo4j.py b/airflow/providers/neo4j/operators/neo4j.py
index b61f0734f0..82939ad779 100644
--- a/airflow/providers/neo4j/operators/neo4j.py
+++ b/airflow/providers/neo4j/operators/neo4j.py
@@ -44,7 +44,7 @@ class Neo4jOperator(BaseOperator):
*,
sql: str,
neo4j_conn_id: str = 'neo4j_default',
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py
index b60d4b6e89..969d69728b 100644
--- a/airflow/providers/oracle/operators/oracle.py
+++ b/airflow/providers/oracle/operators/oracle.py
@@ -50,9 +50,9 @@ class OracleOperator(BaseOperator):
def __init__(
self,
*,
- sql: Union[str, List[str]],
+ sql: Union[str, Iterable[str]],
oracle_conn_id: str = 'oracle_default',
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
autocommit: bool = False,
**kwargs,
) -> None:
@@ -98,7 +98,7 @@ class OracleStoredProcedureOperator(BaseOperator):
self.procedure = procedure
self.parameters = parameters
- def execute(self, context: 'Context') -> Optional[Union[List, Dict]]:
+ def execute(self, context: 'Context'):
self.log.info('Executing: %s', self.procedure)
hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
return hook.callproc(self.procedure, autocommit=True, parameters=self.parameters)
diff --git a/airflow/providers/postgres/operators/postgres.py b/airflow/providers/postgres/operators/postgres.py
index e0238aa882..7a787498d2 100644
--- a/airflow/providers/postgres/operators/postgres.py
+++ b/airflow/providers/postgres/operators/postgres.py
@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
from psycopg2.sql import SQL, Identifier
@@ -53,10 +53,10 @@ class PostgresOperator(BaseOperator):
def __init__(
self,
*,
- sql: Union[str, List[str]],
+ sql: Union[str, Iterable[str]],
postgres_conn_id: str = 'postgres_default',
autocommit: bool = False,
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
database: Optional[str] = None,
runtime_parameters: Optional[Mapping] = None,
**kwargs,
diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py
index 22afc71577..709a378a8d 100644
--- a/airflow/providers/presto/hooks/presto.py
+++ b/airflow/providers/presto/hooks/presto.py
@@ -18,7 +18,7 @@
import json
import os
import warnings
-from typing import Any, Callable, Iterable, Optional, overload
+from typing import Any, Callable, Iterable, List, Mapping, Optional, Union, overload
import prestodb
from prestodb.exceptions import DatabaseError
@@ -142,10 +142,6 @@ class PrestoHook(DbApiHook):
isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper()
return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT)
- @staticmethod
- def _strip_sql(sql: str) -> str:
- return sql.strip().rstrip(';')
-
@overload
def get_records(self, sql: str = "", parameters: Optional[dict] = None):
"""Get a set of records from Presto
@@ -169,7 +165,7 @@ class PrestoHook(DbApiHook):
sql = hql
try:
- return super().get_records(self._strip_sql(sql), parameters)
+ return super().get_records(self.strip_sql_string(sql), parameters)
except DatabaseError as e:
raise PrestoException(e)
@@ -196,7 +192,7 @@ class PrestoHook(DbApiHook):
sql = hql
try:
- return super().get_first(self._strip_sql(sql), parameters)
+ return super().get_first(self.strip_sql_string(sql), parameters)
except DatabaseError as e:
raise PrestoException(e)
@@ -226,7 +222,7 @@ class PrestoHook(DbApiHook):
cursor = self.get_cursor()
try:
- cursor.execute(self._strip_sql(sql), parameters)
+ cursor.execute(self.strip_sql_string(sql), parameters)
data = cursor.fetchall()
except DatabaseError as e:
raise PrestoException(e)
@@ -241,32 +237,38 @@ class PrestoHook(DbApiHook):
@overload
def run(
self,
- sql: str = "",
+ sql: Union[str, Iterable[str]],
autocommit: bool = False,
- parameters: Optional[dict] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
handler: Optional[Callable] = None,
- ) -> None:
+ split_statements: bool = False,
+ return_last: bool = True,
+ ) -> Optional[Union[Any, List[Any]]]:
"""Execute the statement against Presto. Can be used to create views."""
@overload
def run(
self,
- sql: str = "",
+ sql: Union[str, Iterable[str]],
autocommit: bool = False,
- parameters: Optional[dict] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
handler: Optional[Callable] = None,
+ split_statements: bool = False,
+ return_last: bool = True,
hql: str = "",
- ) -> None:
+ ) -> Optional[Union[Any, List[Any]]]:
""":sphinx-autoapi-skip:"""
def run(
self,
- sql: str = "",
+ sql: Union[str, Iterable[str]],
autocommit: bool = False,
- parameters: Optional[dict] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
handler: Optional[Callable] = None,
+ split_statements: bool = False,
+ return_last: bool = True,
hql: str = "",
- ) -> None:
+ ) -> Optional[Union[Any, List[Any]]]:
""":sphinx-autoapi-skip:"""
if hql:
warnings.warn(
@@ -276,7 +278,14 @@ class PrestoHook(DbApiHook):
)
sql = hql
- return super().run(sql=self._strip_sql(sql), parameters=parameters, handler=handler)
+ return super().run(
+ sql=sql,
+ autocommit=autocommit,
+ parameters=parameters,
+ handler=handler,
+ split_statements=split_statements,
+ return_last=return_last,
+ )
def insert_rows(
self,
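PrestoHook.run is now a pure pass-through to DbApiHook.run, so statement splitting becomes available here too (previously the hook only stripped a trailing semicolon). A hedged sketch (the connection id and DDL are illustrative):

from airflow.providers.presto.hooks.presto import PrestoHook

hook = PrestoHook(presto_conn_id="presto_default")
# Opt in to splitting; without a handler, run() returns None:
hook.run(
    "CREATE SCHEMA IF NOT EXISTS demo; CREATE TABLE IF NOT EXISTS demo.t (x int);",
    split_statements=True,
)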
diff --git a/airflow/providers/snowflake/hooks/snowflake.py b/airflow/providers/snowflake/hooks/snowflake.py
index 3dee0989ed..21655aaf5b 100644
--- a/airflow/providers/snowflake/hooks/snowflake.py
+++ b/airflow/providers/snowflake/hooks/snowflake.py
@@ -19,13 +19,12 @@ import os
from contextlib import closing
from io import StringIO
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from snowflake import connector
-from snowflake.connector import DictCursor, SnowflakeConnection
-from snowflake.connector.util_text import split_statements
+from snowflake.connector import DictCursor, SnowflakeConnection, util_text
from snowflake.sqlalchemy import URL
from sqlalchemy import create_engine
@@ -286,11 +285,13 @@ class SnowflakeHook(DbApiHook):
def run(
self,
- sql: Union[str, list],
+ sql: Union[str, Iterable[str]],
autocommit: bool = False,
- parameters: Optional[Union[Sequence[Any], Dict[Any, Any]]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
handler: Optional[Callable] = None,
- ):
+ split_statements: bool = True,
+ return_last: bool = True,
+ ) -> Optional[Union[Any, List[Any]]]:
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
@@ -305,15 +306,22 @@ class SnowflakeHook(DbApiHook):
before executing the query.
:param parameters: The parameters to render the SQL query with.
:param handler: The result handler which is called with the result of each statement.
+ :param split_statements: Whether to split a single SQL string into statements and run them separately
+ :param return_last: Whether to return the result of only the last statement or of all statements after splitting
+ :return: if a handler was provided, the last statement's result when sql is a string and return_last is True, otherwise a list of all results
"""
self.query_ids = []
+ scalar_return_last = isinstance(sql, str) and return_last
if isinstance(sql, str):
- split_statements_tuple = split_statements(StringIO(sql))
- sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string]
+ if split_statements:
+ split_statements_tuple = util_text.split_statements(StringIO(sql))
+ sql = [sql_string for sql_string, _ in split_statements_tuple if sql_string]
+ else:
+ sql = [self.strip_sql_string(sql)]
if sql:
- self.log.debug("Executing %d statements against Snowflake DB", len(sql))
+ self.log.debug("Executing following statements against Snowflake DB: %s", list(sql))
else:
raise ValueError("List of SQL statements is empty")
@@ -322,33 +330,29 @@ class SnowflakeHook(DbApiHook):
# SnowflakeCursor does not extend ContextManager, so we have to ignore mypy error here
with closing(conn.cursor(DictCursor)) as cur: # type: ignore[type-var]
-
+ results = []
for sql_statement in sql:
+ self._run_command(cur, sql_statement, parameters)
- self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
- if parameters:
- cur.execute(sql_statement, parameters)
- else:
- cur.execute(sql_statement)
-
- execution_info = []
if handler is not None:
- cur = handler(cur)
- for row in cur:
- self.log.info("Statement execution info - %s", row)
- execution_info.append(row)
+ result = handler(cur)
+ results.append(result)
query_id = cur.sfqid
self.log.info("Rows affected: %s", cur.rowcount)
self.log.info("Snowflake query id: %s", query_id)
self.query_ids.append(query_id)
- # If autocommit was set to False for db that supports autocommit,
- # or if db does not supports autocommit, we do a manual commit.
+ # If autocommit was set to False or db does not support autocommit, we do a manual commit.
if not self.get_autocommit(conn):
conn.commit()
- return execution_info
+ if handler is None:
+ return None
+ elif scalar_return_last:
+ return results[-1]
+ else:
+ return results
def test_connection(self):
"""Test the Snowflake connection by running a simple query."""
diff --git a/airflow/providers/snowflake/operators/snowflake.py b/airflow/providers/snowflake/operators/snowflake.py
index 086c1d6fd5..dd996cc526 100644
--- a/airflow/providers/snowflake/operators/snowflake.py
+++ b/airflow/providers/snowflake/operators/snowflake.py
@@ -15,10 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, List, Optional, Sequence, SupportsAbs
+from typing import Any, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union
from airflow.models import BaseOperator
from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
@@ -83,9 +84,9 @@ class SnowflakeOperator(BaseOperator):
def __init__(
self,
*,
- sql: Any,
+ sql: Union[str, Iterable[str]],
snowflake_conn_id: str = 'snowflake_default',
- parameters: Optional[dict] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
autocommit: bool = True,
do_xcom_push: bool = True,
warehouse: Optional[str] = None,
@@ -113,11 +114,11 @@ class SnowflakeOperator(BaseOperator):
def get_db_hook(self) -> SnowflakeHook:
return get_db_hook(self)
- def execute(self, context: Any) -> None:
+ def execute(self, context: Any):
"""Run query on snowflake"""
self.log.info('Executing: %s', self.sql)
hook = self.get_db_hook()
- execution_info = hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
+ execution_info = hook.run(self.sql, self.autocommit, self.parameters, fetch_all_handler)
self.query_ids = hook.query_ids
if self.do_xcom_push:
@@ -186,9 +187,9 @@ class SnowflakeCheckOperator(SQLCheckOperator):
def __init__(
self,
*,
- sql: Any,
+ sql: str,
snowflake_conn_id: str = 'snowflake_default',
- parameters: Optional[dict] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
autocommit: bool = True,
do_xcom_push: bool = True,
warehouse: Optional[str] = None,
@@ -257,7 +258,7 @@ class SnowflakeValueCheckOperator(SQLValueCheckOperator):
pass_value: Any,
tolerance: Any = None,
snowflake_conn_id: str = 'snowflake_default',
- parameters: Optional[dict] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
autocommit: bool = True,
do_xcom_push: bool = True,
warehouse: Optional[str] = None,
@@ -334,7 +335,7 @@ class SnowflakeIntervalCheckOperator(SQLIntervalCheckOperator):
date_filter_column: str = 'ds',
days_back: SupportsAbs[int] = -7,
snowflake_conn_id: str = 'snowflake_default',
- parameters: Optional[dict] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
autocommit: bool = True,
do_xcom_push: bool = True,
warehouse: Optional[str] = None,
diff --git a/airflow/providers/sqlite/operators/sqlite.py b/airflow/providers/sqlite/operators/sqlite.py
index 7ef97ca296..ef20d760c0 100644
--- a/airflow/providers/sqlite/operators/sqlite.py
+++ b/airflow/providers/sqlite/operators/sqlite.py
@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Iterable, List, Mapping, Optional, Sequence, Union
+from typing import Any, Iterable, Mapping, Optional, Sequence, Union
from airflow.models import BaseOperator
from airflow.providers.sqlite.hooks.sqlite import SqliteHook
@@ -45,9 +45,9 @@ class SqliteOperator(BaseOperator):
def __init__(
self,
*,
- sql: Union[str, List[str]],
+ sql: Union[str, Iterable[str]],
sqlite_conn_id: str = 'sqlite_default',
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index 9170e19a54..d8ac5148de 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -19,10 +19,8 @@ import json
import os
import warnings
from contextlib import closing
-from itertools import chain
-from typing import Any, Callable, Iterable, Optional, Tuple, overload
+from typing import Any, Callable, Iterable, List, Mapping, Optional, Union, overload
-import sqlparse
import trino
from trino.exceptions import DatabaseError
from trino.transaction import IsolationLevel
@@ -150,12 +148,8 @@ class TrinoHook(DbApiHook):
isolation_level = db.extra_dejson.get('isolation_level', 'AUTOCOMMIT').upper()
return getattr(IsolationLevel, isolation_level, IsolationLevel.AUTOCOMMIT)
- @staticmethod
- def _strip_sql(sql: str) -> str:
- return sql.strip().rstrip(';')
-
@overload
- def get_records(self, sql: str = "", parameters: Optional[dict] = None):
+ def get_records(self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None):
"""Get a set of records from Trino
:param sql: SQL statement to be executed.
@@ -163,10 +157,14 @@ class TrinoHook(DbApiHook):
"""
@overload
- def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""):
+ def get_records(
+ self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = ""
+ ):
""":sphinx-autoapi-skip:"""
- def get_records(self, sql: str = "", parameters: Optional[dict] = None, hql: str = ""):
+ def get_records(
+ self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = ""
+ ):
""":sphinx-autoapi-skip:"""
if hql:
warnings.warn(
@@ -177,12 +175,12 @@ class TrinoHook(DbApiHook):
sql = hql
try:
- return super().get_records(self._strip_sql(sql), parameters)
+ return super().get_records(self.strip_sql_string(sql), parameters)
except DatabaseError as e:
raise TrinoException(e)
@overload
- def get_first(self, sql: str = "", parameters: Optional[dict] = None) -> Any:
+ def get_first(self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
"""Returns only the first row, regardless of how many rows the query returns.
:param sql: SQL statement to be executed.
@@ -190,10 +188,14 @@ class TrinoHook(DbApiHook):
"""
@overload
- def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any:
+ def get_first(
+ self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = ""
+ ) -> Any:
""":sphinx-autoapi-skip:"""
- def get_first(self, sql: str = "", parameters: Optional[dict] = None, hql: str = "") -> Any:
+ def get_first(
+ self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = ""
+ ) -> Any:
""":sphinx-autoapi-skip:"""
if hql:
warnings.warn(
@@ -204,13 +206,13 @@ class TrinoHook(DbApiHook):
sql = hql
try:
- return super().get_first(self._strip_sql(sql), parameters)
+ return super().get_first(self.strip_sql_string(sql), parameters)
except DatabaseError as e:
raise TrinoException(e)
@overload
def get_pandas_df(
- self, sql: str = "", parameters: Optional[dict] = None, **kwargs
+ self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, **kwargs
): # type: ignore[override]
"""Get a pandas dataframe from a sql query.
@@ -220,12 +222,12 @@ class TrinoHook(DbApiHook):
@overload
def get_pandas_df(
- self, sql: str = "", parameters: Optional[dict] = None, hql: str = "", **kwargs
+ self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "", **kwargs
): # type: ignore[override]
""":sphinx-autoapi-skip:"""
def get_pandas_df(
- self, sql: str = "", parameters: Optional[dict] = None, hql: str = "", **kwargs
+ self, sql: str = "", parameters: Optional[Union[Iterable, Mapping]] = None, hql: str = "", **kwargs
): # type: ignore[override]
""":sphinx-autoapi-skip:"""
if hql:
@@ -240,7 +242,7 @@ class TrinoHook(DbApiHook):
cursor = self.get_cursor()
try:
- cursor.execute(self._strip_sql(sql), parameters)
+ cursor.execute(self.strip_sql_string(sql), parameters)
data = cursor.fetchall()
except DatabaseError as e:
raise TrinoException(e)
@@ -255,32 +257,38 @@ class TrinoHook(DbApiHook):
@overload
def run(
self,
- sql,
+ sql: Union[str, Iterable[str]],
autocommit: bool = False,
- parameters: Optional[Tuple] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
handler: Optional[Callable] = None,
- ) -> None:
+ split_statements: bool = False,
+ return_last: bool = True,
+ ) -> Optional[Union[Any, List[Any]]]:
"""Execute the statement against Trino. Can be used to create views."""
@overload
def run(
self,
- sql,
+ sql: Union[str, Iterable[str]],
autocommit: bool = False,
- parameters: Optional[Tuple] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
handler: Optional[Callable] = None,
+ split_statements: bool = False,
+ return_last: bool = True,
hql: str = "",
- ) -> None:
+ ) -> Optional[Union[Any, List[Any]]]:
""":sphinx-autoapi-skip:"""
def run(
self,
- sql,
+ sql: Union[str, Iterable[str]],
autocommit: bool = False,
- parameters: Optional[Tuple] = None,
+ parameters: Optional[Union[Iterable, Mapping]] = None,
handler: Optional[Callable] = None,
+ split_statements: bool = False,
+ return_last: bool = True,
hql: str = "",
- ):
+ ) -> Optional[Union[Any, List[Any]]]:
""":sphinx-autoapi-skip:"""
if hql:
warnings.warn(
@@ -289,38 +297,15 @@ class TrinoHook(DbApiHook):
stacklevel=2,
)
sql = hql
- scalar = isinstance(sql, str)
-
- with closing(self.get_conn()) as conn:
- if self.supports_autocommit:
- self.set_autocommit(conn, autocommit)
- if scalar:
- sql = sqlparse.split(sql)
-
- with closing(conn.cursor()) as cur:
- results = []
- for sql_statement in sql:
- self._run_command(cur, self._strip_sql(sql_statement), parameters)
- self.query_id = cur.stats["queryId"]
- if handler is not None:
- result = handler(cur)
- results.append(result)
-
- # If autocommit was set to False for db that supports autocommit,
- # or if db does not supports autocommit, we do a manual commit.
- if not self.get_autocommit(conn):
- conn.commit()
-
- self.log.info("Query Execution Result: %s", str(list(chain.from_iterable(results))))
-
- if handler is None:
- return None
-
- if scalar:
- return results[0]
-
- return results
+ return super().run(
+ sql=sql,
+ autocommit=autocommit,
+ parameters=parameters,
+ handler=handler,
+ split_statements=split_statements,
+ return_last=return_last,
+ )
def insert_rows(
self,
diff --git a/airflow/providers/vertica/operators/vertica.py b/airflow/providers/vertica/operators/vertica.py
index 3a30e0ee27..7a804f8ed7 100644
--- a/airflow/providers/vertica/operators/vertica.py
+++ b/airflow/providers/vertica/operators/vertica.py
@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Any, List, Sequence, Union
+from typing import TYPE_CHECKING, Any, Iterable, Sequence, Union
from airflow.models import BaseOperator
from airflow.providers.vertica.hooks.vertica import VerticaHook
@@ -40,7 +40,7 @@ class VerticaOperator(BaseOperator):
ui_color = '#b4e0ff'
def __init__(
- self, *, sql: Union[str, List[str]], vertica_conn_id: str = 'vertica_default', **kwargs: Any
+ self, *, sql: Union[str, Iterable[str]], vertica_conn_id: str = 'vertica_default', **kwargs: Any
) -> None:
super().__init__(**kwargs)
self.vertica_conn_id = vertica_conn_id
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 203fa5e22a..743a73d0da 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -62,8 +62,7 @@
"deps": [
"apache-airflow-providers-common-sql",
"apache-airflow>=2.2.0",
- "sqlalchemy-drill>=1.1.0",
- "sqlparse>=0.4.1"
+ "sqlalchemy-drill>=1.1.0"
],
"cross-providers-deps": [
"common.sql"
@@ -191,7 +190,9 @@
"cross-providers-deps": []
},
"common.sql": {
- "deps": [],
+ "deps": [
+ "sqlparse>=0.4.2"
+ ],
"cross-providers-deps": []
},
"databricks": {
diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py b/tests/providers/databricks/hooks/test_databricks_sql.py
index 05952beabd..d70d203e33 100644
--- a/tests/providers/databricks/hooks/test_databricks_sql.py
+++ b/tests/providers/databricks/hooks/test_databricks_sql.py
@@ -23,6 +23,7 @@ from unittest import mock
import pytest
from airflow.models import Connection
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
from airflow.utils.session import provide_session
@@ -72,17 +73,17 @@ class TestDatabricksSqlHookQueryByName(unittest.TestCase):
test_schema = [(field,) for field in test_fields]
conn = mock_conn.return_value
- cur = mock.MagicMock(rowcount=0)
+ cur = mock.MagicMock(rowcount=0, description=test_schema)
+ cur.fetchall.return_value = []
conn.cursor.return_value = cur
- type(cur).description = mock.PropertyMock(return_value=test_schema)
- query = "select * from test.test"
- schema, results = self.hook.run(sql=query)
+ query = "select * from test.test;"
+ schema, results = self.hook.run(sql=query, handler=fetch_all_handler)
assert schema == test_schema
assert results == []
- cur.execute.assert_has_calls([mock.call(q) for q in [query]])
+ cur.execute.assert_has_calls([mock.call(q) for q in [query.rstrip(';')]])
cur.close.assert_called()
def test_no_query(self):
diff --git a/tests/providers/databricks/operators/test_databricks_sql.py b/tests/providers/databricks/operators/test_databricks_sql.py
index 783fa520a7..6775ff8398 100644
--- a/tests/providers/databricks/operators/test_databricks_sql.py
+++ b/tests/providers/databricks/operators/test_databricks_sql.py
@@ -25,6 +25,7 @@ import pytest
from databricks.sql.types import Row
from airflow import AirflowException
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.databricks.operators.databricks_sql import (
DatabricksCopyIntoOperator,
DatabricksSqlOperator,
@@ -47,7 +48,7 @@ class TestDatabricksSqlOperator(unittest.TestCase):
db_mock = db_mock_class.return_value
mock_schema = [('id',), ('value',)]
mock_results = [Row(id=1, value='value1')]
- db_mock.run.return_value = (mock_schema, mock_results)
+ db_mock.run.return_value = [(mock_schema, mock_results)]
results = op.execute(None)
@@ -61,7 +62,7 @@ class TestDatabricksSqlOperator(unittest.TestCase):
catalog=None,
schema=None,
)
- db_mock.run.assert_called_once_with(sql, parameters=None)
+ db_mock.run.assert_called_once_with(sql, parameters=None, handler=fetch_all_handler)
@mock.patch('airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook')
def test_exec_write_file(self, db_mock_class):
@@ -74,7 +75,7 @@ class TestDatabricksSqlOperator(unittest.TestCase):
db_mock = db_mock_class.return_value
mock_schema = [('id',), ('value',)]
mock_results = [Row(id=1, value='value1')]
- db_mock.run.return_value = (mock_schema, mock_results)
+ db_mock.run.return_value = [(mock_schema, mock_results)]
try:
op.execute(None)
@@ -92,7 +93,7 @@ class TestDatabricksSqlOperator(unittest.TestCase):
catalog=None,
schema=None,
)
- db_mock.run.assert_called_once_with(sql, parameters=None)
+ db_mock.run.assert_called_once_with(sql, parameters=None, handler=fetch_all_handler)
class TestDatabricksSqlCopyIntoOperator(unittest.TestCase):
diff --git a/tests/providers/jdbc/operators/test_jdbc.py b/tests/providers/jdbc/operators/test_jdbc.py
index 812d60fd45..9168674c56 100644
--- a/tests/providers/jdbc/operators/test_jdbc.py
+++ b/tests/providers/jdbc/operators/test_jdbc.py
@@ -19,7 +19,8 @@
import unittest
from unittest.mock import patch
-from airflow.providers.jdbc.operators.jdbc import JdbcOperator, fetch_all_handler
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
+from airflow.providers.jdbc.operators.jdbc import JdbcOperator
class TestJdbcOperator(unittest.TestCase):
diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py
index d33dbf79f7..254514bc9f 100644
--- a/tests/providers/oracle/hooks/test_oracle.py
+++ b/tests/providers/oracle/hooks/test_oracle.py
@@ -268,7 +268,7 @@ class TestOracleHook(unittest.TestCase):
self.cur.bindvars = None
result = self.db_hook.callproc('proc', True, parameters)
- assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(); END;')]
+ assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(); END')]
assert result == parameters
def test_callproc_dict(self):
@@ -280,7 +280,7 @@ class TestOracleHook(unittest.TestCase):
self.cur.bindvars = {k: bindvar(v) for k, v in parameters.items()}
result = self.db_hook.callproc('proc', True, parameters)
- assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END;', parameters)]
+ assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END', parameters)]
assert result == parameters
def test_callproc_list(self):
@@ -292,7 +292,7 @@ class TestOracleHook(unittest.TestCase):
self.cur.bindvars = list(map(bindvar, parameters))
result = self.db_hook.callproc('proc', True, parameters)
- assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END;', parameters)]
+ assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END', parameters)]
assert result == parameters
def test_callproc_out_param(self):
@@ -306,7 +306,7 @@ class TestOracleHook(unittest.TestCase):
self.cur.bindvars = [bindvar(p() if type(p) is type else p) for p in parameters]
result = self.db_hook.callproc('proc', True, parameters)
expected = [1, 0, 0.0, False, '']
- assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)]
+ assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END', expected)]
assert result == expected
def test_test_connection_use_dual_table(self):
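The expected statements above lose their trailing semicolon because OracleHook.callproc goes through run(), which now passes single-string SQL through strip_sql_string before executing it. A one-line check:

from airflow.providers.common.sql.hooks.sql import DbApiHook

print(DbApiHook.strip_sql_string("BEGIN proc(:a,:b,:c); END;"))
# -> BEGIN proc(:a,:b,:c); END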