You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2022/07/20 18:33:51 UTC
[arrow-adbc] branch main updated: [Python] Complete minimal bindings for ADBC (#41)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new cf43e0c [Python] Complete minimal bindings for ADBC (#41)
cf43e0c is described below
commit cf43e0cc2ae15ad0ce669b531d475ee218698100
Author: David Li <li...@gmail.com>
AuthorDate: Wed Jul 20 14:33:47 2022 -0400
[Python] Complete minimal bindings for ADBC (#41)
* [Python] Complete minimal bindings for ADBC
* [CI] Fix CI conditions
* [CI] Set up flake8
* [CI] flake8 <4 does not work with Python >=3.10
---
.github/workflows/cpp.yml | 8 +-
c/driver_manager/adbc_driver_manager.cc | 78 ++-
.../adbc_driver_manager/__init__.py | 21 +
.../adbc_driver_manager/_lib.pyx | 561 ++++++++++++++++++---
.../adbc_driver_manager/tests/test_lowlevel.py | 190 ++++++-
python/adbc_driver_manager/setup.py | 2 +-
6 files changed, 750 insertions(+), 110 deletions(-)
diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml
index db36354..b4812bf 100644
--- a/.github/workflows/cpp.yml
+++ b/.github/workflows/cpp.yml
@@ -23,14 +23,14 @@ on:
- main
paths:
- "adbc.h"
- - "adbc_driver_manager/**"
- - "drivers/**"
+ - "c/**"
+ - "python/**"
- ".github/workflows/cpp.yml"
push:
paths:
- "adbc.h"
- - "adbc_driver_manager/**"
- - "drivers/**"
+ - "c/**"
+ - "python/**"
- ".github/workflows/cpp.yml"
concurrency:
diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc
index 1a9dc53..91b2f79 100644
--- a/c/driver_manager/adbc_driver_manager.cc
+++ b/c/driver_manager/adbc_driver_manager.cc
@@ -67,11 +67,6 @@ AdbcStatusCode ConnectionCommit(struct AdbcConnection*, struct AdbcError* error)
return ADBC_STATUS_NOT_IMPLEMENTED;
}
-AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection*, struct AdbcStatement*,
- struct AdbcError* error) {
- return ADBC_STATUS_NOT_IMPLEMENTED;
-}
-
AdbcStatusCode ConnectionGetObjects(struct AdbcConnection*, int, const char*, const char*,
const char*, const char**, const char*,
struct AdbcStatement*, struct AdbcError* error) {
@@ -84,6 +79,11 @@ AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection*, const char*, con
return ADBC_STATUS_NOT_IMPLEMENTED;
}
+AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection*, struct AdbcStatement*,
+ struct AdbcError* error) {
+ return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
AdbcStatusCode ConnectionRollback(struct AdbcConnection*, struct AdbcError* error) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}
@@ -102,6 +102,16 @@ AdbcStatusCode StatementExecute(struct AdbcStatement*, struct AdbcError* error)
return ADBC_STATUS_NOT_IMPLEMENTED;
}
+AdbcStatusCode StatementGetPartitionDesc(struct AdbcStatement*, uint8_t*,
+ struct AdbcError*) {
+ return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
+AdbcStatusCode StatementGetPartitionDescSize(struct AdbcStatement*, size_t*,
+ struct AdbcError*) {
+ return ADBC_STATUS_NOT_IMPLEMENTED;
+}
+
AdbcStatusCode StatementPrepare(struct AdbcStatement*, struct AdbcError* error) {
return ADBC_STATUS_NOT_IMPLEMENTED;
}
@@ -310,6 +320,42 @@ AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection,
return connection->private_driver->ConnectionCommit(connection, error);
}
+AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth,
+ const char* catalog, const char* db_schema,
+ const char* table_name, const char** table_types,
+ const char* column_name,
+ struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ if (!connection->private_driver) {
+ return ADBC_STATUS_INVALID_STATE;
+ }
+ return connection->private_driver->ConnectionGetObjects(
+ connection, depth, catalog, db_schema, table_name, table_types, column_name,
+ statement, error);
+}
+
+AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection,
+ const char* catalog, const char* db_schema,
+ const char* table_name,
+ struct ArrowSchema* schema,
+ struct AdbcError* error) {
+ if (!connection->private_driver) {
+ return ADBC_STATUS_INVALID_STATE;
+ }
+ return connection->private_driver->ConnectionGetTableSchema(
+ connection, catalog, db_schema, table_name, schema, error);
+}
+
+AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection,
+ struct AdbcStatement* statement,
+ struct AdbcError* error) {
+ if (!connection->private_driver) {
+ return ADBC_STATUS_INVALID_STATE;
+ }
+ return connection->private_driver->ConnectionGetTableTypes(connection, statement,
+ error);
+}
+
AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection,
struct AdbcDatabase* database,
struct AdbcError* error) {
@@ -399,6 +445,26 @@ AdbcStatusCode AdbcStatementExecute(struct AdbcStatement* statement,
return statement->private_driver->StatementExecute(statement, error);
}
+AdbcStatusCode AdbcStatementGetPartitionDesc(struct AdbcStatement* statement,
+ uint8_t* partition_desc,
+ struct AdbcError* error) {
+ if (!statement->private_driver) {
+ return ADBC_STATUS_INVALID_STATE;
+ }
+ return statement->private_driver->StatementGetPartitionDesc(statement, partition_desc,
+ error);
+}
+
+AdbcStatusCode AdbcStatementGetPartitionDescSize(struct AdbcStatement* statement,
+ size_t* length,
+ struct AdbcError* error) {
+ if (!statement->private_driver) {
+ return ADBC_STATUS_INVALID_STATE;
+ }
+ return statement->private_driver->StatementGetPartitionDescSize(statement, length,
+ error);
+}
+
AdbcStatusCode AdbcStatementGetStream(struct AdbcStatement* statement,
struct ArrowArrayStream* out,
struct AdbcError* error) {
@@ -627,6 +693,8 @@ AdbcStatusCode AdbcLoadDriver(const char* driver_name, const char* entrypoint,
CHECK_REQUIRED(driver, StatementRelease);
FILL_DEFAULT(driver, StatementBind);
FILL_DEFAULT(driver, StatementExecute);
+ FILL_DEFAULT(driver, StatementGetPartitionDesc);
+ FILL_DEFAULT(driver, StatementGetPartitionDescSize);
FILL_DEFAULT(driver, StatementPrepare);
FILL_DEFAULT(driver, StatementSetOption);
FILL_DEFAULT(driver, StatementSetSqlQuery);
diff --git a/python/adbc_driver_manager/adbc_driver_manager/__init__.py b/python/adbc_driver_manager/adbc_driver_manager/__init__.py
index 13a8339..c07a38b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/__init__.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/__init__.py
@@ -14,3 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+
+from ._lib import ( # noqa: F401
+ INGEST_OPTION_TARGET_TABLE,
+ AdbcConnection,
+ AdbcDatabase,
+ AdbcStatement,
+ AdbcStatusCode,
+ ArrowArrayHandle,
+ ArrowArrayStreamHandle,
+ ArrowSchemaHandle,
+ DatabaseError,
+ DataError,
+ Error,
+ GetObjectsDepth,
+ IntegrityError,
+ InterfaceError,
+ InternalError,
+ NotSupportedError,
+ OperationalError,
+ ProgrammingError,
+)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 410e9af..6dbcd40 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -19,10 +19,11 @@
"""Low-level ADBC API."""
+import enum
import typing
+from typing import List
import cython
-import pyarrow
from libc.stdint cimport int32_t, uint8_t, uintptr_t
from libc.string cimport memset
@@ -30,7 +31,7 @@ if typing.TYPE_CHECKING:
from typing import Self
-cdef extern from "adbc.h":
+cdef extern from "adbc.h" nogil:
# C ABI
cdef struct CArrowSchema"ArrowSchema":
pass
@@ -40,22 +41,34 @@ cdef extern from "adbc.h":
pass
# ADBC
- ctypedef uint8_t AdbcStatusCode
- cdef AdbcStatusCode ADBC_STATUS_OK
- cdef AdbcStatusCode ADBC_STATUS_UNKNOWN
- cdef AdbcStatusCode ADBC_STATUS_NOT_IMPLEMENTED
- cdef AdbcStatusCode ADBC_STATUS_NOT_FOUND
- cdef AdbcStatusCode ADBC_STATUS_ALREADY_EXISTS
- cdef AdbcStatusCode ADBC_STATUS_INVALID_ARGUMENT
- cdef AdbcStatusCode ADBC_STATUS_INVALID_STATE
- cdef AdbcStatusCode ADBC_STATUS_INVALID_DATA
- cdef AdbcStatusCode ADBC_STATUS_INTEGRITY
- cdef AdbcStatusCode ADBC_STATUS_INTERNAL
- cdef AdbcStatusCode ADBC_STATUS_IO
- cdef AdbcStatusCode ADBC_STATUS_CANCELLED
- cdef AdbcStatusCode ADBC_STATUS_TIMEOUT
- cdef AdbcStatusCode ADBC_STATUS_UNAUTHENTICATED
- cdef AdbcStatusCode ADBC_STATUS_UNAUTHORIZED
+ ctypedef uint8_t CAdbcStatusCode"AdbcStatusCode"
+ cdef CAdbcStatusCode ADBC_STATUS_OK
+ cdef CAdbcStatusCode ADBC_STATUS_UNKNOWN
+ cdef CAdbcStatusCode ADBC_STATUS_NOT_IMPLEMENTED
+ cdef CAdbcStatusCode ADBC_STATUS_NOT_FOUND
+ cdef CAdbcStatusCode ADBC_STATUS_ALREADY_EXISTS
+ cdef CAdbcStatusCode ADBC_STATUS_INVALID_ARGUMENT
+ cdef CAdbcStatusCode ADBC_STATUS_INVALID_STATE
+ cdef CAdbcStatusCode ADBC_STATUS_INVALID_DATA
+ cdef CAdbcStatusCode ADBC_STATUS_INTEGRITY
+ cdef CAdbcStatusCode ADBC_STATUS_INTERNAL
+ cdef CAdbcStatusCode ADBC_STATUS_IO
+ cdef CAdbcStatusCode ADBC_STATUS_CANCELLED
+ cdef CAdbcStatusCode ADBC_STATUS_TIMEOUT
+ cdef CAdbcStatusCode ADBC_STATUS_UNAUTHENTICATED
+ cdef CAdbcStatusCode ADBC_STATUS_UNAUTHORIZED
+
+ cdef const char* ADBC_OPTION_VALUE_DISABLED
+ cdef const char* ADBC_OPTION_VALUE_ENABLED
+
+ cdef const char* ADBC_CONNECTION_OPTION_AUTOCOMMIT
+ cdef const char* ADBC_INGEST_OPTION_TARGET_TABLE
+
+ cdef int ADBC_OBJECT_DEPTH_ALL
+ cdef int ADBC_OBJECT_DEPTH_CATALOGS
+ cdef int ADBC_OBJECT_DEPTH_DB_SCHEMAS
+ cdef int ADBC_OBJECT_DEPTH_TABLES
+ cdef int ADBC_OBJECT_DEPTH_COLUMNS
ctypedef void (*CAdbcErrorRelease)(CAdbcError*)
@@ -74,32 +87,134 @@ cdef extern from "adbc.h":
cdef struct CAdbcStatement"AdbcStatement":
void* private_data
- AdbcStatusCode AdbcDatabaseNew(CAdbcDatabase* database, CAdbcError* error)
- AdbcStatusCode AdbcDatabaseSetOption(CAdbcDatabase* database, const char* key, const char* value, CAdbcError* error)
- AdbcStatusCode AdbcDatabaseInit(CAdbcDatabase* database, CAdbcError* error)
- AdbcStatusCode AdbcDatabaseRelease(CAdbcDatabase* database, CAdbcError* error)
-
- AdbcStatusCode AdbcConnectionNew(CAdbcConnection* connection, CAdbcError* error)
- AdbcStatusCode AdbcConnectionSetOption(CAdbcConnection* connection, const char* key, const char* value, CAdbcError* error)
- AdbcStatusCode AdbcConnectionInit(CAdbcConnection* connection, CAdbcDatabase* database, CAdbcError* error)
- AdbcStatusCode AdbcConnectionRelease(CAdbcConnection* connection, CAdbcError* error)
-
- AdbcStatusCode AdbcStatementBind(CAdbcStatement* statement, CArrowArray*, CArrowSchema*, CAdbcError* error)
- AdbcStatusCode AdbcStatementBindStream(CAdbcStatement* statement, CArrowArrayStream*, CAdbcError* error)
- AdbcStatusCode AdbcStatementExecute(CAdbcStatement* statement, CAdbcError* error)
- AdbcStatusCode AdbcStatementGetStream(CAdbcStatement* statement, CArrowArrayStream* c_stream, CAdbcError* error)
- AdbcStatusCode AdbcStatementNew(CAdbcConnection* connection, CAdbcStatement* statement, CAdbcError* error)
- AdbcStatusCode AdbcStatementPrepare(CAdbcStatement* statement, CAdbcError* error)
- AdbcStatusCode AdbcStatementSetOption(CAdbcStatement* statement, const char* key, const char* value, CAdbcError* error)
- AdbcStatusCode AdbcStatementSetSqlQuery(CAdbcStatement* statement, const char* query, CAdbcError* error)
- AdbcStatusCode AdbcStatementRelease(CAdbcStatement* statement, CAdbcError* error)
+ CAdbcStatusCode AdbcDatabaseNew(CAdbcDatabase* database, CAdbcError* error)
+ CAdbcStatusCode AdbcDatabaseSetOption(
+ CAdbcDatabase* database,
+ const char* key,
+ const char* value,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcDatabaseInit(CAdbcDatabase* database, CAdbcError* error)
+ CAdbcStatusCode AdbcDatabaseRelease(CAdbcDatabase* database, CAdbcError* error)
+
+ CAdbcStatusCode AdbcConnectionCommit(
+ CAdbcConnection* connection,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcConnectionRollback(
+ CAdbcConnection* connection,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcConnectionDeserializePartitionDesc(
+ CAdbcConnection* connection,
+ const uint8_t* serialized_partition,
+ size_t serialized_length,
+ CAdbcStatement* statement,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcConnectionGetObjects(
+ CAdbcConnection* connection,
+ int depth,
+ const char* catalog,
+ const char* db_schema,
+ const char* table_name,
+ const char** table_type,
+ const char* column_name,
+ CAdbcStatement* statement,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcConnectionGetTableSchema(
+ CAdbcConnection* connection,
+ const char* catalog,
+ const char* db_schema,
+ const char* table_name,
+ CArrowSchema* schema,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcConnectionGetTableTypes(
+ CAdbcConnection* connection,
+ CAdbcStatement* statement,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcConnectionInit(
+ CAdbcConnection* connection,
+ CAdbcDatabase* database,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcConnectionNew(
+ CAdbcConnection* connection,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcConnectionSetOption(
+ CAdbcConnection* connection,
+ const char* key,
+ const char* value,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcConnectionRelease(
+ CAdbcConnection* connection,
+ CAdbcError* error)
+
+ CAdbcStatusCode AdbcStatementBind(
+ CAdbcStatement* statement,
+ CArrowArray*,
+ CArrowSchema*,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcStatementBindStream(
+ CAdbcStatement* statement,
+ CArrowArrayStream*,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcStatementExecute(
+ CAdbcStatement* statement,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcStatementGetPartitionDesc(
+ CAdbcStatement* statement,
+ uint8_t* partition_desc,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcStatementGetPartitionDescSize(
+ CAdbcStatement* statement,
+ size_t* length,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcStatementGetStream(
+ CAdbcStatement* statement,
+ CArrowArrayStream* c_stream,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcStatementNew(
+ CAdbcConnection* connection,
+ CAdbcStatement* statement,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcStatementPrepare(
+ CAdbcStatement* statement,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcStatementSetOption(
+ CAdbcStatement* statement,
+ const char* key,
+ const char* value,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcStatementSetSqlQuery(
+ CAdbcStatement* statement,
+ const char* query,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcStatementSetSubstraitPlan(
+ CAdbcStatement* statement,
+ const uint8_t* plan,
+ size_t length,
+ CAdbcError* error)
+ CAdbcStatusCode AdbcStatementRelease(
+ CAdbcStatement* statement,
+ CAdbcError* error)
cdef extern from "adbc_driver_manager.h":
- const char* AdbcStatusCodeMessage(AdbcStatusCode code)
-
-
-INGEST_OPTION_TARGET_TABLE = "adbc.ingest.target_table"
+ const char* CAdbcStatusCodeMessage"AdbcStatusCodeMessage"(CAdbcStatusCode code)
+
+
+class AdbcStatusCode(enum.IntEnum):
+ OK = ADBC_STATUS_OK
+ UNKNOWN = ADBC_STATUS_UNKNOWN
+ NOT_IMPLEMENTED = ADBC_STATUS_NOT_IMPLEMENTED
+ NOT_FOUND = ADBC_STATUS_NOT_FOUND
+ ALREADY_EXISTS = ADBC_STATUS_ALREADY_EXISTS
+ INVALID_ARGUMENT = ADBC_STATUS_INVALID_ARGUMENT
+ INVALID_STATE = ADBC_STATUS_INVALID_STATE
+ INVALID_DATA = ADBC_STATUS_INVALID_DATA
+ INTEGRITY = ADBC_STATUS_INTEGRITY
+ INTERNAL = ADBC_STATUS_INTERNAL
+ IO = ADBC_STATUS_IO
+ CANCELLED = ADBC_STATUS_CANCELLED
+ TIMEOUT = ADBC_STATUS_TIMEOUT
+ UNAUTHENTICATED = ADBC_STATUS_UNAUTHENTICATED
+ UNAUTHORIZED = ADBC_STATUS_UNAUTHORIZED
class Error(Exception):
@@ -107,7 +222,7 @@ class Error(Exception):
Attributes
----------
- status_code : int
+ status_code : AdbcStatusCode
The original ADBC status code.
vendor_code : int, optional
A vendor-specific status code if present.
@@ -117,7 +232,7 @@ class Error(Exception):
def __init__(self, message, *, status_code, vendor_code=None, sqlstate=None):
super().__init__(message)
- self.status_code = status_code
+ self.status_code = AdbcStatusCode(status_code)
self.vendor_code = None
self.sqlstate = None
@@ -154,11 +269,14 @@ class NotSupportedError(DatabaseError):
pass
-cdef void check_error(AdbcStatusCode status, CAdbcError* error) except *:
+INGEST_OPTION_TARGET_TABLE = ADBC_INGEST_OPTION_TARGET_TABLE.decode("utf-8")
+
+
+cdef void check_error(CAdbcStatusCode status, CAdbcError* error) except *:
if status == ADBC_STATUS_OK:
return
- message = AdbcStatusCodeMessage(status).decode("utf-8")
+ message = CAdbcStatusCodeMessage(status).decode("utf-8")
vendor_code = None
sqlstate = None
@@ -170,7 +288,8 @@ cdef void check_error(AdbcStatusCode status, CAdbcError* error) except *:
vendor_code = error.vendor_code
if error.sqlstate[0] != 0:
sqlstate = error.sqlstate.decode("ascii")
- error.release(error)
+ if error.release:
+ error.release(error)
klass = Error
if status in (ADBC_STATUS_INVALID_DATA,):
@@ -181,7 +300,11 @@ cdef void check_error(AdbcStatusCode status, CAdbcError* error) except *:
klass = IntegrityError
elif status in (ADBC_STATUS_INTERNAL,):
klass = InternalError
- elif status in (ADBC_STATUS_ALREADY_EXISTS, ADBC_STATUS_INVALID_ARGUMENT, ADBC_STATUS_INVALID_STATE, ADBC_STATUS_UNAUTHENTICATED, ADBC_STATUS_UNAUTHORIZED):
+ elif status in (ADBC_STATUS_ALREADY_EXISTS,
+ ADBC_STATUS_INVALID_ARGUMENT,
+ ADBC_STATUS_INVALID_STATE,
+ ADBC_STATUS_UNAUTHENTICATED,
+ ADBC_STATUS_UNAUTHORIZED):
klass = ProgrammingError
elif status == ADBC_STATUS_NOT_IMPLEMENTED:
klass = NotSupportedError
@@ -194,7 +317,18 @@ cdef CAdbcError empty_error():
return error
+cdef bytes _to_bytes(obj, str name):
+ if isinstance(obj, bytes):
+ return obj
+ elif isinstance(obj, str):
+ return obj.encode("utf-8")
+ raise ValueError(f"{name} must be str or bytes")
+
+
cdef class _AdbcHandle:
+ """
+ Base class for ADBC handles, which are context managers.
+ """
def __enter__(self) -> "Self":
return self
@@ -202,13 +336,70 @@ cdef class _AdbcHandle:
self.close()
+cdef class ArrowSchemaHandle:
+ """
+ A wrapper for an allocated ArrowSchema.
+ """
+ cdef:
+ CArrowSchema schema
+
+ @property
+ def address(self) -> int:
+ """The address of the ArrowSchema."""
+ return <uintptr_t> &self.schema
+
+
+cdef class ArrowArrayHandle:
+ """
+ A wrapper for an allocated ArrowArray.
+ """
+ cdef:
+ CArrowArray array
+
+ @property
+ def address(self) -> int:
+ """The address of the ArrowArray."""
+ return <uintptr_t> &self.array
+
+
+cdef class ArrowArrayStreamHandle:
+ """
+ A wrapper for an allocated ArrowArrayStream.
+ """
+ cdef:
+ CArrowArrayStream stream
+
+ @property
+ def address(self) -> int:
+ """The address of the ArrowArrayStream."""
+ return <uintptr_t> &self.stream
+
+
+class GetObjectsDepth(enum.IntEnum):
+ ALL = ADBC_OBJECT_DEPTH_ALL
+ CATALOGS = ADBC_OBJECT_DEPTH_CATALOGS
+ DB_SCHEMAS = ADBC_OBJECT_DEPTH_DB_SCHEMAS
+ TABLES = ADBC_OBJECT_DEPTH_TABLES
+ COLUMNS = ADBC_OBJECT_DEPTH_COLUMNS
+
+
cdef class AdbcDatabase(_AdbcHandle):
+ """
+ An instance of a database.
+
+ Parameters
+ ----------
+ kwargs : dict
+ String key-value options to pass to the underlying database.
+ Must include at least "driver" and "entrypoint" to identify
+ the underlying database driver to load.
+ """
cdef:
CAdbcDatabase database
def __init__(self, **kwargs) -> None:
cdef CAdbcError c_error = empty_error()
- cdef AdbcStatusCode status
+ cdef CAdbcStatusCode status
cdef const char* c_key
cdef const char* c_value
memset(&self.database, 0, cython.sizeof(CAdbcDatabase))
@@ -228,23 +419,40 @@ cdef class AdbcDatabase(_AdbcHandle):
check_error(status, &c_error)
def close(self) -> None:
+ """Release the handle to the database."""
if self.database.private_data == NULL:
return
cdef CAdbcError c_error = empty_error()
- cdef AdbcStatusCode status = AdbcDatabaseRelease(&self.database, &c_error)
+ cdef CAdbcStatusCode status = AdbcDatabaseRelease(&self.database, &c_error)
check_error(status, &c_error)
cdef class AdbcConnection(_AdbcHandle):
+ """
+ An active database connection.
+
+ Connections are not thread-safe and clients should take care to
+ serialize accesses to a connection.
+
+ Parameters
+ ----------
+ database : AdbcDatabase
+ The database to connect to.
+ kwargs : dict
+ String key-value options to pass to the underlying database.
+ """
cdef:
+ AdbcDatabase database
CAdbcConnection connection
def __init__(self, AdbcDatabase database, **kwargs) -> None:
cdef CAdbcError c_error = empty_error()
- cdef AdbcStatusCode status
+ cdef CAdbcStatusCode status
cdef const char* c_key
cdef const char* c_value
+
+ self.database = database
memset(&self.connection, 0, cython.sizeof(CAdbcConnection))
status = AdbcConnectionNew(&self.connection, &c_error)
@@ -261,50 +469,204 @@ cdef class AdbcConnection(_AdbcHandle):
status = AdbcConnectionInit(&self.connection, &database.database, &c_error)
check_error(status, &c_error)
+ def commit(self) -> None:
+ """Commit the current transaction."""
+ cdef CAdbcError c_error = empty_error()
+ check_error(AdbcConnectionCommit(&self.connection, &c_error), &c_error)
+
+ def get_objects(self, depth, catalog=None, db_schema=None, table_name=None,
+ table_types=None, column_name=None) -> AdbcStatement:
+ """
+ Get a hierarchical view of database objects.
+ """
+ cdef CAdbcError c_error = empty_error()
+ cdef CAdbcStatusCode status
+ cdef AdbcStatement statement = AdbcStatement(self)
+
+ cdef char* c_catalog = NULL
+ if catalog is not None:
+ catalog = _to_bytes(catalog, "catalog")
+ c_catalog = catalog
+
+ cdef char* c_db_schema = NULL
+ if db_schema is not None:
+ db_schema = _to_bytes(db_schema, "db_schema")
+ c_db_schema = db_schema
+
+ cdef char* c_table_name = NULL
+ if table_name is not None:
+ table_name = _to_bytes(table_name, "table_name")
+ c_table_name = table_name
+
+ cdef char* c_column_name = NULL
+ if column_name is not None:
+ column_name = _to_bytes(column_name, "column_name")
+ c_column_name = column_name
+
+ status = AdbcConnectionGetObjects(
+ &self.connection,
+ GetObjectsDepth(depth).value,
+ c_catalog,
+ c_db_schema,
+ c_table_name,
+ NULL, # TODO: support table_types
+ c_column_name,
+ &statement.statement,
+ &c_error)
+ check_error(status, &c_error)
+
+ return statement
+
+ def get_table_schema(self, catalog, db_schema, table_name) -> ArrowSchemaHandle:
+ """
+ Get the Arrow schema of a table.
+
+ Returns
+ -------
+ ArrowSchemaHandle
+ A C Data Interface ArrowSchema struct containing the schema.
+ """
+ cdef CAdbcError c_error = empty_error()
+ cdef CAdbcStatusCode status
+ cdef ArrowSchemaHandle handle = ArrowSchemaHandle()
+
+ cdef char* c_catalog = NULL
+ if catalog is not None:
+ catalog = _to_bytes(catalog, "catalog")
+ c_catalog = catalog
+
+ cdef char* c_db_schema = NULL
+ if db_schema is not None:
+ db_schema = _to_bytes(db_schema, "db_schema")
+ c_db_schema = db_schema
+
+ status = AdbcConnectionGetTableSchema(
+ &self.connection,
+ c_catalog,
+ c_db_schema,
+ _to_bytes(table_name, "table_name"),
+ &handle.schema,
+ &c_error)
+ check_error(status, &c_error)
+
+ return handle
+
+ def get_table_types(self) -> AdbcStatement:
+ """
+ Get the list of supported table types.
+ """
+ cdef CAdbcError c_error = empty_error()
+ cdef CAdbcStatusCode status
+ cdef AdbcStatement statement = AdbcStatement(self)
+
+ status = AdbcConnectionGetTableTypes(
+ &self.connection, &statement.statement, &c_error)
+ check_error(status, &c_error)
+
+ return statement
+
+ def rollback(self) -> None:
+ """Roll back the current transaction."""
+ cdef CAdbcError c_error = empty_error()
+ check_error(AdbcConnectionRollback(&self.connection, &c_error), &c_error)
+
+ def set_autocommit(self, bint enabled) -> None:
+ """Toggle whether autocommit is enabled."""
+ cdef CAdbcError c_error = empty_error()
+ if enabled:
+ value = ADBC_OPTION_VALUE_ENABLED
+ else:
+ value = ADBC_OPTION_VALUE_DISABLED
+ status = AdbcConnectionSetOption(
+ &self.connection,
+ ADBC_CONNECTION_OPTION_AUTOCOMMIT,
+ value,
+ &c_error)
+ check_error(status, &c_error)
+
def close(self) -> None:
+ """Release the handle to the connection."""
if self.connection.private_data == NULL:
return
cdef CAdbcError c_error = empty_error()
- cdef AdbcStatusCode status = AdbcConnectionRelease(&self.connection, &c_error)
+ cdef CAdbcStatusCode status = AdbcConnectionRelease(&self.connection, &c_error)
check_error(status, &c_error)
cdef class AdbcStatement(_AdbcHandle):
+ """
+ A database statement.
+
+ Statements are not thread-safe and clients should take care to
+ serialize accesses to a statement.
+
+ Parameters
+ ----------
+ connection : AdbcConnection
+ The connection to create the statement for.
+ """
cdef:
CAdbcStatement statement
def __init__(self, AdbcConnection connection) -> None:
cdef CAdbcError c_error = empty_error()
- cdef const char* c_key
- cdef const char* c_value
memset(&self.statement, 0, cython.sizeof(CAdbcStatement))
status = AdbcStatementNew(&connection.connection, &self.statement, &c_error)
check_error(status, &c_error)
- def bind(self, data) -> None:
+ def bind(self, data, schema) -> None:
+ """
+ Bind an ArrowArray to this statement.
+
+ Parameters
+ ----------
+ data : int or ArrowArrayHandle
+ schema : int or ArrowSchemaHandle
+ """
+ cdef CAdbcError c_error = empty_error()
+ cdef CArrowArray* c_array
+ cdef CArrowSchema* c_schema
+
+ if isinstance(data, ArrowArrayHandle):
+ c_array = &(<ArrowArrayHandle> data).array
+ elif isinstance(data, int):
+ c_array = <CArrowArray*> data
+ else:
+ raise TypeError(f"data must be int or ArrowArrayHandle, not {type(data)}")
+
+ if isinstance(schema, ArrowSchemaHandle):
+ c_schema = &(<ArrowSchemaHandle> schema).schema
+ elif isinstance(schema, int):
+ c_schema = <CArrowSchema*> schema
+ else:
+ raise TypeError(f"schema must be int or ArrowSchemaHandle, "
+ f"not {type(schema)}")
+
+ status = AdbcStatementBind(&self.statement, c_array, c_schema, &c_error)
+ check_error(status, &c_error)
+
+ def bind_stream(self, stream) -> None:
"""
+ Bind an ArrowArrayStream to this statement.
+
Parameters
----------
- data : pyarrow.RecordBatch, pyarrow.RecordBatchReader, or pyarrow.Table
+ stream : int or ArrowArrayStreamHandle
"""
cdef CAdbcError c_error = empty_error()
- cdef CArrowArray c_array
- cdef CArrowSchema c_schema
- cdef CArrowArrayStream c_stream
- if isinstance(data, pyarrow.RecordBatch):
- data._export_to_c(<uintptr_t> &c_array, <uintptr_t>&c_schema)
- status = AdbcStatementBind(&self.statement, &c_array, &c_schema, &c_error)
+ cdef CArrowArrayStream* c_stream
+
+ if isinstance(stream, ArrowArrayStreamHandle):
+ c_stream = &(<ArrowArrayStreamHandle> stream).stream
+ elif isinstance(stream, int):
+ c_stream = <CArrowArrayStream*> stream
else:
- if isinstance(data, pyarrow.Table):
- # Table lacks the export function
- data = data.to_reader()
- elif not isinstance(data, pyarrow.RecordBatchReader):
- raise TypeError("data must be RecordBatch(Reader) or Table")
- data._export_to_c(<uintptr_t> &c_stream)
- status = AdbcStatementBindStream(&self.statement, &c_stream, &c_error)
+ raise TypeError(f"data must be int or ArrowArrayStreamHandle, "
+ f"not {type(stream)}")
+ status = AdbcStatementBindStream(&self.statement, c_stream, &c_error)
check_error(status, &c_error)
def close(self) -> None:
@@ -312,37 +674,84 @@ cdef class AdbcStatement(_AdbcHandle):
return
cdef CAdbcError c_error = empty_error()
- cdef AdbcStatusCode status = AdbcStatementRelease(&self.statement, &c_error)
+ cdef CAdbcStatusCode status = AdbcStatementRelease(&self.statement, &c_error)
check_error(status, &c_error)
def execute(self) -> None:
+ """Execute the query."""
cdef CAdbcError c_error = empty_error()
- status = AdbcStatementExecute(&self.statement, &c_error)
+ with nogil:
+ status = AdbcStatementExecute(&self.statement, &c_error)
check_error(status, &c_error)
- def get_stream(self) -> pyarrow.RecordBatchReader:
+ def get_partitions(self) -> List[bytes]:
+ """Get the partitions of a distributed result set."""
+ cdef CAdbcError c_error = empty_error()
+ cdef size_t length = 0
+ cdef bytes buf
+ cdef uint8_t* c_buf
+
+ result = []
+ while True:
+ with nogil:
+ status = AdbcStatementGetPartitionDescSize(
+ &self.statement,
+ &length,
+ &c_error,
+ )
+ check_error(status, &c_error)
+ if length == 0:
+ break
+
+ buf = bytes(length)
+ c_buf = <uint8_t*> buf
+ with nogil:
+ status = AdbcStatementGetPartitionDesc(
+ &self.statement,
+ c_buf,
+ &c_error,
+ )
+ check_error(status, &c_error)
+ result.append(buf)
+
+ return result
+
+ def get_stream(self) -> ArrowArrayStreamHandle:
+ """Get a reader for the result set."""
cdef CAdbcError c_error = empty_error()
- cdef CArrowArrayStream c_stream
- status = AdbcStatementGetStream(&self.statement, &c_stream, &c_error)
+ cdef ArrowArrayStreamHandle stream = ArrowArrayStreamHandle()
+ status = AdbcStatementGetStream(&self.statement, &stream.stream, &c_error)
check_error(status, &c_error)
- return pyarrow.RecordBatchReader._import_from_c(<uintptr_t> &c_stream)
+ return stream
def prepare(self) -> None:
+ """Turn this statement into a prepared statement."""
cdef CAdbcError c_error = empty_error()
status = AdbcStatementPrepare(&self.statement, &c_error)
check_error(status, &c_error)
def set_options(self, **kwargs) -> None:
+ """Set arbitrary key-value options."""
cdef CAdbcError c_error = empty_error()
for key, value in kwargs.items():
key = key.encode("utf-8")
value = value.encode("utf-8")
c_key = key
c_value = value
- status = AdbcStatementSetOption(&self.statement, c_key, c_value, &c_error)
+ status = AdbcStatementSetOption(
+ &self.statement, c_key, c_value, &c_error)
check_error(status, &c_error)
def set_sql_query(self, query: str) -> None:
+ """Set a SQL query to be executed."""
+ cdef CAdbcError c_error = empty_error()
+ status = AdbcStatementSetSqlQuery(
+ &self.statement, query.encode("utf-8"), &c_error)
+ check_error(status, &c_error)
+
+ def set_substrait_plan(self, plan: bytes) -> None:
+ """Set a Substrait plan to be executed."""
cdef CAdbcError c_error = empty_error()
- status = AdbcStatementSetSqlQuery(&self.statement, query.encode("utf-8"), &c_error)
+ status = AdbcStatementSetSubstraitPlan(
+ &self.statement, plan, len(plan), &c_error)
check_error(status, &c_error)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py
index 114ccde..5f716e9 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/tests/test_lowlevel.py
@@ -15,66 +15,208 @@
# specific language governing permissions and limitations
# under the License.
-import adbc_driver_manager._lib as _lib
+import adbc_driver_manager
import pyarrow
import pytest
-# TODO: make this parameterizable on different drivers?
+
+@pytest.fixture
+def sqlite():
+ """Dynamically load the SQLite driver."""
+ with adbc_driver_manager.AdbcDatabase(
+ driver="adbc_driver_sqlite",
+ entrypoint="AdbcSqliteDriverInit",
+ ) as db:
+ with adbc_driver_manager.AdbcConnection(db) as conn:
+ yield (db, conn)
+
+
+def _import(handle):
+ """Helper to import a C Data Interface handle."""
+ if isinstance(handle, adbc_driver_manager.ArrowArrayStreamHandle):
+ return pyarrow.RecordBatchReader._import_from_c(handle.address)
+ elif isinstance(handle, adbc_driver_manager.ArrowSchemaHandle):
+ return pyarrow.Schema._import_from_c(handle.address)
+ raise NotImplementedError(f"Importing {handle!r}")
+
+
+def _bind(stmt, batch):
+ array = adbc_driver_manager.ArrowArrayHandle()
+ schema = adbc_driver_manager.ArrowSchemaHandle()
+ batch._export_to_c(array.address, schema.address)
+ stmt.bind(array, schema)
def test_database_init():
with pytest.raises(
- _lib.ProgrammingError, match=".*Must provide 'driver' parameter.*"
+ adbc_driver_manager.ProgrammingError,
+ match=".*Must provide 'driver' parameter.*",
):
- with _lib.AdbcDatabase():
+ with adbc_driver_manager.AdbcDatabase():
pass
-@pytest.fixture
-def sqlite():
- with _lib.AdbcDatabase(
- driver="adbc_driver_sqlite",
- entrypoint="AdbcSqliteDriverInit",
- ) as db:
- with _lib.AdbcConnection(db) as conn:
- yield (db, conn)
+def test_connection_get_objects(sqlite):
+ _, conn = sqlite
+ data = pyarrow.record_batch(
+ [
+ [1, 2, 3, 4],
+ ["a", "b", "c", "d"],
+ ],
+ names=["ints", "strs"],
+ )
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "foo"})
+ _bind(stmt, data)
+ stmt.execute()
+
+ with conn.get_objects(adbc_driver_manager.GetObjectsDepth.ALL) as stmt:
+ table = _import(stmt.get_stream()).read_all()
+
+ db_schemas = pyarrow.concat_arrays(table[1].chunks).flatten()
+ tables = db_schemas.flatten()[1].flatten()
+ table_names, _, columns, *_ = tables.flatten()
+ columns = columns.flatten()
+ column_names = columns.flatten()[0]
+
+ assert "foo" in table_names.to_pylist()
+ assert "ints" in column_names.to_pylist()
+ assert "strs" in column_names.to_pylist()
+
+
+def test_connection_get_table_schema(sqlite):
+ _, conn = sqlite
+ data = pyarrow.record_batch(
+ [
+ [1, 2, 3, 4],
+ ["a", "b", "c", "d"],
+ ],
+ names=["ints", "strs"],
+ )
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "foo"})
+ _bind(stmt, data)
+ stmt.execute()
+
+ handle = conn.get_table_schema(catalog=None, db_schema=None, table_name="foo")
+ assert data.schema == _import(handle)
+
+
+def test_connection_get_table_types(sqlite):
+ _, conn = sqlite
+ with conn.get_table_types() as stmt:
+ table = _import(stmt.get_stream()).read_all()
+ assert "table" in table[0].to_pylist()
def test_query(sqlite):
_, conn = sqlite
- with _lib.AdbcStatement(conn) as stmt:
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
stmt.set_sql_query("SELECT 1")
stmt.execute()
- assert stmt.get_stream().read_all() == pyarrow.table([[1]], names=["1"])
+ table = _import(stmt.get_stream()).read_all()
+ assert table == pyarrow.table([[1]], names=["1"])
def test_prepared(sqlite):
_, conn = sqlite
- with _lib.AdbcStatement(conn) as stmt:
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
stmt.set_sql_query("SELECT ?")
stmt.prepare()
- stmt.bind(pyarrow.table([[1, 2, 3, 4]], names=["1"]))
+ _bind(stmt, pyarrow.record_batch([[1, 2, 3, 4]], names=["1"]))
stmt.execute()
- assert stmt.get_stream().read_all() == pyarrow.table(
- [[1, 2, 3, 4]], names=["?"]
- )
+ table = _import(stmt.get_stream()).read_all()
+ assert table == pyarrow.table([[1, 2, 3, 4]], names=["?"])
def test_ingest(sqlite):
_, conn = sqlite
- data = pyarrow.table(
+ data = pyarrow.record_batch(
[
[1, 2, 3, 4],
["a", "b", "c", "d"],
],
names=["ints", "strs"],
)
- with _lib.AdbcStatement(conn) as stmt:
- stmt.set_options(**{_lib.INGEST_OPTION_TARGET_TABLE: "foo"})
- stmt.bind(data)
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "foo"})
+ _bind(stmt, data)
+ stmt.execute()
+
+ stmt.set_sql_query("SELECT * FROM foo")
+ stmt.execute()
+ table = _import(stmt.get_stream()).read_all()
+ assert table == pyarrow.Table.from_batches([data])
+
+
+def test_autocommit(sqlite):
+ _, conn = sqlite
+
+ # Autocommit enabled by default
+ with pytest.raises(adbc_driver_manager.ProgrammingError) as errholder:
+ conn.commit()
+ assert (
+ errholder.value.status_code == adbc_driver_manager.AdbcStatusCode.INVALID_STATE
+ )
+
+ with pytest.raises(adbc_driver_manager.ProgrammingError) as errholder:
+ conn.rollback()
+ assert (
+ errholder.value.status_code == adbc_driver_manager.AdbcStatusCode.INVALID_STATE
+ )
+
+ conn.set_autocommit(True)
+
+ conn.set_autocommit(False)
+
+ # Test rollback
+ data = pyarrow.record_batch(
+ [
+ [1, 2, 3, 4],
+ ["a", "b", "c", "d"],
+ ],
+ names=["ints", "strs"],
+ )
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "foo"})
+ _bind(stmt, data)
+ stmt.execute()
+
+ stmt.set_sql_query("SELECT * FROM foo")
+ stmt.execute()
+ table = _import(stmt.get_stream()).read_all()
+ assert table == pyarrow.Table.from_batches([data])
+
+ conn.rollback()
+
+ # Data should not be readable
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ with pytest.raises(adbc_driver_manager.OperationalError):
+ stmt.set_sql_query("SELECT * FROM foo")
+ stmt.execute()
+
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "foo"})
+ _bind(stmt, data)
+ stmt.execute()
+
+ # Enabling autocommit should implicitly commit
+ conn.set_autocommit(True)
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_sql_query("SELECT * FROM foo")
+ stmt.execute()
+
+ conn.set_autocommit(False)
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "bar"})
+ _bind(stmt, data)
stmt.execute()
+ # Explicit commit
+ conn.commit()
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
stmt.set_sql_query("SELECT * FROM foo")
stmt.execute()
- assert stmt.get_stream().read_all() == data
+ table = _import(stmt.get_stream()).read_all()
+ assert table == pyarrow.Table.from_batches([data])
diff --git a/python/adbc_driver_manager/setup.py b/python/adbc_driver_manager/setup.py
index 9a426a8..74f1f7c 100644
--- a/python/adbc_driver_manager/setup.py
+++ b/python/adbc_driver_manager/setup.py
@@ -30,7 +30,7 @@ setup(
"../../c/driver_manager/adbc_driver_manager.cc",
],
include_dirs=["../../", "../../c/driver_manager"],
- # extra_compile_args=["-ggdb", "-Og"],
+ extra_compile_args=["-ggdb", "-Og"],
),
),
packages=["adbc_driver_manager"],