You are viewing a plain text version of this content. The canonical link for it (a hyperlink in the original page) is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/06/07 15:26:08 UTC

[arrow] branch master updated: ARROW-4452: [Python] Serialize sparse torch tensors

This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new e49631d  ARROW-4452: [Python] Serialize sparse torch tensors
e49631d is described below

commit e49631dc18ff11ee755ddb94d2fa82cf1be2e9a0
Author: Philipp Moritz <pc...@gmail.com>
AuthorDate: Fri Jun 7 10:25:54 2019 -0500

    ARROW-4452: [Python] Serialize sparse torch tensors
    
    Author: Philipp Moritz <pc...@gmail.com>
    
    Closes #3542 from pcmoritz/torch-sparse-tensors and squashes the following commits:
    
    0ba0a33f3 <Philipp Moritz> fix comparison
    fd4774936 <Philipp Moritz> test for sparse tensor properly
    9da9bfe3d <Philipp Moritz> add serialization for sparse pytorch tensors
    2c37e8ebd <Philipp Moritz> add sparse tensor support for torch
---
 python/pyarrow/serialization.py            | 13 +++++++++++--
 python/pyarrow/tests/test_serialization.py | 19 +++++++++++++++++++
 2 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/python/pyarrow/serialization.py b/python/pyarrow/serialization.py
index 3a605a9..89f6e23 100644
--- a/python/pyarrow/serialization.py
+++ b/python/pyarrow/serialization.py
@@ -230,10 +230,19 @@ def register_torch_serialization_handlers(serialization_context):
         import torch
 
         def _serialize_torch_tensor(obj):
-            return obj.detach().numpy()
+            if obj.is_sparse:
+                # TODO(pcm): Once ARROW-4453 is resolved, return sparse
+                # tensor representation here
+                return (obj._indices().detach().numpy(),
+                        obj._values().detach().numpy(), list(obj.shape))
+            else:
+                return obj.detach().numpy()
 
         def _deserialize_torch_tensor(data):
-            return torch.from_numpy(data)
+            if isinstance(data, tuple):
+                return torch.sparse_coo_tensor(data[0], data[1], data[2])
+            else:
+                return torch.from_numpy(data)
 
         for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor,
                   torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py
index cbff0ff..22983e7 100644
--- a/python/pyarrow/tests/test_serialization.py
+++ b/python/pyarrow/tests/test_serialization.py
@@ -43,6 +43,10 @@ except ImportError:
 
 def assert_equal(obj1, obj2):
     if torch is not None and torch.is_tensor(obj1) and torch.is_tensor(obj2):
+        if obj1.is_sparse:
+            obj1 = obj1.to_dense()
+        if obj2.is_sparse:
+            obj2 = obj2.to_dense()
         assert torch.equal(obj1, obj2)
         return
     module_numpy = (type(obj1).__module__ == np.__name__ or
@@ -390,6 +394,9 @@ def test_torch_serialization(large_buffer):
 
     serialization_context = pa.default_serialization_context()
     pa.register_torch_serialization_handlers(serialization_context)
+
+    # Dense tensors:
+
     # These are the only types that are supported for the
     # PyTorch to NumPy conversion
     for t in ["float32", "float64",
@@ -402,6 +409,18 @@ def test_torch_serialization(large_buffer):
     serialization_roundtrip(tensor_requiring_grad, large_buffer,
                             context=serialization_context)
 
+    # Sparse tensors:
+
+    # These are the only types that are supported for the
+    # PyTorch to NumPy conversion
+    for t in ["float32", "float64",
+              "uint8", "int16", "int32", "int64"]:
+        i = torch.LongTensor([[0, 2], [1, 0], [1, 2]])
+        v = torch.from_numpy(np.array([3, 4, 5]).astype(t))
+        obj = torch.sparse_coo_tensor(i.t(), v, torch.Size([2, 3]))
+        serialization_roundtrip(obj, large_buffer,
+                                context=serialization_context)
+
 
 @pytest.mark.skipif(not torch or not torch.cuda.is_available(),
                     reason="requires pytorch with CUDA")