You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/10/31 05:30:06 UTC

[airflow] branch main updated: Use DbApiHook.run for DbApiHook.get_records and DbApiHook.get_first (#26944)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c37b503f1 Use DbApiHook.run for DbApiHook.get_records and DbApiHook.get_first (#26944)
5c37b503f1 is described below

commit 5c37b503f118b8ad2585dff9949dd8fdb96689ed
Author: Dmytro Kazanzhy <dk...@gmail.com>
AuthorDate: Mon Oct 31 07:29:58 2022 +0200

    Use DbApiHook.run for DbApiHook.get_records and DbApiHook.get_first (#26944)
---
 airflow/providers/common/sql/hooks/sql.py      | 35 +++++++++++---------------
 airflow/providers/exasol/hooks/exasol.py       |  3 +--
 airflow/providers/presto/hooks/presto.py       |  5 ++--
 airflow/providers/trino/hooks/trino.py         |  3 +--
 tests/providers/common/sql/hooks/test_dbapi.py | 22 ++++++++++++++--
 5 files changed, 38 insertions(+), 30 deletions(-)

diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index 73f0eea0e8..1c67350c4d 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -38,6 +38,14 @@ def fetch_all_handler(cursor) -> list[tuple] | None:
         return None
 
 
+def fetch_one_handler(cursor) -> list[tuple] | None:
+    """Handler for DbApiHook.run() to return results"""
+    if cursor.description is not None:
+        return cursor.fetchone()
+    else:
+        return None
+
+
 class ConnectorProtocol(Protocol):
     """A protocol where you can connect to a database."""
 
@@ -178,38 +186,23 @@ class DbApiHook(BaseForDbApiHook):
         self,
         sql: str | list[str],
         parameters: Iterable | Mapping | None = None,
-        **kwargs: dict,
-    ):
+    ) -> Any:
         """
         Executes the sql and returns a set of records.
 
-        :param sql: the sql statement to be executed (str) or a list of
-            sql statements to execute
+        :param sql: the sql statement to be executed (str) or a list of sql statements to execute
         :param parameters: The parameters to render the SQL query with.
         """
-        with closing(self.get_conn()) as conn:
-            with closing(conn.cursor()) as cur:
-                if parameters is not None:
-                    cur.execute(sql, parameters)
-                else:
-                    cur.execute(sql)
-                return cur.fetchall()
+        return self.run(sql=sql, parameters=parameters, handler=fetch_all_handler)
 
-    def get_first(self, sql: str | list[str], parameters=None):
+    def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | None = None) -> Any:
         """
         Executes the sql and returns the first resulting row.
 
-        :param sql: the sql statement to be executed (str) or a list of
-            sql statements to execute
+        :param sql: the sql statement to be executed (str) or a list of sql statements to execute
         :param parameters: The parameters to render the SQL query with.
         """
-        with closing(self.get_conn()) as conn:
-            with closing(conn.cursor()) as cur:
-                if parameters is not None:
-                    cur.execute(sql, parameters)
-                else:
-                    cur.execute(sql)
-                return cur.fetchone()
+        return self.run(sql=sql, parameters=parameters, handler=fetch_one_handler)
 
     @staticmethod
     def strip_sql_string(sql: str) -> str:
diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py
index 51d2e30759..3b45e8f2f2 100644
--- a/airflow/providers/exasol/hooks/exasol.py
+++ b/airflow/providers/exasol/hooks/exasol.py
@@ -82,7 +82,6 @@ class ExasolHook(DbApiHook):
         self,
         sql: str | list[str],
         parameters: Iterable | Mapping | None = None,
-        **kwargs: dict,
     ) -> list[dict | tuple[Any, ...]]:
         """
         Executes the sql and returns a set of records.
@@ -95,7 +94,7 @@ class ExasolHook(DbApiHook):
             with closing(conn.execute(sql, parameters)) as cur:
                 return cur.fetchall()
 
-    def get_first(self, sql: str | list[str], parameters: dict | None = None) -> Any | None:
+    def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | None = None) -> Any:
         """
         Executes the sql and returns the first resulting row.
 
