You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by pa...@apache.org on 2023/11/17 17:14:07 UTC
(arrow-nanoarrow) branch main updated: feat(python): Support the PyCapsule protocol (#318)
This is an automated email from the ASF dual-hosted git repository.
paleolimbot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-nanoarrow.git
The following commit(s) were added to refs/heads/main by this push:
new 2d28306 feat(python): Support the PyCapsule protocol (#318)
2d28306 is described below
commit 2d2830686eb5a1ceb5c6389a22f97865187e40ba
Author: Joris Van den Bossche <jo...@gmail.com>
AuthorDate: Fri Nov 17 18:14:02 2023 +0100
feat(python): Support the PyCapsule protocol (#318)
First commit with just importing for now. Will do exporting next.
---------
Co-authored-by: Dewey Dunnington <de...@fishandwhistle.net>
---
python/src/nanoarrow/_lib.pyx | 71 ++++++++++++++++++++++++++++++++++---
python/src/nanoarrow/lib.py | 25 +++++++------
python/tests/test_capsules.py | 81 +++++++++++++++++++++++++++++++++++++++++++
3 files changed, 162 insertions(+), 15 deletions(-)
diff --git a/python/src/nanoarrow/_lib.pyx b/python/src/nanoarrow/_lib.pyx
index 906eecf..61a53e2 100644
--- a/python/src/nanoarrow/_lib.pyx
+++ b/python/src/nanoarrow/_lib.pyx
@@ -30,9 +30,11 @@ be literal and stay close to the structure definitions.
from libc.stdint cimport uintptr_t, int64_t
from cpython.mem cimport PyMem_Malloc, PyMem_Free
from cpython.bytes cimport PyBytes_FromStringAndSize
+from cpython.pycapsule cimport PyCapsule_GetPointer
from cpython cimport Py_buffer
from nanoarrow_c cimport *
+
def c_version():
"""Return the nanoarrow C library version string
"""
@@ -200,6 +202,22 @@ cdef class Schema:
self._base = base,
self._ptr = <ArrowSchema*>addr
+ @staticmethod
+ def _import_from_c_capsule(schema_capsule):
+ """
+ Import from an ArrowSchema PyCapsule
+
+ Parameters
+ ----------
+ schema_capsule : PyCapsule
+ A valid PyCapsule with name 'arrow_schema' containing an
+ ArrowSchema pointer.
+ """
+ return Schema(
+ schema_capsule,
+ <uintptr_t>PyCapsule_GetPointer(schema_capsule, 'arrow_schema')
+ )
+
def _addr(self):
return <uintptr_t>self._ptr
@@ -428,6 +446,33 @@ cdef class Array:
self._ptr = <ArrowArray*>addr
self._schema = schema
+ @staticmethod
+ def _import_from_c_capsule(schema_capsule, array_capsule):
+ """
+ Import from an ArrowSchema and ArrowArray PyCapsule tuple.
+
+ Parameters
+ ----------
+ schema_capsule : PyCapsule
+ A valid PyCapsule with name 'arrow_schema' containing an
+ ArrowSchema pointer.
+ array_capsule : PyCapsule
+ A valid PyCapsule with name 'arrow_array' containing an
+ ArrowArray pointer.
+ """
+ cdef:
+ Schema out_schema
+ Array out
+
+ out_schema = Schema._import_from_c_capsule(schema_capsule)
+ out = Array(
+ array_capsule,
+ <uintptr_t>PyCapsule_GetPointer(array_capsule, 'arrow_array'),
+ out_schema
+ )
+
+ return out
+
def _addr(self):
return <uintptr_t>self._ptr
@@ -818,11 +863,32 @@ cdef class ArrayStream:
cdef ArrowArrayStream* _ptr
cdef object _cached_schema
+ @staticmethod
+ def allocate():
+ base = ArrayStreamHolder()
+ return ArrayStream(base, base._addr())
+
def __cinit__(self, object base, uintptr_t addr):
self._base = base
self._ptr = <ArrowArrayStream*>addr
self._cached_schema = None
+ @staticmethod
+ def _import_from_c_capsule(stream_capsule):
+ """
+ Import from an ArrowArrayStream PyCapsule.
+
+ Parameters
+ ----------
+ stream_capsule : PyCapsule
+ A valid PyCapsule with name 'arrow_array_stream' containing an
+ ArrowArrayStream pointer.
+ """
+ return ArrayStream(
+ stream_capsule,
+ <uintptr_t>PyCapsule_GetPointer(stream_capsule, 'arrow_array_stream')
+ )
+
def _addr(self):
return <uintptr_t>self._ptr
@@ -898,8 +964,3 @@ cdef class ArrayStream:
def __next__(self):
return self.get_next()
-
- @staticmethod
- def allocate():
- base = ArrayStreamHolder()
- return ArrayStream(base, base._addr())
diff --git a/python/src/nanoarrow/lib.py b/python/src/nanoarrow/lib.py
index ad9323a..6625aff 100644
--- a/python/src/nanoarrow/lib.py
+++ b/python/src/nanoarrow/lib.py
@@ -22,9 +22,10 @@ def schema(obj):
if isinstance(obj, Schema):
return obj
- # Not particularly safe because _export_to_c() could be exporting an
- # array, schema, or array_stream. The ideal
- # solution here would be something like __arrow_c_schema__()
+ if hasattr(obj, "__arrow_c_schema__"):
+ return Schema._import_from_c_capsule(obj.__arrow_c_schema__())
+
+ # for pyarrow < 14.0
if hasattr(obj, "_export_to_c"):
out = Schema.allocate()
obj._export_to_c(out._addr())
@@ -39,9 +40,11 @@ def array(obj):
if isinstance(obj, Array):
return obj
- # Somewhat safe because calling _export_to_c() with two arguments will
- # not fail with a crash (but will fail with a confusing error). The ideal
- # solution here would be something like __arrow_c_array__()
+ if hasattr(obj, "__arrow_c_array__"):
+ # TODO support requested schema
+ return Array._import_from_c_capsule(*obj.__arrow_c_array__())
+
+ # for pyarrow < 14.0
if hasattr(obj, "_export_to_c"):
out = Array.allocate(Schema.allocate())
obj._export_to_c(out._addr(), out.schema._addr())
@@ -53,12 +56,14 @@ def array(obj):
def array_stream(obj):
- if isinstance(obj, Schema):
+ if isinstance(obj, ArrayStream):
return obj
- # Not particularly safe because _export_to_c() could be exporting an
- # array, schema, or array_stream. The ideal
- # solution here would be something like __arrow_c_array_stream__()
+ if hasattr(obj, "__arrow_c_stream__"):
+ # TODO support requested schema
+ return ArrayStream._import_from_c_capsule(obj.__arrow_c_stream__())
+
+ # for pyarrow < 14.0
if hasattr(obj, "_export_to_c"):
out = ArrayStream.allocate()
obj._export_to_c(out._addr())
diff --git a/python/tests/test_capsules.py b/python/tests/test_capsules.py
new file mode 100644
index 0000000..f42f57e
--- /dev/null
+++ b/python/tests/test_capsules.py
@@ -0,0 +1,81 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pyarrow as pa
+
+import nanoarrow as na
+
+
+class SchemaWrapper:
+ def __init__(self, schema):
+ self.schema = schema
+
+ def __arrow_c_schema__(self):
+ return self.schema.__arrow_c_schema__()
+
+
+class ArrayWrapper:
+ def __init__(self, array):
+ self.array = array
+
+ def __arrow_c_array__(self, requested_schema=None):
+ return self.array.__arrow_c_array__(requested_schema=requested_schema)
+
+
+class StreamWrapper:
+ def __init__(self, stream):
+ self.stream = stream
+
+ def __arrow_c_stream__(self, requested_schema=None):
+ return self.stream.__arrow_c_stream__(requested_schema=requested_schema)
+
+
+def test_schema_import():
+ pa_schema = pa.schema([pa.field("some_name", pa.int32())])
+
+ for schema_obj in [pa_schema, SchemaWrapper(pa_schema)]:
+ schema = na.schema(schema_obj)
+ # some basic validation
+ assert schema.is_valid()
+ assert schema.format == "+s"
+ assert str(schema) == "struct<some_name: int32>"
+
+
+def test_array_import():
+ pa_arr = pa.array([1, 2, 3], pa.int32())
+
+ for arr_obj in [pa_arr, ArrayWrapper(pa_arr)]:
+ array = na.array(arr_obj)
+ # some basic validation
+ assert array.is_valid()
+ assert array.length == 3
+ assert str(array.schema) == "int32"
+
+
+def test_array_stream_import():
+ def make_reader():
+ pa_array_child = pa.array([1, 2, 3], pa.int32())
+ pa_array = pa.record_batch([pa_array_child], names=["some_column"])
+ return pa.RecordBatchReader.from_batches(pa_array.schema, [pa_array])
+
+ for stream_obj in [make_reader(), StreamWrapper(make_reader())]:
+ array_stream = na.array_stream(stream_obj)
+ # some basic validation
+ assert array_stream.is_valid()
+ array = array_stream.get_next()
+ assert array.length == 3
+ assert str(array_stream.get_schema()) == "struct<some_column: int32>"