You are viewing a plain-text version of this content. The canonical link to it is not preserved in this plain-text rendering.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/02/07 20:38:34 UTC

[airflow] branch main updated: [Oracle] Oracle Hook - automatically set current_schema when defined in Connection (#19084)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 471e368  [Oracle] Oracle Hook - automatically set current_schema when defined in Connection (#19084)
471e368 is described below

commit 471e368eacbcae1eedf9b7e1cb4290c385396ea9
Author: mehmax <84...@users.noreply.github.com>
AuthorDate: Mon Feb 7 21:37:51 2022 +0100

    [Oracle] Oracle Hook - automatically set current_schema when defined in Connection (#19084)
---
 airflow/providers/oracle/hooks/oracle.py    | 20 ++++++++++++++++++--
 tests/providers/oracle/hooks/test_oracle.py |  4 ++++
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/oracle/hooks/oracle.py b/airflow/providers/oracle/hooks/oracle.py
index 95b5fdc..9ebed94 100644
--- a/airflow/providers/oracle/hooks/oracle.py
+++ b/airflow/providers/oracle/hooks/oracle.py
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import warnings
 from datetime import datetime
 from typing import Dict, List, Optional, Union
 
@@ -87,6 +88,7 @@ class OracleHook(DbApiHook):
         conn_config = {'user': conn.login, 'password': conn.password}
         sid = conn.extra_dejson.get('sid')
         mod = conn.extra_dejson.get('module')
+        schema = conn.schema
 
         service_name = conn.extra_dejson.get('service_name')
         port = conn.port if conn.port else 1521
@@ -100,8 +102,16 @@ class OracleHook(DbApiHook):
                 dsn = conn.host
                 if conn.port is not None:
                     dsn += ":" + str(conn.port)
-                if service_name or conn.schema:
-                    dsn += "/" + (service_name or conn.schema)
+                if service_name:
+                    dsn += "/" + service_name
+                elif conn.schema:
+                    warnings.warn(
+                        """Using conn.schema to pass the Oracle Service Name is deprecated.
+                        Please use conn.extra.service_name instead.""",
+                        DeprecationWarning,
+                        stacklevel=2,
+                    )
+                    dsn += "/" + conn.schema
             conn_config['dsn'] = dsn
 
         if 'encoding' in conn.extra_dejson:
@@ -146,6 +156,12 @@ class OracleHook(DbApiHook):
         if mod is not None:
             conn.module = mod
 
+        # if Connection.schema is defined, set schema after connecting successfully
+        # cannot be part of conn_config
+        # https://cx-oracle.readthedocs.io/en/latest/api_manual/connection.html?highlight=schema#Connection.current_schema
+        if schema is not None:
+            conn.current_schema = schema
+
         return conn
 
     def insert_rows(
diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py
index 3eae248..9217fec 100644
--- a/tests/providers/oracle/hooks/test_oracle.py
+++ b/tests/providers/oracle/hooks/test_oracle.py
@@ -176,6 +176,10 @@ class TestOracleHookConn(unittest.TestCase):
             assert args == ()
             assert kwargs['purity'] == purity.get(pur)
 
+    @mock.patch('airflow.providers.oracle.hooks.oracle.cx_Oracle.connect')
+    def test_set_current_schema(self, mock_connect):
+        assert self.db_hook.get_conn().current_schema == self.connection.schema
+
 
 @unittest.skipIf(cx_Oracle is None, 'cx_Oracle package not present')
 class TestOracleHook(unittest.TestCase):