You are viewing a plain-text version of this content. Its canonical link was rendered as a hyperlink and is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by we...@apache.org on 2018/07/27 15:38:00 UTC

[arrow] branch master updated: ARROW-2920: [Python] Fix pytorch segfault

This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 537e7f7  ARROW-2920: [Python] Fix pytorch segfault
537e7f7 is described below

commit 537e7f7fd503dd920c0b9f0cef8a2de86bc69e3b
Author: Philipp Moritz <pc...@gmail.com>
AuthorDate: Fri Jul 27 11:37:52 2018 -0400

    ARROW-2920: [Python] Fix pytorch segfault
    
    This fixes ARROW-2920 for me (see also https://github.com/ray-project/ray/issues/2447).
    
    Unfortunately, the regression test added here is skipped unless PyTorch with CUDA is available, and we don't have CUDA in our test toolchain right now, so the test will not exercise this fix there.
    
    Author: Philipp Moritz <pc...@gmail.com>
    
    Closes #2329 from pcmoritz/fix-pytorch-segfault and squashes the following commits:
    
    1d828251 <Philipp Moritz> fix
    74bc93ea <Philipp Moritz> add note
    ff14c4db <Philipp Moritz> fix
    b343ca61 <Philipp Moritz> add regression test
    5f0cafa5 <Philipp Moritz> fix
    2751679d <Philipp Moritz> fix
    10c5a5c4 <Philipp Moritz> workaround for pyarrow segfault
---
 python/pyarrow/__init__.py                 |  2 +
 python/pyarrow/compat.py                   | 81 +++++++++++++++++++-----------
 python/pyarrow/tests/test_serialization.py | 11 ++++
 3 files changed, 65 insertions(+), 29 deletions(-)

diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index 1aaad99..8010b72 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -51,8 +51,10 @@ import pyarrow.compat as compat
 
 
 # Workaround for https://issues.apache.org/jira/browse/ARROW-2657
+# and https://issues.apache.org/jira/browse/ARROW-2920
 if _sys.platform in ('linux', 'linux2'):
     compat.import_tensorflow_extension()
+    compat.import_pytorch_extension()
 
 
 from pyarrow.lib import cpu_count, set_cpu_count
diff --git a/python/pyarrow/compat.py b/python/pyarrow/compat.py
index 47aeaa5..44e156e 100644
--- a/python/pyarrow/compat.py
+++ b/python/pyarrow/compat.py
@@ -160,31 +160,17 @@ def encode_file_path(path):
     # will convert utf8 to utf16
     return encoded_path
 
-def import_tensorflow_extension():
+def _iterate_python_module_paths(package_name):
     """
-    Load the TensorFlow extension if it exists.
+    Return an iterator to full paths of a python package.
 
-    This is used to load the TensorFlow extension before
-    pyarrow.lib. If we don't do this there are symbol clashes
-    between TensorFlow's use of threading and our global
-    thread pool, see also
-    https://issues.apache.org/jira/browse/ARROW-2657 and
-    https://github.com/apache/arrow/pull/2096.
+    This is a best effort and might fail (for example on Python 2).
+    It uses the official way of loading modules from
+    https://docs.python.org/3/library/importlib.html#approximating-importlib-import-module
     """
-    import os
-    tensorflow_loaded = False
-
-    # Try to load the tensorflow extension directly
-    # This is a performance optimization, tensorflow will always be
-    # loaded via the "import tensorflow" statement below if this
-    # doesn't succeed.
-    #
-    # This uses the official way of loading modules from
-    # https://docs.python.org/3/library/importlib.html#approximating-importlib-import-module
-
     try:
         import importlib
-        absolute_name = importlib.util.resolve_name("tensorflow", None)
+        absolute_name = importlib.util.resolve_name(package_name, None)
     except (ImportError, AttributeError):
         # Sometimes, importlib is not available (e.g. Python 2)
         # or importlib.util is not available (e.g. Python 2.7)
@@ -205,16 +191,37 @@ def import_tensorflow_extension():
     if spec:
         module = importlib.util.module_from_spec(spec)
         for path in module.__path__:
-            ext = os.path.join(path, "libtensorflow_framework.so")
-            if os.path.exists(ext):
-                import ctypes
-                try:
-                    ctypes.CDLL(ext)
-                except OSError:
-                    pass
-                tensorflow_loaded = True
-                break
+            yield path
 
+def import_tensorflow_extension():
+    """
+    Load the TensorFlow extension if it exists.
+
+    This is used to load the TensorFlow extension before
+    pyarrow.lib. If we don't do this there are symbol clashes
+    between TensorFlow's use of threading and our global
+    thread pool, see also
+    https://issues.apache.org/jira/browse/ARROW-2657 and
+    https://github.com/apache/arrow/pull/2096.
+    """
+    import os
+    tensorflow_loaded = False
+
+    # Try to load the tensorflow extension directly
+    # This is a performance optimization, tensorflow will always be
+    # loaded via the "import tensorflow" statement below if this
+    # doesn't succeed.
+
+    for path in _iterate_python_module_paths("tensorflow"):
+        ext = os.path.join(path, "libtensorflow_framework.so")
+        if os.path.exists(ext):
+            import ctypes
+            try:
+                ctypes.CDLL(ext)
+            except OSError:
+                pass
+            tensorflow_loaded = True
+            break
 
     # If the above failed, try to load tensorflow the normal way
     # (this is more expensive)
@@ -225,6 +232,22 @@ def import_tensorflow_extension():
         except ImportError:
             pass
 
+def import_pytorch_extension():
+    """
+    Load the PyTorch extension if it exists.
+
+    This is used to load the PyTorch extension before
+    pyarrow.lib. If we don't do this there are symbol clashes
+    between PyTorch's use of threading and our global
+    thread pool, see also
+    https://issues.apache.org/jira/browse/ARROW-2920
+    """
+    import ctypes
+    import os
+
+    for path in _iterate_python_module_paths("torch"):
+        ctypes.CDLL(os.path.join(path, "lib/libcaffe2.so"))
+
 
 integer_types = six.integer_types + (np.integer,)
 
diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py
index 6cc391a..53dd5c0 100644
--- a/python/pyarrow/tests/test_serialization.py
+++ b/python/pyarrow/tests/test_serialization.py
@@ -369,6 +369,17 @@ def test_torch_serialization(large_buffer):
                             context=serialization_context)
 
 
+@pytest.mark.skipif(not torch or not torch.cuda.is_available(),
+                    reason="requires pytorch with CUDA")
+def test_torch_cuda():
+    # ARROW-2920: This used to segfault if torch is not imported
+    # before pyarrow
+    # Note that this test will only catch the issue if it is run
+    # with a pyarrow that has been built in the manylinux1 environment
+    torch.nn.Conv2d(64, 2, kernel_size=3, stride=1,
+                    padding=1, bias=False).cuda()
+
+
 def test_numpy_immutable(large_buffer):
     obj = np.zeros([10])