You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by jo...@apache.org on 2023/04/04 12:51:46 UTC
[arrow] branch main updated: GH-15483: [C++] Add a Fixed Shape Tensor canonical ExtensionType (#8510)
This is an automated email from the ASF dual-hosted git repository.
jorisvandenbossche pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new a84a39b640 GH-15483: [C++] Add a Fixed Shape Tensor canonical ExtensionType (#8510)
a84a39b640 is described below
commit a84a39b640475dcc7cf8f6bde38e7e471e489e3d
Author: Rok Mihevc <ro...@mihevc.org>
AuthorDate: Tue Apr 4 14:51:38 2023 +0200
GH-15483: [C++] Add a Fixed Shape Tensor canonical ExtensionType (#8510)
> [ARROW-1614](https://issues.apache.org/jira/browse/ARROW-1614): In an Arrow table, we would like to add support for a column whose cells each contain a tensor value, with all tensors having the same dimensions. These would be stored as a binary value, plus some metadata to store type and shape/strides.
* Closes: #15483
Lead-authored-by: Rok Mihevc <ro...@mihevc.org>
Co-authored-by: Rok <ro...@mihevc.org>
Co-authored-by: Joris Van den Bossche <jo...@gmail.com>
Co-authored-by: Ben Harkins <60...@users.noreply.github.com>
Signed-off-by: Joris Van den Bossche <jo...@gmail.com>
---
cpp/src/arrow/CMakeLists.txt | 2 +
cpp/src/arrow/extension/CMakeLists.txt | 24 +++
cpp/src/arrow/extension/fixed_shape_tensor.cc | 170 ++++++++++++++++
cpp/src/arrow/extension/fixed_shape_tensor.h | 92 +++++++++
cpp/src/arrow/extension/fixed_shape_tensor_test.cc | 215 +++++++++++++++++++++
cpp/src/arrow/extension_type.cc | 12 ++
6 files changed, 515 insertions(+)
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 07ea8930ff..143fb13ddc 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -520,6 +520,7 @@ endif()
if(ARROW_JSON)
list(APPEND
ARROW_SRCS
+ extension/fixed_shape_tensor.cc
json/options.cc
json/chunked_builder.cc
json/chunker.cc
@@ -856,6 +857,7 @@ endif()
if(ARROW_JSON)
add_subdirectory(json)
+ # The extension sources include RapidJSON headers for metadata
+ # (de)serialization, which is presumably why they are only built together
+ # with the JSON component.
+ add_subdirectory(extension)
endif()
if(ARROW_ORC)
diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt
new file mode 100644
index 0000000000..c15c42874d
--- /dev/null
+++ b/cpp/src/arrow/extension/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Unit test built from fixed_shape_tensor_test.cc; the PREFIX is presumably
+# combined with the test name to form the target name (confirm against the
+# add_arrow_test definition).
+add_arrow_test(test
+ SOURCES
+ fixed_shape_tensor_test.cc
+ PREFIX
+ "arrow-fixed-shape-tensor")
+
+# Install the headers of this directory under "arrow/extension".
+arrow_install_all_headers("arrow/extension")
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc b/cpp/src/arrow/extension/fixed_shape_tensor.cc
new file mode 100644
index 0000000000..8b0ed43df5
--- /dev/null
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.cc
@@ -0,0 +1,170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <numeric>
+#include <sstream>
+
+#include "arrow/extension/fixed_shape_tensor.h"
+
+#include "arrow/array/array_nested.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
+#include "arrow/util/int_util_overflow.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/sort.h"
+
+#include <rapidjson/document.h>
+#include <rapidjson/writer.h>
+
+namespace rj = arrow::rapidjson;
+
+namespace arrow {
+namespace extension {
+
+// Compare this type with another extension type.
+//
+// Two fixed shape tensor types are equal when they share the extension name,
+// storage type, shape and dimension names, and their permutations are
+// equivalent. An empty permutation is treated as equivalent to the identity
+// permutation {0, 1, ..., n - 1}.
+bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
+  if (extension_name() != other.extension_name()) {
+    return false;
+  }
+  const auto& other_ext = static_cast<const FixedShapeTensorType&>(other);
+
+  // Only the exact identity {0, 1, 2, ...} is interchangeable with an empty
+  // permutation. A merely consecutive run that does not start at 0 (such as
+  // {1, 2}) must not be considered trivial.
+  auto is_permutation_trivial = [](const std::vector<int64_t>& permutation) {
+    for (size_t i = 0; i < permutation.size(); ++i) {
+      if (permutation[i] != static_cast<int64_t>(i)) {
+        return false;
+      }
+    }
+    return true;
+  };
+  const bool permutation_equivalent =
+      ((permutation_ == other_ext.permutation()) ||
+       (permutation_.empty() && is_permutation_trivial(other_ext.permutation())) ||
+       (is_permutation_trivial(permutation_) && other_ext.permutation().empty()));
+
+  return (storage_type()->Equals(other_ext.storage_type())) &&
+         (shape_ == other_ext.shape()) && (dim_names_ == other_ext.dim_names()) &&
+         permutation_equivalent;
+}
+
+// Serialize the type parameters to a JSON object string.
+//
+// The object always contains "shape"; "permutation" and "dim_names" are only
+// emitted when they are non-empty.
+std::string FixedShapeTensorType::Serialize() const {
+  rj::Document document;
+  document.SetObject();
+  rj::Document::AllocatorType& allocator = document.GetAllocator();
+
+  rj::Value shape(rj::kArrayType);
+  for (auto v : shape_) {
+    shape.PushBack(v, allocator);
+  }
+  document.AddMember(rj::Value("shape", allocator), shape, allocator);
+
+  if (!permutation_.empty()) {
+    rj::Value permutation(rj::kArrayType);
+    for (auto v : permutation_) {
+      permutation.PushBack(v, allocator);
+    }
+    document.AddMember(rj::Value("permutation", allocator), permutation, allocator);
+  }
+
+  if (!dim_names_.empty()) {
+    rj::Value dim_names(rj::kArrayType);
+    // Iterate by reference to avoid copying every name, and pass the explicit
+    // length so the string is not re-scanned or cut at an embedded NUL.
+    for (const std::string& v : dim_names_) {
+      dim_names.PushBack(
+          rj::Value{}.SetString(v.data(), static_cast<rj::SizeType>(v.size()),
+                                allocator),
+          allocator);
+    }
+    document.AddMember(rj::Value("dim_names", allocator), dim_names, allocator);
+  }
+
+  rj::StringBuffer buffer;
+  rj::Writer<rj::StringBuffer> writer(buffer);
+  document.Accept(writer);
+  return buffer.GetString();
+}
+
+// Reconstruct a FixedShapeTensorType from its storage type and the JSON
+// metadata produced by Serialize().
+//
+// Returns Status::Invalid when the storage type is not a FixedSizeList, when
+// the metadata is not a JSON object with an integer array "shape", or when
+// "permutation" / "dim_names" are present but malformed or of a different
+// length than "shape".
+Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
+    std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const {
+  if (storage_type->id() != Type::FIXED_SIZE_LIST) {
+    return Status::Invalid("Expected FixedSizeList storage type, got ",
+                           storage_type->ToString());
+  }
+  auto value_type =
+      internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type();
+  rj::Document document;
+  // The metadata comes from outside the process, so every JSON value is
+  // type-checked before its typed accessor is used.
+  if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() ||
+      !document.IsObject() || !document.HasMember("shape") ||
+      !document["shape"].IsArray()) {
+    return Status::Invalid("Invalid serialized JSON data: ", serialized_data);
+  }
+
+  std::vector<int64_t> shape;
+  for (const auto& x : document["shape"].GetArray()) {
+    if (!x.IsInt64()) {
+      return Status::Invalid("Invalid serialized JSON data: ", serialized_data);
+    }
+    shape.emplace_back(x.GetInt64());
+  }
+  std::vector<int64_t> permutation;
+  if (document.HasMember("permutation")) {
+    const auto& json_permutation = document["permutation"];
+    if (!json_permutation.IsArray()) {
+      return Status::Invalid("Invalid permutation");
+    }
+    for (const auto& x : json_permutation.GetArray()) {
+      if (!x.IsInt64()) {
+        return Status::Invalid("Invalid permutation");
+      }
+      permutation.emplace_back(x.GetInt64());
+    }
+    if (shape.size() != permutation.size()) {
+      return Status::Invalid("Invalid permutation");
+    }
+  }
+  std::vector<std::string> dim_names;
+  if (document.HasMember("dim_names")) {
+    const auto& json_dim_names = document["dim_names"];
+    if (!json_dim_names.IsArray()) {
+      return Status::Invalid("Invalid dim_names");
+    }
+    for (const auto& x : json_dim_names.GetArray()) {
+      if (!x.IsString()) {
+        return Status::Invalid("Invalid dim_names");
+      }
+      dim_names.emplace_back(x.GetString(), x.GetStringLength());
+    }
+    if (shape.size() != dim_names.size()) {
+      return Status::Invalid("Invalid dim_names");
+    }
+  }
+
+  // Go through Make() so that any validation failure is returned as a Status
+  // rather than being asserted on by the fixed_shape_tensor() helper.
+  return FixedShapeTensorType::Make(value_type, shape, permutation, dim_names);
+}
+
+// Wrap ArrayData whose type is this extension type in a FixedShapeTensorArray,
+// the array class declared for this type in fixed_shape_tensor.h.
+std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
+    std::shared_ptr<ArrayData> data) const {
+  DCHECK_EQ(data->type->id(), Type::EXTENSION);
+  DCHECK_EQ("arrow.fixed_shape_tensor",
+            static_cast<const ExtensionType&>(*data->type).extension_name());
+  return std::make_shared<FixedShapeTensorArray>(data);
+}
+
+// Validate the arguments and create a FixedShapeTensorType.
+//
+// Returns Status::Invalid when:
+//  - permutation or dim_names is non-empty and its size differs from shape,
+//  - permutation is not a permutation of {0, ..., shape.size() - 1},
+//  - shape contains a negative value,
+//  - the product of the shape values does not fit the int32 list size of the
+//    FixedSizeList storage type.
+Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
+    const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
+    const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) {
+  if (!permutation.empty() && shape.size() != permutation.size()) {
+    return Status::Invalid("permutation size must match shape size. Expected: ",
+                           shape.size(), " Got: ", permutation.size());
+  }
+  if (!dim_names.empty() && shape.size() != dim_names.size()) {
+    return Status::Invalid("dim_names size must match shape size. Expected: ",
+                           shape.size(), " Got: ", dim_names.size());
+  }
+
+  // Every dimension index must appear exactly once in a non-empty permutation.
+  if (!permutation.empty()) {
+    const auto ndim = static_cast<int64_t>(shape.size());
+    std::vector<bool> seen(shape.size(), false);
+    for (const int64_t index : permutation) {
+      if (index < 0 || index >= ndim || seen[static_cast<size_t>(index)]) {
+        return Status::Invalid(
+            "permutation must contain each dimension index exactly once. Got: ", index);
+      }
+      seen[static_cast<size_t>(index)] = true;
+    }
+  }
+
+  bool has_zero_dim = false;
+  for (const int64_t dim : shape) {
+    if (dim < 0) {
+      return Status::Invalid("shape must not contain negative values. Got: ", dim);
+    }
+    has_zero_dim = has_zero_dim || (dim == 0);
+  }
+
+  // The storage list size is an int32; check before each multiplication so
+  // the product can neither overflow nor be silently truncated by the cast.
+  constexpr int64_t kMaxListSize = 0x7FFFFFFF;
+  int64_t size = has_zero_dim ? 0 : 1;
+  if (!has_zero_dim) {
+    for (const int64_t dim : shape) {
+      if (size > kMaxListSize / dim) {
+        return Status::Invalid(
+            "Product of shape values does not fit in a 32-bit list size");
+      }
+      size *= dim;
+    }
+  }
+  return std::make_shared<FixedShapeTensorType>(value_type, static_cast<int32_t>(size),
+                                                shape, permutation, dim_names);
+}
+
+// Convenience factory: build the type through FixedShapeTensorType::Make() and
+// unwrap the result. A failed Make() is only caught by ARROW_DCHECK_OK, so
+// callers are expected to pass arguments that Make() accepts.
+std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& value_type,
+                                             const std::vector<int64_t>& shape,
+                                             const std::vector<int64_t>& permutation,
+                                             const std::vector<std::string>& dim_names) {
+  auto result = FixedShapeTensorType::Make(value_type, shape, permutation, dim_names);
+  ARROW_DCHECK_OK(result.status());
+  return result.MoveValueUnsafe();
+}
+
+} // namespace extension
+} // namespace arrow
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.h b/cpp/src/arrow/extension/fixed_shape_tensor.h
new file mode 100644
index 0000000000..4ee2b894ee
--- /dev/null
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.h
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/extension_type.h"
+
+namespace arrow {
+namespace extension {
+
+/// \brief Array class for the fixed shape tensor extension type.
+/// It adds no members of its own and inherits the ExtensionArray constructors.
+class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
+ public:
+ using ExtensionArray::ExtensionArray;
+};
+
+/// \brief Concrete type class for constant-size Tensor data.
+/// This is a canonical arrow extension type.
+/// See: https://arrow.apache.org/docs/format/CanonicalExtensions.html
+class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
+ public:
+  /// \brief Construct the type with fixed_size_list(value_type, size) storage.
+  ///
+  /// \param[in] value_type element type of the tensors
+  /// \param[in] size list size of the storage type; no consistency check
+  /// against shape is performed here (Make() computes it from shape)
+  /// \param[in] shape shape of each tensor
+  /// \param[in] permutation optional dimension permutation; empty if unset
+  /// \param[in] dim_names optional dimension names; empty if unset
+  FixedShapeTensorType(const std::shared_ptr<DataType>& value_type, const int32_t& size,
+                       const std::vector<int64_t>& shape,
+                       const std::vector<int64_t>& permutation = {},
+                       const std::vector<std::string>& dim_names = {})
+      : ExtensionType(fixed_size_list(value_type, size)),
+        value_type_(value_type),
+        shape_(shape),
+        permutation_(permutation),
+        dim_names_(dim_names) {}
+
+  std::string extension_name() const override { return "arrow.fixed_shape_tensor"; }
+
+  /// Number of dimensions of tensor elements
+  size_t ndim() const { return shape_.size(); }
+
+  /// Shape of tensor elements
+  const std::vector<int64_t> shape() const { return shape_; }
+
+  /// Value type of tensor elements
+  const std::shared_ptr<DataType> value_type() const { return value_type_; }
+
+  /// Permutation mapping from logical to physical memory layout of tensor elements
+  const std::vector<int64_t>& permutation() const { return permutation_; }
+
+  /// Dimension names of tensor elements. Dimensions are ordered physically.
+  const std::vector<std::string>& dim_names() const { return dim_names_; }
+
+  bool ExtensionEquals(const ExtensionType& other) const override;
+
+  std::string Serialize() const override;
+
+  Result<std::shared_ptr<DataType>> Deserialize(
+      std::shared_ptr<DataType> storage_type,
+      const std::string& serialized_data) const override;
+
+  /// Create a FixedShapeTensorArray from ArrayData
+  std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
+
+  /// \brief Create a FixedShapeTensorType instance
+  static Result<std::shared_ptr<DataType>> Make(
+      const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
+      const std::vector<int64_t>& permutation = {},
+      const std::vector<std::string>& dim_names = {});
+
+ private:
+  // The storage type is passed to the ExtensionType base constructor and read
+  // via storage_type(); no separate copy is kept here.
+  std::shared_ptr<DataType> value_type_;
+  std::vector<int64_t> shape_;
+  std::vector<int64_t> permutation_;
+  std::vector<std::string> dim_names_;
+};
+
+/// \brief Return a FixedShapeTensorType instance.
+///
+/// \param[in] value_type element type of the tensors (not the storage type)
+/// \param[in] shape shape of each tensor
+/// \param[in] permutation optional dimension permutation; if non-empty it must
+/// have the same size as shape
+/// \param[in] dim_names optional dimension names; if non-empty it must have the
+/// same size as shape
+ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
+    const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
+    const std::vector<int64_t>& permutation = {},
+    const std::vector<std::string>& dim_names = {});
+
+} // namespace extension
+} // namespace arrow
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
new file mode 100644
index 0000000000..16ba9d2014
--- /dev/null
+++ b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
@@ -0,0 +1,215 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/extension/fixed_shape_tensor.h"
+
+#include "arrow/testing/matchers.h"
+
+#include "arrow/array/array_nested.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/tensor.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+
+using FixedShapeTensorType = extension::FixedShapeTensorType;
+using extension::fixed_shape_tensor;
+using extension::FixedShapeTensorArray;
+
+// Fixture shared by the tests below. SetUp() prepares a 3x4 int64 tensor
+// extension type with dimension names "x" and "y", the matching
+// fixed_size_list(int64, 12) storage type, 36 sequential values (enough for
+// three cells) and the expected serialized metadata string.
+class TestExtensionType : public ::testing::Test {
+ public:
+ void SetUp() override {
+ // shape_ describes the whole column (3 cells of 3x4); cell_shape_ one tensor.
+ shape_ = {3, 3, 4};
+ cell_shape_ = {3, 4};
+ value_type_ = int64();
+ cell_type_ = fixed_size_list(value_type_, 12);
+ dim_names_ = {"x", "y"};
+ ext_type_ = internal::checked_pointer_cast<ExtensionType>(
+ fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_));
+ values_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
+ 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35};
+ serialized_ = R"({"shape":[3,4],"dim_names":["x","y"]})";
+ }
+
+ protected:
+ std::vector<int64_t> shape_;
+ std::vector<int64_t> cell_shape_;
+ std::shared_ptr<DataType> value_type_;
+ std::shared_ptr<DataType> cell_type_;
+ std::vector<std::string> dim_names_;
+ std::shared_ptr<ExtensionType> ext_type_;
+ std::vector<int64_t> values_;
+ std::string serialized_;
+};
+
+// Write `batch` to an in-memory IPC stream, then read the first record batch
+// of that stream back into `*out`.
+auto RoundtripBatch = [](const std::shared_ptr<RecordBatch>& batch,
+ std::shared_ptr<RecordBatch>* out) {
+ ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
+ ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
+ out_stream.get()));
+
+ ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());
+
+ // Re-open the finished buffer as an IPC stream and read one batch from it.
+ io::BufferReader reader(complete_ipc_stream);
+ std::shared_ptr<RecordBatchReader> batch_reader;
+ ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader));
+ ASSERT_OK(batch_reader->ReadNext(out));
+};
+
+TEST_F(TestExtensionType, CheckDummyRegistration) {
+  // We need a registered dummy type at runtime to allow for IPC deserialization
+  auto registered_type = GetExtensionType("arrow.fixed_shape_tensor");
+  // Fail cleanly instead of dereferencing a null pointer when nothing is
+  // registered, and check the instance's id() rather than the static type_id
+  // constant, which does not depend on the returned object.
+  ASSERT_NE(registered_type, nullptr);
+  ASSERT_EQ(registered_type->id(), Type::EXTENSION);
+}
+
+// Exercises the generic ExtensionType interface, the FixedShapeTensorType
+// accessors and the argument validation in FixedShapeTensorType::Make().
+TEST_F(TestExtensionType, CreateExtensionType) {
+  auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
+
+  // Test ExtensionType methods
+  ASSERT_EQ(ext_type_->extension_name(), "arrow.fixed_shape_tensor");
+  ASSERT_TRUE(ext_type_->Equals(*exact_ext_type));
+  ASSERT_FALSE(ext_type_->Equals(*cell_type_));
+  ASSERT_TRUE(ext_type_->storage_type()->Equals(*cell_type_));
+  ASSERT_EQ(ext_type_->Serialize(), serialized_);
+  ASSERT_OK_AND_ASSIGN(auto ds,
+                       ext_type_->Deserialize(ext_type_->storage_type(), serialized_));
+  // Downcast within the class hierarchy; reinterpret_pointer_cast is not meant
+  // for base-to-derived conversions.
+  auto deserialized = internal::checked_pointer_cast<ExtensionType>(ds);
+  ASSERT_TRUE(deserialized->Equals(*ext_type_));
+
+  // Test FixedShapeTensorType methods
+  ASSERT_EQ(exact_ext_type->id(), Type::EXTENSION);
+  ASSERT_EQ(exact_ext_type->ndim(), cell_shape_.size());
+  ASSERT_EQ(exact_ext_type->shape(), cell_shape_);
+  ASSERT_EQ(exact_ext_type->value_type(), value_type_);
+  ASSERT_EQ(exact_ext_type->dim_names(), dim_names_);
+
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, testing::HasSubstr("Invalid: permutation size must match shape size."),
+      FixedShapeTensorType::Make(value_type_, cell_shape_, {0}));
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, testing::HasSubstr("Invalid: dim_names size must match shape size."),
+      FixedShapeTensorType::Make(value_type_, cell_shape_, {}, {"x"}));
+}
+
+// Equality must depend on value type, shape and dimension names, and must
+// treat an empty permutation like the identity permutation {0, 1}.
+TEST_F(TestExtensionType, EqualsCases) {
+ auto ext_type_permutation_1 = fixed_shape_tensor(int64(), {3, 4}, {0, 1}, {"x", "y"});
+ auto ext_type_permutation_2 = fixed_shape_tensor(int64(), {3, 4}, {1, 0}, {"x", "y"});
+ auto ext_type_no_permutation = fixed_shape_tensor(int64(), {3, 4}, {}, {"x", "y"});
+
+ ASSERT_TRUE(ext_type_permutation_1->Equals(ext_type_permutation_1));
+
+ // A different value type, shape or set of dimension names breaks equality.
+ ASSERT_FALSE(fixed_shape_tensor(int32(), {3, 4}, {}, {"x", "y"})
+ ->Equals(ext_type_no_permutation));
+ ASSERT_FALSE(fixed_shape_tensor(int64(), {2, 4}, {}, {"x", "y"})
+ ->Equals(ext_type_no_permutation));
+ ASSERT_FALSE(fixed_shape_tensor(int64(), {3, 4}, {}, {"H", "W"})
+ ->Equals(ext_type_no_permutation));
+
+ // Identity vs. empty permutation compare equal in both directions; the
+ // swapped permutation {1, 0} is equal to neither.
+ ASSERT_TRUE(ext_type_no_permutation->Equals(ext_type_permutation_1));
+ ASSERT_TRUE(ext_type_permutation_1->Equals(ext_type_no_permutation));
+ ASSERT_FALSE(ext_type_no_permutation->Equals(ext_type_permutation_2));
+ ASSERT_FALSE(ext_type_permutation_2->Equals(ext_type_no_permutation));
+ ASSERT_FALSE(ext_type_permutation_1->Equals(ext_type_permutation_2));
+ ASSERT_FALSE(ext_type_permutation_2->Equals(ext_type_permutation_1));
+}
+
+// Wraps a FixedSizeListArray built from the fixture values in the extension
+// type and checks the resulting length and null count.
+TEST_F(TestExtensionType, CreateFromArray) {
+  std::vector<std::shared_ptr<Buffer>> buffers = {nullptr, Buffer::Wrap(values_)};
+  auto arr_data = std::make_shared<ArrayData>(value_type_, values_.size(), buffers, 0, 0);
+  auto arr = std::make_shared<Int64Array>(arr_data);
+  ASSERT_OK_AND_ASSIGN(auto fsla_arr, FixedSizeListArray::FromArrays(arr, cell_type_));
+  auto ext_arr = ExtensionType::WrapArray(ext_type_, fsla_arr);
+  // 36 values in cells of 12 give shape_[0] == 3 cells.
+  ASSERT_EQ(ext_arr->length(), shape_[0]);
+  ASSERT_EQ(ext_arr->null_count(), 0);
+}
+
+// Serialize `ext_type`, deserialize the result against the same storage type
+// and assert that the reconstructed type equals the original.
+void CheckSerializationRoundtrip(const std::shared_ptr<DataType>& ext_type) {
+ auto fst_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type);
+ auto serialized = fst_type->Serialize();
+ ASSERT_OK_AND_ASSIGN(auto deserialized,
+ fst_type->Deserialize(fst_type->storage_type(), serialized));
+ ASSERT_TRUE(fst_type->Equals(*deserialized));
+}
+
+// Expect Deserialize(storage_type, serialized) to fail with an Invalid status
+// whose message contains `expected_message`. The int64 3x4 type created here
+// only serves as the object on which Deserialize() is called.
+void CheckDeserializationRaises(const std::shared_ptr<DataType>& storage_type,
+ const std::string& serialized,
+ const std::string& expected_message) {
+ auto fst_type = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(int64(), {3, 4}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr(expected_message),
+ fst_type->Deserialize(storage_type, serialized));
+}
+
+// Round-trips several parameter combinations through Serialize/Deserialize,
+// then checks that malformed storage types or metadata are rejected.
+TEST_F(TestExtensionType, MetadataSerializationRoundtrip) {
+ CheckSerializationRoundtrip(ext_type_);
+ // Edge cases: no dimensions, a zero-sized dimension, a single dimension.
+ CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {}, {}, {}));
+ CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {0}, {}, {}));
+ CheckSerializationRoundtrip(fixed_shape_tensor(value_type_, {1}, {0}, {"x"}));
+ CheckSerializationRoundtrip(
+ fixed_shape_tensor(value_type_, {256, 256, 3}, {0, 1, 2}, {"H", "W", "C"}));
+ CheckSerializationRoundtrip(
+ fixed_shape_tensor(value_type_, {256, 256, 3}, {2, 0, 1}, {"C", "H", "W"}));
+
+ auto storage_type = fixed_size_list(int64(), 12);
+ // Wrong storage type.
+ CheckDeserializationRaises(boolean(), R"({"shape":[3,4]})",
+ "Expected FixedSizeList storage type, got bool");
+ // Missing "shape" member.
+ CheckDeserializationRaises(storage_type, R"({"dim_names":["x","y"]})",
+ "Invalid serialized JSON data");
+ // Not valid JSON.
+ CheckDeserializationRaises(storage_type, R"({"shape":(3,4)})",
+ "Invalid serialized JSON data");
+ // permutation / dim_names length differs from shape length.
+ CheckDeserializationRaises(storage_type, R"({"shape":[3,4],"permutation":[1,0,2]})",
+ "Invalid permutation");
+ CheckDeserializationRaises(storage_type, R"({"shape":[3],"dim_names":["x","y"]})",
+ "Invalid dim_names");
+}
+
+// IPC round trip of a record batch holding the extension array, both when the
+// field carries the extension type directly and when it carries the storage
+// type plus the ARROW:extension:* field metadata.
+TEST_F(TestExtensionType, RoundtripBatch) {
+  auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
+
+  std::vector<std::shared_ptr<Buffer>> buffers = {nullptr, Buffer::Wrap(values_)};
+  auto arr_data = std::make_shared<ArrayData>(value_type_, values_.size(), buffers, 0, 0);
+  auto arr = std::make_shared<Int64Array>(arr_data);
+  ASSERT_OK_AND_ASSIGN(auto fsla_arr, FixedSizeListArray::FromArrays(arr, cell_type_));
+  auto ext_arr = ExtensionType::WrapArray(ext_type_, fsla_arr);
+
+  // Pass extension array, expect getting back extension array
+  std::shared_ptr<RecordBatch> read_batch;
+  auto ext_field = field(/*name=*/"f0", /*type=*/ext_type_);
+  auto batch = RecordBatch::Make(schema({ext_field}), ext_arr->length(), {ext_arr});
+  RoundtripBatch(batch, &read_batch);
+  CompareBatch(*batch, *read_batch, /*compare_metadata=*/true);
+
+  // Pass extension metadata and storage array, expect getting back extension array
+  std::shared_ptr<RecordBatch> read_batch2;
+  auto ext_metadata =
+      key_value_metadata({{"ARROW:extension:name", exact_ext_type->extension_name()},
+                          {"ARROW:extension:metadata", serialized_}});
+  ext_field = field(/*name=*/"f0", /*type=*/cell_type_, /*nullable=*/true,
+                    /*metadata=*/ext_metadata);
+  auto batch2 = RecordBatch::Make(schema({ext_field}), fsla_arr->length(), {fsla_arr});
+  RoundtripBatch(batch2, &read_batch2);
+  CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true);
+}
+
+} // namespace arrow
diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc
index e579b69102..1199336763 100644
--- a/cpp/src/arrow/extension_type.cc
+++ b/cpp/src/arrow/extension_type.cc
@@ -26,6 +26,10 @@
#include "arrow/array/util.h"
#include "arrow/chunked_array.h"
+#include "arrow/config.h"
+#ifdef ARROW_JSON
+#include "arrow/extension/fixed_shape_tensor.h"
+#endif
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/util/checked_cast.h"
@@ -139,6 +143,14 @@ namespace internal {
static void CreateGlobalRegistry() {
g_registry = std::make_shared<ExtensionTypeRegistryImpl>();
+
+#ifdef ARROW_JSON
+ // Register canonical extension types
+ // The int64 value type and empty shape are placeholders: the registered
+ // instance is only looked up by name, and FixedShapeTensorType::Deserialize()
+ // builds the actual type from the storage type and serialized metadata.
+ auto ext_type =
+ checked_pointer_cast<ExtensionType>(extension::fixed_shape_tensor(int64(), {}));
+
+ ARROW_CHECK_OK(g_registry->RegisterType(ext_type));
+#endif
}
} // namespace internal