You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2019/12/01 15:41:11 UTC

[incubator-tvm] branch master updated: [Runtime] Make ADTObject POD container type (#4346)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 2bf5fd2  [Runtime] Make ADTObject POD container type (#4346)
2bf5fd2 is described below

commit 2bf5fd2b5e5c032d0c1803b271c2462e171e5d40
Author: Wei Chen <ip...@gmail.com>
AuthorDate: Sun Dec 1 07:41:00 2019 -0800

    [Runtime] Make ADTObject POD container type (#4346)
---
 include/tvm/runtime/container.h | 279 ++++++++++++++++++++++++++++++++++++++++
 include/tvm/runtime/memory.h    |  79 +++++++++++-
 include/tvm/runtime/vm.h        |  29 -----
 src/runtime/vm/object.cc        |  29 ++---
 src/runtime/vm/vm.cc            |  21 ++-
 tests/cpp/container_test.cc     | 130 ++++++++++++++++++-
 6 files changed, 498 insertions(+), 69 deletions(-)

diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h
new file mode 100644
index 0000000..2714ac2
--- /dev/null
+++ b/include/tvm/runtime/container.h
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/container.h
+ * \brief Common POD(plain old data) container types.
+ */
+#ifndef TVM_RUNTIME_CONTAINER_H_
+#define TVM_RUNTIME_CONTAINER_H_
+#include <dmlc/logging.h>
+#include <tvm/runtime/memory.h>
+#include <tvm/runtime/object.h>
+
+#include <initializer_list>
+#include <iterator>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+
+/*!
+ * \brief Base template for classes with array like memory layout.
+ *
+ *        It provides general methods to access the memory. The memory
+ *        layout is ArrayType + [ElemType]. The alignment of ArrayType
+ *        and ElemType is handled by the memory allocator.
+ *
+ *        ArrayType must provide `size_t GetSize() const` returning the
+ *        number of elements constructed so far: operator[] bounds-checks
+ *        against it and ~InplaceArrayBase destroys exactly that many.
+ *
+ * \tparam ArrayType The array header type, contains object specific metadata.
+ * \tparam ElemType The type of objects stored in the array right after
+ * ArrayType.
+ *
+ * \code
+ * // Example usage of the template to define a simple array wrapper
+ * // (object type-info declarations omitted for brevity).
+ * class ArrayObj : public Object, public InplaceArrayBase<ArrayObj, Elem> {
+ * public:
+ *  // Number of initialized elements, read back through GetSize().
+ *  uint32_t size;
+ *  size_t GetSize() const { return size; }
+ *
+ *  // Wrap EmplaceInit to initialize the elements
+ *  template <typename Iterator>
+ *  void Init(Iterator begin, Iterator end) {
+ *   size_t num_elems = std::distance(begin, end);
+ *   auto it = begin;
+ *   this->size = 0;
+ *   for (size_t i = 0; i < num_elems; ++i) {
+ *     InplaceArrayBase::EmplaceInit(i, *it++);
+ *     this->size++;
+ *   }
+ *  }
+ * };
+ *
+ * void test_function() {
+ *   std::vector<Elem> fields = ...;  // assumed non-empty
+ *   auto ptr = make_inplace_array_object<ArrayObj, Elem>(fields.size());
+ *   ptr->Init(fields.begin(), fields.end());
+ *
+ *   // Access the 0th element in the array.
+ *   assert(ptr->operator[](0) == fields[0]);
+ * }
+ *
+ * \endcode
+ */
+template <typename ArrayType, typename ElemType>
+class InplaceArrayBase {
+ public:
+  /*!
+   * \brief Access element at index
+   * \param idx The index of the element; must be smaller than GetSize().
+   * \return Const reference to ElemType at the index.
+   */
+  const ElemType& operator[](size_t idx) const {
+    size_t size = Self()->GetSize();
+    CHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n";
+    return *static_cast<const ElemType*>(AddressOf(idx));
+  }
+
+  /*!
+   * \brief Access element at index
+   * \param idx The index of the element; must be smaller than GetSize().
+   * \return Reference to ElemType at the index.
+   */
+  ElemType& operator[](size_t idx) {
+    size_t size = Self()->GetSize();
+    CHECK_LT(idx, size) << "Index " << idx << " out of bounds " << size << "\n";
+    return *static_cast<ElemType*>(AddressOf(idx));
+  }
+
+  /*!
+   * \brief Destroy the Inplace Array Base object
+   *
+   * Runs ~ElemType on the first GetSize() elements. Trivial, standard-layout
+   * element types are skipped because they have no destructor to run.
+   */
+  ~InplaceArrayBase() {
+    if (!(std::is_standard_layout<ElemType>::value &&
+          std::is_trivial<ElemType>::value)) {
+      size_t size = Self()->GetSize();
+      for (size_t i = 0; i < size; ++i) {
+        ElemType* fp = static_cast<ElemType*>(AddressOf(i));
+        fp->ElemType::~ElemType();
+      }
+    }
+  }
+
+ protected:
+  /*!
+   * \brief Construct a value in place with the arguments.
+   *
+   * \tparam Args Type parameters of the arguments.
+   * \param idx Index of the element.
+   * \param args Arguments to construct the new value.
+   *
+   * \note Please make sure ArrayType::GetSize returns 0 before first call of
+   * EmplaceInit, and increment GetSize by 1 each time EmplaceInit succeeds.
+   */
+  template <typename... Args>
+  void EmplaceInit(size_t idx, Args&&... args) {
+    void* field_ptr = AddressOf(idx);
+    new (field_ptr) ElemType(std::forward<Args>(args)...);
+  }
+
+ private:
+  /*!
+   * \brief Return the self object for the array.
+   *
+   * \return Pointer to ArrayType.
+   */
+  inline ArrayType* Self() const {
+    return static_cast<ArrayType*>(const_cast<InplaceArrayBase*>(this));
+  }
+
+  /*!
+   * \brief Return the raw pointer to the element at idx.
+   *
+   * \param idx The index of the element.
+   * \return Raw pointer to the element.
+   */
+  void* AddressOf(size_t idx) const {
+    static_assert(alignof(ArrayType) % alignof(ElemType) == 0 &&
+                      sizeof(ArrayType) % alignof(ElemType) == 0,
+                  "The size and alignment of ArrayType should respect "
+                  "ElemType's alignment.");
+
+    // Elements are stored contiguously, starting right after the header.
+    constexpr size_t kDataStart = sizeof(ArrayType);
+    ArrayType* self = Self();
+    char* data_start = reinterpret_cast<char*>(self) + kDataStart;
+    return data_start + idx * sizeof(ElemType);
+  }
+};
+
+/*!
+ * \brief An object representing a structure or enumeration.
+ *
+ * The ObjectRef fields live inline, directly after this header, and are
+ * managed through InplaceArrayBase; `size` counts the constructed fields.
+ */
+class ADTObj : public Object, public InplaceArrayBase<ADTObj, ObjectRef> {
+ public:
+  /*! \brief The tag representing the constructor used. */
+  uint32_t tag;
+  /*! \brief Number of fields in the ADT object. */
+  uint32_t size;
+  // The fields of the structure follow directly in memory.
+
+  static constexpr const uint32_t _type_index = TypeIndex::kVMADT;
+  static constexpr const char* _type_key = "vm.ADT";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object);
+
+ private:
+  /*!
+   * \brief Report how many fields are currently constructed.
+   * \return The value of `size`.
+   */
+  size_t GetSize() const { return size; }
+
+  /*!
+   * \brief Construct the fields in place from the range [begin, end).
+   *
+   * `size` grows only after each field has been built, so if building a
+   * field throws, the base destructor destroys exactly the fields that exist.
+   *
+   * \tparam Iterator Iterator type of the source range.
+   * \param begin The begin iterator.
+   * \param end The end iterator.
+   */
+  template <typename Iterator>
+  void Init(Iterator begin, Iterator end) {
+    const size_t total = std::distance(begin, end);
+    this->size = 0;
+    Iterator cursor = begin;
+    for (size_t slot = 0; slot != total; ++slot, ++cursor) {
+      InplaceArrayBase::EmplaceInit(slot, *cursor);
+      // Count the field only once its construction has succeeded.
+      ++this->size;
+    }
+  }
+
+  friend class ADT;
+  friend class InplaceArrayBase;
+};
+
+/*! \brief reference to algebraic data type objects. */
+class ADT : public ObjectRef {
+ public:
+  /*!
+   * \brief construct an ADT object reference.
+   * \param tag The tag of the ADT object.
+   * \param fields The fields of the ADT object. The vector is taken by
+   *        value, so its elements are moved (rather than copied a second
+   *        time) into the newly allocated object.
+   */
+  ADT(uint32_t tag, std::vector<ObjectRef> fields)
+      : ADT(tag, std::make_move_iterator(fields.begin()),
+            std::make_move_iterator(fields.end())) {}
+
+  /*!
+   * \brief construct an ADT object reference.
+   * \tparam Iterator Iterator over ObjectRef values; it is passed to
+   *         std::distance and dereferenced once per field.
+   * \param tag The tag of the ADT object.
+   * \param begin The begin iterator to the start of the fields array.
+   * \param end The end iterator to the end of the fields array.
+   */
+  template <typename Iterator>
+  ADT(uint32_t tag, Iterator begin, Iterator end) {
+    size_t num_elems = std::distance(begin, end);
+    // Allocate the header plus room for num_elems fields in one block,
+    // then construct the fields in place.
+    auto ptr = make_inplace_array_object<ADTObj, ObjectRef>(num_elems);
+    ptr->tag = tag;
+    ptr->Init(begin, end);
+    data_ = std::move(ptr);
+  }
+
+  /*!
+   * \brief construct an ADT object reference.
+   * \param tag The tag of the ADT object.
+   * \param init The initializer list of fields.
+   */
+  ADT(uint32_t tag, std::initializer_list<ObjectRef> init)
+      : ADT(tag, init.begin(), init.end()) {}
+
+  /*!
+   * \brief Access element at index.
+   *
+   * \param idx The array index; must be smaller than size() (checked by
+   *        the underlying ADTObj element accessor).
+   * \return Const reference to the field at idx.
+   */
+  const ObjectRef& operator[](size_t idx) const {
+    return operator->()->operator[](idx);
+  }
+
+  /*!
+   * \brief Return the ADT tag.
+   */
+  size_t tag() const { return operator->()->tag; }
+
+  /*!
+   * \brief Return the number of fields.
+   */
+  size_t size() const { return operator->()->size; }
+
+  /*!
+   * \brief Construct a tuple object, i.e. an ADT with tag 0.
+   *
+   * \tparam Args Type params of tuple fields.
+   * \param args Arguments forwarded after the tag to an ADT constructor,
+   *        e.g. a std::vector<ObjectRef> or a begin/end iterator pair.
+   * \return ADT The tuple object reference.
+   */
+  template <typename... Args>
+  static ADT Tuple(Args&&... args) {
+    return ADT(0, std::forward<Args>(args)...);
+  }
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj);
+};
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_CONTAINER_H_
diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h
index 07e22d7..63f3e4e 100644
--- a/include/tvm/runtime/memory.h
+++ b/include/tvm/runtime/memory.h
@@ -23,6 +23,7 @@
 #ifndef TVM_RUNTIME_MEMORY_H_
 #define TVM_RUNTIME_MEMORY_H_
 
