You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/12/13 21:41:24 UTC

[airflow] branch main updated: Add method 'callproc' on Oracle hook (#20072)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c7f36f2  Add method 'callproc' on Oracle hook (#20072)
c7f36f2 is described below

commit c7f36f25cb1d7d35a658e08552c4b6ac480e0cbf
Author: Malthe Borch <mb...@gmail.com>
AuthorDate: Mon Dec 13 22:40:55 2021 +0100

    Add method 'callproc' on Oracle hook (#20072)
---
 airflow/providers/oracle/hooks/oracle.py        | 67 ++++++++++++++++++++++++-
 airflow/providers/oracle/operators/oracle.py    | 40 ++++++++++++++-
 tests/providers/oracle/hooks/test_oracle.py     | 38 ++++++++++++++
 tests/providers/oracle/operators/test_oracle.py | 28 ++++++++++-
 4 files changed, 169 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/oracle/hooks/oracle.py b/airflow/providers/oracle/hooks/oracle.py
index 057ec1a..f079197 100644
--- a/airflow/providers/oracle/hooks/oracle.py
+++ b/airflow/providers/oracle/hooks/oracle.py
@@ -17,13 +17,26 @@
 # under the License.
 
 from datetime import datetime
-from typing import List, Optional
+from typing import Dict, List, Optional, TypeVar
 
 import cx_Oracle
 import numpy
 
 from airflow.hooks.dbapi import DbApiHook
 
+PARAM_TYPES = {bool, float, int, str}  # Python types that may stand in for an OUT parameter
+
+ParameterType = TypeVar('ParameterType', Dict, List, None)  # same type for callproc's input and result
+
+
+def _map_param(value):
+    """Swap an OUT-parameter type (e.g. ``int``) for an instance; pass other values through."""
+    # isinstance comes first so unhashable values (e.g. lists) skip the set lookup instead of
+    # raising TypeError; the instance is what the Oracle driver maps as an out parameter.
+    if isinstance(value, type) and value in PARAM_TYPES:
+        value = value()
+    return value
+
 
 class OracleHook(DbApiHook):
     """
@@ -266,3 +279,55 @@ class OracleHook(DbApiHook):
         self.log.info('[%s] inserted %s rows', table, row_count)
         cursor.close()
         conn.close()  # type: ignore[attr-defined]
+
+    def callproc(
+        self,
+        identifier: str,
+        autocommit: bool = False,
+        parameters: ParameterType = None,
+    ) -> ParameterType:
+        """
+        Invoke the stored procedure named by ``identifier``.
+
+        Every 'OUT parameter' has to be passed as either the expected
+        Python type (e.g., `int`) or an instance of that type.
+
+        The result is a list or mapping containing the parameters of both
+        directions; which one is returned follows from the type of the
+        `parameters` argument that was passed in.
+
+        See
+        https://cx-oracle.readthedocs.io/en/latest/api_manual/cursor.html#Cursor.var
+        for further reference.
+        """
+        if parameters is None:
+            parameters = ()
+
+        # A mapping yields named binds, anything else positional binds numbered
+        # from 1; values that are OUT-parameter types go through _map_param.
+        if isinstance(parameters, dict):
+            placeholders = [f":{key}" for key in parameters]
+            bound = {key: _map_param(item) for (key, item) in parameters.items()}
+        else:
+            placeholders = [f":{pos}" for pos in range(1, len(parameters) + 1)]
+            bound = [_map_param(item) for item in parameters]
+
+        sql = f"BEGIN {identifier}({','.join(placeholders)}); END;"
+
+        def handler(cursor):
+            # Read every bind variable back in the shape the cursor holds them.
+            bindvars = cursor.bindvars
+            if isinstance(bindvars, list):
+                return [var.getvalue() for var in bindvars]
+
+            if isinstance(bindvars, dict):
+                return {key: var.getvalue() for (key, var) in bindvars.items()}
+
+            raise TypeError(f"Unexpected bindvars: {bindvars!r}")
+
+        return self.run(
+            sql,
+            autocommit=autocommit,
+            parameters=bound,
+            handler=handler,
+        )
diff --git a/airflow/providers/oracle/operators/oracle.py b/airflow/providers/oracle/operators/oracle.py
index dcc07a2..b80d570 100644
--- a/airflow/providers/oracle/operators/oracle.py
+++ b/airflow/providers/oracle/operators/oracle.py
@@ -15,7 +15,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Iterable, List, Mapping, Optional, Union
+from typing import Dict, Iterable, List, Mapping, Optional, Union
 
 from airflow.models import BaseOperator
 from airflow.providers.oracle.hooks.oracle import OracleHook
@@ -62,4 +62,40 @@ class OracleOperator(BaseOperator):
     def execute(self, context) -> None:
         self.log.info('Executing: %s', self.sql)
         hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
-        hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
+        if self.sql:
+            hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
+
+
+class OracleStoredProcedureOperator(BaseOperator):
+    """
+    Executes a stored procedure in a specific Oracle database.
+
+    :param procedure: name of stored procedure to call (templated)
+    :type procedure: str
+    :param oracle_conn_id: :ref:`Oracle connection id <howto/connection:oracle>` of the database to use.
+    :type oracle_conn_id: str
+    :param parameters: (optional) the parameters provided in the call
+    :type parameters: dict or list
+    """
+
+    template_fields = ('procedure',)
+    ui_color = '#ededed'
+
+    def __init__(
+        self,
+        *,
+        procedure: str,
+        oracle_conn_id: str = 'oracle_default',
+        parameters: Optional[Union[Dict, List]] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.oracle_conn_id = oracle_conn_id
+        self.procedure = procedure
+        self.parameters = parameters
+
+    def execute(self, context) -> Optional[Union[List, Dict]]:
+        """Call the procedure with autocommit on and return the result of the hook's ``callproc``."""
+        self.log.info('Executing: %s', self.procedure)
+        hook = OracleHook(oracle_conn_id=self.oracle_conn_id)
+        return hook.callproc(self.procedure, autocommit=True, parameters=self.parameters)
diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py
index 0101a34..0f5c7df 100644
--- a/tests/providers/oracle/hooks/test_oracle.py
+++ b/tests/providers/oracle/hooks/test_oracle.py
@@ -291,3 +291,41 @@ class TestOracleHook(unittest.TestCase):
         rows = []
         with pytest.raises(ValueError):
             self.db_hook.bulk_insert_rows('table', rows)
+
+    def test_callproc_dict(self):
+        # A mapping must be sent as named binds and read back as a mapping.
+
+        class BoundInt(int):
+            def getvalue(self):
+                return self
+
+        params = {"a": 1, "b": 2, "c": 3}
+        self.cur.bindvars = {name: BoundInt(number) for name, number in params.items()}
+        assert self.db_hook.callproc('proc', True, params) == params
+        assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END;', params)]
+
+    def test_callproc_list(self):
+        # A list must be sent as numbered binds and read back as a list.
+
+        class BoundInt(int):
+            def getvalue(self):
+                return self
+
+        params = [1, 2, 3]
+        self.cur.bindvars = [BoundInt(number) for number in params]
+        assert self.db_hook.callproc('proc', True, params) == params
+        assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END;', params)]
+
+    def test_callproc_out_param(self):
+        # Types passed as OUT parameters must be bound as their default instances.
+        params = [1, int, float, bool, str]
+        expected = [1, 0, 0.0, False, '']
+
+        def make_bindvar(current):
+            var = mock.Mock()
+            var.getvalue.return_value = current
+            return var
+
+        self.cur.bindvars = [make_bindvar(value) for value in expected]
+        assert self.db_hook.callproc('proc', True, params) == expected
+        assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)]
diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py
index 8565efe..40359f6 100644
--- a/tests/providers/oracle/operators/test_oracle.py
+++ b/tests/providers/oracle/operators/test_oracle.py
@@ -19,7 +19,7 @@ import unittest
 from unittest import mock
 
 from airflow.providers.oracle.hooks.oracle import OracleHook
