You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/06/07 15:26:08 UTC
[arrow] branch master updated: ARROW-4452: [Python] Serialize sparse torch tensors
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new e49631d ARROW-4452: [Python] Serialize sparse torch tensors
e49631d is described below
commit e49631dc18ff11ee755ddb94d2fa82cf1be2e9a0
Author: Philipp Moritz <pc...@gmail.com>
AuthorDate: Fri Jun 7 10:25:54 2019 -0500
ARROW-4452: [Python] Serialize sparse torch tensors
Author: Philipp Moritz <pc...@gmail.com>
Closes #3542 from pcmoritz/torch-sparse-tensors and squashes the following commits:
0ba0a33f3 <Philipp Moritz> fix comparison
fd4774936 <Philipp Moritz> test for sparse tensor properly
9da9bfe3d <Philipp Moritz> add serialization for sparse pytorch tensors
2c37e8ebd <Philipp Moritz> add sparse tensor support for torch
---
python/pyarrow/serialization.py | 13 +++++++++++--
python/pyarrow/tests/test_serialization.py | 19 +++++++++++++++++++
2 files changed, 30 insertions(+), 2 deletions(-)
diff --git a/python/pyarrow/serialization.py b/python/pyarrow/serialization.py
index 3a605a9..89f6e23 100644
--- a/python/pyarrow/serialization.py
+++ b/python/pyarrow/serialization.py
@@ -230,10 +230,19 @@ def register_torch_serialization_handlers(serialization_context):
import torch
def _serialize_torch_tensor(obj):
- return obj.detach().numpy()
+ if obj.is_sparse:
+ # TODO(pcm): Once ARROW-4453 is resolved, return sparse
+ # tensor representation here
+ return (obj._indices().detach().numpy(),
+ obj._values().detach().numpy(), list(obj.shape))
+ else:
+ return obj.detach().numpy()
def _deserialize_torch_tensor(data):
- return torch.from_numpy(data)
+ if isinstance(data, tuple):
+ return torch.sparse_coo_tensor(data[0], data[1], data[2])
+ else:
+ return torch.from_numpy(data)
for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor,
torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py
index cbff0ff..22983e7 100644
--- a/python/pyarrow/tests/test_serialization.py
+++ b/python/pyarrow/tests/test_serialization.py
@@ -43,6 +43,10 @@ except ImportError:
def assert_equal(obj1, obj2):
if torch is not None and torch.is_tensor(obj1) and torch.is_tensor(obj2):
+ if obj1.is_sparse:
+ obj1 = obj1.to_dense()
+ if obj2.is_sparse:
+ obj2 = obj2.to_dense()
assert torch.equal(obj1, obj2)
return
module_numpy = (type(obj1).__module__ == np.__name__ or
@@ -390,6 +394,9 @@ def test_torch_serialization(large_buffer):
serialization_context = pa.default_serialization_context()
pa.register_torch_serialization_handlers(serialization_context)
+
+ # Dense tensors:
+
# These are the only types that are supported for the
# PyTorch to NumPy conversion
for t in ["float32", "float64",
@@ -402,6 +409,18 @@ def test_torch_serialization(large_buffer):
serialization_roundtrip(tensor_requiring_grad, large_buffer,
context=serialization_context)
+ # Sparse tensors:
+
+ # These are the only types that are supported for the
+ # PyTorch to NumPy conversion
+ for t in ["float32", "float64",
+ "uint8", "int16", "int32", "int64"]:
+ i = torch.LongTensor([[0, 2], [1, 0], [1, 2]])
+ v = torch.from_numpy(np.array([3, 4, 5]).astype(t))
+ obj = torch.sparse_coo_tensor(i.t(), v, torch.Size([2, 3]))
+ serialization_roundtrip(obj, large_buffer,
+ context=serialization_context)
+
@pytest.mark.skipif(not torch or not torch.cuda.is_available(),
reason="requires pytorch with CUDA")