+#include <cstdlib>
 #include <utility>
 #include <type_traits>
 #include "object.h"
@@ -33,7 +34,7 @@ namespace runtime {
  * \brief Allocate an object using default allocator.
  * \param args arguments to the constructor.
  * \tparam T the node type.
- * \return The NodePtr to the allocated object.
+ * \return The ObjectPtr to the allocated object.
  */
 template<typename T, typename... Args>
 inline ObjectPtr<T> make_object(Args&&... args);
@@ -67,13 +68,33 @@ class ObjAllocatorBase {
   inline ObjectPtr<T> make_object(Args&&... args) {
     using Handler = typename Derived::template Handler<T>;
     static_assert(std::is_base_of<Object, T>::value,
-                  "make_node can only be used to create NodeBase");
+                  "make can only be used to create Object");
     T* ptr = Handler::New(static_cast<Derived*>(this),
                          std::forward<Args>(args)...);
     ptr->type_index_ = T::RuntimeTypeIndex();
     ptr->deleter_ = Handler::Deleter();
     return ObjectPtr<T>(ptr);
   }
+
+  /*!
+   * \brief Allocate an ArrayType object followed in memory by room for
+   *        num_elems elements of type ElemType.
+   * \tparam ArrayType The type to be allocated; must derive from Object.
+   * \tparam ElemType The type of array element.
+   * \tparam Args The constructor signature.
+   * \param num_elems The number of array elements.
+   * \param args The arguments forwarded to the ArrayType constructor.
+   * \return The ObjectPtr to the allocated object, with its type index and
+   *         deleter filled in.
+   */
+  template<typename ArrayType, typename ElemType, typename... Args>
+  inline ObjectPtr<ArrayType> make_inplace_array(size_t num_elems, Args&&... args) {
+    static_assert(std::is_base_of<Object, ArrayType>::value,
+                  "make_inplace_array can only be used to create Object");
+    using Handler = typename Derived::template ArrayHandler<ArrayType, ElemType>;
+    Derived* allocator = static_cast<Derived*>(this);
+    ArrayType* obj = Handler::New(allocator, num_elems, std::forward<Args>(args)...);
+    obj->type_index_ = ArrayType::RuntimeTypeIndex();
+    obj->deleter_ = Handler::Deleter();
+    return ObjectPtr<ArrayType>(obj);
+  }
 };
 
 // Simple allocator that uses new/delete.