diff --git a/airflow/providers/presto/hooks/presto.py b/airflow/providers/presto/hooks/presto.py
index 902ae67eac..b5f8fab545 100644
--- a/airflow/providers/presto/hooks/presto.py
+++ b/airflow/providers/presto/hooks/presto.py
@@ -145,8 +145,7 @@ class PrestoHook(DbApiHook):
         self,
         sql: str | list[str] = "",
         parameters: Iterable | Mapping | None = None,
-        **kwargs: dict,
-    ):
+    ) -> Any:
         if not isinstance(sql, str):
             raise ValueError(f"The sql in Presto Hook must be a string and is {sql}!")
         try:
@@ -154,7 +153,7 @@ class PrestoHook(DbApiHook):
         except DatabaseError as e:
             raise PrestoException(e)
 
-    def get_first(self, sql: str | list[str] = "", parameters: dict | None = None) -> Any:
+    def get_first(self, sql: str | list[str] = "", parameters: Iterable | Mapping | None = None) -> Any:
         if not isinstance(sql, str):
             raise ValueError(f"The sql in Presto Hook must be a string and is {sql}!")
         try:
diff --git a/airflow/providers/trino/hooks/trino.py b/airflow/providers/trino/hooks/trino.py
index 629ea2bf57..e4be4a092a 100644
--- a/airflow/providers/trino/hooks/trino.py
+++ b/airflow/providers/trino/hooks/trino.py
@@ -158,8 +158,7 @@ class TrinoHook(DbApiHook):
         self,
         sql: str | list[str] = "",
         parameters: Iterable | Mapping | None = None,
-        **kwargs: dict,
-    ):
+    ) -> Any:
         if not isinstance(sql, str):
             raise ValueError(f"The sql in Trino Hook must be a string and is {sql}!")
         try:
diff --git a/tests/providers/common/sql/hooks/test_dbapi.py b/tests/providers/common/sql/hooks/test_dbapi.py
index 65abbb2768..beca713949 100644
--- a/tests/providers/common/sql/hooks/test_dbapi.py
+++ b/tests/providers/common/sql/hooks/test_dbapi.py
@@ -25,7 +25,7 @@ import pytest
 
 from airflow.hooks.base import BaseHook
 from airflow.models import Connection
-from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
+from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler, fetch_one_handler
 
 
 class DbApiHookInProvider(DbApiHook):
@@ -41,7 +41,7 @@ class TestDbApiHook(unittest.TestCase):
         super().setUp()
 
         self.cur = mock.MagicMock(
-            rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "close"]
+            rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "fetchone", "close"]
         )
         self.conn = mock.MagicMock()
         self.conn.cursor.return_value = self.cur
@@ -430,3 +430,21 @@ class TestDbApiHook(unittest.TestCase):
 
         self.cur.fetchall.side_effect = Exception("Should not get called !")
         assert rows == self.db_hook.run(sql=query, handler=fetch_all_handler)
+
+    def test_run_fetch_one_handler_select_1(self):
+        self.cur.rowcount = -1  # can be -1 according to pep249
+        self.cur.description = (tuple([None] * 7),)
+        query = "SELECT 1"
+        rows = [[1]]
+
+        self.cur.fetchone.return_value = rows
+        assert rows == self.db_hook.run(sql=query, handler=fetch_one_handler)
+
+    def test_run_fetch_one_handler_print(self):
+        self.cur.rowcount = -1
+        self.cur.description = None
+        query = "PRINT('Hello World !')"
+        rows = None
+
+        self.cur.fetchone.side_effect = Exception("Should not get called !")
+        assert rows == self.db_hook.run(sql=query, handler=fetch_one_handler)