You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by as...@apache.org on 2021/05/17 16:30:24 UTC
[airflow] branch master updated: Add optional result handler to
database hooks (#15581)
This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new abcd487 Add optional result handler to database hooks (#15581)
abcd487 is described below
commit abcd48731303d9e141bdc94acc2db46d73ccbe12
Author: Malthe Borch <mb...@gmail.com>
AuthorDate: Mon May 17 18:30:05 2021 +0200
Add optional result handler to database hooks (#15581)
This allows retrieving result rows, or performing other post-query access
to the cursor object: the handler is called with the cursor after each
statement is executed.
Co-authored-by: Ash Berlin-Taylor <as...@apache.org>
---
CHANGELOG.txt | 1 +
airflow/hooks/dbapi.py | 41 ++++++++++++++++------
airflow/providers/oracle/operators/oracle.py | 2 +-
tests/hooks/test_dbapi.py | 36 ++++++++++++++++++-
tests/providers/apache/druid/hooks/test_druid.py | 2 +-
tests/providers/apache/pinot/hooks/test_pinot.py | 2 +-
.../elasticsearch/hooks/test_elasticsearch.py | 2 +-
tests/providers/exasol/hooks/test_exasol.py | 2 +-
tests/providers/mysql/hooks/test_mysql.py | 2 +-
tests/providers/oracle/hooks/test_oracle.py | 2 +-
tests/providers/oracle/operators/test_oracle.py | 10 ++++--
tests/providers/postgres/hooks/test_postgres.py | 2 +-
tests/providers/presto/hooks/test_presto.py | 2 +-
tests/providers/snowflake/hooks/test_snowflake.py | 4 +--
tests/providers/sqlite/hooks/test_sqlite.py | 2 +-
tests/providers/trino/hooks/test_trino.py | 2 +-
tests/providers/vertica/hooks/test_vertica.py | 2 +-
17 files changed, 88 insertions(+), 28 deletions(-)
diff --git a/CHANGELOG.txt b/CHANGELOG.txt
index 9719a14..34f4c38 100644
--- a/CHANGELOG.txt
+++ b/CHANGELOG.txt
@@ -43,6 +43,7 @@ New Features
Improvements
""""""""""""
+- Add optional result handler callback to ``DbApiHook`` (#15581)
- Update Flask App Builder limit to recently released 3.3 (#15792)
- Prevent creating flask sessions on REST API requests (#15295)
- Sync DAG specific permissions when parsing (#15311)
diff --git a/airflow/hooks/dbapi.py b/airflow/hooks/dbapi.py
index 9821643..6c00320 100644
--- a/airflow/hooks/dbapi.py
+++ b/airflow/hooks/dbapi.py
@@ -152,7 +152,7 @@ class DbApiHook(BaseHook):
cur.execute(sql)
return cur.fetchone()
- def run(self, sql, autocommit=False, parameters=None):
+ def run(self, sql, autocommit=False, parameters=None, handler=None):
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
@@ -166,8 +166,12 @@ class DbApiHook(BaseHook):
:type autocommit: bool
:param parameters: The parameters to render the SQL query with.
:type parameters: dict or iterable
+ :param handler: The result handler which is called with the result of each statement.
+ :type handler: callable
+ :return: query results if handler was provided.
"""
- if isinstance(sql, str):
+ scalar = isinstance(sql, str)
+ if scalar:
sql = [sql]
with closing(self.get_conn()) as conn:
@@ -175,21 +179,38 @@ class DbApiHook(BaseHook):
self.set_autocommit(conn, autocommit)
with closing(conn.cursor()) as cur:
+ results = []
for sql_statement in sql:
-
- self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
- if parameters:
- cur.execute(sql_statement, parameters)
- else:
- cur.execute(sql_statement)
- if hasattr(cur, 'rowcount'):
- self.log.info("Rows affected: %s", cur.rowcount)
+ self._run_command(cur, sql_statement, parameters)
+ if handler is not None:
+ result = handler(cur)
+ results.append(result)
# If autocommit was set to False for db that supports autocommit,
# or if db does not supports autocommit, we do a manual commit.
if not self.get_autocommit(conn):
conn.commit()
+ if handler is None:
+ return None
+
+ if scalar:
+ return results[0]
+
+ return results
+
+ def _run_command(self, cur, sql_statement, parameters):
+ """Runs a statement using an already open cursor."""
+ self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
+ if parameters:
+ cur.execute(sql_statement, parameters)
+ else:
+ cur.execute(sql_statement)
+
+ # According to PEP 249, this is -1 when query result is not applicable.
+ if cur.rowcount >= 0:
+ self.log.info("Rows affected: %s", cur.rowcount)
+
def set_autocommit(self, conn, autocommit):
"""Sets the autocommit flag on the connection"""
if not self.supports_autocommit and autocommit:
diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py
index d764225..41e9c31 100644
--- a/airflow/providers/oracle/operators/oracle.py
+++ b/airflow/providers/oracle/operators/oracle.py
@@ -23,7 +23,7 @@ from airflow.providers.oracle.hooks.oracle import OracleHook
class OracleOperator(BaseOperator):
"""
- Executes sql code in a specific Oracle database
+ Executes sql code in a specific Oracle database.
:param sql: the sql code to be executed. Can receive a str representing a sql statement,
a list of str (sql statements), or reference to a template file.
diff --git a/tests/hooks/test_dbapi.py b/tests/hooks/test_dbapi.py
index 0f6c55a..383d69e 100644
--- a/tests/hooks/test_dbapi.py
+++ b/tests/hooks/test_dbapi.py
@@ -30,7 +30,7 @@ class TestDbApiHook(unittest.TestCase):
def setUp(self):
super().setUp()
- self.cur = mock.MagicMock()
+ self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
@@ -184,3 +184,37 @@ class TestDbApiHook(unittest.TestCase):
statement = 'SQL'
self.db_hook.run(statement)
assert self.db_hook.log.info.call_count == 2
+
+ def test_run_with_handler(self):
+ sql = 'SQL'
+ param = ('p1', 'p2')
+ called = 0
+ obj = object()
+
+ def handler(cur):
+ cur.execute.assert_called_once_with(sql, param)
+ nonlocal called
+ called += 1
+ return obj
+
+ result = self.db_hook.run(sql, parameters=param, handler=handler)
+ assert called == 1
+ assert self.conn.commit.called
+ assert result == obj
+
+ def test_run_with_handler_multiple(self):
+ sql = ['SQL', 'SQL']
+ param = ('p1', 'p2')
+ called = 0
+ obj = object()
+
+ def handler(cur):
+ cur.execute.assert_called_with(sql[0], param)
+ nonlocal called
+ called += 1
+ return obj
+
+ result = self.db_hook.run(sql, parameters=param, handler=handler)
+ assert called == 2
+ assert self.conn.commit.called
+ assert result == [obj, obj]
diff --git a/tests/providers/apache/druid/hooks/test_druid.py b/tests/providers/apache/druid/hooks/test_druid.py
index afd7c1c..fca5d06 100644
--- a/tests/providers/apache/druid/hooks/test_druid.py
+++ b/tests/providers/apache/druid/hooks/test_druid.py
@@ -192,7 +192,7 @@ class TestDruidHook(unittest.TestCase):
class TestDruidDbApiHook(unittest.TestCase):
def setUp(self):
super().setUp()
- self.cur = MagicMock()
+ self.cur = MagicMock(rowcount=0)
self.conn = conn = MagicMock()
self.conn.host = 'host'
self.conn.port = '1000'
diff --git a/tests/providers/apache/pinot/hooks/test_pinot.py b/tests/providers/apache/pinot/hooks/test_pinot.py
index af0676f..763704f 100644
--- a/tests/providers/apache/pinot/hooks/test_pinot.py
+++ b/tests/providers/apache/pinot/hooks/test_pinot.py
@@ -212,7 +212,7 @@ class TestPinotDbApiHook(unittest.TestCase):
self.conn.port = '1000'
self.conn.conn_type = 'http'
self.conn.extra_dejson = {'endpoint': 'query/sql'}
- self.cur = mock.MagicMock()
+ self.cur = mock.MagicMock(rowcount=0)
self.conn.cursor.return_value = self.cur
self.conn.__enter__.return_value = self.cur
self.conn.__exit__.return_value = None
diff --git a/tests/providers/elasticsearch/hooks/test_elasticsearch.py b/tests/providers/elasticsearch/hooks/test_elasticsearch.py
index b96f64e..06057b0 100644
--- a/tests/providers/elasticsearch/hooks/test_elasticsearch.py
+++ b/tests/providers/elasticsearch/hooks/test_elasticsearch.py
@@ -48,7 +48,7 @@ class TestElasticsearchHook(unittest.TestCase):
def setUp(self):
super().setUp()
- self.cur = mock.MagicMock()
+ self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
diff --git a/tests/providers/exasol/hooks/test_exasol.py b/tests/providers/exasol/hooks/test_exasol.py
index 6dc4299..607d11c 100644
--- a/tests/providers/exasol/hooks/test_exasol.py
+++ b/tests/providers/exasol/hooks/test_exasol.py
@@ -70,7 +70,7 @@ class TestExasolHook(unittest.TestCase):
def setUp(self):
super().setUp()
- self.cur = mock.MagicMock()
+ self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.execute.return_value = self.cur
conn = self.conn
diff --git a/tests/providers/mysql/hooks/test_mysql.py b/tests/providers/mysql/hooks/test_mysql.py
index c50b8a0..c18be07 100644
--- a/tests/providers/mysql/hooks/test_mysql.py
+++ b/tests/providers/mysql/hooks/test_mysql.py
@@ -236,7 +236,7 @@ class TestMySqlHook(unittest.TestCase):
def setUp(self):
super().setUp()
- self.cur = mock.MagicMock()
+ self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py
index d22a0b4..b3a6a61 100644
--- a/tests/providers/oracle/hooks/test_oracle.py
+++ b/tests/providers/oracle/hooks/test_oracle.py
@@ -183,7 +183,7 @@ class TestOracleHook(unittest.TestCase):
def setUp(self):
super().setUp()
- self.cur = mock.MagicMock()
+ self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py
index 2c4d984..8565efe 100644
--- a/tests/providers/oracle/operators/test_oracle.py
+++ b/tests/providers/oracle/operators/test_oracle.py
@@ -23,7 +23,7 @@ from airflow.providers.oracle.operators.oracle import OracleOperator
class TestOracleOperator(unittest.TestCase):
- @mock.patch.object(OracleHook, 'run')
+ @mock.patch.object(OracleHook, 'run', autospec=OracleHook.run)
def test_execute(self, mock_run):
sql = 'SELECT * FROM test_table'
oracle_conn_id = 'oracle_default'
@@ -40,5 +40,9 @@ class TestOracleOperator(unittest.TestCase):
task_id=task_id,
)
operator.execute(context=context)
-
- mock_run.assert_called_once_with(sql, autocommit=autocommit, parameters=parameters)
+ mock_run.assert_called_once_with(
+ mock.ANY,
+ sql,
+ autocommit=autocommit,
+ parameters=parameters,
+ )
diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py
index 2890ad7..9a0226f 100644
--- a/tests/providers/postgres/hooks/test_postgres.py
+++ b/tests/providers/postgres/hooks/test_postgres.py
@@ -148,7 +148,7 @@ class TestPostgresHook(unittest.TestCase):
def setUp(self):
super().setUp()
- self.cur = mock.MagicMock()
+ self.cur = mock.MagicMock(rowcount=0)
self.conn = conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
diff --git a/tests/providers/presto/hooks/test_presto.py b/tests/providers/presto/hooks/test_presto.py
index e6ebb73..08278dd 100644
--- a/tests/providers/presto/hooks/test_presto.py
+++ b/tests/providers/presto/hooks/test_presto.py
@@ -147,7 +147,7 @@ class TestPrestoHook(unittest.TestCase):
def setUp(self):
super().setUp()
- self.cur = mock.MagicMock()
+ self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py
index f2175fa..deec76d 100644
--- a/tests/providers/snowflake/hooks/test_snowflake.py
+++ b/tests/providers/snowflake/hooks/test_snowflake.py
@@ -31,8 +31,8 @@ class TestSnowflakeHook(unittest.TestCase):
def setUp(self):
super().setUp()
- self.cur = mock.MagicMock()
- self.cur2 = mock.MagicMock()
+ self.cur = mock.MagicMock(rowcount=0)
+ self.cur2 = mock.MagicMock(rowcount=0)
self.cur.sfqid = 'uuid'
self.cur2.sfqid = 'uuid2'
diff --git a/tests/providers/sqlite/hooks/test_sqlite.py b/tests/providers/sqlite/hooks/test_sqlite.py
index 7a25479..80fd6a4 100644
--- a/tests/providers/sqlite/hooks/test_sqlite.py
+++ b/tests/providers/sqlite/hooks/test_sqlite.py
@@ -53,7 +53,7 @@ class TestSqliteHookConn(unittest.TestCase):
class TestSqliteHook(unittest.TestCase):
def setUp(self):
- self.cur = mock.MagicMock()
+ self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
diff --git a/tests/providers/trino/hooks/test_trino.py b/tests/providers/trino/hooks/test_trino.py
index e649d2b..a688362 100644
--- a/tests/providers/trino/hooks/test_trino.py
+++ b/tests/providers/trino/hooks/test_trino.py
@@ -147,7 +147,7 @@ class TestTrinoHook(unittest.TestCase):
def setUp(self):
super().setUp()
- self.cur = mock.MagicMock()
+ self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn
diff --git a/tests/providers/vertica/hooks/test_vertica.py b/tests/providers/vertica/hooks/test_vertica.py
index 513627f..23ee7e5 100644
--- a/tests/providers/vertica/hooks/test_vertica.py
+++ b/tests/providers/vertica/hooks/test_vertica.py
@@ -55,7 +55,7 @@ class TestVerticaHook(unittest.TestCase):
def setUp(self):
super().setUp()
- self.cur = mock.MagicMock()
+ self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn