You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ro...@apache.org on 2019/04/17 05:08:28 UTC

[arrow] branch master updated: ARROW-3399: [Python] Implementing numpy matrix serialization

This is an automated email from the ASF dual-hosted git repository.

robertnishihara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 7561027  ARROW-3399: [Python] Implementing numpy matrix serialization
7561027 is described below

commit 7561027ea09925f3077c2deb047f8a9c7dfb7ded
Author: Rok <ro...@mihevc.org>
AuthorDate: Tue Apr 16 22:08:20 2019 -0700

    ARROW-3399: [Python] Implementing numpy matrix serialization
    
    See [ARROW-3399](https://issues.apache.org/jira/browse/ARROW-3399).
    
    Author: Rok <ro...@mihevc.org>
    
    Closes #4096 from rok/ARROW-3399 and squashes the following commits:
    
    a2a6ae30 <Rok> Moving np.matrix serialization tests from test_plasma to test_serialization.
    65af7d2f <Rok> Adding custom type to np.matrix serialization tests.
    5349b3c5 <Rok> Removing copy from np.matrix deserialization. Adding tests for various datatypes at np.matrix deserialization.
    6a9539ee <Rok> Adding numpy matrix serialization support and test.
---
 python/pyarrow/serialization.py            | 24 ++++++++++++++++++++++++
 python/pyarrow/tests/test_serialization.py | 23 +++++++++++++++++++++++
 2 files changed, 47 insertions(+)

diff --git a/python/pyarrow/serialization.py b/python/pyarrow/serialization.py
index 3eed68d..fe170b2 100644
--- a/python/pyarrow/serialization.py
+++ b/python/pyarrow/serialization.py
@@ -55,6 +55,25 @@ def _deserialize_numpy_array_list(data):
         return np.array(data[0], dtype=np.dtype(data[1]))
 
 
+def _serialize_numpy_matrix(obj):
+    if obj.dtype.str != '|O':
+        # Make the array c_contiguous if necessary so that we can change
+        # the view.
+        if not obj.flags.c_contiguous:
+            obj = np.ascontiguousarray(obj.A)
+        return obj.A.view('uint8'), obj.A.dtype.str
+    else:
+        return obj.A.tolist(), obj.A.dtype.str
+
+
+def _deserialize_numpy_matrix(data):
+    if data[1] != '|O':
+        assert data[0].dtype == np.uint8
+        return np.matrix(data[0].view(data[1]), copy=False)
+    else:
+        return np.matrix(data[0], dtype=np.dtype(data[1]), copy=False)
+
+
 # ----------------------------------------------------------------------
 # pyarrow.RecordBatch-specific serialization matters
 
@@ -298,6 +317,11 @@ def register_default_serialization_handlers(serialization_context):
     serialization_context.register_type(type, "type", pickle=True)
 
     serialization_context.register_type(
+        np.matrix, 'np.matrix',
+        custom_serializer=_serialize_numpy_matrix,
+        custom_deserializer=_deserialize_numpy_matrix)
+
+    serialization_context.register_type(
         np.ndarray, 'np.array',
         custom_serializer=_serialize_numpy_array_list,
         custom_deserializer=_deserialize_numpy_array_list)
diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py
index c37e8b6..0a7e443 100644
--- a/python/pyarrow/tests/test_serialization.py
+++ b/python/pyarrow/tests/test_serialization.py
@@ -489,6 +489,29 @@ def test_numpy_subclass_serialization():
     assert np.alltrue(new_x.view(np.ndarray) == np.zeros(3))
 
 
+def test_numpy_matrix_serialization(tmpdir):
+    class CustomType(object):
+        def __init__(self, val):
+            self.val = val
+
+    path = os.path.join(str(tmpdir), 'pyarrow_npmatrix_serialization_test.bin')
+    array = np.random.randint(low=-1, high=1, size=(2, 2))
+
+    for data_type in [str, int, float, CustomType]:
+        matrix = np.matrix(array.astype(data_type))
+
+        with open(path, 'wb') as f:
+            f.write(pa.serialize(matrix).to_buffer())
+
+        serialized = pa.read_serialized(pa.OSFile(path))
+        result = serialized.deserialize()
+        assert_equal(result, matrix)
+        assert_equal(result.dtype, matrix.dtype)
+        serialized = None
+        assert_equal(result, matrix)
+        assert result.base is not None
+
+
 def test_pyarrow_objects_serialization(large_buffer):
     # NOTE: We have to put these objects inside,
     # or it will affect 'test_total_bytes_allocated'.