@@ -123,6 +144,54 @@ class SimpleObjAllocator :
       delete reinterpret_cast<StorageType*>(tptr);
     }
   };
+
+  // Array handler that uses new/delete.
+  template<typename ArrayType, typename ElemType>
+  class ArrayHandler {
+   public:
+    // One storage slot is at least sizeof(ArrayType) bytes and is suitably
+    // aligned for both ArrayType and ElemType.
+    using StorageType = typename std::aligned_union<sizeof(ArrayType), ArrayType, ElemType>::type;
+
+    /*!
+     * \brief Allocate storage for an ArrayType header followed by num_elems
+     *        ElemType slots, and construct the header in place.
+     * \param num_elems The number of element slots to reserve after the header.
+     * \param args Arguments forwarded to the ArrayType constructor.
+     * \return Pointer to the constructed header. The element slots are left
+     *         unconstructed for the caller to initialize.
+     */
+    template<typename... Args>
+    static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) {
+      // NOTE: the first argument is not needed for SimpleObjAllocator.
+      // It is reserved for special allocators that need to recycle
+      // the object to itself (e.g. in the case of object pool).
+      //
+      // In the case of an object pool, an allocator needs to create
+      // a special chunk of memory that hides a reference to the allocator
+      // and calls the allocator's release function in the deleter.
+
+      // NOTE2: Use inplace new to allocate
+      // This is used to get rid of warning when deleting a virtual
+      // class with non-virtual destructor.
+      // We are fine here as we captured the right deleter during construction.
+      // This is also the right way to get storage type for an object pool.
+
+      // Size the buffer from the total number of bytes required (header plus
+      // elements), rounded up to whole StorageType slots. A ratio of
+      // sizeof(ArrayType) / sizeof(ElemType) would be 0, and cause a division
+      // by zero, whenever ElemType is larger than ArrayType.
+      size_t unit = sizeof(StorageType);
+      size_t requested_size = num_elems * sizeof(ElemType) + sizeof(ArrayType);
+      size_t num_storage_slots = (requested_size + unit - 1) / unit;
+      StorageType* data = new StorageType[num_storage_slots];
+      new (data) ArrayType(std::forward<Args>(args)...);
+      return reinterpret_cast<ArrayType*>(data);
+    }
+
+    static Object::FDeleter Deleter() {
+      return Deleter_;
+    }
+
+   private:
+    static void Deleter_(Object* objptr) {
+      // NOTE: this is important to cast back to ArrayType*
+      // because objptr and tptr may not be the same
+      // depending on how sub-class allocates the space.
+      ArrayType* tptr = static_cast<ArrayType*>(objptr);
+      // It is important to do tptr->ArrayType::~ArrayType(),
+      // so that we explicitly call the specific destructor
+      // instead of tptr->~ArrayType(), which could mean the intention
+      // call a virtual destructor(which may not be available and is not required).
+      tptr->ArrayType::~ArrayType();
+      StorageType* p = reinterpret_cast<StorageType*>(tptr);
+      delete []p;
+    }
+  };
 };
 
 template<typename T, typename... Args>
