You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/05/17 16:30:24 UTC

[airflow] branch master updated: Add optional result handler to database hooks (#15581)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new abcd487  Add optional result handler to database hooks (#15581)
abcd487 is described below

commit abcd48731303d9e141bdc94acc2db46d73ccbe12
Author: Malthe Borch <mb...@gmail.com>
AuthorDate: Mon May 17 18:30:05 2021 +0200

    Add optional result handler to database hooks (#15581)
    
    This allows retrieving result rows, or other post-query access
    to the cursor object.
    
    Co-authored-by: Ash Berlin-Taylor <as...@apache.org>
---
 CHANGELOG.txt                                      |  1 +
 airflow/hooks/dbapi.py                             | 41 ++++++++++++++++------
 airflow/providers/oracle/operators/oracle.py       |  2 +-
 tests/hooks/test_dbapi.py                          | 36 ++++++++++++++++++-
 tests/providers/apache/druid/hooks/test_druid.py   |  2 +-
 tests/providers/apache/pinot/hooks/test_pinot.py   |  2 +-
 .../elasticsearch/hooks/test_elasticsearch.py      |  2 +-
 tests/providers/exasol/hooks/test_exasol.py        |  2 +-
 tests/providers/mysql/hooks/test_mysql.py          |  2 +-
 tests/providers/oracle/hooks/test_oracle.py        |  2 +-
 tests/providers/oracle/operators/test_oracle.py    | 10 ++++--
 tests/providers/postgres/hooks/test_postgres.py    |  2 +-
 tests/providers/presto/hooks/test_presto.py        |  2 +-
 tests/providers/snowflake/hooks/test_snowflake.py  |  4 +--
 tests/providers/sqlite/hooks/test_sqlite.py        |  2 +-
 tests/providers/trino/hooks/test_trino.py          |  2 +-
 tests/providers/vertica/hooks/test_vertica.py      |  2 +-
 17 files changed, 88 insertions(+), 28 deletions(-)

diff --git a/CHANGELOG.txt b/CHANGELOG.txt
index 9719a14..34f4c38 100644
--- a/CHANGELOG.txt
+++ b/CHANGELOG.txt
@@ -43,6 +43,7 @@ New Features
 Improvements
 """"""""""""
 
+- Add optional result handler callback to ``DbApiHook`` (#15581)
 - Update Flask App Builder limit to recently released 3.3 (#15792)
 - Prevent creating flask sessions on REST API requests (#15295)
 - Sync DAG specific permissions when parsing (#15311)
diff --git a/airflow/hooks/dbapi.py b/airflow/hooks/dbapi.py
index 9821643..6c00320 100644
--- a/airflow/hooks/dbapi.py
+++ b/airflow/hooks/dbapi.py
@@ -152,7 +152,7 @@ class DbApiHook(BaseHook):
                     cur.execute(sql)
                 return cur.fetchone()
 
-    def run(self, sql, autocommit=False, parameters=None):
+    def run(self, sql, autocommit=False, parameters=None, handler=None):
         """
         Runs a command or a list of commands. Pass a list of sql
         statements to the sql parameter to get them to execute
@@ -166,8 +166,12 @@ class DbApiHook(BaseHook):
         :type autocommit: bool
         :param parameters: The parameters to render the SQL query with.
         :type parameters: dict or iterable
+        :param handler: The result handler, which is called with the cursor after each statement.
+        :type handler: callable
+        :return: the handler result (a list of results if sql is a list), or None if no handler was given.
         """
-        if isinstance(sql, str):
+        scalar = isinstance(sql, str)
+        if scalar:
             sql = [sql]
 
         with closing(self.get_conn()) as conn:
@@ -175,21 +179,38 @@ class DbApiHook(BaseHook):
                 self.set_autocommit(conn, autocommit)
 
             with closing(conn.cursor()) as cur:
+                results = []
                 for sql_statement in sql:
-
-                    self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
-                    if parameters:
-                        cur.execute(sql_statement, parameters)
-                    else:
-                        cur.execute(sql_statement)
-                    if hasattr(cur, 'rowcount'):
-                        self.log.info("Rows affected: %s", cur.rowcount)
+                    self._run_command(cur, sql_statement, parameters)
+                    if handler is not None:
+                        result = handler(cur)
+                        results.append(result)
 
             # If autocommit was set to False for db that supports autocommit,
             # or if db does not supports autocommit, we do a manual commit.
             if not self.get_autocommit(conn):
                 conn.commit()
 
+        if handler is None:
+            return None
+
+        if scalar:
+            return results[0]
+
+        return results
+
+    def _run_command(self, cur, sql_statement, parameters):
+        """Runs a statement using an already open cursor."""
+        self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
+        if parameters:
+            cur.execute(sql_statement, parameters)
+        else:
+            cur.execute(sql_statement)
+
+        # According to PEP 249, this is -1 when query result is not applicable.
+        if cur.rowcount >= 0:
+            self.log.info("Rows affected: %s", cur.rowcount)
+
     def set_autocommit(self, conn, autocommit):
         """Sets the autocommit flag on the connection"""
         if not self.supports_autocommit and autocommit:
diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py
index d764225..41e9c31 100644
--- a/airflow/providers/oracle/operators/oracle.py
+++ b/airflow/providers/oracle/operators/oracle.py
@@ -23,7 +23,7 @@ from airflow.providers.oracle.hooks.oracle import OracleHook
 
 class OracleOperator(BaseOperator):
     """
-    Executes sql code in a specific Oracle database
+    Executes sql code in a specific Oracle database.
 
     :param sql: the sql code to be executed. Can receive a str representing a sql statement,
         a list of str (sql statements), or reference to a template file.
diff --git a/tests/hooks/test_dbapi.py b/tests/hooks/test_dbapi.py
index 0f6c55a..383d69e 100644
--- a/tests/hooks/test_dbapi.py
+++ b/tests/hooks/test_dbapi.py
@@ -30,7 +30,7 @@ class TestDbApiHook(unittest.TestCase):
     def setUp(self):
         super().setUp()
 
-        self.cur = mock.MagicMock()
+        self.cur = mock.MagicMock(rowcount=0)
         self.conn = mock.MagicMock()
         self.conn.cursor.return_value = self.cur
         conn = self.conn
@@ -184,3 +184,37 @@ class TestDbApiHook(unittest.TestCase):
         statement = 'SQL'
         self.db_hook.run(statement)
         assert self.db_hook.log.info.call_count == 2
+
+    def test_run_with_handler(self):
+        sql = 'SQL'
+        param = ('p1', 'p2')
+        called = 0
+        obj = object()
+
+        def handler(cur):
+            cur.execute.assert_called_once_with(sql, param)
+            nonlocal called
+            called += 1
+            return obj
+
+        result = self.db_hook.run(sql, parameters=param, handler=handler)
+        assert called == 1
+        assert self.conn.commit.called
+        assert result == obj
+
+    def test_run_with_handler_multiple(self):
+        sql = ['SQL', 'SQL']
+        param = ('p1', 'p2')
+        called = 0
+        obj = object()
+
+        def handler(cur):
+            cur.execute.assert_called_with(sql[0], param)
+            nonlocal called
+            called += 1
+            return obj
+
+        result = self.db_hook.run(sql, parameters=param, handler=handler)
+        assert called == 2
+        assert self.conn.commit.called
+        assert result == [obj, obj]
diff --git a/tests/providers/apache/druid/hooks/test_druid.py b/tests/providers/apache/druid/hooks/test_druid.py
index afd7c1c..fca5d06 100644
--- a/tests/providers/apache/druid/hooks/test_druid.py
+++ b/tests/providers/apache/druid/hooks/test_druid.py
@@ -192,7 +192,7 @@ class TestDruidHook(unittest.TestCase):
 class TestDruidDbApiHook(unittest.TestCase):
     def setUp(self):
         super().setUp()
-        self.cur = MagicMock()
+        self.cur = MagicMock(rowcount=0)
         self.conn = conn = MagicMock()
         self.conn.host = 'host'
         self.conn.port = '1000'
diff --git a/tests/providers/apache/pinot/hooks/test_pinot.py b/tests/providers/apache/pinot/hooks/test_pinot.py
index af0676f..763704f 100644
--- a/tests/providers/apache/pinot/hooks/test_pinot.py
+++ b/tests/providers/apache/pinot/hooks/test_pinot.py
@@ -212,7 +212,7 @@ class TestPinotDbApiHook(unittest.TestCase):
         self.conn.port = '1000'
         self.conn.conn_type = 'http'
         self.conn.extra_dejson = {'endpoint': 'query/sql'}
-        self.cur = mock.MagicMock()
+        self.cur = mock.MagicMock(rowcount=0)
         self.conn.cursor.return_value = self.cur
         self.conn.__enter__.return_value = self.cur
         self.conn.__exit__.return_value = None
diff --git a/tests/providers/elasticsearch/hooks/test_elasticsearch.py b/tests/providers/elasticsearch/hooks/test_elasticsearch.py
index b96f64e..06057b0 100644
--- a/tests/providers/elasticsearch/hooks/test_elasticsearch.py
+++ b/tests/providers/elasticsearch/hooks/test_elasticsearch.py
@@ -48,7 +48,7 @@ class TestElasticsearchHook(unittest.TestCase):
     def setUp(self):
         super().setUp()
 
-        self.cur = mock.MagicMock()
+        self.cur = mock.MagicMock(rowcount=0)
         self.conn = mock.MagicMock()
         self.conn.cursor.return_value = self.cur
         conn = self.conn
diff --git a/tests/providers/exasol/hooks/test_exasol.py b/tests/providers/exasol/hooks/test_exasol.py
index 6dc4299..607d11c 100644
--- a/tests/providers/exasol/hooks/test_exasol.py
+++ b/tests/providers/exasol/hooks/test_exasol.py
@@ -70,7 +70,7 @@ class TestExasolHook(unittest.TestCase):
     def setUp(self):
         super().setUp()
 
-        self.cur = mock.MagicMock()
+        self.cur = mock.MagicMock(rowcount=0)
         self.conn = mock.MagicMock()
         self.conn.execute.return_value = self.cur
         conn = self.conn
diff --git a/tests/providers/mysql/hooks/test_mysql.py b/tests/providers/mysql/hooks/test_mysql.py
index c50b8a0..c18be07 100644
--- a/tests/providers/mysql/hooks/test_mysql.py
+++ b/tests/providers/mysql/hooks/test_mysql.py
@@ -236,7 +236,7 @@ class TestMySqlHook(unittest.TestCase):
     def setUp(self):
         super().setUp()
 
-        self.cur = mock.MagicMock()
+        self.cur = mock.MagicMock(rowcount=0)
         self.conn = mock.MagicMock()
         self.conn.cursor.return_value = self.cur
         conn = self.conn
diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py
index d22a0b4..b3a6a61 100644
--- a/tests/providers/oracle/hooks/test_oracle.py
+++ b/tests/providers/oracle/hooks/test_oracle.py
@@ -183,7 +183,7 @@ class TestOracleHook(unittest.TestCase):
     def setUp(self):
         super().setUp()
 
-        self.cur = mock.MagicMock()
+        self.cur = mock.MagicMock(rowcount=0)
         self.conn = mock.MagicMock()
         self.conn.cursor.return_value = self.cur
         conn = self.conn
diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py
index 2c4d984..8565efe 100644
--- a/tests/providers/oracle/operators/test_oracle.py
+++ b/tests/providers/oracle/operators/test_oracle.py
@@ -23,7 +23,7 @@ from airflow.providers.oracle.operators.oracle import OracleOperator
 
 
 class TestOracleOperator(unittest.TestCase):
-    @mock.patch.object(OracleHook, 'run')
+    @mock.patch.object(OracleHook, 'run', autospec=OracleHook.run)
     def test_execute(self, mock_run):
         sql = 'SELECT * FROM test_table'
         oracle_conn_id = 'oracle_default'
@@ -40,5 +40,9 @@ class TestOracleOperator(unittest.TestCase):
             task_id=task_id,
         )
         operator.execute(context=context)
-
-        mock_run.assert_called_once_with(sql, autocommit=autocommit, parameters=parameters)
+        mock_run.assert_called_once_with(
+            mock.ANY,
+            sql,
+            autocommit=autocommit,
+            parameters=parameters,
+        )
diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py
index 2890ad7..9a0226f 100644
--- a/tests/providers/postgres/hooks/test_postgres.py
+++ b/tests/providers/postgres/hooks/test_postgres.py
@@ -148,7 +148,7 @@ class TestPostgresHook(unittest.TestCase):
     def setUp(self):
         super().setUp()
 
-        self.cur = mock.MagicMock()
+        self.cur = mock.MagicMock(rowcount=0)
         self.conn = conn = mock.MagicMock()
         self.conn.cursor.return_value = self.cur
 
diff --git a/tests/providers/presto/hooks/test_presto.py b/tests/providers/presto/hooks/test_presto.py
index e6ebb73..08278dd 100644
--- a/tests/providers/presto/hooks/test_presto.py
+++ b/tests/providers/presto/hooks/test_presto.py
@@ -147,7 +147,7 @@ class TestPrestoHook(unittest.TestCase):
     def setUp(self):
         super().setUp()
 
-        self.cur = mock.MagicMock()
+        self.cur = mock.MagicMock(rowcount=0)
         self.conn = mock.MagicMock()
         self.conn.cursor.return_value = self.cur
         conn = self.conn
diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py
index f2175fa..deec76d 100644
--- a/tests/providers/snowflake/hooks/test_snowflake.py
+++ b/tests/providers/snowflake/hooks/test_snowflake.py
@@ -31,8 +31,8 @@ class TestSnowflakeHook(unittest.TestCase):
     def setUp(self):
         super().setUp()
 
-        self.cur = mock.MagicMock()
-        self.cur2 = mock.MagicMock()
+        self.cur = mock.MagicMock(rowcount=0)
+        self.cur2 = mock.MagicMock(rowcount=0)
 
         self.cur.sfqid = 'uuid'
         self.cur2.sfqid = 'uuid2'
diff --git a/tests/providers/sqlite/hooks/test_sqlite.py b/tests/providers/sqlite/hooks/test_sqlite.py
index 7a25479..80fd6a4 100644
--- a/tests/providers/sqlite/hooks/test_sqlite.py
+++ b/tests/providers/sqlite/hooks/test_sqlite.py
@@ -53,7 +53,7 @@ class TestSqliteHookConn(unittest.TestCase):
 class TestSqliteHook(unittest.TestCase):
     def setUp(self):
 
-        self.cur = mock.MagicMock()
+        self.cur = mock.MagicMock(rowcount=0)
         self.conn = mock.MagicMock()
         self.conn.cursor.return_value = self.cur
         conn = self.conn
diff --git a/tests/providers/trino/hooks/test_trino.py b/tests/providers/trino/hooks/test_trino.py
index e649d2b..a688362 100644
--- a/tests/providers/trino/hooks/test_trino.py
+++ b/tests/providers/trino/hooks/test_trino.py
@@ -147,7 +147,7 @@ class TestTrinoHook(unittest.TestCase):
     def setUp(self):
         super().setUp()
 
-        self.cur = mock.MagicMock()
+        self.cur = mock.MagicMock(rowcount=0)
         self.conn = mock.MagicMock()
         self.conn.cursor.return_value = self.cur
         conn = self.conn
diff --git a/tests/providers/vertica/hooks/test_vertica.py b/tests/providers/vertica/hooks/test_vertica.py
index 513627f..23ee7e5 100644
--- a/tests/providers/vertica/hooks/test_vertica.py
+++ b/tests/providers/vertica/hooks/test_vertica.py
@@ -55,7 +55,7 @@ class TestVerticaHook(unittest.TestCase):
     def setUp(self):
         super().setUp()
 
-        self.cur = mock.MagicMock()
+        self.cur = mock.MagicMock(rowcount=0)
         self.conn = mock.MagicMock()
         self.conn.cursor.return_value = self.cur
         conn = self.conn