You are viewing a plain-text version of this content; the hyperlink to its canonical version ("here") was not preserved in this copy.
Posted to commits@arrow.apache.org by jo...@apache.org on 2023/06/13 08:58:19 UTC

[arrow] branch main updated: GH-35599: [Python] Canonical fixed-shape tensor extension array/type is not picklable. (#35933)

This is an automated email from the ASF dual-hosted git repository.

jorisvandenbossche pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 059d8b669a GH-35599: [Python] Canonical fixed-shape tensor extension array/type is not picklable. (#35933)
059d8b669a is described below

commit 059d8b669aeeda0ab8ac1ae9eab336a93daa3613
Author: Alenka Frim <Al...@users.noreply.github.com>
AuthorDate: Tue Jun 13 10:58:10 2023 +0200

    GH-35599: [Python] Canonical fixed-shape tensor extension array/type is not picklable. (#35933)
    
    This PR adds a `__reduce__` method to `FixedShapeTensorType` so that the type, and extension arrays that use it, can be pickled.
    * Closes: #35599
    
    Authored-by: AlenkaF <fr...@gmail.com>
    Signed-off-by: Joris Van den Bossche <jo...@gmail.com>
---
 python/pyarrow/tests/test_extension_type.py | 16 ++++++++++++++++
 python/pyarrow/types.pxi                    |  4 ++++
 2 files changed, 20 insertions(+)

diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py
index 023968b20d..78618c19ff 100644
--- a/python/pyarrow/tests/test_extension_type.py
+++ b/python/pyarrow/tests/test_extension_type.py
@@ -1306,3 +1306,19 @@ def test_extension_to_pandas_storage_type(registered_period_type):
         # Check the usage of types_mapper
         result = table.to_pandas(types_mapper=pd.ArrowDtype)
         assert isinstance(result["ext"].dtype, pd.ArrowDtype)
+
+
+def test_tensor_type_is_picklable():
+    # GH-35599
+
+    expected_type = pa.fixed_shape_tensor(pa.int32(), (2, 2))
+    result = pickle.loads(pickle.dumps(expected_type))
+
+    assert result == expected_type
+
+    arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]]
+    storage = pa.array(arr, pa.list_(pa.int32(), 4))
+    expected_arr = pa.ExtensionArray.from_storage(expected_type, storage)
+    result = pickle.loads(pickle.dumps(expected_arr))
+
+    assert result == expected_arr
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index bcd358c9a5..48605c293e 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -1586,6 +1586,10 @@ cdef class FixedShapeTensorType(BaseExtensionType):
     def __arrow_ext_class__(self):
         return FixedShapeTensorArray
 
+    def __reduce__(self):
+        return fixed_shape_tensor, (self.value_type, self.shape,
+                                    self.dim_names, self.permutation)
+
 
 cdef class PyExtensionType(ExtensionType):
     """