You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/10/31 05:30:06 UTC
[airflow] branch main updated: Use DbApiHook.run for DbApiHook.get_records and DbApiHook.get_first (#26944)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5c37b503f1 Use DbApiHook.run for DbApiHook.get_records and DbApiHook.get_first (#26944)
5c37b503f1 is described below
commit 5c37b503f118b8ad2585dff9949dd8fdb96689ed
Author: Dmytro Kazanzhy <dk...@gmail.com>
AuthorDate: Mon Oct 31 07:29:58 2022 +0200
Use DbApiHook.run for DbApiHook.get_records and DbApiHook.get_first (#26944)
---
airflow/providers/common/sql/hooks/sql.py | 35 +++++++++++---------------
airflow/providers/exasol/hooks/exasol.py | 3 +--
airflow/providers/presto/hooks/presto.py | 5 ++--
airflow/providers/trino/hooks/trino.py | 3 +--
tests/providers/common/sql/hooks/test_dbapi.py | 22 ++++++++++++++--
5 files changed, 38 insertions(+), 30 deletions(-)
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index 73f0eea0e8..1c67350c4d 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -38,6 +38,14 @@ def fetch_all_handler(cursor) -> list[tuple] | None:
return None
+def fetch_one_handler(cursor) -> Any:
+ """Handler for DbApiHook.run() to return the first row, or None when there is no result set"""
+ if cursor.description is not None:
+ return cursor.fetchone()
+ else:
+ return None
+
+
class ConnectorProtocol(Protocol):
"""A protocol where you can connect to a database."""
@@ -178,38 +186,23 @@ class DbApiHook(BaseForDbApiHook):
self,
sql: str | list[str],
parameters: Iterable | Mapping | None = None,
- **kwargs: dict,
- ):
+ ) -> Any:
"""
Executes the sql and returns a set of records.
- :param sql: the sql statement to be executed (str) or a list of
- sql statements to execute
+ :param sql: the sql statement to be executed (str) or a list of sql statements to execute
:param parameters: The parameters to render the SQL query with.
"""
- with closing(self.get_conn()) as conn:
- with closing(conn.cursor()) as cur:
- if parameters is not None:
- cur.execute(sql, parameters)
- else:
- cur.execute(sql)
- return cur.fetchall()
+ return self.run(sql=sql, parameters=parameters, handler=fetch_all_handler)
- def get_first(self, sql: str | list[str], parameters=None):
+ def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | None = None) -> Any:
"""
Executes the sql and returns the first resulting row.
- :param sql: the sql statement to be executed (str) or a list of
- sql statements to execute
+ :param sql: the sql statement to be executed (str) or a list of sql statements to execute
:param parameters: The parameters to render the SQL query with.
"""
- with closing(self.get_conn()) as conn:
- with closing(conn.cursor()) as cur:
- if parameters is not None:
- cur.execute(sql, parameters)
- else:
- cur.execute(sql)
- return cur.fetchone()
+ return self.run(sql=sql, parameters=parameters, handler=fetch_one_handler)
@staticmethod
def strip_sql_string(sql: str) -> str:
diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py
index 51d2e30759..3b45e8f2f2 100644
--- a/airflow/providers/exasol/hooks/exasol.py
+++ b/airflow/providers/exasol/hooks/exasol.py
@@ -82,7 +82,6 @@ class ExasolHook(DbApiHook):
self,
sql: str | list[str],
parameters: Iterable | Mapping | None = None,
- **kwargs: dict,
) -> list[dict | tuple[Any, ...]]:
"""
Executes the sql and returns a set of records.
@@ -95,7 +94,7 @@ class ExasolHook(DbApiHook):
with closing(conn.execute(sql, parameters)) as cur:
return cur.fetchall()
- def get_first(self, sql: str | list[str], parameters: dict | None = None) -> Any | None:
+ def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | None = None) -> Any:
"""
Executes the sql and returns the first resulting row.
diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py
index 902ae67eac..b5f8fab545 100644
--- a/airflow/providers/presto/hooks/presto.py
+++ b/airflow/providers/presto/hooks/presto.py
@@ -145,8 +145,7 @@ class PrestoHook(DbApiHook):
self,
sql: str | list[str] = "",
parameters: Iterable | Mapping | None = None,
- **kwargs: dict,
- ):
+ ) -> Any:
if not isinstance(sql, str):
raise ValueError(f"The sql in Presto Hook must be a string and is {sql}!")
try:
@@ -154,7 +153,7 @@ class PrestoHook(DbApiHook):
except DatabaseError as e:
raise PrestoException(e)
- def get_first(self, sql: str | list[str] = "", parameters: dict | None = None) -> Any:
+ def get_first(self, sql: str | list[str] = "", parameters: Iterable | Mapping | None = None) -> Any:
if not isinstance(sql, str):
raise ValueError(f"The sql in Presto Hook must be a string and is {sql}!")
try:
diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index 629ea2bf57..e4be4a092a 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -158,8 +158,7 @@ class TrinoHook(DbApiHook):
self,
sql: str | list[str] = "",
parameters: Iterable | Mapping | None = None,
- **kwargs: dict,
- ):
+ ) -> Any:
if not isinstance(sql, str):
raise ValueError(f"The sql in Trino Hook must be a string and is {sql}!")
try:
diff --git a/tests/providers/common/sql/hooks/test_dbapi.py b/tests/providers/common/sql/hooks/test_dbapi.py
index 65abbb2768..beca713949 100644
--- a/tests/providers/common/sql/hooks/test_dbapi.py
+++ b/tests/providers/common/sql/hooks/test_dbapi.py
@@ -25,7 +25,7 @@ import pytest
from airflow.hooks.base import BaseHook
from airflow.models import Connection
-from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
+from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler, fetch_one_handler
class DbApiHookInProvider(DbApiHook):
@@ -41,7 +41,7 @@ class TestDbApiHook(unittest.TestCase):
super().setUp()
self.cur = mock.MagicMock(
- rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "close"]
+ rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "fetchone", "close"]
)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
@@ -430,3 +430,21 @@ class TestDbApiHook(unittest.TestCase):
self.cur.fetchall.side_effect = Exception("Should not get called !")
assert rows == self.db_hook.run(sql=query, handler=fetch_all_handler)
+
+ def test_run_fetch_one_handler_select_1(self):
+ self.cur.rowcount = -1 # can be -1 according to pep249
+ self.cur.description = (tuple([None] * 7),) # one column (7-item entry per PEP 249)
+ query = "SELECT 1"
+ rows = [[1]] # opaque mock value; run() is expected to return it as-is
+
+ self.cur.fetchone.return_value = rows
+ assert rows == self.db_hook.run(sql=query, handler=fetch_one_handler)
+
+ def test_run_fetch_one_handler_print(self):
+ self.cur.rowcount = -1
+ self.cur.description = None # no result set, so fetchone() must not be called
+ query = "PRINT('Hello World !')"
+ rows = None
+
+ self.cur.fetchone.side_effect = Exception("Should not get called !")
+ assert rows == self.db_hook.run(sql=query, handler=fetch_one_handler)