You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by jo...@apache.org on 2023/04/04 12:51:46 UTC

[arrow] branch main updated: GH-15483: [C++] Add a Fixed Shape Tensor canonical ExtensionType (#8510)

This is an automated email from the ASF dual-hosted git repository.

jorisvandenbossche pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new a84a39b640 GH-15483: [C++] Add a Fixed Shape Tensor canonical ExtensionType (#8510)
a84a39b640 is described below

commit a84a39b640475dcc7cf8f6bde38e7e471e489e3d
Author: Rok Mihevc <ro...@mihevc.org>
AuthorDate: Tue Apr 4 14:51:38 2023 +0200

    GH-15483: [C++] Add a Fixed Shape Tensor canonical ExtensionType (#8510)
    
    > [ARROW-1614](https://issues.apache.org/jira/browse/ARROW-1614): In an Arrow table, we would like to add support for a column that has cells each containing a tensor value, with all tensors having the same dimensions. These would be stored as a binary value, plus some metadata to store type and shape/strides.
    * Closes: #15483
    
    Lead-authored-by: Rok Mihevc <ro...@mihevc.org>
    Co-authored-by: Rok <ro...@mihevc.org>
    Co-authored-by: Joris Van den Bossche <jo...@gmail.com>
    Co-authored-by: Ben Harkins <60...@users.noreply.github.com>
    Signed-off-by: Joris Van den Bossche <jo...@gmail.com>
---
 cpp/src/arrow/CMakeLists.txt                       |   2 +
 cpp/src/arrow/extension/CMakeLists.txt             |  24 +++
 cpp/src/arrow/extension/fixed_shape_tensor.cc      | 170 ++++++++++++++++
 cpp/src/arrow/extension/fixed_shape_tensor.h       |  92 +++++++++
 cpp/src/arrow/extension/fixed_shape_tensor_test.cc | 215 +++++++++++++++++++++
 cpp/src/arrow/extension_type.cc                    |  12 ++
 6 files changed, 515 insertions(+)

diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 07ea8930ff..143fb13ddc 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -520,6 +520,7 @@ endif()
 if(ARROW_JSON)
   list(APPEND
        ARROW_SRCS
+       extension/fixed_shape_tensor.cc
        json/options.cc
        json/chunked_builder.cc
        json/chunker.cc
@@ -856,6 +857,7 @@ endif()
 
 if(ARROW_JSON)
   add_subdirectory(json)
+  # extension/ is gated on ARROW_JSON like fixed_shape_tensor.cc above (which is
+  # only added to ARROW_SRCS under ARROW_JSON); presumably because that source
+  # uses rapidjson for its metadata (de)serialization.
+  add_subdirectory(extension)
 endif()
 
 if(ARROW_ORC)
diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt
new file mode 100644
index 0000000000..c15c42874d
--- /dev/null
+++ b/cpp/src/arrow/extension/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Unit tests for the fixed_shape_tensor canonical extension type. PREFIX is
+# forwarded to Arrow's add_arrow_test helper to name the test target
+# (presumably "arrow-fixed-shape-tensor-test" -- TODO confirm).
+add_arrow_test(test
+               SOURCES
+               fixed_shape_tensor_test.cc
+               PREFIX
+               "arrow-fixed-shape-tensor")
+
+# Install this directory's headers under arrow/extension (inferred from the
+# helper's name; see arrow_install_all_headers for the exact behavior).
+arrow_install_all_headers("arrow/extension")
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc b/cpp/src/arrow/extension/fixed_shape_tensor.cc
new file mode 100644
index 0000000000..8b0ed43df5
--- /dev/null
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.cc
@@ -0,0 +1,170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <limits>
+#include <numeric>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "arrow/extension/fixed_shape_tensor.h"
+
+#include "arrow/array/array_nested.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/json/rapidjson_defs.h"  // IWYU pragma: keep
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/int_util_overflow.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/sort.h"
+
+#include <rapidjson/document.h>
+#include <rapidjson/writer.h>
+
+namespace rj = arrow::rapidjson;
+
+namespace arrow {
+namespace extension {
+
+/// Two FixedShapeTensorTypes are equal when they have the same storage type
+/// (value type and list size), the same shape and dimension names, and
+/// equivalent permutations. An empty permutation is treated as equivalent to
+/// the identity permutation [0, 1, ..., ndim - 1].
+bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
+  if (extension_name() != other.extension_name()) {
+    return false;
+  }
+  // A matching extension name is taken to mean `other` is also a
+  // FixedShapeTensorType (verified in debug builds by checked_cast).
+  const auto& other_ext = internal::checked_cast<const FixedShapeTensorType&>(other);
+
+  // True when `permutation` is the identity, i.e. permutation[i] == i for all i.
+  // Comparing against the index (instead of only checking that consecutive
+  // entries increase by one) rejects shifted sequences such as [1, 2].
+  auto is_permutation_trivial = [](const std::vector<int64_t>& permutation) {
+    for (size_t i = 0; i < permutation.size(); ++i) {
+      if (permutation[i] != static_cast<int64_t>(i)) {
+        return false;
+      }
+    }
+    return true;
+  };
+  const bool permutation_equivalent =
+      (permutation_ == other_ext.permutation_) ||
+      (permutation_.empty() && is_permutation_trivial(other_ext.permutation_)) ||
+      (is_permutation_trivial(permutation_) && other_ext.permutation_.empty());
+
+  // Compare the members directly rather than through the accessors.
+  return storage_type()->Equals(other_ext.storage_type()) &&
+         shape_ == other_ext.shape_ && dim_names_ == other_ext.dim_names_ &&
+         permutation_equivalent;
+}
+
+/// Serialize the type parameters as a compact JSON object, e.g.
+/// {"shape":[3,4],"permutation":[1,0],"dim_names":["x","y"]}.
+/// "shape" is always written; "permutation" and "dim_names" are omitted when
+/// empty.
+std::string FixedShapeTensorType::Serialize() const {
+  rj::Document document;
+  document.SetObject();
+  rj::Document::AllocatorType& allocator = document.GetAllocator();
+
+  rj::Value shape(rj::kArrayType);
+  for (const int64_t v : shape_) {
+    shape.PushBack(v, allocator);
+  }
+  document.AddMember(rj::Value("shape", allocator), shape, allocator);
+
+  if (!permutation_.empty()) {
+    rj::Value permutation(rj::kArrayType);
+    for (const int64_t v : permutation_) {
+      permutation.PushBack(v, allocator);
+    }
+    document.AddMember(rj::Value("permutation", allocator), permutation, allocator);
+  }
+
+  if (!dim_names_.empty()) {
+    rj::Value dim_names(rj::kArrayType);
+    // Iterate by reference (no per-name copy) and pass the explicit length so
+    // each name is copied in full into the document's allocator, including any
+    // embedded NUL bytes that a c_str()-based copy would truncate at.
+    for (const std::string& v : dim_names_) {
+      rj::Value name;
+      name.SetString(v.data(), static_cast<rj::SizeType>(v.size()), allocator);
+      dim_names.PushBack(name, allocator);
+    }
+    document.AddMember(rj::Value("dim_names", allocator), dim_names, allocator);
+  }
+
+  rj::StringBuffer buffer;
+  rj::Writer<rj::StringBuffer> writer(buffer);
+  document.Accept(writer);
+  return std::string(buffer.GetString(), buffer.GetSize());
+}
+
+/// Reconstruct a FixedShapeTensorType from its storage type and the JSON
+/// produced by Serialize().
+///
+/// Returns Status::Invalid if the storage type is not a FixedSizeList, if the
+/// JSON is malformed or structurally wrong (not an object, missing "shape",
+/// non-integer dimensions, wrongly typed "permutation"/"dim_names"), if
+/// "permutation" or "dim_names" do not match the length of "shape", if the
+/// parameters are rejected by Make(), or if the product of "shape" does not
+/// match the storage list size.
+Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
+    std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const {
+  if (storage_type->id() != Type::FIXED_SIZE_LIST) {
+    return Status::Invalid("Expected FixedSizeList storage type, got ",
+                           storage_type->ToString());
+  }
+  const auto& list_type = internal::checked_cast<const FixedSizeListType&>(*storage_type);
+
+  rj::Document document;
+  // IsObject() must be checked before HasMember(): rapidjson asserts when
+  // HasMember() is called on a non-object value (e.g. a top-level JSON array).
+  if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() ||
+      !document.IsObject() || !document.HasMember("shape") ||
+      !document["shape"].IsArray()) {
+    return Status::Invalid("Invalid serialized JSON data: ", serialized_data);
+  }
+
+  std::vector<int64_t> shape;
+  for (const auto& x : document["shape"].GetArray()) {
+    // GetInt64() asserts on non-integer values, so validate the type first.
+    if (!x.IsInt64()) {
+      return Status::Invalid("Invalid serialized JSON data: ", serialized_data);
+    }
+    shape.emplace_back(x.GetInt64());
+  }
+
+  std::vector<int64_t> permutation;
+  if (document.HasMember("permutation")) {
+    const auto& json_permutation = document["permutation"];
+    if (!json_permutation.IsArray()) {
+      return Status::Invalid("Invalid permutation");
+    }
+    for (const auto& x : json_permutation.GetArray()) {
+      if (!x.IsInt64()) {
+        return Status::Invalid("Invalid permutation");
+      }
+      permutation.emplace_back(x.GetInt64());
+    }
+    if (shape.size() != permutation.size()) {
+      return Status::Invalid("Invalid permutation");
+    }
+  }
+
+  std::vector<std::string> dim_names;
+  if (document.HasMember("dim_names")) {
+    const auto& json_dim_names = document["dim_names"];
+    if (!json_dim_names.IsArray()) {
+      return Status::Invalid("Invalid dim_names");
+    }
+    for (const auto& x : json_dim_names.GetArray()) {
+      if (!x.IsString()) {
+        return Status::Invalid("Invalid dim_names");
+      }
+      dim_names.emplace_back(x.GetString(), x.GetStringLength());
+    }
+    if (shape.size() != dim_names.size()) {
+      return Status::Invalid("Invalid dim_names");
+    }
+  }
+
+  // Call Make() directly so invalid parameters surface as an error Status
+  // returned to the caller rather than as a failed check.
+  ARROW_ASSIGN_OR_RAISE(
+      auto type,
+      FixedShapeTensorType::Make(list_type.value_type(), shape, permutation, dim_names));
+
+  // The list size implied by "shape" must agree with the actual storage;
+  // otherwise the returned type would not describe the data it is applied to.
+  const auto& expected_storage = internal::checked_cast<const FixedSizeListType&>(
+      *internal::checked_cast<const ExtensionType&>(*type).storage_type());
+  if (expected_storage.list_size() != list_type.list_size()) {
+    return Status::Invalid("Storage list size ", list_type.list_size(),
+                           " does not match the product of shape (",
+                           expected_storage.list_size(), ")");
+  }
+  return type;
+}
+
+/// Wrap `data`, which must carry this extension type, in a FixedShapeTensorArray
+/// (the array class documented for this type in the header).
+std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
+    std::shared_ptr<ArrayData> data) const {
+  DCHECK_EQ(data->type->id(), Type::EXTENSION);
+  DCHECK_EQ("arrow.fixed_shape_tensor",
+            internal::checked_cast<const ExtensionType&>(*data->type).extension_name());
+  return std::make_shared<FixedShapeTensorArray>(data);
+}
+
+/// Validate the parameters and create a FixedShapeTensorType whose storage is
+/// fixed_size_list(value_type, product(shape)).
+///
+/// Returns Status::Invalid if `permutation` or `dim_names` is non-empty and its
+/// length differs from `shape`, if `permutation` is not a permutation of
+/// [0, shape.size()), if any dimension is negative, or if the number of
+/// elements per tensor does not fit in the int32 list size.
+Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
+    const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
+    const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) {
+  if (!permutation.empty() && shape.size() != permutation.size()) {
+    return Status::Invalid("permutation size must match shape size. Expected: ",
+                           shape.size(), " Got: ", permutation.size());
+  }
+  if (!dim_names.empty() && shape.size() != dim_names.size()) {
+    return Status::Invalid("dim_names size must match shape size. Expected: ",
+                           shape.size(), " Got: ", dim_names.size());
+  }
+  if (!permutation.empty()) {
+    // Every index in [0, ndim) must appear exactly once.
+    const auto ndim = static_cast<int64_t>(permutation.size());
+    std::vector<char> seen(permutation.size(), 0);
+    for (const int64_t p : permutation) {
+      if (p < 0 || p >= ndim || seen[static_cast<size_t>(p)]) {
+        return Status::Invalid("permutation must contain each index in [0, ", ndim,
+                               ") exactly once");
+      }
+      seen[static_cast<size_t>(p)] = 1;
+    }
+  }
+
+  // Number of elements per tensor. It becomes the FixedSizeList size, which is
+  // an int32, so negative dimensions and overflow are rejected explicitly
+  // instead of being silently truncated by the cast below. A zero dimension
+  // makes the product 0 regardless of any intermediate overflow.
+  int64_t size = 1;
+  bool overflow = false;
+  bool has_zero = false;
+  for (const int64_t dim : shape) {
+    if (dim < 0) {
+      return Status::Invalid("shape dimensions must be non-negative. Got: ", dim);
+    }
+    has_zero = has_zero || dim == 0;
+    overflow = internal::MultiplyWithOverflow(size, dim, &size) || overflow;
+  }
+  if (has_zero) {
+    size = 0;
+  } else if (overflow || size > std::numeric_limits<int32_t>::max()) {
+    return Status::Invalid("product of shape dimensions does not fit in int32");
+  }
+  return std::make_shared<FixedShapeTensorType>(value_type, static_cast<int32_t>(size),
+                                                shape, permutation, dim_names);
+}
+
+/// Convenience factory around FixedShapeTensorType::Make() for parameters known
+/// to be valid. Invalid parameters abort via ARROW_CHECK_OK in every build mode;
+/// callers that need to handle bad input should call Make() instead.
+std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& value_type,
+                                             const std::vector<int64_t>& shape,
+                                             const std::vector<int64_t>& permutation,
+                                             const std::vector<std::string>& dim_names) {
+  auto maybe_type = FixedShapeTensorType::Make(value_type, shape, permutation, dim_names);
+  // A DCHECK is compiled out in release builds, after which MoveValueUnsafe()
+  // on an error Result would be undefined behavior.
+  ARROW_CHECK_OK(maybe_type.status());
+  return maybe_type.MoveValueUnsafe();
+}
+
+}  // namespace extension
+}  // namespace arrow
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.h b/cpp/src/arrow/extension/fixed_shape_tensor.h
new file mode 100644
index 0000000000..4ee2b894ee
--- /dev/null
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.h
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "arrow/extension_type.h"
+
+namespace arrow {
+namespace extension {
+
+/// \brief Array class for arrays of the arrow.fixed_shape_tensor extension type.
+class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
+ public:
+  using ExtensionArray::ExtensionArray;
+};
+
+/// \brief Concrete type class for constant-size Tensor data.
+/// This is a canonical arrow extension type.
+/// See: https://arrow.apache.org/docs/format/CanonicalExtensions.html
+///
+/// Each cell holds one tensor of shape `shape()`, stored as a FixedSizeList of
+/// `value_type()` values.
+class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
+ public:
+  /// \brief Construct the type directly, without validating the parameters.
+  ///
+  /// \param[in] value_type type of the individual tensor elements
+  /// \param[in] size list size of the FixedSizeList storage type; expected to
+  ///            equal the product of `shape`
+  /// \param[in] shape shape of each tensor
+  /// \param[in] permutation optional mapping from logical to physical layout
+  /// \param[in] dim_names optional dimension names
+  ///
+  /// Prefer Make(), which computes `size` from `shape` and checks the inputs.
+  FixedShapeTensorType(const std::shared_ptr<DataType>& value_type, int32_t size,
+                       const std::vector<int64_t>& shape,
+                       const std::vector<int64_t>& permutation = {},
+                       const std::vector<std::string>& dim_names = {})
+      : ExtensionType(fixed_size_list(value_type, size)),
+        value_type_(value_type),
+        shape_(shape),
+        permutation_(permutation),
+        dim_names_(dim_names) {}
+
+  std::string extension_name() const override { return "arrow.fixed_shape_tensor"; }
+
+  /// Number of dimensions of tensor elements
+  size_t ndim() const { return shape_.size(); }
+
+  /// Shape of tensor elements
+  const std::vector<int64_t>& shape() const { return shape_; }
+
+  /// Value type of tensor elements
+  const std::shared_ptr<DataType>& value_type() const { return value_type_; }
+
+  /// Permutation mapping from logical to physical memory layout of tensor elements
+  const std::vector<int64_t>& permutation() const { return permutation_; }
+
+  /// Dimension names of tensor elements. Dimensions are ordered physically.
+  const std::vector<std::string>& dim_names() const { return dim_names_; }
+
+  bool ExtensionEquals(const ExtensionType& other) const override;
+
+  std::string Serialize() const override;
+
+  Result<std::shared_ptr<DataType>> Deserialize(
+      std::shared_ptr<DataType> storage_type,
+      const std::string& serialized_data) const override;
+
+  /// Create a FixedShapeTensorArray from ArrayData
+  std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
+
+  /// \brief Create a FixedShapeTensorType instance
+  static Result<std::shared_ptr<DataType>> Make(
+      const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
+      const std::vector<int64_t>& permutation = {},
+      const std::vector<std::string>& dim_names = {});
+
+ private:
+  // The storage type is held by the ExtensionType base (see storage_type());
+  // a same-named member here would only shadow it.
+  std::shared_ptr<DataType> value_type_;
+  std::vector<int64_t> shape_;
+  std::vector<int64_t> permutation_;
+  std::vector<std::string> dim_names_;
+};
+
+/// \brief Return a FixedShapeTensorType instance.
+///
+/// \param[in] value_type type of the individual tensor elements (not the
+///            storage type; the FixedSizeList storage is derived from it)
+/// \param[in] shape shape of each tensor
+/// \param[in] permutation optional mapping from logical to physical layout
+/// \param[in] dim_names optional dimension names
+///
+/// Convenience wrapper around FixedShapeTensorType::Make() for parameters known
+/// to be valid; use Make() to receive a Status for invalid input.
+ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
+    const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
+    const std::vector<int64_t>& permutation = {},
+    const std::vector<std::string>& dim_names = {});
+
+}  // namespace extension
+}  // namespace arrow
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
new file mode 100644
index 0000000000..16ba9d2014
--- /dev/null
+++ b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
@@ -0,0 +1,215 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/extension/fixed_shape_tensor.h"
+
+#include "arrow/testing/matchers.h"
+
+#include "arrow/array/array_nested.h"
+#include "arrow/array/array_primitive.h"
+#include "arrow/io/memory.h"
+#include "arrow/ipc/reader.h"
+#include "arrow/ipc/writer.h"
+#include "arrow/record_batch.h"
+#include "arrow/tensor.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/key_value_metadata.h"
+
+namespace arrow {
+
+using FixedShapeTensorType = extension::FixedShapeTensorType;
+using extension::fixed_shape_tensor;
+using extension::FixedShapeTensorArray;
+
+/// Fixture: three 3x4 int64 tensors (dimension names "x", "y") stored as a
+/// FixedSizeList<int64, 12> column, plus the expected serialized metadata.
+class TestExtensionType : public ::testing::Test {
+ public:
+  void SetUp() override {
+    value_type_ = int64();
+    cell_type_ = fixed_size_list(value_type_, 12);
+    shape_ = {3, 3, 4};
+    cell_shape_ = {3, 4};
+    dim_names_ = {"x", "y"};
+    ext_type_ = internal::checked_pointer_cast<ExtensionType>(
+        fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_));
+
+    // values_ = 0, 1, ..., 35: 3 tensors * 12 elements.
+    values_.clear();
+    for (int64_t i = 0; i < 36; ++i) {
+      values_.push_back(i);
+    }
+    serialized_ = R"({"shape":[3,4],"dim_names":["x","y"]})";
+  }
+
+ protected:
+  std::vector<int64_t> shape_;
+  std::vector<int64_t> cell_shape_;
+  std::shared_ptr<DataType> value_type_;
+  std::shared_ptr<DataType> cell_type_;
+  std::vector<std::string> dim_names_;
+  std::shared_ptr<ExtensionType> ext_type_;
+  std::vector<int64_t> values_;
+  std::string serialized_;
+};
+
+namespace {
+
+/// Write `batch` to an in-memory IPC stream and read it back into `*out`.
+/// Failures are reported with gtest ASSERT_* macros, which only return from
+/// this helper; callers should wrap the call in ASSERT_NO_FATAL_FAILURE.
+void RoundtripBatch(const std::shared_ptr<RecordBatch>& batch,
+                    std::shared_ptr<RecordBatch>* out) {
+  ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
+  ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
+                                        out_stream.get()));
+
+  ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());
+
+  io::BufferReader reader(complete_ipc_stream);
+  ASSERT_OK_AND_ASSIGN(std::shared_ptr<RecordBatchReader> batch_reader,
+                       ipc::RecordBatchStreamReader::Open(&reader));
+  ASSERT_OK(batch_reader->ReadNext(out));
+}
+
+}  // namespace
+
+TEST_F(TestExtensionType, CheckDummyRegistration) {
+  // We need a registered dummy type at runtime to allow for IPC deserialization
+  auto registered_type = GetExtensionType("arrow.fixed_shape_tensor");
+  // Fail cleanly instead of dereferencing null when the type is not registered.
+  ASSERT_NE(registered_type, nullptr);
+  // Check the runtime id(); `type_id` is a static class constant and would make
+  // this assertion hold regardless of the registered object.
+  ASSERT_EQ(registered_type->id(), Type::EXTENSION);
+}
+
+TEST_F(TestExtensionType, CreateExtensionType) {
+  auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
+
+  // Test ExtensionType methods
+  ASSERT_EQ(ext_type_->extension_name(), "arrow.fixed_shape_tensor");
+  ASSERT_TRUE(ext_type_->Equals(*exact_ext_type));
+  ASSERT_FALSE(ext_type_->Equals(*cell_type_));
+  ASSERT_TRUE(ext_type_->storage_type()->Equals(*cell_type_));
+  ASSERT_EQ(ext_type_->Serialize(), serialized_);
+  ASSERT_OK_AND_ASSIGN(auto ds,
+                       ext_type_->Deserialize(ext_type_->storage_type(), serialized_));
+  // checked_pointer_cast verifies the dynamic type in debug builds;
+  // reinterpret_pointer_cast performs no check at all.
+  auto deserialized = internal::checked_pointer_cast<ExtensionType>(ds);
+  ASSERT_TRUE(deserialized->Equals(*ext_type_));
+
+  // Test FixedShapeTensorType methods
+  ASSERT_EQ(exact_ext_type->id(), Type::EXTENSION);
+  ASSERT_EQ(exact_ext_type->ndim(), cell_shape_.size());
+  ASSERT_EQ(exact_ext_type->shape(), cell_shape_);
+  // Compare types structurally rather than by shared_ptr identity.
+  ASSERT_TRUE(exact_ext_type->value_type()->Equals(value_type_));
+  ASSERT_EQ(exact_ext_type->dim_names(), dim_names_);
+
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, testing::HasSubstr("Invalid: permutation size must match shape size."),
+      FixedShapeTensorType::Make(value_type_, cell_shape_, {0}));
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, testing::HasSubstr("Invalid: dim_names size must match shape size."),
+      FixedShapeTensorType::Make(value_type_, cell_shape_, {}, {"x"}));
+}
+
+TEST_F(TestExtensionType, EqualsCases) {
+  const std::vector<std::string> names = {"x", "y"};
+  auto identity_perm = fixed_shape_tensor(int64(), {3, 4}, {0, 1}, names);
+  auto swapped_perm = fixed_shape_tensor(int64(), {3, 4}, {1, 0}, names);
+  auto no_perm = fixed_shape_tensor(int64(), {3, 4}, {}, names);
+
+  // Reflexivity.
+  ASSERT_TRUE(identity_perm->Equals(identity_perm));
+
+  // Any difference in value type, shape or dimension names breaks equality.
+  auto other_value_type = fixed_shape_tensor(int32(), {3, 4}, {}, names);
+  auto other_shape = fixed_shape_tensor(int64(), {2, 4}, {}, names);
+  auto other_names = fixed_shape_tensor(int64(), {3, 4}, {}, {"H", "W"});
+  ASSERT_FALSE(other_value_type->Equals(no_perm));
+  ASSERT_FALSE(other_shape->Equals(no_perm));
+  ASSERT_FALSE(other_names->Equals(no_perm));
+
+  // An empty permutation matches the identity permutation in both directions,
+  // but not a non-trivial one; distinct explicit permutations never match.
+  ASSERT_TRUE(no_perm->Equals(identity_perm));
+  ASSERT_TRUE(identity_perm->Equals(no_perm));
+  ASSERT_FALSE(no_perm->Equals(swapped_perm));
+  ASSERT_FALSE(swapped_perm->Equals(no_perm));
+  ASSERT_FALSE(identity_perm->Equals(swapped_perm));
+  ASSERT_FALSE(swapped_perm->Equals(identity_perm));
+}
+
+TEST_F(TestExtensionType, CreateFromArray) {
+  // Build an Int64 values array over values_, group it into 12-element lists,
+  // and wrap those lists with the extension type.
+  std::vector<std::shared_ptr<Buffer>> buffers = {nullptr, Buffer::Wrap(values_)};
+  auto arr_data = std::make_shared<ArrayData>(value_type_, values_.size(), buffers, 0, 0);
+  auto arr = std::make_shared<Int64Array>(arr_data);
+  ASSERT_OK_AND_ASSIGN(auto fsla_arr, FixedSizeListArray::FromArrays(arr, cell_type_));
+  auto ext_arr = ExtensionType::WrapArray(ext_type_, fsla_arr);
+  ASSERT_OK(ext_arr->ValidateFull());
+  ASSERT_EQ(ext_arr->length(), shape_[0]);
+  ASSERT_EQ(ext_arr->null_count(), 0);
+}
+
+/// Serialize `ext_type`, deserialize the result against its own storage type,
+/// and assert the round-tripped type compares equal to the original.
+void CheckSerializationRoundtrip(const std::shared_ptr<DataType>& ext_type) {
+  const auto& tensor_type =
+      internal::checked_cast<const FixedShapeTensorType&>(*ext_type);
+  const std::string metadata = tensor_type.Serialize();
+  ASSERT_OK_AND_ASSIGN(auto roundtripped,
+                       tensor_type.Deserialize(tensor_type.storage_type(), metadata));
+  ASSERT_TRUE(tensor_type.Equals(*roundtripped));
+}
+
+/// Assert that deserializing `serialized` against `storage_type` fails with
+/// Status::Invalid whose message contains `expected_message`.
+void CheckDeserializationRaises(const std::shared_ptr<DataType>& storage_type,
+                                const std::string& serialized,
+                                const std::string& expected_message) {
+  // Deserialize() is an instance method that reads no instance state, so any
+  // FixedShapeTensorType works as the receiver.
+  const auto prototype = fixed_shape_tensor(int64(), {3, 4});
+  const auto& tensor_type =
+      internal::checked_cast<const FixedShapeTensorType&>(*prototype);
+  EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr(expected_message),
+                                  tensor_type.Deserialize(storage_type, serialized));
+}
+
+TEST_F(TestExtensionType, MetadataSerializationRoundtrip) {
+  // Types whose metadata must survive Serialize() -> Deserialize().
+  const std::vector<std::shared_ptr<DataType>> roundtrip_types = {
+      ext_type_,
+      fixed_shape_tensor(value_type_, {}, {}, {}),
+      fixed_shape_tensor(value_type_, {0}, {}, {}),
+      fixed_shape_tensor(value_type_, {1}, {0}, {"x"}),
+      fixed_shape_tensor(value_type_, {256, 256, 3}, {0, 1, 2}, {"H", "W", "C"}),
+      fixed_shape_tensor(value_type_, {256, 256, 3}, {2, 0, 1}, {"C", "H", "W"}),
+  };
+  for (const auto& type : roundtrip_types) {
+    CheckSerializationRoundtrip(type);
+  }
+
+  // Invalid storage/metadata combinations and the error each must produce.
+  struct InvalidCase {
+    std::shared_ptr<DataType> storage_type;
+    std::string serialized;
+    std::string expected_message;
+  };
+  const auto storage_type = fixed_size_list(int64(), 12);
+  const std::vector<InvalidCase> invalid_cases = {
+      {boolean(), R"({"shape":[3,4]})", "Expected FixedSizeList storage type, got bool"},
+      {storage_type, R"({"dim_names":["x","y"]})", "Invalid serialized JSON data"},
+      {storage_type, R"({"shape":(3,4)})", "Invalid serialized JSON data"},
+      {storage_type, R"({"shape":[3,4],"permutation":[1,0,2]})", "Invalid permutation"},
+      {storage_type, R"({"shape":[3],"dim_names":["x","y"]})", "Invalid dim_names"},
+  };
+  for (const auto& invalid : invalid_cases) {
+    CheckDeserializationRaises(invalid.storage_type, invalid.serialized,
+                               invalid.expected_message);
+  }
+}
+
+TEST_F(TestExtensionType, RoundtripBatch) {
+  auto exact_ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
+
+  std::vector<std::shared_ptr<Buffer>> buffers = {nullptr, Buffer::Wrap(values_)};
+  auto arr_data = std::make_shared<ArrayData>(value_type_, values_.size(), buffers, 0, 0);
+  auto arr = std::make_shared<Int64Array>(arr_data);
+  ASSERT_OK_AND_ASSIGN(auto fsla_arr, FixedSizeListArray::FromArrays(arr, cell_type_));
+  auto ext_arr = ExtensionType::WrapArray(ext_type_, fsla_arr);
+
+  // Pass extension array, expect getting back extension array
+  std::shared_ptr<RecordBatch> read_batch;
+  auto ext_field = field(/*name=*/"f0", /*type=*/ext_type_);
+  auto batch = RecordBatch::Make(schema({ext_field}), ext_arr->length(), {ext_arr});
+  // The helper reports failures with ASSERT_*, which only leave the helper;
+  // stop here instead of comparing against an unset batch.
+  ASSERT_NO_FATAL_FAILURE(RoundtripBatch(batch, &read_batch));
+  CompareBatch(*batch, *read_batch, /*compare_metadata=*/true);
+
+  // Pass extension metadata and storage array, expect getting back extension array
+  std::shared_ptr<RecordBatch> read_batch2;
+  auto ext_metadata =
+      key_value_metadata({{"ARROW:extension:name", exact_ext_type->extension_name()},
+                          {"ARROW:extension:metadata", serialized_}});
+  ext_field = field(/*name=*/"f0", /*type=*/cell_type_, /*nullable=*/true,
+                    /*metadata=*/ext_metadata);
+  auto batch2 = RecordBatch::Make(schema({ext_field}), fsla_arr->length(), {fsla_arr});
+  ASSERT_NO_FATAL_FAILURE(RoundtripBatch(batch2, &read_batch2));
+  CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true);
+}
+
+}  // namespace arrow
diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc
index e579b69102..1199336763 100644
--- a/cpp/src/arrow/extension_type.cc
+++ b/cpp/src/arrow/extension_type.cc
@@ -26,6 +26,10 @@
 
 #include "arrow/array/util.h"
 #include "arrow/chunked_array.h"
+#include "arrow/config.h"
+#ifdef ARROW_JSON
+#include "arrow/extension/fixed_shape_tensor.h"
+#endif
 #include "arrow/status.h"
 #include "arrow/type.h"
 #include "arrow/util/checked_cast.h"
@@ -139,6 +143,14 @@ namespace internal {
 
+// Creates the global extension type registry (g_registry) and registers the
+// canonical extension types that were compiled in.
 static void CreateGlobalRegistry() {
   g_registry = std::make_shared<ExtensionTypeRegistryImpl>();
+
+#ifdef ARROW_JSON
+  // Register canonical extension types
+  // arrow.fixed_shape_tensor is only built when ARROW_JSON is enabled. The
+  // int64 / empty-shape instance is a placeholder: lookups are by extension
+  // name, and the actual parameters of a received type are rebuilt from its
+  // serialized metadata through Deserialize().
+  auto ext_type =
+      checked_pointer_cast<ExtensionType>(extension::fixed_shape_tensor(int64(), {}));
+
+  ARROW_CHECK_OK(g_registry->RegisterType(ext_type));
+#endif
 }
 
 }  // namespace internal