You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/01/23 22:06:58 UTC
[arrow-adbc] branch main updated: feat(python/adbc_driver_manager): add more sanity checking (#370)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new f763cd1 feat(python/adbc_driver_manager): add more sanity checking (#370)
f763cd1 is described below
commit f763cd113ce417cdaa3c43ace4a2beeea2fcd623
Author: David Li <li...@gmail.com>
AuthorDate: Mon Jan 23 17:06:52 2023 -0500
feat(python/adbc_driver_manager): add more sanity checking (#370)
Fixes #368.
---
.../adbc_driver_manager/_lib.pyx | 84 +++++++++++++++++-----
python/adbc_driver_manager/tests/test_lowlevel.py | 21 ++++++
2 files changed, 87 insertions(+), 18 deletions(-)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 0e619d9..7598b69 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -20,6 +20,7 @@
"""Low-level ADBC API."""
import enum
+import threading
import typing
from typing import List, Tuple
@@ -389,12 +390,40 @@ cdef class _AdbcHandle:
"""
Base class for ADBC handles, which are context managers.
"""
+
+ cdef:
+ size_t _open_children
+ object _lock
+ str _child_type
+
+ def __init__(self, str child_type):
+ self._lock = threading.Lock()
+ self._child_type = child_type
+
def __enter__(self) -> "Self":
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()
+ cdef _open_child(self):
+ with self._lock:
+ self._open_children += 1
+
+ cdef _close_child(self):
+ with self._lock:
+ if self._open_children == 0:
+ raise RuntimeError(
+ f"Underflow in closing this {self._child_type}")
+ self._open_children -= 1
+
+ cdef _check_open_children(self):
+ with self._lock:
+ if self._open_children != 0:
+ raise RuntimeError(
+ f"Cannot close {self.__class__.__name__} "
+ f"with open {self._child_type}")
+
cdef class ArrowSchemaHandle:
"""
@@ -458,6 +487,7 @@ cdef class AdbcDatabase(_AdbcHandle):
CAdbcDatabase database
def __init__(self, **kwargs) -> None:
+ super().__init__("AdbcConnection")
cdef CAdbcError c_error = empty_error()
cdef CAdbcStatusCode status
cdef const char* c_key
@@ -487,14 +517,16 @@ cdef class AdbcDatabase(_AdbcHandle):
def close(self) -> None:
"""Release the handle to the database."""
- if self.database.private_data == NULL:
- return
-
cdef CAdbcError c_error = empty_error()
cdef CAdbcStatusCode status
- with nogil:
- status = AdbcDatabaseRelease(&self.database, &c_error)
- check_error(status, &c_error)
+ self._check_open_children()
+ with self._lock:
+ if self.database.private_data == NULL:
+ return
+
+ with nogil:
+ status = AdbcDatabaseRelease(&self.database, &c_error)
+ check_error(status, &c_error)
cdef class AdbcConnection(_AdbcHandle):
@@ -516,6 +548,7 @@ cdef class AdbcConnection(_AdbcHandle):
CAdbcConnection connection
def __init__(self, AdbcDatabase database, **kwargs) -> None:
+ super().__init__("AdbcStatement")
cdef CAdbcError c_error = empty_error()
cdef CAdbcStatusCode status
cdef const char* c_key
@@ -534,12 +567,18 @@ cdef class AdbcConnection(_AdbcHandle):
c_key = key
c_value = value
status = AdbcConnectionSetOption(&self.connection, c_key, c_value, &c_error)
+ if status != ADBC_STATUS_OK:
+ AdbcConnectionRelease(&self.connection, NULL)
check_error(status, &c_error)
with nogil:
status = AdbcConnectionInit(&self.connection, &database.database, &c_error)
+ if status != ADBC_STATUS_OK:
+ AdbcConnectionRelease(&self.connection, NULL)
check_error(status, &c_error)
+ database._open_child()
+
def commit(self) -> None:
"""Commit the current transaction."""
cdef CAdbcError c_error = empty_error()
@@ -719,15 +758,17 @@ cdef class AdbcConnection(_AdbcHandle):
def close(self) -> None:
"""Release the handle to the connection."""
- if self.connection.private_data == NULL:
- return
-
cdef CAdbcError c_error = empty_error()
cdef CAdbcStatusCode status
+ self._check_open_children()
+ with self._lock:
+ if self.connection.private_data == NULL:
+ return
- with nogil:
- status = AdbcConnectionRelease(&self.connection, &c_error)
- check_error(status, &c_error)
+ with nogil:
+ status = AdbcConnectionRelease(&self.connection, &c_error)
+ check_error(status, &c_error)
+ self.database._close_child()
cdef class AdbcStatement(_AdbcHandle):
@@ -743,10 +784,13 @@ cdef class AdbcStatement(_AdbcHandle):
The connection to create the statement for.
"""
cdef:
+ AdbcConnection connection
CAdbcStatement statement
def __init__(self, AdbcConnection connection) -> None:
+ super().__init__("(no child type)")
cdef CAdbcError c_error = empty_error()
+ self.connection = connection
memset(&self.statement, 0, cython.sizeof(CAdbcStatement))
with nogil:
@@ -756,6 +800,8 @@ cdef class AdbcStatement(_AdbcHandle):
&c_error)
check_error(status, &c_error)
+ connection._open_child()
+
def bind(self, data, schema) -> None:
"""
Bind an ArrowArray to this statement.
@@ -819,14 +865,16 @@ cdef class AdbcStatement(_AdbcHandle):
check_error(status, &c_error)
def close(self) -> None:
- if self.statement.private_data == NULL:
- return
-
cdef CAdbcError c_error = empty_error()
cdef CAdbcStatusCode status
- with nogil:
- status = AdbcStatementRelease(&self.statement, &c_error)
- check_error(status, &c_error)
+ self.connection._close_child()
+ with self._lock:
+ if self.statement.private_data == NULL:
+ return
+
+ with nogil:
+ status = AdbcStatementRelease(&self.statement, &c_error)
+ check_error(status, &c_error)
def execute_query(self) -> Tuple[ArrowArrayStreamHandle, int]:
"""
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index c8f9de7..4233ee0 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -254,3 +254,24 @@ def test_autocommit(sqlite):
handle, _ = stmt.execute_query()
table = _import(handle).read_all()
assert table == pyarrow.Table.from_batches([data])
+
+
+@pytest.mark.sqlite
+def test_child_tracking(sqlite):
+ with adbc_driver_manager.AdbcDatabase(driver="adbc_driver_sqlite") as db:
+ with adbc_driver_manager.AdbcConnection(db) as conn:
+ with adbc_driver_manager.AdbcStatement(conn):
+ with pytest.raises(
+ RuntimeError,
+ match="Cannot close AdbcDatabase with open AdbcConnection",
+ ):
+ db.close()
+ with pytest.raises(
+ RuntimeError,
+ match="Cannot close AdbcConnection with open AdbcStatement",
+ ):
+ conn.close()
+ with pytest.raises(
+ RuntimeError, match="Cannot close AdbcDatabase with open AdbcConnection"
+ ):
+ db.close()