-from airflow.providers.oracle.operators.oracle import OracleOperator
+from airflow.providers.oracle.operators.oracle import OracleOperator, OracleStoredProcedureOperator
 
 
 class TestOracleOperator(unittest.TestCase):
@@ -46,3 +46,29 @@ class TestOracleOperator(unittest.TestCase):
             autocommit=autocommit,
             parameters=parameters,
         )
+
+
+class TestOracleStoredProcedureOperator(unittest.TestCase):
+    """Checks that the operator reaches ``OracleHook.run`` with the generated call statement."""
+
+    @mock.patch.object(OracleHook, 'run', autospec=OracleHook.run)
+    def test_execute(self, mock_run):
+        call_parameters = {'parameter': 'value'}
+        operator = OracleStoredProcedureOperator(
+            procedure='test',
+            oracle_conn_id='oracle_default',
+            parameters=call_parameters,
+            task_id="test_task_id",
+        )
+
+        # Whatever the (mocked) run() returns must be handed back by execute().
+        assert operator.execute(context="test_context") is mock_run.return_value
+
+        # mock.ANY stands in for the hook instance and the internal handler.
+        mock_run.assert_called_once_with(
+            mock.ANY,
+            'BEGIN test(:parameter); END;',
+            autocommit=True,
+            parameters=call_parameters,
+            handler=mock.ANY,
+        )