@@ -130,6 +199,12 @@ inline ObjectPtr<T> make_object(Args&&... args) {
   return SimpleObjAllocator().make_object<T>(std::forward<Args>(args)...);
 }
 
+/*!
+ * \brief Allocate an array-like object using the default allocator.
+ * \tparam ArrayType The header type to allocate; must derive from Object.
+ * \tparam ElemType The type of the elements stored right after the header.
+ * \param num_elems The number of element slots to reserve.
+ * \param args Arguments forwarded to the ArrayType constructor.
+ * \return The ObjectPtr to the allocated object.
+ */
+template<typename ArrayType, typename ElemType, typename... Args>
+inline ObjectPtr<ArrayType> make_inplace_array_object(size_t num_elems, Args&&... args) {
+  SimpleObjAllocator allocator;
+  return allocator.make_inplace_array<ArrayType, ElemType>(num_elems,
+                                                           std::forward<Args>(args)...);
+}
+
 }  // namespace runtime
 }  // namespace tvm
 #endif  // TVM_RUNTIME_MEMORY_H_
diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h
index f7188e4..59e9ae8 100644
--- a/include/tvm/runtime/vm.h
+++ b/include/tvm/runtime/vm.h
@@ -55,35 +55,6 @@ class Tensor : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj);
 };
 
