You are viewing a plain text version of this content. The canonical link for it was a hyperlink in the original page and is not preserved in this text.
Posted to commits@arrow.apache.org by jo...@apache.org on 2023/06/13 08:58:19 UTC
[arrow] branch main updated: GH-35599: [Python] Canonical fixed-shape tensor extension array/type is not picklable. (#35933)
This is an automated email from the ASF dual-hosted git repository.
jorisvandenbossche pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 059d8b669a GH-35599: [Python] Canonical fixed-shape tensor extension array/type is not picklable. (#35933)
059d8b669a is described below
commit 059d8b669aeeda0ab8ac1ae9eab336a93daa3613
Author: Alenka Frim <Al...@users.noreply.github.com>
AuthorDate: Tue Jun 13 10:58:10 2023 +0200
GH-35599: [Python] Canonical fixed-shape tensor extension array/type is not picklable. (#35933)
This PR adds a `__reduce__` method to `FixedShapeTensorType` so that the type, and extension arrays that use it, can be pickled and unpickled.
* Closes: #35599
Authored-by: AlenkaF <fr...@gmail.com>
Signed-off-by: Joris Van den Bossche <jo...@gmail.com>
---
python/pyarrow/tests/test_extension_type.py | 16 ++++++++++++++++
python/pyarrow/types.pxi | 4 ++++
2 files changed, 20 insertions(+)
diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py
index 023968b20d..78618c19ff 100644
--- a/python/pyarrow/tests/test_extension_type.py
+++ b/python/pyarrow/tests/test_extension_type.py
@@ -1306,3 +1306,19 @@ def test_extension_to_pandas_storage_type(registered_period_type):
# Check the usage of types_mapper
result = table.to_pandas(types_mapper=pd.ArrowDtype)
assert isinstance(result["ext"].dtype, pd.ArrowDtype)
+
+
+def test_tensor_type_is_picklable():
+ # GH-35599
+
+ expected_type = pa.fixed_shape_tensor(pa.int32(), (2, 2))
+ result = pickle.loads(pickle.dumps(expected_type))
+
+ assert result == expected_type
+
+ arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]]
+ storage = pa.array(arr, pa.list_(pa.int32(), 4))
+ expected_arr = pa.ExtensionArray.from_storage(expected_type, storage)
+ result = pickle.loads(pickle.dumps(expected_arr))
+
+ assert result == expected_arr
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index bcd358c9a5..48605c293e 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -1586,6 +1586,10 @@ cdef class FixedShapeTensorType(BaseExtensionType):
def __arrow_ext_class__(self):
return FixedShapeTensorArray
+ def __reduce__(self):
+ return fixed_shape_tensor, (self.value_type, self.shape,
+ self.dim_names, self.permutation)
+
cdef class PyExtensionType(ExtensionType):
"""