You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2022/07/19 11:46:30 UTC

[arrow] branch master updated: ARROW-17065: [Python] Allow using subclassed ExtensionScalar in ExtensionType (#13594)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 0b53adc8de ARROW-17065: [Python] Allow using subclassed ExtensionScalar in ExtensionType (#13594)
0b53adc8de is described below

commit 0b53adc8de0b7fb6fc96ebbe704d928797f28222
Author: Rok Mihevc <ro...@mihevc.org>
AuthorDate: Tue Jul 19 13:46:20 2022 +0200

    ARROW-17065: [Python] Allow using subclassed ExtensionScalar in ExtensionType (#13594)
    
    This is to resolve [ARROW-17065](https://issues.apache.org/jira/browse/ARROW-17065).
    
    Lead-authored-by: Rok Mihevc <ro...@mihevc.org>
    Co-authored-by: Rok <ro...@mihevc.org>
    Signed-off-by: David Li <li...@gmail.com>
---
 docs/source/python/extending_types.rst      | 10 +++++++---
 python/pyarrow/public-api.pxi               |  2 +-
 python/pyarrow/scalar.pxi                   | 28 +++++++++++++++++++---------
 python/pyarrow/tests/test_extension_type.py | 14 ++++++++++++--
 python/pyarrow/types.pxi                    | 22 ++++++++--------------
 5 files changed, 47 insertions(+), 29 deletions(-)

diff --git a/docs/source/python/extending_types.rst b/docs/source/python/extending_types.rst
index 13c3f226d3..7e241f6441 100644
--- a/docs/source/python/extending_types.rst
+++ b/docs/source/python/extending_types.rst
@@ -290,12 +290,16 @@ Custom scalar conversion
 ~~~~~~~~~~~~~~~~~~~~~~~~
 
 If you want scalars of your custom extension type to convert to a custom type when
-:meth:`ExtensionScalar.as_py()` is called, you can override the :meth:`ExtensionType.scalar_as_py()`
+:meth:`ExtensionScalar.as_py()` is called, you can subclass :class:`ExtensionScalar`, return that subclass from :meth:`ExtensionType.__arrow_ext_scalar_class__`, and override its :meth:`ExtensionScalar.as_py()`
 method. For example, if we wanted the above example 3D point type to return a custom
 3D point class instead of a list, we would implement::
 
     Point3D = namedtuple("Point3D", ["x", "y", "z"])
 
+    class Point3DScalar(pa.ExtensionScalar):
+        def as_py(self) -> Point3D:
+            return None if self.value is None else Point3D(*self.value.as_py())
+
     class Point3DType(pa.PyExtensionType):
         def __init__(self):
             pa.PyExtensionType.__init__(self, pa.list_(pa.float32(), 3))
@@ -303,8 +307,8 @@ method. For example, if we wanted the above example 3D point type to return a cu
         def __reduce__(self):
             return Point3DType, ()
 
-        def scalar_as_py(self, scalar: pa.ListScalar) -> Point3D:
-            return Point3D(*scalar.as_py())
+        def __arrow_ext_scalar_class__(self):
+            return Point3DScalar
 
 Arrays built using this extension type now provide scalars that convert to our ``Point3D`` class::
 
diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi
index c427fb9f5d..6e9edd55b4 100644
--- a/python/pyarrow/public-api.pxi
+++ b/python/pyarrow/public-api.pxi
@@ -257,7 +257,7 @@ cdef api object pyarrow_wrap_scalar(const shared_ptr[CScalar]& sp_scalar):
     if data_type.id() not in _scalar_classes:
         raise ValueError('Scalar type not supported')
 
-    klass = _scalar_classes[data_type.id()]
+    klass = get_scalar_class_from_type(sp_scalar.get().type)
 
     cdef Scalar scalar = klass.__new__(klass)
     scalar.init(sp_scalar)
diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi
index 5995242b2d..c802caa15f 100644
--- a/python/pyarrow/scalar.pxi
+++ b/python/pyarrow/scalar.pxi
@@ -35,16 +35,16 @@ cdef class Scalar(_Weakrefable):
         cdef:
             Scalar self
             Type type_id = wrapped.get().type.get().id()
+            shared_ptr[CDataType] sp_data_type = wrapped.get().type
 
         if type_id == _Type_NA:
             return _NULL
 
-        try:
-            typ = _scalar_classes[type_id]
-        except KeyError:
+        if type_id not in _scalar_classes:
             raise NotImplementedError(
-                "Wrapping scalar of type " +
-                frombytes(wrapped.get().type.get().ToString()))
+                "Wrapping scalar of type " + frombytes(sp_data_type.get().ToString()))
+
+        typ = get_scalar_class_from_type(sp_data_type)
         self = typ.__new__(typ)
         self.init(wrapped)
 
@@ -902,10 +902,7 @@ cdef class ExtensionScalar(Scalar):
         """
         Return this scalar as a Python object.
         """
-        if self.value is None:
-            return None
-        else:
-            return self.type.scalar_as_py(self.value)
+        return None if self.value is None else self.value.as_py()
 
     @staticmethod
     def from_storage(BaseExtensionType typ, value):
@@ -990,6 +987,19 @@ cdef dict _scalar_classes = {
 }
 
 
+cdef object get_scalar_class_from_type(
+        const shared_ptr[CDataType]& sp_data_type):
+    cdef CDataType* data_type = sp_data_type.get()
+    if data_type == NULL:
+        raise ValueError('Scalar data type was NULL')
+
+    if data_type.id() == _Type_EXTENSION:
+        py_ext_data_type = pyarrow_wrap_data_type(sp_data_type)
+        return py_ext_data_type.__arrow_ext_scalar_class__()
+    else:
+        return _scalar_classes[data_type.id()]
+
+
 def scalar(value, type=None, *, from_pandas=None, MemoryPool memory_pool=None):
     """
     Create a pyarrow.Scalar instance from a Python object.
diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py
index dedef160f4..9c5a394f89 100644
--- a/python/pyarrow/tests/test_extension_type.py
+++ b/python/pyarrow/tests/test_extension_type.py
@@ -34,6 +34,11 @@ class IntegerType(pa.PyExtensionType):
         return IntegerType, ()
 
 
+class UuidScalarType(pa.ExtensionScalar):
+    def as_py(self):
+        return None if self.value is None else UUID(bytes=self.value.as_py())
+
+
 class UuidType(pa.PyExtensionType):
 
     def __init__(self):
@@ -42,8 +47,8 @@ class UuidType(pa.PyExtensionType):
     def __reduce__(self):
         return UuidType, ()
 
-    def scalar_as_py(self, scalar):
-        return UUID(bytes=scalar.as_py())
+    def __arrow_ext_scalar_class__(self):
+        return UuidScalarType
 
 
 class UuidType2(pa.PyExtensionType):
@@ -302,6 +307,10 @@ def test_ext_scalar_from_array():
     scalars_a = list(a)
     assert len(scalars_a) == 4
 
+    assert ty1.__arrow_ext_scalar_class__() == UuidScalarType
+    assert type(a[0]) == UuidScalarType
+    assert type(scalars_a[0]) == UuidScalarType
+
     for s, val in zip(scalars_a, data):
         assert isinstance(s, pa.ExtensionScalar)
         assert s.is_valid == (val is not None)
@@ -316,6 +325,7 @@ def test_ext_scalar_from_array():
     assert len(scalars_b) == 4
 
     for sa, sb in zip(scalars_a, scalars_b):
+        assert isinstance(sb, pa.ExtensionScalar)
         assert sa.is_valid == sb.is_valid
         if sa.as_py() is None:
             assert sa.as_py() == sb.as_py()
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 894dd9617f..8407f95c98 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -919,22 +919,16 @@ cdef class ExtensionType(BaseExtensionType):
         """
         return ExtensionArray
 
-    def scalar_as_py(self, scalar):
-        """Convert scalar to a Python type.
+    def __arrow_ext_scalar_class__(self):
+        """Return an extension scalar class for building scalars with this
+        extension type.
 
-        This method can be overridden in subclasses to customize what type
-        scalars are converted to.
-
-        Parameters
-        ----------
-        scalar : pyarrow.Scalar
-          Not-None Scalar of storage type to be converted to a Python object.
-
-        Returns
-        -------
-        Scalar value as a native Python object.
+        This method should return a subclass of the ExtensionScalar class. By
+        default, if not specialized in the extension implementation, an
+        extension type scalar will be a built-in ExtensionScalar instance.
         """
-        return scalar.as_py()
+        return ExtensionScalar
+
 
 cdef class PyExtensionType(ExtensionType):
     """