-
-/*! \brief An object representing a structure or enumeration. */
-class ADTObj : public Object {
- public:
-  /*! \brief The tag representing the constructor used. */
-  size_t tag;
-  /*! \brief The fields of the structure. */
-  std::vector<ObjectRef> fields;
-
-  static constexpr const uint32_t _type_index = TypeIndex::kVMADT;
-  static constexpr const char* _type_key = "vm.ADT";
-  TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object);
-};
-
-/*! \brief reference to algebraic data type objects. */
-class ADT : public ObjectRef {
- public:
-  ADT(size_t tag, std::vector<ObjectRef> fields);
-
-  /*!
-   * \brief construct a tuple object.
-   * \param fields The fields of the tuple.
-   * \return The constructed tuple type.
-   */
-  static ADT Tuple(std::vector<ObjectRef> fields);
-
-  TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj);
-};
-
 /*! \brief An object representing a closure. */
 class ClosureObj : public Object {
  public:
diff --git a/src/runtime/vm/object.cc b/src/runtime/vm/object.cc
index 12edf51..988ba5d 100644
--- a/src/runtime/vm/object.cc
+++ b/src/runtime/vm/object.cc
@@ -22,6 +22,7 @@
  * \brief VM related objects.
  */
 #include <tvm/logging.h>
+#include <tvm/runtime/container.h>
 #include <tvm/runtime/object.h>
 #include <tvm/runtime/vm.h>
 #include <tvm/runtime/memory.h>
@@ -39,17 +40,6 @@ Tensor::Tensor(NDArray data) {
   data_ = std::move(ptr);
 }
 
-ADT::ADT(size_t tag, std::vector<ObjectRef> fields) {
-  auto ptr = make_object<ADTObj>();
-  ptr->tag = tag;
-  ptr->fields = std::move(fields);
-  data_ = std::move(ptr);
-}
-
-ADT ADT::Tuple(std::vector<ObjectRef> fields) {
-  return ADT(0, fields);
-}
-
 Closure::Closure(size_t func_index, std::vector<ObjectRef> free_vars) {
   auto ptr = make_object<ClosureObj>();
   ptr->func_index = func_index;
@@ -69,17 +59,15 @@ TVM_REGISTER_GLOBAL("_vmobj.GetTensorData")
 TVM_REGISTER_GLOBAL("_vmobj.GetADTTag")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
+  // Return the constructor tag of the ADT passed in args[0], as an int64.
   ObjectRef obj = args[0];
-  const auto* cell = obj.as<ADTObj>();
-  CHECK(cell != nullptr);
-  *rv = static_cast<int64_t>(cell->tag);
+  const auto& adt = Downcast<ADT>(obj);
+  *rv = static_cast<int64_t>(adt.tag());
 });
 
 TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
