You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by "rok (via GitHub)" <gi...@apache.org> on 2023/02/15 15:13:23 UTC

[GitHub] [arrow] rok commented on a diff in pull request #33948: GH-33947: [Python][Docs] Tensor canonical type extension example

rok commented on code in PR #33948:
URL: https://github.com/apache/arrow/pull/33948#discussion_r1107257121


##########
python/pyarrow/tests/test_extension_type.py:
##########
@@ -1079,3 +1082,282 @@ def test_array_constructor_from_pandas():
         pd.Series([1, 2, 3], dtype="category"), type=IntegerType()
     )
     assert result.equals(expected)
+
+
+class FixedShapeTensorType(pa.ExtensionType):
+    """
+    Canonical extension type class for fixed shape tensors.
+
+    Parameters
+    ----------
+    value_type : DataType or Field
+        The data type of an individual tensor
+    shape : tuple
+        Shape of the tensors
+    dim_names : tuple, default: None
+        Explicit names of the dimensions.
+    permutation : tuple, default: None
+        Indices of the dimensions ordering.
+
+    Examples
+    --------
+    >>> import pyarrow as pa
+    >>> tensor_type = FixedShapeTensorType(pa.int32(), (2, 2))
+    >>> tensor_type
+    FixedShapeTensorType(FixedSizeListType(fixed_size_list<item: int32>[4]))
+    >>> pa.register_extension_type(tensor_type)
+    """
+
+    def __init__(self, value_type, shape, dim_names=None, permutation=None):
+        self._value_type = value_type
+        self._shape = shape
+        size = math.prod(shape)
+        self._dim_names = dim_names
+        self._permutation = permutation
+        pa.ExtensionType.__init__(self, pa.list_(self._value_type, size),
+                                  'arrow.fixed_size_tensor')
+
+    @property
+    def value_type(self):
+        """
+        Data type of an individual tensor.
+        """
+        return self._value_type
+
+    @property
+    def shape(self):
+        """
+        Shape of the tensors.
+        """
+        return self._shape
+
+    @property
+    def dim_names(self):
+        """
+        Explicit names of the dimensions.
+        """
+        return self._dim_names
+
+    @property
+    def permutation(self):
+        """
+        Indices of the dimensions ordering.
+        """
+        return self._permutation
+
+    def __arrow_ext_serialize__(self):
+        metadata = {"shape": str(self._shape),
+                    "dim_names": str(self._dim_names),
+                    "permutation": str(self._permutation)}
+        return json.dumps(metadata).encode()
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+        # return an instance of this subclass given the serialized
+        # metadata.
+        assert serialized.decode().startswith('{"shape":')
+
+        metadata = json.loads(serialized.decode())
+        shape = ast.literal_eval(metadata['shape'])
+        dim_names = ast.literal_eval(metadata['dim_names'])
+        permutation = ast.literal_eval(metadata['permutation'])
+
+        return FixedShapeTensorType(storage_type.value_type, shape,
+                                    dim_names, permutation)
+
+    def __arrow_ext_class__(self):
+        return FixedShapeTensorArray
+
+
+class FixedShapeTensorArray(pa.ExtensionArray):
+    """
+    Canonical extension array class for fixed shape tensors.
+
+    Examples
+    --------
+    Define and register extension type for tensor array
+
+    >>> import pyarrow as pa
+    >>> tensor_type = FixedShapeTensorType(pa.int32(), (2, 2))
+    >>> pa.register_extension_type(tensor_type)
+
+    Create an extension array
+
+    >>> arr = [[1, 2, 3, 4], [10, 20, 30, 40], [100, 200, 300, 400]]
+    >>> storage = pa.array(arr, pa.list_(pa.int32(), 4))
+    >>> pa.ExtensionArray.from_storage(tensor_type, storage)
+    <__main__.FixedShapeTensorArray object at ...>
+    [
+      [
+        1,
+        2,
+        3,
+        4
+      ],
+      [
+        10,
+        20,
+        30,
+        40
+      ],
+      [
+        100,
+        200,
+        300,
+        400
+      ]
+    ]
+    """
+
+    def to_numpy_tensor(self):

Review Comment:
   I'd perhaps name these methods `to_numpy` and `from_numpy`. The "tensor" part is already implied by the context, and the `Tensor` class [uses the same names](https://github.com/apache/arrow/blob/master/python/pyarrow/tests/test_tensor.py#L78).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org