Posted to commits@arrow.apache.org by we...@apache.org on 2018/07/27 00:07:23 UTC
[arrow] branch master updated: ARROW-2917: [Python] Use detach() to
avoid PyTorch gradient errors
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new fdc8e6a ARROW-2917: [Python] Use detach() to avoid PyTorch gradient errors
fdc8e6a is described below
commit fdc8e6a7b278dfaf59c7fa2d7367c407bfb264ec
Author: Wes McKinney <we...@apache.org>
AuthorDate: Thu Jul 26 20:07:16 2018 -0400
ARROW-2917: [Python] Use detach() to avoid PyTorch gradient errors
`detach()` doesn't copy data unless it has to and will give a RuntimeError if the detached data needs to have its gradient calculated.
Author: Wes McKinney <we...@apache.org>
Author: Alok Singh <83...@users.noreply.github.com>
Closes #2311 from alok/patch-1 and squashes the following commits:
e451de85 <Wes McKinney> Add unit test serializing PyTorch tensor requiring gradient that fails on master
f8e298f5 <Alok Singh> Use detach() to avoid torch gradient errors
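
As a rough illustration of the problem this change addresses, the sketch below calls `numpy()` on a tensor created with `requires_grad=True`, then converts the same tensor through `detach()` and compares buffer addresses to check that no copy was made. The names `to_numpy_without_grad`, `detach_demo` and `detach_result` are hypothetical and not part of pyarrow. The demo returns None when PyTorch or NumPy is not installed, and it assumes a CPU tensor.

```python
import importlib.util


def to_numpy_without_grad(tensor):
    """Hypothetical helper mirroring the patched serializer body."""
    return tensor.detach().numpy()


def detach_demo():
    """Return (numpy_raised, shares_memory), or None without PyTorch/NumPy."""
    if any(importlib.util.find_spec(m) is None for m in ("torch", "numpy")):
        return None
    import torch

    t = torch.randn(10, 10, requires_grad=True)

    # Pre-patch behaviour: numpy() refuses tensors that track gradients.
    try:
        t.numpy()
        numpy_raised = False
    except RuntimeError:
        numpy_raised = True

    # Post-patch behaviour: detach() drops autograd tracking first.
    arr = to_numpy_without_grad(t)

    # Compare the underlying buffer addresses; equal addresses mean the
    # NumPy array is backed by the tensor's own storage, not a copy.
    shares_memory = arr.ctypes.data == t.data_ptr()
    return numpy_raised, shares_memory


detach_result = detach_demo()
```

Because the array and the tensor share one buffer in this sketch, writing to the array also changes the tensor's values.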
---
python/pyarrow/serialization.py | 2 +-
python/pyarrow/tests/test_serialization.py | 4 ++++
2 files changed, 5 insertions(+), 1 deletion(-)
diff --git a/python/pyarrow/serialization.py b/python/pyarrow/serialization.py
index 8669e82..e398e9d 100644
--- a/python/pyarrow/serialization.py
+++ b/python/pyarrow/serialization.py
@@ -136,7 +136,7 @@ def register_torch_serialization_handlers(serialization_context):
         import torch
 
         def _serialize_torch_tensor(obj):
-            return obj.numpy()
+            return obj.detach().numpy()
 
         def _deserialize_torch_tensor(data):
             return torch.from_numpy(data)
diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py
index e484ebb..6cc391a 100644
--- a/python/pyarrow/tests/test_serialization.py
+++ b/python/pyarrow/tests/test_serialization.py
@@ -364,6 +364,10 @@ def test_torch_serialization(large_buffer):
         serialization_roundtrip(obj, large_buffer,
                                 context=serialization_context)
 
+    tensor_requiring_grad = torch.randn(10, 10, requires_grad=True)
+    serialization_roundtrip(tensor_requiring_grad, large_buffer,
+                            context=serialization_context)
+
 
 def test_numpy_immutable(large_buffer):
     obj = np.zeros([10])
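
The handler pair in the first hunk can be exercised on its own, without pyarrow, to see what a round trip does to a tensor that requires gradients. This is only a sketch: it leaves out the actual Arrow serialization step between the two handlers, and `roundtrip_demo` and `roundtrip_result` are hypothetical names. It returns None when PyTorch or NumPy is not installed.

```python
import importlib.util


def roundtrip_demo():
    """Return (values_equal, requires_grad_after), or None without PyTorch/NumPy."""
    if any(importlib.util.find_spec(m) is None for m in ("torch", "numpy")):
        return None
    import torch

    # Same bodies as the handlers shown in the diff.
    def _serialize_torch_tensor(obj):
        return obj.detach().numpy()

    def _deserialize_torch_tensor(data):
        return torch.from_numpy(data)

    original = torch.randn(10, 10, requires_grad=True)

    # Assumption: the Arrow transport between the two calls is skipped here.
    restored = _deserialize_torch_tensor(_serialize_torch_tensor(original))

    values_equal = torch.equal(original.detach(), restored)
    return values_equal, restored.requires_grad


roundtrip_result = roundtrip_demo()
```

In this sketch the values survive the round trip but the `requires_grad` flag does not, so a caller that needs gradients on the restored tensor would have to re-enable them, for example with `requires_grad_()`.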