You are viewing a plain-text version of this content. (The link to the canonical version was lost when this copy was extracted.)
Posted to commits@arrow.apache.org by ap...@apache.org on 2018/04/04 19:54:53 UTC

[arrow] branch master updated: ARROW-2276: [Python] Expose buffer protocol on Tensor

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 640fc83  ARROW-2276: [Python] Expose buffer protocol on Tensor
640fc83 is described below

commit 640fc83fd8e6ebdfd1b4dca8b8ca36bca00f77f4
Author: Antoine Pitrou <an...@python.org>
AuthorDate: Wed Apr 4 21:54:26 2018 +0200

    ARROW-2276: [Python] Expose buffer protocol on Tensor
    
    Also add a bit_width property to the DataType class.
    
    Author: Antoine Pitrou <an...@python.org>
    
    Closes #1741 from pitrou/ARROW-2276-tensor-buffer-protocol and squashes the following commits:
    
    104388a <Antoine Pitrou> ARROW-2276:  Expose buffer protocol on Tensor
---
 python/pyarrow/array.pxi            | 24 +++++++++++++++++++
 python/pyarrow/lib.pxd              |  3 +++
 python/pyarrow/tests/test_tensor.py | 27 ++++++++++++++++++++++
 python/pyarrow/tests/test_types.py  | 13 +++++++++++
 python/pyarrow/types.pxi            | 46 +++++++++++++++++++++++++++++++++++++
 5 files changed, 113 insertions(+)

diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index a67bd8b..490a37b 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -651,6 +651,30 @@ strides: {0.strides}""".format(self)
         self._validate()
         return tuple(self.tp.strides())
 
+    def __getbuffer__(self, cp.Py_buffer* buffer, int flags):
+        self._validate()
+
+        buffer.buf = <char *> self.tp.data().get().data()
+        pep3118_format = self.type.pep3118_format
+        if pep3118_format is None:
+            raise NotImplementedError("type %s not supported for buffer "
+                                      "protocol" % (self.type,))
+        buffer.format = pep3118_format
+        buffer.itemsize = self.type.bit_width // 8
+        buffer.internal = NULL
+        buffer.len = self.tp.size() * buffer.itemsize
+        buffer.ndim = self.tp.ndim()
+        buffer.obj = self
+        if self.tp.is_mutable():
+            buffer.readonly = 0
+        else:
+            buffer.readonly = 1
+        # NOTE: This assumes Py_ssize_t == int64_t, and that the shape
+        # and strides arrays lifetime is tied to the tensor's
+        buffer.shape = <Py_ssize_t *> &self.tp.shape()[0]
+        buffer.strides = <Py_ssize_t *> &self.tp.strides()[0]
+        buffer.suboffsets = NULL
+
 
 cdef wrap_array_output(PyObject* output):
     cdef object obj = PyObject_to_object(output)
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index 4c24b48..6f4100f 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -20,6 +20,8 @@ from pyarrow.includes.libarrow cimport *
 from pyarrow.includes.libarrow cimport CStatus
 from cpython cimport PyObject
 from libcpp cimport nullptr
+from libcpp.cast cimport dynamic_cast
+
 
 cdef extern from "Python.h":
     int PySlice_Check(object)
@@ -42,6 +44,7 @@ cdef class DataType:
     cdef:
         shared_ptr[CDataType] sp_type
         CDataType* type
+        bytes pep3118_format
 
     cdef void init(self, const shared_ptr[CDataType]& type)
 
diff --git a/python/pyarrow/tests/test_tensor.py b/python/pyarrow/tests/test_tensor.py
index 093bc86..188a4a5 100644
--- a/python/pyarrow/tests/test_tensor.py
+++ b/python/pyarrow/tests/test_tensor.py
@@ -165,3 +165,30 @@ def test_read_tensor(tmpdir):
     read_mmap = pa.memory_map(path, mode='r')
     array = pa.read_tensor(read_mmap).to_numpy()
     np.testing.assert_equal(data, array)
+
+
+@pytest.mark.skipif(sys.version_info < (3,),
+                    reason="requires Python 3+")
+def test_tensor_memoryview():
+    # Tensors support the PEP 3118 buffer protocol
+    for dtype, expected_format in [(np.int8, '=b'),
+                                   (np.int64, '=q'),
+                                   (np.uint64, '=Q'),
+                                   (np.float16, 'e'),
+                                   (np.float64, 'd'),
+                                   ]:
+        data = np.arange(10, dtype=dtype)
+        dtype = data.dtype
+        lst = data.tolist()
+        tensor = pa.Tensor.from_numpy(data)
+        m = memoryview(tensor)
+        assert m.format == expected_format
+        assert m.shape == data.shape
+        assert m.strides == data.strides
+        assert m.ndim == 1
+        assert m.nbytes == data.nbytes
+        assert m.itemsize == data.itemsize
+        assert m.itemsize * 8 == tensor.type.bit_width
+        assert np.frombuffer(m, dtype).tolist() == lst
+        del tensor, data
+        assert np.frombuffer(m, dtype).tolist() == lst
diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py
index b517020..5057359 100644
--- a/python/pyarrow/tests/test_types.py
+++ b/python/pyarrow/tests/test_types.py
@@ -230,6 +230,19 @@ def test_exact_primitive_types(t, check_func):
     assert check_func(t)
 
 
+def test_bit_width():
+    for ty, expected in [(pa.bool_(), 1),
+                         (pa.int8(), 8),
+                         (pa.uint32(), 32),
+                         (pa.float16(), 16),
+                         (pa.decimal128(19, 4), 128),
+                         (pa.binary(42), 42 * 8)]:
+        assert ty.bit_width == expected
+    for ty in [pa.binary(), pa.string(), pa.list_(pa.int16())]:
+        with pytest.raises(ValueError, match="fixed width"):
+            ty.bit_width
+
+
 def test_fixed_size_binary_byte_width():
     ty = pa.binary(5)
     assert ty.byte_width == 5
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 2abdb30..cb3a72d 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -43,6 +43,42 @@ cdef dict _pandas_type_map = {
     _Type_DECIMAL: np.object_,
 }
 
+cdef dict _pep3118_type_map = {
+    _Type_INT8: b'b',
+    _Type_INT16: b'h',
+    _Type_INT32: b'i',
+    _Type_INT64: b'q',
+    _Type_UINT8: b'B',
+    _Type_UINT16: b'H',
+    _Type_UINT32: b'I',
+    _Type_UINT64: b'Q',
+    _Type_HALF_FLOAT: b'e',
+    _Type_FLOAT: b'f',
+    _Type_DOUBLE: b'd',
+}
+
+
+cdef bytes _datatype_to_pep3118(CDataType* type):
+    """
+    Construct a PEP 3118 format string describing the given datatype.
+    None is returned for unsupported types.
+    """
+    try:
+        char = _pep3118_type_map[type.id()]
+    except KeyError:
+        return None
+    else:
+        if char in b'bBhHiIqQ':
+            # Use "standard" int widths, not native
+            return b'=' + char
+        else:
+            return char
+
+
+# Workaround for Cython parsing bug
+# https://github.com/cython/cython/issues/2143
+ctypedef CFixedWidthType* _CFixedWidthTypePtr
+
 
 cdef class DataType:
     """
@@ -54,12 +90,22 @@ cdef class DataType:
     cdef void init(self, const shared_ptr[CDataType]& type):
         self.sp_type = type
         self.type = type.get()
+        self.pep3118_format = _datatype_to_pep3118(self.type)
 
     property id:
 
         def __get__(self):
             return self.type.id()
 
+    property bit_width:
+
+        def __get__(self):
+            cdef _CFixedWidthTypePtr ty
+            ty = dynamic_cast[_CFixedWidthTypePtr](self.type)
+            if ty == nullptr:
+                raise ValueError("Non-fixed width type")
+            return ty.bit_width()
+
     def __str__(self):
         if self.type is NULL:
             raise TypeError(

-- 
To stop receiving notification emails like this one, please contact
apitrou@apache.org.