You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/01/23 22:06:58 UTC

[arrow-adbc] branch main updated: feat(python/adbc_driver_manager): add more sanity checking (#370)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new f763cd1  feat(python/adbc_driver_manager): add more sanity checking (#370)
f763cd1 is described below

commit f763cd113ce417cdaa3c43ace4a2beeea2fcd623
Author: David Li <li...@gmail.com>
AuthorDate: Mon Jan 23 17:06:52 2023 -0500

    feat(python/adbc_driver_manager): add more sanity checking (#370)
    
    Fixes #368.
---
 .../adbc_driver_manager/_lib.pyx                   | 84 +++++++++++++++++-----
 python/adbc_driver_manager/tests/test_lowlevel.py  | 21 ++++++
 2 files changed, 87 insertions(+), 18 deletions(-)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 0e619d9..7598b69 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -20,6 +20,7 @@
 """Low-level ADBC API."""
 
 import enum
+import threading
 import typing
 from typing import List, Tuple
 
@@ -389,12 +390,40 @@ cdef class _AdbcHandle:
     """
     Base class for ADBC handles, which are context managers.
     """
+
+    cdef:
+        size_t _open_children
+        object _lock
+        str _child_type
+
+    def __init__(self, str child_type):
+        self._lock = threading.Lock()  # serializes access to _open_children (subclasses also hold it in close())
+        self._child_type = child_type  # name of the child handle class; only used in error messages
+
     def __enter__(self) -> "Self":
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         self.close()
 
+    cdef _open_child(self):
+        with self._lock:
+            self._open_children += 1  # a child handle calls this on its parent at the end of its __init__
+
+    cdef _close_child(self):
+        with self._lock:
+            if self._open_children == 0:  # counter is a size_t: decrementing from 0 would wrap around
+                raise RuntimeError(
+                    f"Underflow in closing this {self._child_type}")
+            self._open_children -= 1
+
+    cdef _check_open_children(self):
+        with self._lock:
+            if self._open_children != 0:  # refuse to close a parent while any child is still open
+                raise RuntimeError(
+                    f"Cannot close {self.__class__.__name__} "
+                    f"with open {self._child_type}")
+
 
 cdef class ArrowSchemaHandle:
     """
@@ -458,6 +487,7 @@ cdef class AdbcDatabase(_AdbcHandle):
         CAdbcDatabase database
 
     def __init__(self, **kwargs) -> None:
+        super().__init__("AdbcConnection")
         cdef CAdbcError c_error = empty_error()
         cdef CAdbcStatusCode status
         cdef const char* c_key
@@ -487,14 +517,16 @@ cdef class AdbcDatabase(_AdbcHandle):
 
     def close(self) -> None:
         """Release the handle to the database."""
-        if self.database.private_data == NULL:
-            return
-
         cdef CAdbcError c_error = empty_error()
         cdef CAdbcStatusCode status
-        with nogil:
-            status = AdbcDatabaseRelease(&self.database, &c_error)
-        check_error(status, &c_error)
+        self._check_open_children()  # raises RuntimeError while connections are still open
+        with self._lock:
+            if self.database.private_data == NULL:  # nothing to release (e.g. already closed)
+                return
+
+            with nogil:
+                status = AdbcDatabaseRelease(&self.database, &c_error)
+            check_error(status, &c_error)
 
 
 cdef class AdbcConnection(_AdbcHandle):
@@ -516,6 +548,7 @@ cdef class AdbcConnection(_AdbcHandle):
         CAdbcConnection connection
 
     def __init__(self, AdbcDatabase database, **kwargs) -> None:
+        super().__init__("AdbcStatement")
         cdef CAdbcError c_error = empty_error()
         cdef CAdbcStatusCode status
         cdef const char* c_key
@@ -534,12 +567,18 @@ cdef class AdbcConnection(_AdbcHandle):
             c_key = key
             c_value = value
             status = AdbcConnectionSetOption(&self.connection, c_key, c_value, &c_error)
+            if status != ADBC_STATUS_OK:
+                AdbcConnectionRelease(&self.connection, NULL)
             check_error(status, &c_error)
 
         with nogil:
             status = AdbcConnectionInit(&self.connection, &database.database, &c_error)
+            if status != ADBC_STATUS_OK:
+                AdbcConnectionRelease(&self.connection, NULL)
         check_error(status, &c_error)
 
+        database._open_child()
+
     def commit(self) -> None:
         """Commit the current transaction."""
         cdef CAdbcError c_error = empty_error()
@@ -719,15 +758,17 @@ cdef class AdbcConnection(_AdbcHandle):
 
     def close(self) -> None:
         """Release the handle to the connection."""
-        if self.connection.private_data == NULL:
-            return
-
         cdef CAdbcError c_error = empty_error()
         cdef CAdbcStatusCode status
+        self._check_open_children()  # raises RuntimeError while statements are still open
+        with self._lock:
+            if self.connection.private_data == NULL:  # nothing to release (e.g. already closed)
+                return
 
-        with nogil:
-            status = AdbcConnectionRelease(&self.connection, &c_error)
-        check_error(status, &c_error)
+            with nogil:
+                status = AdbcConnectionRelease(&self.connection, &c_error)
+            check_error(status, &c_error)
+            self.database._close_child()  # reached only when the early return above was not taken
 
 
 cdef class AdbcStatement(_AdbcHandle):
@@ -743,10 +784,13 @@ cdef class AdbcStatement(_AdbcHandle):
         The connection to create the statement for.
     """
     cdef:
+        AdbcConnection connection
         CAdbcStatement statement
 
     def __init__(self, AdbcConnection connection) -> None:
+        super().__init__("(no child type)")
         cdef CAdbcError c_error = empty_error()
+        self.connection = connection
         memset(&self.statement, 0, cython.sizeof(CAdbcStatement))
 
         with nogil:
@@ -756,6 +800,8 @@ cdef class AdbcStatement(_AdbcHandle):
                 &c_error)
         check_error(status, &c_error)
 
+        connection._open_child()
+
     def bind(self, data, schema) -> None:
         """
         Bind an ArrowArray to this statement.
@@ -819,14 +865,16 @@ cdef class AdbcStatement(_AdbcHandle):
         check_error(status, &c_error)
 
     def close(self) -> None:
-        if self.statement.private_data == NULL:
-            return
-
         cdef CAdbcError c_error = empty_error()
         cdef CAdbcStatusCode status
-        with nogil:
-            status = AdbcStatementRelease(&self.statement, &c_error)
-        check_error(status, &c_error)
+        with self._lock:
+            if self.statement.private_data == NULL:  # nothing to release: a repeated close() is a no-op
+                return
+
+            with nogil:
+                status = AdbcStatementRelease(&self.statement, &c_error)
+            check_error(status, &c_error)
+            self.connection._close_child()  # decrement the parent once, and only after the release (as AdbcConnection.close does)
 
     def execute_query(self) -> Tuple[ArrowArrayStreamHandle, int]:
         """
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index c8f9de7..4233ee0 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -254,3 +254,24 @@ def test_autocommit(sqlite):
         handle, _ = stmt.execute_query()
         table = _import(handle).read_all()
         assert table == pyarrow.Table.from_batches([data])
+
+
+@pytest.mark.sqlite
+def test_child_tracking(sqlite):
+    with adbc_driver_manager.AdbcDatabase(driver="adbc_driver_sqlite") as db:
+        with adbc_driver_manager.AdbcConnection(db) as conn:
+            with adbc_driver_manager.AdbcStatement(conn):
+                with pytest.raises(
+                    RuntimeError,
+                    match="Cannot close AdbcDatabase with open AdbcConnection",
+                ):
+                    db.close()  # conn is still open here
+                with pytest.raises(
+                    RuntimeError,
+                    match="Cannot close AdbcConnection with open AdbcStatement",
+                ):
+                    conn.close()  # the statement is still open here
+            with pytest.raises(
+                RuntimeError, match="Cannot close AdbcDatabase with open AdbcConnection"
+            ):
+                db.close()  # the statement's with-block has exited, but conn is still open