You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this text.
Posted to commits@arrow.apache.org by jo...@apache.org on 2023/06/27 15:42:11 UTC

[arrow] branch main updated: GH-36038: [Python] Implement __reduce__ on ExtensionType class (#36170)

This is an automated email from the ASF dual-hosted git repository.

jorisvandenbossche pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new f865fbd76c GH-36038: [Python] Implement __reduce__ on ExtensionType class (#36170)
f865fbd76c is described below

commit f865fbd76ce6235bbef2f7ba706d78dc7fd32190
Author: Alenka Frim <Al...@users.noreply.github.com>
AuthorDate: Tue Jun 27 17:42:04 2023 +0200

    GH-36038: [Python] Implement __reduce__ on ExtensionType class (#36170)
    
    ### Rationale for this change
    `ExtensionType` subclasses can't be pickled unless they implement their own `__reduce__` method.
    
    ### What changes are included in this PR?
    Add a `__reduce__` method to the `ExtensionType` class.
    
    ### Are these changes tested?
    Yes, two tests (`test_generic_ext_type_pickling` and `test_generic_ext_array_pickling`) are added to python/pyarrow/tests/test_extension_type.py.
    
    ### Are there any user-facing changes?
    No.
    * Closes: #36038
    
    Authored-by: AlenkaF <fr...@gmail.com>
    Signed-off-by: Joris Van den Bossche <jo...@gmail.com>
---
 python/pyarrow/tests/test_extension_type.py | 25 +++++++++++++++++++++++++
 python/pyarrow/types.pxi                    |  3 +++
 2 files changed, 28 insertions(+)

diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py
index 78618c19ff..51bf57c9ba 100644
--- a/python/pyarrow/tests/test_extension_type.py
+++ b/python/pyarrow/tests/test_extension_type.py
@@ -867,6 +867,31 @@ def test_generic_ext_type_equality():
     assert not period_type == period_type3
 
 
+def test_generic_ext_type_pickling(registered_period_type):
+    # GH-36038
+    for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
+        period_type, _ = registered_period_type
+        ser = pickle.dumps(period_type, protocol=proto)
+        period_type_pickled = pickle.loads(ser)
+        assert period_type == period_type_pickled
+
+
+def test_generic_ext_array_pickling(registered_period_type):
+    for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
+        period_type, _ = registered_period_type
+        storage = pa.array([1, 2, 3, 4], pa.int64())
+        arr = pa.ExtensionArray.from_storage(period_type, storage)
+        ser = pickle.dumps(arr, protocol=proto)
+        del storage, arr
+        arr = pickle.loads(ser)
+        arr.validate()
+        assert isinstance(arr, pa.ExtensionArray)
+        assert arr.type == period_type
+        assert arr.type.storage_type == pa.int64()
+        assert arr.storage.type == pa.int64()
+        assert arr.storage.to_pylist() == [1, 2, 3, 4]
+
+
 def test_generic_ext_type_register(registered_period_type):
     # test that trying to register other type does not segfault
     with pytest.raises(TypeError):
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 48605c293e..a3311cbbcf 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -1487,6 +1487,9 @@ cdef class ExtensionType(BaseExtensionType):
         """
         return NotImplementedError
 
+    def __reduce__(self):
+        return self.__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())
+
     def __arrow_ext_class__(self):
         """Return an extension array class to be used for building or
         deserializing arrays with this extension type.