Posted to commits@arrow.apache.org by we...@apache.org on 2018/07/27 00:07:23 UTC

[arrow] branch master updated: ARROW-2917: [Python] Use detach() to avoid PyTorch gradient errors

This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new fdc8e6a  ARROW-2917: [Python] Use detach() to avoid PyTorch gradient errors
fdc8e6a is described below

commit fdc8e6a7b278dfaf59c7fa2d7367c407bfb264ec
Author: Wes McKinney <we...@apache.org>
AuthorDate: Thu Jul 26 20:07:16 2018 -0400

    ARROW-2917: [Python] Use detach() to avoid PyTorch gradient errors
    
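    `detach()` doesn't copy data. It returns a tensor that shares storage with the original but has `requires_grad=False`, so `.numpy()` can be called on it without the RuntimeError PyTorch raises for tensors that require grad. If the shared data is later modified in place while autograd still needs it for a backward pass, PyTorch raises a RuntimeError instead of silently computing wrong gradients.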
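A minimal sketch of both behaviors, assuming PyTorch 0.4 or later, where tensors carry `requires_grad` directly. The variable names are illustrative only. The in-place case uses `sigmoid()` because its backward pass reuses its own output, so overwriting that output through a detached alias invalidates the saved value.

```python
import torch

# A tensor tracked by autograd, like the one added in the new unit test.
t = torch.randn(3, 3, requires_grad=True)

# Calling .numpy() directly on a tensor that requires grad raises RuntimeError.
try:
    t.numpy()
    numpy_failed = False
except RuntimeError:
    numpy_failed = True

# detach() gives a requires_grad=False tensor over the same storage (no copy),
# so .numpy() succeeds and the array aliases t's memory.
d = t.detach()
arr = d.numpy()
shares_memory = d.data_ptr() == t.data_ptr()

# In-place edits through a detached alias are caught by autograd's version check
# when the original value is still needed for backward.
a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out = a.sigmoid()          # sigmoid's backward needs `out` itself
out.detach().zero_()       # overwrite out's storage through the detached view
try:
    out.sum().backward()
    backward_failed = False
except RuntimeError:
    backward_failed = True
```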
    
    Author: Wes McKinney <we...@apache.org>
    Author: Alok Singh <83...@users.noreply.github.com>
    
    Closes #2311 from alok/patch-1 and squashes the following commits:
    
    e451de85 <Wes McKinney> Add unit test serializing pytorch tensor requiring gradient that fails on master
    f8e298f5 <Alok Singh> Use detach() to avoid torch gradient errors
---
 python/pyarrow/serialization.py            | 2 +-
 python/pyarrow/tests/test_serialization.py | 4 ++++
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/pyarrow/serialization.py b/python/pyarrow/serialization.py
index 8669e82..e398e9d 100644
--- a/python/pyarrow/serialization.py
+++ b/python/pyarrow/serialization.py
@@ -136,7 +136,7 @@ def register_torch_serialization_handlers(serialization_context):
         import torch
 
         def _serialize_torch_tensor(obj):
-            return obj.numpy()
+            return obj.detach().numpy()
 
         def _deserialize_torch_tensor(data):
             return torch.from_numpy(data)
diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py
index e484ebb..6cc391a 100644
--- a/python/pyarrow/tests/test_serialization.py
+++ b/python/pyarrow/tests/test_serialization.py
@@ -364,6 +364,10 @@ def test_torch_serialization(large_buffer):
         serialization_roundtrip(obj, large_buffer,
                                 context=serialization_context)
 
+    tensor_requiring_grad = torch.randn(10, 10, requires_grad=True)
+    serialization_roundtrip(tensor_requiring_grad, large_buffer,
+                            context=serialization_context)
+
 
 def test_numpy_immutable(large_buffer):
     obj = np.zeros([10])
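The patched handler pair can be exercised on its own, outside pyarrow's serialization context, to see what the new unit test checks. This is a sketch assuming only NumPy and PyTorch. `serialize_torch_tensor` and `deserialize_torch_tensor` are hypothetical local copies of the two handler bodies from the diff, and the explicit array copy is a stand-in for the byte-level round trip that `serialization_roundtrip` performs.

```python
import numpy as np
import torch

# Local copies of the patched handlers in python/pyarrow/serialization.py.
def serialize_torch_tensor(obj):
    # detach() first, so .numpy() works even when requires_grad=True.
    return obj.detach().numpy()

def deserialize_torch_tensor(data):
    return torch.from_numpy(data)

# Same kind of input as the new test case.
tensor_requiring_grad = torch.randn(10, 10, requires_grad=True)

# Stand-in for the real round trip: copy the array so the restored tensor
# does not alias the original's memory.
payload = np.array(serialize_torch_tensor(tensor_requiring_grad), copy=True)
restored = deserialize_torch_tensor(payload)

same_values = torch.equal(restored, tensor_requiring_grad.detach())
```

The values survive the round trip, but the `requires_grad` flag does not: `torch.from_numpy()` always returns a tensor that autograd does not track, so a receiver that needs gradients has to set it again.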