You are viewing a plain text version of this content. The canonical link to the original was a hyperlink that is not preserved in this plain text copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/12/13 21:41:24 UTC
[airflow] branch main updated: Add method 'callproc' on Oracle hook (#20072)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c7f36f2 Add method 'callproc' on Oracle hook (#20072)
c7f36f2 is described below
commit c7f36f25cb1d7d35a658e08552c4b6ac480e0cbf
Author: Malthe Borch <mb...@gmail.com>
AuthorDate: Mon Dec 13 22:40:55 2021 +0100
Add method 'callproc' on Oracle hook (#20072)
---
airflow/providers/oracle/hooks/oracle.py | 67 ++++++++++++++++++++++++-
airflow/providers/oracle/operators/oracle.py | 40 ++++++++++++++-
tests/providers/oracle/hooks/test_oracle.py | 38 ++++++++++++++
tests/providers/oracle/operators/test_oracle.py | 28 ++++++++++-
4 files changed, 169 insertions(+), 4 deletions(-)
diff --git a/airflow/providers/oracle/hooks/oracle.py b/airflow/providers/oracle/hooks/oracle.py
index 057ec1a..f079197 100644
--- a/airflow/providers/oracle/hooks/oracle.py
+++ b/airflow/providers/oracle/hooks/oracle.py
@@ -17,13 +17,26 @@
# under the License.
from datetime import datetime
-from typing import List, Optional
+from typing import Dict, List, Optional, TypeVar
import cx_Oracle
import numpy
from airflow.hooks.dbapi import DbApiHook
+PARAM_TYPES = {bool, float, int, str}
+
+ParameterType = TypeVar('ParameterType', Dict, List, None)
+
+
+def _map_param(value):
+ if value in PARAM_TYPES:
+ # In this branch, value is a Python type; calling it produces
+ # an instance of the type which is understood by the Oracle driver
+ # in the out parameter mapping mechanism.
+ value = value()
+ return value
+
class OracleHook(DbApiHook):
"""
@@ -266,3 +279,55 @@ class OracleHook(DbApiHook):
self.log.info('[%s] inserted %s rows', table, row_count)
cursor.close()
conn.close() # type: ignore[attr-defined]
+
+ def callproc(
+ self,
+ identifier: str,
+ autocommit: bool = False,
+ parameters: ParameterType = None,
+ ) -> ParameterType:
+ """
+ Call the stored procedure identified by the provided string.
+
+ Any 'OUT parameters' must be provided with a value of either the
+ expected Python type (e.g., `int`) or an instance of that type.
+
+ The return value is a list or mapping that includes parameters in
+ both directions; the actual return type depends on the type of the
+ provided `parameters` argument.
+
+ See
+ https://cx-oracle.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.var
+ for further reference.
+ """
+ if parameters is None:
+ parameters = ()
+
+ args = ",".join(
+ f":{name}"
+ for name in (parameters if isinstance(parameters, dict) else range(1, len(parameters) + 1))
+ )
+
+ sql = f"BEGIN {identifier}({args}); END;"
+
+ def handler(cursor):
+ if isinstance(cursor.bindvars, list):
+ return [v.getvalue() for v in cursor.bindvars]
+
+ if isinstance(cursor.bindvars, dict):
+ return {n: v.getvalue() for (n, v) in cursor.bindvars.items()}
+
+ raise TypeError(f"Unexpected bindvars: {cursor.bindvars!r}")
+
+ result = self.run(
+ sql,
+ autocommit=autocommit,
+ parameters=(
+ {name: _map_param(value) for (name, value) in parameters.items()}
+ if isinstance(parameters, dict)
+ else [_map_param(value) for value in parameters]
+ ),
+ handler=handler,
+ )
+
+ return result
diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py
index dcc07a2..b80d570 100644
--- a/airflow/providers/oracle/operators/oracle.py
+++ b/airflow/providers/oracle/operators/oracle.py
@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Iterable, List, Mapping, Optional, Union
+from typing import Dict, Iterable, List, Mapping, Optional, Union
from airflow.models import BaseOperator
from airflow.providers.oracle.hooks.oracle import OracleHook
@@ -62,4 +62,40 @@ class OracleOperator(BaseOperator):
def execute(self, context) -> None:
self.log.info('Executing: %s', self.sql)
hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
- hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
+ if self.sql:
+ hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
+
+
+class OracleStoredProcedureOperator(BaseOperator):
+ """
+ Executes stored procedure in a specific Oracle database.
+
+ :param procedure: name of stored procedure to call (templated)
+ :type procedure: str
+ :param oracle_conn_id: The :ref:`Oracle connection id <howto/connection:oracle>`
+ reference to a specific Oracle database.
+ :type oracle_conn_id: str
+ :param parameters: (optional) the parameters provided in the call
+ :type parameters: dict or iterable
+ """
+
+ template_fields = ('procedure',)
+ ui_color = '#ededed'
+
+ def __init__(
+ self,
+ *,
+ procedure: str,
+ oracle_conn_id: str = 'oracle_default',
+ parameters: Optional[Union[Dict, List]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.oracle_conn_id = oracle_conn_id
+ self.procedure = procedure
+ self.parameters = parameters
+
+ def execute(self, context) -> None:
+ self.log.info('Executing: %s', self.procedure)
+ hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
+ return hook.callproc(self.procedure, autocommit=True, parameters=self.parameters)
diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py
index 0101a34..0f5c7df 100644
--- a/tests/providers/oracle/hooks/test_oracle.py
+++ b/tests/providers/oracle/hooks/test_oracle.py
@@ -291,3 +291,41 @@ class TestOracleHook(unittest.TestCase):
rows = []
with pytest.raises(ValueError):
self.db_hook.bulk_insert_rows('table', rows)
+
+ def test_callproc_dict(self):
+ parameters = {"a": 1, "b": 2, "c": 3}
+
+ class bindvar(int):
+ def getvalue(self):
+ return self
+
+ self.cur.bindvars = {k: bindvar(v) for k, v in parameters.items()}
+ result = self.db_hook.callproc('proc', True, parameters)
+ assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END;', parameters)]
+ assert result == parameters
+
+ def test_callproc_list(self):
+ parameters = [1, 2, 3]
+
+ class bindvar(int):
+ def getvalue(self):
+ return self
+
+ self.cur.bindvars = list(map(bindvar, parameters))
+ result = self.db_hook.callproc('proc', True, parameters)
+ assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END;', parameters)]
+ assert result == parameters
+
+ def test_callproc_out_param(self):
+ parameters = [1, int, float, bool, str]
+
+ def bindvar(value):
+ m = mock.Mock()
+ m.getvalue.return_value = value
+ return m
+
+ self.cur.bindvars = [bindvar(p() if type(p) is type else p) for p in parameters]
+ result = self.db_hook.callproc('proc', True, parameters)
+ expected = [1, 0, 0.0, False, '']
+ assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)]
+ assert result == expected
diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py
index 8565efe..40359f6 100644
--- a/tests/providers/oracle/operators/test_oracle.py
+++ b/tests/providers/oracle/operators/test_oracle.py
@@ -19,7 +19,7 @@ import unittest
from unittest import mock
from airflow.providers.oracle.hooks.oracle import OracleHook
-from airflow.providers.oracle.operators.oracle import OracleOperator
+from airflow.providers.oracle.operators.oracle import OracleOperator, OracleStoredProcedureOperator
class TestOracleOperator(unittest.TestCase):
@@ -46,3 +46,29 @@ class TestOracleOperator(unittest.TestCase):
autocommit=autocommit,
parameters=parameters,
)
+
+
+class TestOracleStoredProcedureOperator(unittest.TestCase):
+ @mock.patch.object(OracleHook, 'run', autospec=OracleHook.run)
+ def test_execute(self, mock_run):
+ procedure = 'test'
+ oracle_conn_id = 'oracle_default'
+ parameters = {'parameter': 'value'}
+ context = "test_context"
+ task_id = "test_task_id"
+
+ operator = OracleStoredProcedureOperator(
+ procedure=procedure,
+ oracle_conn_id=oracle_conn_id,
+ parameters=parameters,
+ task_id=task_id,
+ )
+ result = operator.execute(context=context)
+ assert result is mock_run.return_value
+ mock_run.assert_called_once_with(
+ mock.ANY,
+ 'BEGIN test(:parameter); END;',
+ autocommit=True,
+ parameters=parameters,
+ handler=mock.ANY,
+ )