You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2018/07/27 15:38:00 UTC
[arrow] branch master updated: ARROW-2920: [Python] Fix pytorch
segfault
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 537e7f7 ARROW-2920: [Python] Fix pytorch segfault
537e7f7 is described below
commit 537e7f7fd503dd920c0b9f0cef8a2de86bc69e3b
Author: Philipp Moritz <pc...@gmail.com>
AuthorDate: Fri Jul 27 11:37:52 2018 -0400
ARROW-2920: [Python] Fix pytorch segfault
This fixes ARROW-2920 (see also https://github.com/ray-project/ray/issues/2447) for me
Unfortunately we might not be able to have regression tests for this right now because we don't have CUDA in our test toolchain.
Author: Philipp Moritz <pc...@gmail.com>
Closes #2329 from pcmoritz/fix-pytorch-segfault and squashes the following commits:
1d828251 <Philipp Moritz> fix
74bc93ea <Philipp Moritz> add note
ff14c4db <Philipp Moritz> fix
b343ca61 <Philipp Moritz> add regression test
5f0cafa5 <Philipp Moritz> fix
2751679d <Philipp Moritz> fix
10c5a5c4 <Philipp Moritz> workaround for pyarrow segfault
---
python/pyarrow/__init__.py | 2 +
python/pyarrow/compat.py | 81 +++++++++++++++++++-----------
python/pyarrow/tests/test_serialization.py | 11 ++++
3 files changed, 65 insertions(+), 29 deletions(-)
diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index 1aaad99..8010b72 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -51,8 +51,10 @@ import pyarrow.compat as compat
# Workaround for https://issues.apache.org/jira/browse/ARROW-2657
+# and https://issues.apache.org/jira/browse/ARROW-2920
if _sys.platform in ('linux', 'linux2'):
compat.import_tensorflow_extension()
+ compat.import_pytorch_extension()
from pyarrow.lib import cpu_count, set_cpu_count
diff --git a/python/pyarrow/compat.py b/python/pyarrow/compat.py
index 47aeaa5..44e156e 100644
--- a/python/pyarrow/compat.py
+++ b/python/pyarrow/compat.py
@@ -160,31 +160,17 @@ def encode_file_path(path):
# will convert utf8 to utf16
return encoded_path
-def import_tensorflow_extension():
+def _iterate_python_module_paths(package_name):
"""
- Load the TensorFlow extension if it exists.
+ Return an iterator to full paths of a python package.
- This is used to load the TensorFlow extension before
- pyarrow.lib. If we don't do this there are symbol clashes
- between TensorFlow's use of threading and our global
- thread pool, see also
- https://issues.apache.org/jira/browse/ARROW-2657 and
- https://github.com/apache/arrow/pull/2096.
+ This is a best effort and might fail (for example on Python 2).
+ It uses the official way of loading modules from
+ https://docs.python.org/3/library/importlib.html#approximating-importlib-import-module
"""
- import os
- tensorflow_loaded = False
-
- # Try to load the tensorflow extension directly
- # This is a performance optimization, tensorflow will always be
- # loaded via the "import tensorflow" statement below if this
- # doesn't succeed.
- #
- # This uses the official way of loading modules from
- # https://docs.python.org/3/library/importlib.html#approximating-importlib-import-module
-
try:
import importlib
- absolute_name = importlib.util.resolve_name("tensorflow", None)
+ absolute_name = importlib.util.resolve_name(package_name, None)
except (ImportError, AttributeError):
# Sometimes, importlib is not available (e.g. Python 2)
# or importlib.util is not available (e.g. Python 2.7)
@@ -205,16 +191,37 @@ def import_tensorflow_extension():
if spec:
module = importlib.util.module_from_spec(spec)
for path in module.__path__:
- ext = os.path.join(path, "libtensorflow_framework.so")
- if os.path.exists(ext):
- import ctypes
- try:
- ctypes.CDLL(ext)
- except OSError:
- pass
- tensorflow_loaded = True
- break
+ yield path
+def import_tensorflow_extension():
+ """
+ Load the TensorFlow extension if it exists.
+
+ This is used to load the TensorFlow extension before
+ pyarrow.lib. If we don't do this there are symbol clashes
+ between TensorFlow's use of threading and our global
+ thread pool, see also
+ https://issues.apache.org/jira/browse/ARROW-2657 and
+ https://github.com/apache/arrow/pull/2096.
+ """
+ import os
+ tensorflow_loaded = False
+
+ # Try to load the tensorflow extension directly
+ # This is a performance optimization, tensorflow will always be
+ # loaded via the "import tensorflow" statement below if this
+ # doesn't succeed.
+
+ for path in _iterate_python_module_paths("tensorflow"):
+ ext = os.path.join(path, "libtensorflow_framework.so")
+ if os.path.exists(ext):
+ import ctypes
+ try:
+ ctypes.CDLL(ext)
+ except OSError:
+ pass
+ tensorflow_loaded = True
+ break
# If the above failed, try to load tensorflow the normal way
# (this is more expensive)
@@ -225,6 +232,22 @@ def import_tensorflow_extension():
except ImportError:
pass
+def import_pytorch_extension():
+    """
+    Load the PyTorch extension if it exists.
+
+    This is used to load the PyTorch extension before
+    pyarrow.lib. If we don't do this there are symbol clashes
+    between PyTorch's use of threading and our global
+    thread pool, see also
+    https://issues.apache.org/jira/browse/ARROW-2920
+
+    This is a best-effort workaround: if torch is not installed, if a
+    torch package directory has no lib/libcaffe2.so, or if loading that
+    library fails, nothing is loaded and no error is raised, so that
+    importing pyarrow never fails because of this function.
+    """
+    import ctypes
+    import os
+
+    for path in _iterate_python_module_paths("torch"):
+        try:
+            ctypes.CDLL(os.path.join(path, "lib", "libcaffe2.so"))
+        except OSError:
+            # ctypes.CDLL raises OSError when the file is missing or the
+            # dynamic loader rejects it. This runs during "import pyarrow",
+            # so an unguarded failure here would make pyarrow itself
+            # unimportable; skipping mirrors import_tensorflow_extension.
+            pass
+
integer_types = six.integer_types + (np.integer,)
diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py
index 6cc391a..53dd5c0 100644
--- a/python/pyarrow/tests/test_serialization.py
+++ b/python/pyarrow/tests/test_serialization.py
@@ -369,6 +369,17 @@ def test_torch_serialization(large_buffer):
context=serialization_context)
+@pytest.mark.skipif(not torch or not torch.cuda.is_available(),
+                    reason="requires pytorch with CUDA")
+def test_torch_cuda():
+    """Regression test for ARROW-2920.
+
+    Builds a Conv2d layer and calls .cuda() on it; this used to segfault
+    when pyarrow had been imported before torch.
+    """
+    # ARROW-2920: This used to segfault if torch is not imported
+    # before pyarrow
+    # Note that this test will only catch the issue if it is run
+    # with a pyarrow that has been built in the manylinux1 environment
+    # NOTE(review): `torch` is presumably a module-level name that is
+    # falsy when PyTorch is not installed -- confirm at the top of this
+    # test module.
+    torch.nn.Conv2d(64, 2, kernel_size=3, stride=1,
+                    padding=1, bias=False).cuda()
+
+
def test_numpy_immutable(large_buffer):
obj = np.zeros([10])