+  // Return the number of fields of the ADT passed in args[0], as an int64.
   ObjectRef obj = args[0];
-  const auto* cell = obj.as<ADTObj>();
-  CHECK(cell != nullptr);
-  *rv = static_cast<int64_t>(cell->fields.size());
+  const auto& adt = Downcast<ADT>(obj);
+  *rv = static_cast<int64_t>(adt.size());
 });
 
 
@@ -87,10 +75,9 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
   ObjectRef obj = args[0];
   int idx = args[1];
-  const auto* cell = obj.as<ADTObj>();
-  CHECK(cell != nullptr);
-  CHECK_LT(idx, cell->fields.size());
-  *rv = cell->fields[idx];
+  const auto& adt = Downcast<ADT>(obj);
+  CHECK_LT(idx, adt.size());
+  *rv = adt[idx];
 });
 
 TVM_REGISTER_GLOBAL("_vmobj.Tensor")
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 333dd1e..41fe71a 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -24,6 +24,7 @@
 
 #include <dmlc/memory_io.h>
 #include <tvm/logging.h>
+#include <tvm/runtime/container.h>
 #include <tvm/runtime/vm.h>
 #include <tvm/runtime/memory.h>
 #include <tvm/runtime/object.h>
@@ -755,7 +756,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
   size_t arity = 0;
   for (Index i = 0; i < arg_count; i++) {
     if (const auto* obj = args[i].as<ADTObj>()) {
-      arity += obj->fields.size();
+      arity += obj->size;
     } else {
       ++arity;
     }
@@ -767,7 +768,8 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func,
   int idx = 0;
   for (Index i = 0; i < arg_count; i++) {
     if (const auto* dt_cell = args[i].as<ADTObj>()) {
-      for (auto obj : dt_cell->fields) {
+      for (size_t fi = 0; fi < dt_cell->size; ++fi) {
+        auto obj = (*dt_cell)[fi];
         const auto* tensor = obj.as<TensorObj>();
         CHECK(tensor != nullptr);
         setter(idx++, tensor->data);
@@ -924,23 +926,16 @@ void VirtualMachine::RunLoop() {
       }
       case Opcode::GetField: {
         auto object = ReadRegister(instr.object);
-        const auto* tuple = object.as<ADTObj>();
-        CHECK(tuple != nullptr)
-            << "Object is not data type object, register " << instr.object << ", Object tag "
-            << object->type_index();
-        auto field = tuple->fields[instr.field_index];
+        const auto& tuple = Downcast<ADT>(object);
+        auto field = tuple[instr.field_index];
         WriteRegister(instr.dst, field);
         pc_++;
         goto main_loop;
       }
       case Opcode::GetTag: {
         auto object = ReadRegister(instr.get_tag.object);
-        const auto* data = object.as<ADTObj>();
-        CHECK(data != nullptr)
-            << "Object is not data type object, register "
-            << instr.get_tag.object << ", Object tag "
-            << object->type_index();
-        auto tag = data->tag;
+        const auto& adt = Downcast<ADT>(object);
+        auto tag = adt.tag();
         auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0});
         reinterpret_cast<int32_t*>(tag_tensor->data)[0] = tag;
         WriteRegister(instr.dst, Tensor(tag_tensor));
diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc
index 005e159..4428642 100644
--- a/tests/cpp/container_test.cc
+++ b/tests/cpp/container_test.cc
@@ -17,11 +17,132 @@
  * under the License.
  */
 
-#include <vector>
-#include <unordered_map>
 #include <dmlc/logging.h>
 #include <gtest/gtest.h>
 #include <tvm/packed_func_ext.h>
+#include <tvm/runtime/container.h>
+#include <new>
+#include <unordered_map>
+#include <vector>
+
+using namespace tvm;
+using namespace tvm::runtime;
+
+// Test element whose destructor terminates the process with exit(1) while
+// `should_fail` is set. Copying hands the "armed" flag to the new copy and
+// disarms the source, so destroying the intermediate copies made while the
+// fixtures are built does not trigger the exit.
+class TestErrorSwitch {
+ public:
+  bool should_fail{false};
+
+  TestErrorSwitch(bool fail_flag) : should_fail{fail_flag} {}
+
+  TestErrorSwitch(const TestErrorSwitch& other) : should_fail(other.should_fail) {
+    // Deliberately mutate the source despite the const reference: the flag
+    // travels with the newest copy.
+    TestErrorSwitch& source = const_cast<TestErrorSwitch&>(other);
+    source.should_fail = false;
+  }
+
+  ~TestErrorSwitch() {
+    if (!should_fail) {
+      return;
+    }
+    exit(1);
+  }
+};
+
+// Test array object that stores TestErrorSwitch elements inline. Init and
+// WrongInit both throw std::bad_alloc right after element 1 is constructed;
+// they differ only in how `size` (the count ~InplaceArrayBase destroys) is
+// maintained.
+class TestArrayObj : public Object,
+                     public InplaceArrayBase<TestArrayObj, TestErrorSwitch> {
+ public:
+  static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+  static constexpr const char* _type_key = "test.TestArrayObj";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TestArrayObj, Object);
+  // Number of elements that ~InplaceArrayBase will destroy.
+  uint32_t size;
+
+  // Required by InplaceArrayBase: the number of elements treated as live.
+  size_t GetSize() const { return size; }
+
+  // Correct pattern: `size` counts only elements that have been fully
+  // processed. The simulated failure fires after element 1 is built but
+  // before it is counted, so the destructor never visits element 1.
+  template <typename Iterator>
+  void Init(Iterator begin, Iterator end) {
+    size_t num_elems = std::distance(begin, end);
+    this->size = 0;
+    auto it = begin;
+    for (size_t i = 0; i < num_elems; ++i) {
+      InplaceArrayBase::EmplaceInit(i, *it++);
+      // Simulate a failure partway through initialization.
+      if (i == 1) {
+        throw std::bad_alloc();
+      }
+      // Only increment size after the initialization succeeds
+      this->size++;
+    }
+  }
+
+  // Incorrect pattern, kept on purpose for the death test: `size` is set to
+  // the full count up front, so after the throw the destructor also visits
+  // element 1 and beyond, including slots that were never constructed.
+  template <typename Iterator>
+  void WrongInit(Iterator begin, Iterator end) {
+    size_t num_elems = std::distance(begin, end);
+    this->size = num_elems;
+    auto it = begin;
+    for (size_t i = 0; i < num_elems; ++i) {
+      InplaceArrayBase::EmplaceInit(i, *it++);
+      // Simulate a failure partway through initialization.
+      if (i == 1) {
+        throw std::bad_alloc();
+      }
+    }
+  }
+
+  friend class InplaceArrayBase;
+};
+
+// ADTs built via Tuple (tag 0, no fields) and via the tag + initializer-list
+// constructor report the expected tag, size and nested fields.
+TEST(ADT, Constructor) {
+  std::vector<ObjectRef> fields;
+  auto f1 = ADT::Tuple(fields);
+  auto f2 = ADT::Tuple(fields);
+  ADT v1{1, {f1, f2}};
+  // tag() and size() return size_t; compare against unsigned literals so
+  // ASSERT_EQ does not mix signed and unsigned operands (-Wsign-compare).
+  ASSERT_EQ(f1.tag(), 0u);
+  ASSERT_EQ(f2.size(), 0u);
+  ASSERT_EQ(v1.tag(), 1u);
+  ASSERT_EQ(v1.size(), 2u);
+  ASSERT_EQ(Downcast<ADT>(v1[0]).tag(), 0u);
+  ASSERT_EQ(Downcast<ADT>(v1[1]).size(), 0u);
+}
+
+// Death test: WrongInit marks all 3 elements live before constructing them,
+// so the destructor reaches the armed element at index 1 and the child
+// process exits with status 1.
+TEST(InplaceArrayBase, BadExceptionSafety) {
+  auto wrong_init = []() {
+    TestErrorSwitch f1{false};
+    // WrongInit will set size to 3, so the destructor will be called at
+    // index 1, which will exit with an error status.
+    TestErrorSwitch f2{true};
+    TestErrorSwitch f3{false};
+    // Each copy hands the armed flag forward, so after EmplaceInit copies
+    // element 1 into the array, only that array element is still armed.
+    std::vector<TestErrorSwitch> fields{f1, f2, f3};
+    auto ptr =
+        make_inplace_array_object<TestArrayObj, TestErrorSwitch>(fields.size());
+    try {
+      ptr->WrongInit(fields.begin(), fields.end());
+    } catch (...) {
+      // Expected: WrongInit throws right after constructing element 1.
+    }
+    // Call ~InplaceArrayBase
+    ptr.reset();
+    // never reaches here.
+    exit(0);
+  };
+  ASSERT_EXIT(wrong_init(), ::testing::ExitedWithCode(1), "");
+}
+
+// Init throws right after constructing element 1 but before counting it, so
+// the destructor only visits element 0 and the child process exits with 0.
+TEST(InplaceArrayBase, ExceptionSafety) {
+  auto correct_init = []() {
+    TestErrorSwitch f1{false};
+    // Init fails at index 1 before size is incremented, so the destructor
+    // must not be called at index 1: that element is never counted as
+    // initialized, even though its copy was already constructed.
+    TestErrorSwitch f2{true};
+    std::vector<TestErrorSwitch> fields{f1, f2};
+    auto ptr =
+        make_inplace_array_object<TestArrayObj, TestErrorSwitch>(fields.size());
+    try {
+      ptr->Init(fields.begin(), fields.end());
+    } catch (...) {
+      // Expected: Init throws right after constructing element 1.
+    }
+    // Call ~InplaceArrayBase
+    ptr.reset();
+    // Skip the destructors of f1, f2, and fields
+    exit(0);
+  };
+  ASSERT_EXIT(correct_init(), ::testing::ExitedWithCode(0), "");
+}
 
 TEST(Array, Expr) {
   using namespace tvm;
@@ -99,11 +220,12 @@ TEST(Map, Iterator) {
   using namespace tvm;
   Expr a = 1, b = 2;
   Map<Expr, Expr> map1{{a, b}};
-  std::unordered_map<Expr, Expr, NodeHash, NodeEqual> map2(map1.begin(), map1.end());
+  std::unordered_map<Expr, Expr, NodeHash, NodeEqual> map2(map1.begin(),
+                                                           map1.end());
   CHECK(map2[a].as<IntImm>()->value == 2);
 }
 
-int main(int argc, char ** argv) {
+int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   testing::FLAGS_gtest_death_test_style = "threadsafe";
   return RUN_ALL_TESTS();