You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2022/02/22 17:35:13 UTC

[tvm] branch main updated: [runtime] Add Metadata classes for AOTExecutor (#10282)

This is an automated email from the ASF dual-hosted git repository.

moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 33082e0  [runtime] Add Metadata classes for AOTExecutor (#10282)
33082e0 is described below

commit 33082e0032fb57b0516ad7e3eabd11fe0203437e
Author: Andrew Reusch <ar...@gmail.com>
AuthorDate: Tue Feb 22 09:34:23 2022 -0800

    [runtime] Add Metadata classes for AOTExecutor (#10282)
    
    * Add new Metadata classes and base implementation.
    
     * These were autogenerated in the original PR, but checking them in
       as plain code until we can revisit the auto-generator approach.
    
    * address masa comments
    
    * Add documentation per Manupa's comments, and move kMetadataVersion namespace.
    
    * remove get_name function, used for debugging
    
    * clang-format
---
 include/tvm/runtime/metadata.h      | 160 ++++++++++++++++++++++++
 include/tvm/runtime/metadata_base.h | 198 ++++++++++++++++++++++++++++++
 include/tvm/support/span.h          | 103 ++++++++++++++++
 src/runtime/metadata.cc             |  56 +++++++++
 src/target/metadata.cc              |  47 +++++++
 src/target/metadata.h               | 173 ++++++++++++++++++++++++++
 tests/cpp/aot_metadata_test.cc      | 236 ++++++++++++++++++++++++++++++++++++
 7 files changed, 973 insertions(+)

diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h
new file mode 100644
index 0000000..b716d41
--- /dev/null
+++ b/include/tvm/runtime/metadata.h
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/metadata.h
+ * \brief Defines types which can be used in Metadata.
+ */
+#ifndef TVM_RUNTIME_METADATA_H_
+#define TVM_RUNTIME_METADATA_H_
+
+#include <inttypes.h>
+#ifdef __cplusplus
+#include <memory>
+#include <string>
+#include <vector>
+#endif
+#include <tvm/runtime/c_runtime_api.h>
+#ifdef __cplusplus
+#include <tvm/runtime/metadata_base.h>
+#endif
+#include <tvm/support/span.h>
+
+// Version number recorded in emitted artifacts for runtime checking.
+#define TVM_METADATA_VERSION 1
+
+namespace tvm {
+namespace runtime {
+namespace metadata {
+/*!
+ * \brief Version of metadata emitted and understood by this compiler/runtime.
+ * Should be populated into the `version` field of all TVMMetadata.
+ */
+static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION;
+}  // namespace metadata
+}  // namespace runtime
+}  // namespace tvm
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*!
+ * \brief Top-level metadata structure. Holds all other metadata types.
+ */
+struct TVMMetadata {
+  /*! \brief Version identifier for this metadata. */
+  int64_t version;
+  /*! \brief Inputs to the AOT run_model function.
+   * The order of the elements is the same as in the arguments to run_model. That is to say,
+   * this array specifies the first `num_inputs` arguments to run_model.
+   */
+  const struct TVMTensorInfo* inputs;
+  /*! \brief Number of elements in `inputs` array. */
+  int64_t num_inputs;
+  /*! \brief Outputs of the AOT run_model function.
+   * The order of the elements is the same as in the arguments to run_model. That is to say,
+   * this array specifies the last `num_outputs` arguments to run_model.
+   */
+  const struct TVMTensorInfo* outputs;
+  /*! \brief Number of elements in `outputs` array. */
+  int64_t num_outputs;
+  /*! \brief Name of the model, as passed to tvm.relay.build. */
+  const char* mod_name;
+};
+
+/*!
+ * \brief Describes one tensor argument to `run_model`.
+ * NOTE: while TIR allows for other types of arguments, such as scalars, the AOT run_model
+ * function does not currently accept these. Therefore it's not possible to express those
+ * in this metadata. A future patch may modify this.
+ */
+struct TVMTensorInfo {
+  /*! \brief Name of the tensor, as specified in the Relay program. */
+  const char* name;
+  /*! \brief Shape of the tensor. */
+  const int64_t* shape;
+  /*! \brief Rank of this tensor. */
+  int64_t num_shape;
+  /*! \brief Data type of one element of this tensor. */
+  DLDataType dtype;
+};
+#ifdef __cplusplus
+}  // extern "C"
+#include <tvm/runtime/object.h>
+namespace tvm {
+namespace runtime {
+namespace metadata {
+
+class Metadata;
+class TensorInfo;
+
+class MetadataNode : public MetadataBaseNode {
+ public:
+  explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {}
+  static constexpr const char* _type_key = "metadata.MetadataNode";
+  inline int64_t version() const { return int64_t(data_->version); }
+  inline int64_t num_inputs() const { return data_->num_inputs; }
+  ArrayAccessor<struct TVMTensorInfo, TensorInfo> inputs();
+  inline int64_t num_outputs() const { return data_->num_outputs; }
+  ArrayAccessor<struct TVMTensorInfo, TensorInfo> outputs();
+  inline ::tvm::runtime::String mod_name() const { return ::tvm::runtime::String(data_->mod_name); }
+  const struct ::TVMMetadata* data() const { return data_; }
+  TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode);
+
+ private:
+  const struct ::TVMMetadata* data_;
+};
+
+class Metadata : public MetadataBase {
+ public:
+  explicit Metadata(const struct ::TVMMetadata* data);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Metadata, MetadataBase, MetadataNode);
+};
+
+class TensorInfoNode : public MetadataBaseNode {
+ public:
+  explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {}
+  static constexpr const char* _type_key = "metadata.TensorInfoNode";
+  inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); }
+  inline int64_t num_shape() const { return data_->num_shape; }
+  inline ::tvm::support::Span<const int64_t, int64_t> shape() const {
+    return ::tvm::support::Span<const int64_t, int64_t>(data_->shape,
+                                                        data_->shape + data_->num_shape);
+  }
+  inline ::tvm::runtime::DataType dtype() const { return ::tvm::runtime::DataType(data_->dtype); }
+  const struct ::TVMTensorInfo* data() const { return data_; }
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, MetadataBaseNode);
+
+ private:
+  const struct ::TVMTensorInfo* data_;
+};
+
+class TensorInfo : public MetadataBase {
+ public:
+  explicit TensorInfo(const struct ::TVMTensorInfo* data);
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorInfo, MetadataBase, TensorInfoNode);
+};
+
+}  // namespace metadata
+}  // namespace runtime
+}  // namespace tvm
+#endif  // defined(__cplusplus)
+
+#endif  // TVM_RUNTIME_METADATA_H_
diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h
new file mode 100644
index 0000000..9674319
--- /dev/null
+++ b/include/tvm/runtime/metadata_base.h
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/metadata_base.h
+ * \brief Defines types which can be used in Metadata.
+ */
+#ifndef TVM_RUNTIME_METADATA_BASE_H_
+#define TVM_RUNTIME_METADATA_BASE_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/runtime/object.h>
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+namespace metadata {
+
+/*!
+ * \brief Common base class for all Metadata.
+ *
+ * This class is used in the visitor classes as an internal check to verify that all
+ * parts of the Metadata struct used in codegen are Metadata objects.
+ */
+class MetadataBaseNode : public ::tvm::runtime::Object {
+ public:
+  static constexpr const char* _type_key = "metadata.MetadataBaseNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object);
+};
+
+/*! \brief Reference class for the common MetadataBaseNode class. */
+class MetadataBase : public ::tvm::runtime::ObjectRef {
+ public:
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataBase, ::tvm::runtime::ObjectRef, MetadataBaseNode);
+};
+
+template <typename C, class Ref>
+class ArrayAccessor;
+
+/*! \brief An iterator implementation that lazily instantiates the C++ wrapping Metadata class. */
+template <typename C, class Ref>
+class ArrayIterator {
+ public:
+  ArrayIterator(size_t index, const ArrayAccessor<C, Ref>* parent)
+      : index_{index}, parent_{parent} {}
+
+  inline Ref operator*() { return (*parent_)[index_]; }
+
+  inline ArrayIterator<C, Ref>& operator++() {
+    if (index_ < parent_->size()) {
+      index_++;
+    }
+
+    return *this;
+  }
+
+  inline bool operator==(const ArrayIterator<C, Ref>& other) const {
+    return parent_ == other.parent_ && index_ == other.index_;
+  }
+
+  inline bool operator!=(const ArrayIterator<C, Ref>& other) const { return !operator==(other); }
+
+ private:
+  size_t index_;
+  const ArrayAccessor<C, Ref>* parent_;
+};
+
+/*! \brief A span-like class which permits access to Array fields with complex elements.
+ * These array fields should be accessed from C++ using the Metadata wrapper classes. This class
+ * lazily instantiates those wrappers as they are accessed.
+ */
+template <typename C, class Ref>
+class ArrayAccessor {
+ public:
+  using value_type = Ref;
+  using iterator = ArrayIterator<C, Ref>;
+  using const_iterator = iterator;
+
+  template <typename T = typename std::enable_if<std::is_base_of<ObjectRef, Ref>::value>::type>
+  ArrayAccessor(const C* data, size_t num_data) : data_{data}, num_data_{num_data} {}
+
+  inline size_t size() const { return num_data_; }
+
+  inline Ref operator[](size_t index) const {
+    if (index >= num_data_) {
+      throw std::runtime_error("Index out of range");
+    }
+
+    return Ref(&data_[index]);
+  }
+
+  inline ArrayIterator<C, Ref> begin() const { return ArrayIterator<C, Ref>{0, this}; }
+
+  inline ArrayIterator<C, Ref> end() const { return ArrayIterator<C, Ref>{num_data_, this}; }
+
+ private:
+  const C* data_;
+  size_t num_data_;
+};
+
+/*! \brief A specialization of ArrayAccessor for String.
+ * This class is needed because the String constructor signature is different from the typical
+ * Metadata subclass.
+ */
+template <>
+class ArrayAccessor<const char*, ::tvm::runtime::String> {
+ public:
+  using value_type = ::tvm::runtime::String;
+  using iterator = ArrayIterator<const char*, ::tvm::runtime::String>;
+  using const_iterator = iterator;
+
+  ArrayAccessor(const char** data, size_t num_data) : data_{data}, num_data_{num_data} {}
+
+  inline size_t size() const { return num_data_; }
+
+  inline ::tvm::runtime::String operator[](size_t index) const {
+    if (index >= num_data_) {
+      throw std::runtime_error("Index out of range");
+    }
+    return ::tvm::runtime::String(data_[index]);
+  }
+
+  inline ArrayIterator<const char*, ::tvm::runtime::String> begin() const {
+    return ArrayIterator<const char*, ::tvm::runtime::String>{0, this};
+  }
+
+  inline ArrayIterator<const char*, ::tvm::runtime::String> end() const {
+    return ArrayIterator<const char*, ::tvm::runtime::String>{num_data_, this};
+  }
+
+ private:
+  const char** data_;
+  size_t num_data_;
+};
+
+/*! \brief Enumerates the primitive types which can be part of a Metadata instance.
+ *
+ * These are separate from TIR DataType because TIR does not model structs.
+ */
+enum MetadataTypeIndex : uint8_t {
+  kUint64 = 0,
+  kInt64 = 1,
+  kBool = 2,
+  kString = 3,
+  kHandle = 4,
+  kMetadata = 5,
+};
+
+/*! \brief Container for arrays in the metadata.
+ *
+ * Type information is needed when emitting arrays. This container augments the data field with
+ * the necessary typing information.
+ */
+class MetadataArrayNode : public MetadataBaseNode {
+ public:
+  MetadataArrayNode(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name)
+      : array(::std::move(array)), type_index{type_index}, struct_name{struct_name} {}
+
+  Array<ObjectRef> array;
+  MetadataTypeIndex type_index;
+  const char* struct_name;
+  static constexpr const char* _type_key = "metadata.MetadataArrayNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode);
+};
+
+/*! \brief Reference class for MetadataArray. */
+class MetadataArray : public MetadataBase {
+ public:
+  MetadataArray(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name);
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode);
+};
+
+}  // namespace metadata
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // TVM_RUNTIME_METADATA_BASE_H_
diff --git a/include/tvm/support/span.h b/include/tvm/support/span.h
new file mode 100644
index 0000000..faa849c
--- /dev/null
+++ b/include/tvm/support/span.h
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file tvm/support/span.h
+ * \brief Partial reimplementation of the C++20-style std::span.
+ */
+#ifndef TVM_SUPPORT_SPAN_H_
+#define TVM_SUPPORT_SPAN_H_
+
+#include <cstddef>
+#include <iterator>
+#include <type_traits>
+#include <vector>
+
+namespace tvm {
+namespace support {
+
+/*!
+ * \brief A partial implementation of the C++20 std::span.
+ *
+ * At the time of writing, TVM must compile against C++14.
+ */
+template <class T, class W>
+class Span {
+ public:
+  using value_type = W;
+  using const_W = typename ::std::add_const<W>::type;
+  /*! \brief Input iterator over [ptr, end) that yields each element converted to W1. */
+  template <class W1>
+  class iterator_base : public std::iterator<std::input_iterator_tag, W> {
+   public:
+    inline iterator_base(T* ptr, T* end) : ptr_{ptr}, end_{end} { CHECK_GE(end, ptr); }
+
+    inline W1 operator*() const { return W1(*ptr_); }
+    /*! \brief Advance by one element; saturates at end instead of running past it. */
+    inline iterator_base<W1>& operator++() {
+      if (ptr_ != end_) ptr_++;
+      return *this;
+    }
+
+    inline bool operator==(iterator_base<W1> other) const {
+      return ptr_ == other.ptr_ && end_ == other.end_;
+    }
+
+    inline bool operator!=(iterator_base<W1> other) const { return !(*this == other); }
+    /*! \brief Implicit conversion from a mutable iterator to its const counterpart. */
+    template <class X = W1, typename = ::std::enable_if_t<!::std::is_const<X>::value> >
+    inline operator iterator_base<const_W>() const {
+      return iterator_base<const_W>(ptr_, end_);
+    }
+
+   private:
+    T* ptr_;
+    T* end_;
+  };
+
+  using iterator = iterator_base<W>;
+  using const_iterator = iterator_base<const_W>;
+
+  inline Span(T* begin, int num_elements) : begin_{begin}, end_{begin + num_elements} {}
+  inline Span(T* begin, T* end) : begin_{begin}, end_{end} {}
+
+  inline iterator begin() const { return iterator(begin_, end_); }
+
+  inline iterator end() const { return iterator(end_, end_); }
+
+  size_t size() const { return end_ - begin_; }
+  /*! \brief Bounds-checked element access; fails the check for negative or too-large i. */
+  inline W operator[](int i) const {
+    ICHECK_GE(i, 0) << "Span access out of bounds: " << i;
+    ICHECK_LT(static_cast<size_t>(i), size()) << "Span access out of bounds: " << i;
+    return W(begin_[i]);
+  }
+
+  inline operator std::vector<W>() const { return std::vector<W>(begin(), end()); }
+
+ protected:
+  T* begin_;
+  T* end_;
+};
+
+}  // namespace support
+}  // namespace tvm
+
+#endif  // TVM_SUPPORT_SPAN_H_
diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc
new file mode 100644
index 0000000..7ca333b
--- /dev/null
+++ b/src/runtime/metadata.cc
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file metadata.cc
+ * \brief Implementations of the runtime component of Metadata.
+ */
+
+#include <tvm/runtime/metadata.h>
+
+namespace tvm {
+namespace runtime {
+namespace metadata {
+
+ArrayAccessor<struct TVMTensorInfo, TensorInfo> MetadataNode::inputs() {
+  return ArrayAccessor<struct TVMTensorInfo, TensorInfo>(data_->inputs, data_->num_inputs);
+}
+ArrayAccessor<struct TVMTensorInfo, TensorInfo> MetadataNode::outputs() {
+  return ArrayAccessor<struct TVMTensorInfo, TensorInfo>(data_->outputs, data_->num_outputs);
+}
+
+TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode);
+
+MetadataArray::MetadataArray(Array<ObjectRef> array, MetadataTypeIndex type_index,
+                             const char* struct_name)
+    : MetadataBase{make_object<MetadataArrayNode>(array, type_index, struct_name)} {}
+
+TVM_REGISTER_OBJECT_TYPE(MetadataArrayNode);
+
+Metadata::Metadata(const struct ::TVMMetadata* data)
+    : MetadataBase{make_object<MetadataNode>(data)} {}
+TVM_REGISTER_OBJECT_TYPE(MetadataNode);
+
+TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data)
+    : MetadataBase{make_object<TensorInfoNode>(data)} {}
+TVM_REGISTER_OBJECT_TYPE(TensorInfoNode);
+
+}  // namespace metadata
+}  // namespace runtime
+}  // namespace tvm
diff --git a/src/target/metadata.cc b/src/target/metadata.cc
new file mode 100644
index 0000000..adf4cba
--- /dev/null
+++ b/src/target/metadata.cc
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file metadata.cc
+ * \brief Implementations of the compiler extensions for Metadata.
+ */
+
+#include "metadata.h"
+
+#include <tvm/node/reflection.h>
+
+namespace tvm {
+namespace target {
+namespace metadata {
+
+TVM_REGISTER_REFLECTION_VTABLE(VisitableMetadataNode,
+                               ::tvm::detail::ReflectionTrait<VisitableMetadataNode>)
+    .set_creator([](const std::string&) -> ObjectPtr<Object> {
+      return ::tvm::runtime::make_object<VisitableMetadataNode>();
+    });
+
+TVM_REGISTER_REFLECTION_VTABLE(VisitableTensorInfoNode,
+                               ::tvm::detail::ReflectionTrait<VisitableTensorInfoNode>)
+    .set_creator([](const std::string&) -> ObjectPtr<Object> {
+      return ::tvm::runtime::make_object<VisitableTensorInfoNode>();
+    });
+
+}  // namespace metadata
+}  // namespace target
+}  // namespace tvm
diff --git a/src/target/metadata.h b/src/target/metadata.h
new file mode 100644
index 0000000..2621d5d
--- /dev/null
+++ b/src/target/metadata.h
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/target/metadata.h
+ * \brief Extends Metadata for use in the compiler.
+ */
+#ifndef TVM_TARGET_METADATA_H_
+#define TVM_TARGET_METADATA_H_
+
+#include <tvm/runtime/metadata.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace tvm {
+namespace target {
+namespace metadata {
+
+/*!
+ * \brief Subclass of MetadataNode that implements the VisitAttrs reflection method.
+ *
+ * This implementation (and other such Visitable subclasses) is compiled into libtvm.so, but not
+ * libtvm_runtime.so, because reflection is not supported in libtvm_runtime.so over code size
+ * concerns. It is used during compilation by the generic metadata code-generators.
+ */
+class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode {
+ public:
+  explicit VisitableMetadataNode(const struct ::TVMMetadata* data) : MetadataNode{data} {}
+  VisitableMetadataNode() : MetadataNode{nullptr} {}
+
+  void VisitAttrs(AttrVisitor* v) {
+    int64_t version_cpp{version()};
+    v->Visit("version", &version_cpp);
+    auto inputs_array = Array<ObjectRef>();
+    auto inputs_accessor = inputs();
+    inputs_array.reserve(num_inputs());
+    for (int64_t i = 0; i < num_inputs(); ++i) {
+      inputs_array.push_back(::tvm::runtime::metadata::TensorInfo{inputs_accessor[i]});
+    }
+    ::tvm::runtime::metadata::MetadataArray inputs_metadata_array{
+        inputs_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"};
+    v->Visit("inputs", &inputs_metadata_array);
+    int64_t num_inputs_cpp = num_inputs();
+    v->Visit("num_inputs", &num_inputs_cpp);
+    auto outputs_array = Array<ObjectRef>();
+    auto outputs_accessor = outputs();
+    outputs_array.reserve(num_outputs());
+    for (int64_t i = 0; i < num_outputs(); ++i) {
+      outputs_array.push_back(::tvm::runtime::metadata::TensorInfo{outputs_accessor[i]});
+    }
+    ::tvm::runtime::metadata::MetadataArray outputs_metadata_array{
+        outputs_array, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"};
+    v->Visit("outputs", &outputs_metadata_array);
+    int64_t num_outputs_cpp = num_outputs();
+    v->Visit("num_outputs", &num_outputs_cpp);
+    ::std::string mod_name_cpp{data()->mod_name};
+    v->Visit("mod_name", &mod_name_cpp);
+  }
+};
+
+/*!
+ * \brief Subclass of MetadataNode which also owns the backing C structures.
+ *
+ * This class (and other InMemory subclasses) are used during compilation to instantiate Metadata
+ * instances whose storage lives outside of .rodata. This class exists because the Module returned
+ * from tvm.relay.build must also be ready to run inference.
+ */
+class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNode {
+ public:
+  /*! \brief Build an empty instance: version 0, no inputs or outputs, empty mod_name. */
+  InMemoryMetadataNode()
+      : InMemoryMetadataNode(0 /* version */, {} /* inputs */, {} /* outputs */,
+                             "" /* mod_name */) {}
+  /*! \brief Copy the given fields into storage owned by this node and point the base-class
+   * view at it. Each TVMTensorInfo is a shallow copy of the struct behind a TensorInfo. */
+  InMemoryMetadataNode(int64_t version,
+                       const ::std::vector<::tvm::runtime::metadata::TensorInfo>& inputs,
+                       const ::std::vector<::tvm::runtime::metadata::TensorInfo>& outputs,
+                       const ::tvm::runtime::String mod_name)
+      : VisitableMetadataNode{&storage_},
+        inputs_{new struct TVMTensorInfo[inputs.size()]()},
+        inputs_objs_{inputs},
+        outputs_{new struct TVMTensorInfo[outputs.size()]()},
+        outputs_objs_{outputs},
+        mod_name_{mod_name},
+        storage_{version, nullptr, 0, nullptr, 0, mod_name_.c_str()} {
+    storage_.inputs = inputs_.get();
+    storage_.num_inputs = inputs.size();
+    for (size_t i = 0; i < inputs.size(); ++i) inputs_[i] = *inputs[i]->data();
+    storage_.outputs = outputs_.get();
+    storage_.num_outputs = outputs.size();
+    for (size_t i = 0; i < outputs.size(); ++i) outputs_[i] = *outputs[i]->data();
+  }
+
+ private:
+  // The arrays come from new[], so they are held as unique_ptr<T[]> to be freed by delete[].
+  ::std::unique_ptr<struct TVMTensorInfo[]> inputs_;
+  std::vector<::tvm::runtime::metadata::TensorInfo> inputs_objs_;
+  ::std::unique_ptr<struct TVMTensorInfo[]> outputs_;
+  std::vector<::tvm::runtime::metadata::TensorInfo> outputs_objs_;
+  ::std::string mod_name_;
+  struct ::TVMMetadata storage_;
+};
+
+class VisitableTensorInfoNode : public ::tvm::runtime::metadata::TensorInfoNode {
+ public:
+  explicit VisitableTensorInfoNode(const struct ::TVMTensorInfo* data) : TensorInfoNode{data} {}
+  VisitableTensorInfoNode() : TensorInfoNode{nullptr} {}
+
+  void VisitAttrs(AttrVisitor* v) {
+    ::std::string name_cpp{data()->name};
+    v->Visit("name", &name_cpp);
+    auto shape_array = Array<ObjectRef>();
+    auto shape_accessor = shape();
+    shape_array.reserve(num_shape());
+    for (int64_t i = 0; i < num_shape(); ++i) {
+      shape_array.push_back(::tvm::Integer{static_cast<int>(shape_accessor[i])});
+    }
+    ::tvm::runtime::metadata::MetadataArray shape_metadata_array{
+        shape_array, ::tvm::runtime::metadata::MetadataTypeIndex::kInt64, nullptr};
+    v->Visit("shape", &shape_metadata_array);
+    int64_t num_shape_cpp = num_shape();
+    v->Visit("num_shape", &num_shape_cpp);
+    ::tvm::runtime::DataType dtype_cpp{dtype()};
+    v->Visit("dtype", &dtype_cpp);
+  }
+};
+
+/*! \brief VisitableTensorInfoNode that owns the name and shape its TVMTensorInfo points to. */
+class InMemoryTensorInfoNode : public ::tvm::target::metadata::VisitableTensorInfoNode {
+ public:
+  InMemoryTensorInfoNode() : InMemoryTensorInfoNode("", {}, ::tvm::runtime::DataType(0, 0, 0)) {}
+  InMemoryTensorInfoNode(const ::tvm::runtime::String& name, const ::std::vector<int64_t>& shape,
+                         ::tvm::runtime::DataType dtype)
+      : VisitableTensorInfoNode{&storage_},
+        name_{name},
+        shape_{new int64_t[shape.size()]()},
+        storage_{name_.c_str(), nullptr, 0, dtype} {
+    storage_.shape = shape_.get();
+    storage_.num_shape = shape.size();
+    for (size_t i = 0; i < shape.size(); ++i) shape_[i] = shape[i];
+  }
+
+ private:
+  ::std::string name_;
+  // Allocated with new[], so it is held as unique_ptr<T[]> to be released with delete[].
+  ::std::unique_ptr<int64_t[]> shape_;
+  struct ::TVMTensorInfo storage_;
+};
+
+}  // namespace metadata
+}  // namespace target
+}  // namespace tvm
+
+#endif  // TVM_TARGET_METADATA_H_
diff --git a/tests/cpp/aot_metadata_test.cc b/tests/cpp/aot_metadata_test.cc
new file mode 100644
index 0000000..7307622
--- /dev/null
+++ b/tests/cpp/aot_metadata_test.cc
@@ -0,0 +1,236 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <tvm/ir/type.h>
+#include <tvm/node/reflection.h>
+#include <tvm/runtime/metadata.h>
+
+#include "../src/target/metadata.h"
+
+namespace {
+
+// Static fixture data for the tests below.
+
+// Shape array referenced by both entries of kNormalInputs.
+const int64_t kNormalInput1Shape[4] = {1, 5, 5, 3};
+// Two input tensors. Initializers are positional: name, shape, num_shape, dtype.
+// Each tensor gets a different DLDataType so the tests can tell them apart.
+const struct TVMTensorInfo kNormalInputs[2] = {
+    {"input1", kNormalInput1Shape, 4, DLDataType{1, 2, 3}},
+    {"input2", kNormalInput1Shape, 4, DLDataType{2, 3, 4}}};
+
+// Shape array for the single output tensor.
+const int64_t kNormalOutput1Shape[3] = {3, 8, 8};
+// One output tensor, using the same positional layout as the inputs.
+const struct TVMTensorInfo kNormalOutputs[1] = {
+    {"output1", kNormalOutput1Shape, 3, DLDataType{3, 4, 5}}};
+
+// Top-level metadata struct tying the arrays above together: version, the input array and its
+// length, the output array and its length, and the module name.
+const struct TVMMetadata kNormal = {
+    TVM_METADATA_VERSION, kNormalInputs, 2, kNormalOutputs, 1, "default",
+};
+}  // namespace
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::StrEq;
+using ::tvm::runtime::Downcast;
+
+// Wraps the static kNormal struct in a runtime Metadata object and checks that every accessor
+// returns the value written in the fixture.
+TEST(Metadata, ParseStruct) {
+  tvm::runtime::metadata::Metadata md(&kNormal);
+  EXPECT_THAT(md->version(), Eq(TVM_METADATA_VERSION));
+
+  // Inputs: the count is available from num_inputs() and from the accessor's size().
+  EXPECT_THAT(md->num_inputs(), Eq(2));
+  auto input_tensors = md->inputs();
+  EXPECT_THAT(input_tensors.size(), Eq(2));
+
+  auto first_input = input_tensors[0];
+  EXPECT_THAT(first_input->name(), Eq("input1"));
+  EXPECT_THAT(first_input->shape(), ElementsAre(1, 5, 5, 3));
+  EXPECT_THAT(first_input->dtype(), Eq(tvm::runtime::DataType(DLDataType{1, 2, 3})));
+
+  auto second_input = input_tensors[1];
+  EXPECT_THAT(second_input->name(), Eq("input2"));
+  EXPECT_THAT(second_input->shape(), ElementsAre(1, 5, 5, 3));
+  EXPECT_THAT(second_input->dtype(), Eq(tvm::runtime::DataType(DLDataType{2, 3, 4})));
+
+  // Outputs: same pattern as the inputs.
+  EXPECT_THAT(md->num_outputs(), Eq(1));
+  auto output_tensors = md->outputs();
+  EXPECT_THAT(output_tensors.size(), Eq(1));
+
+  auto first_output = output_tensors[0];
+  EXPECT_THAT(first_output->name(), Eq("output1"));
+  EXPECT_THAT(first_output->shape(), ElementsAre(3, 8, 8));
+  EXPECT_THAT(first_output->dtype(), Eq(tvm::runtime::DataType(DLDataType{3, 4, 5})));
+
+  EXPECT_THAT(md->mod_name(), Eq("default"));
+}
+
+/*!
+ * \brief AttrVisitor that records every attribute it is shown, for later inspection by a test.
+ *
+ * Each Visit overload appends the key to `keys` and an ObjectRef built from the value to
+ * `values`, so keys[i] and values[i] describe the i-th visited attribute. The void** overload is
+ * the exception: it records nothing and fails a CHECK.
+ */
+class TestVisitor : public tvm::AttrVisitor {
+ public:
+  using Element = ::std::tuple<::std::string, ::tvm::runtime::ObjectRef>;
+
+  void Visit(const char* key, double* value) final {
+    Record(key, ::tvm::FloatImm(::tvm::runtime::DataType(kDLFloat, 64, 1), *value));
+  }
+
+  void Visit(const char* key, int64_t* value) final {
+    Record(key, ::tvm::IntImm(::tvm::runtime::DataType(kDLInt, 64, 1), *value));
+  }
+
+  void Visit(const char* key, uint64_t* value) final {
+    // IntImm is given an int64_t here, so carry the unsigned bit pattern over unchanged.
+    const int64_t signed_bits = *reinterpret_cast<const int64_t*>(value);
+    Record(key, ::tvm::IntImm(::tvm::runtime::DataType(kDLUInt, 64, 1), signed_bits));
+  }
+
+  void Visit(const char* key, int* value) final {
+    Record(key, ::tvm::IntImm(::tvm::runtime::DataType(kDLInt, 64, 1), *value));
+  }
+
+  void Visit(const char* key, bool* value) final { Record(key, ::tvm::Bool(*value)); }
+
+  void Visit(const char* key, std::string* value) final {
+    Record(key, ::tvm::runtime::String(*value));
+  }
+
+  void Visit(const char* key, tvm::runtime::DataType* value) final {
+    Record(key, ::tvm::PrimType(*value));
+  }
+
+  void Visit(const char* key, tvm::runtime::NDArray* value) final { Record(key, *value); }
+
+  void Visit(const char* key, void** value) final { CHECK(false) << "Do not expect this type"; }
+
+  void Visit(const char* key, ::tvm::runtime::ObjectRef* value) final { Record(key, *value); }
+
+  /*! \brief Keys in visit order. */
+  std::vector<std::string> keys;
+  /*! \brief Values in visit order; values[i] belongs to keys[i]. */
+  std::vector<::tvm::runtime::ObjectRef> values;
+
+ private:
+  /*! \brief Append one key/value pair to the recorded lists. */
+  void Record(const char* key, const ::tvm::runtime::ObjectRef& value) {
+    keys.push_back(key);
+    values.push_back(value);
+  }
+};
+
+// Visits a Metadata built from kNormal through the reflection vtable and checks the keys and
+// values recorded by TestVisitor.
+TEST(Metadata, Visitor) {
+  tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal);
+  TestVisitor v;
+  ::tvm::ReflectionVTable::Global()->VisitAttrs(md.operator->(), &v);
+
+  EXPECT_THAT(v.keys, ElementsAre(StrEq("version"), StrEq("inputs"), StrEq("num_inputs"),
+                                  StrEq("outputs"), StrEq("num_outputs"), StrEq("mod_name")));
+  EXPECT_THAT(Downcast<tvm::IntImm>(v.values[0])->value, Eq(TVM_METADATA_VERSION));
+
+  // Just identify the tensor.
+  auto input_array = Downcast<tvm::runtime::metadata::MetadataArray>(v.values[1]);
+  EXPECT_THAT(input_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata));
+  EXPECT_THAT(input_array->struct_name, StrEq("TVMTensorInfo"));
+  EXPECT_THAT(input_array->array.size(), Eq(2));
+
+  auto input1 = Downcast<tvm::runtime::metadata::TensorInfo>(input_array->array[0]);
+  EXPECT_THAT(input1->name(), StrEq("input1"));
+  EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3));
+  EXPECT_THAT(input1->dtype(), tvm::runtime::DataType(DLDataType{1, 2, 3}));
+
+  // Check the second tensor itself; these assertions previously re-checked input1.
+  auto input2 = Downcast<tvm::runtime::metadata::TensorInfo>(input_array->array[1]);
+  EXPECT_THAT(input2->name(), StrEq("input2"));
+  EXPECT_THAT(input2->shape(), ElementsAre(1, 5, 5, 3));
+  EXPECT_THAT(input2->dtype(), tvm::runtime::DataType(DLDataType{2, 3, 4}));
+
+  auto num_inputs = Downcast<tvm::IntImm>(v.values[2]);
+  EXPECT_THAT(num_inputs->value, Eq(2));
+
+  auto output_array = Downcast<tvm::runtime::metadata::MetadataArray>(v.values[3]);
+  EXPECT_THAT(output_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata));
+  EXPECT_THAT(output_array->struct_name, StrEq("TVMTensorInfo"));
+  auto output1 = Downcast<tvm::runtime::metadata::TensorInfo>(output_array->array[0]);
+
+  EXPECT_THAT(output1->name(), Eq("output1"));
+
+  auto num_outputs = Downcast<tvm::IntImm>(v.values[4]);
+  EXPECT_THAT(num_outputs->value, Eq(1));
+}
+
+using ::tvm::runtime::make_object;
+// Builds a Metadata object from InMemory* nodes and checks the raw C structs reachable through
+// data(): version, both inputs, the output and the module name.
+TEST(Metadata, InMemory) {
+  using TensorInfo = tvm::runtime::metadata::TensorInfo;
+  using tvm::target::metadata::InMemoryMetadataNode;
+  using tvm::target::metadata::InMemoryTensorInfoNode;
+
+  // Wraps a freshly allocated InMemoryTensorInfoNode in a TensorInfo reference.
+  auto make_tensor = [](const char* name, const std::vector<int64_t>& shape, DLDataType dtype) {
+    return TensorInfo(make_object<InMemoryTensorInfoNode>(tvm::String(name), shape,
+                                                          tvm::runtime::DataType(dtype)));
+  };
+  // Copies the shape array of a raw tensor struct into a vector so ElementsAre can match it.
+  auto dims_of = [](const TVMTensorInfo& tensor) {
+    return std::vector<int64_t>(tensor.shape, tensor.shape + tensor.num_shape);
+  };
+
+  tvm::runtime::metadata::Metadata md(make_object<InMemoryMetadataNode>(
+      TVM_METADATA_VERSION,
+      std::vector<TensorInfo>({make_tensor("Input1", {1, 5, 5, 3}, DLDataType{1, 2, 3}),
+                               make_tensor("Input2", {1, 5, 5, 3}, DLDataType{2, 3, 4})}),
+      std::vector<TensorInfo>({make_tensor("Output1", {3, 8, 8}, DLDataType{3, 4, 5})}),
+      "default"));
+
+  auto raw = md->data();
+  EXPECT_THAT(raw->version, Eq(TVM_METADATA_VERSION));
+  EXPECT_THAT(raw->num_inputs, Eq(2));
+
+  const TVMTensorInfo& raw_input0 = raw->inputs[0];
+  EXPECT_THAT(raw_input0.name, StrEq("Input1"));
+  EXPECT_THAT(dims_of(raw_input0), ElementsAre(1, 5, 5, 3));
+  EXPECT_THAT(tvm::runtime::DataType(raw_input0.dtype),
+              Eq(tvm::runtime::DataType(DLDataType{1, 2, 3})));
+
+  const TVMTensorInfo& raw_input1 = raw->inputs[1];
+  EXPECT_THAT(raw_input1.name, StrEq("Input2"));
+  EXPECT_THAT(dims_of(raw_input1), ElementsAre(1, 5, 5, 3));
+  EXPECT_THAT(tvm::runtime::DataType(raw_input1.dtype),
+              Eq(tvm::runtime::DataType(DLDataType{2, 3, 4})));
+
+  const TVMTensorInfo& raw_output0 = raw->outputs[0];
+  EXPECT_THAT(raw_output0.name, StrEq("Output1"));
+  EXPECT_THAT(dims_of(raw_output0), ElementsAre(3, 8, 8));
+  EXPECT_THAT(tvm::runtime::DataType(raw_output0.dtype),
+              Eq(tvm::runtime::DataType(DLDataType{3, 4, 5})));
+
+  EXPECT_THAT(raw->mod_name, StrEq("default"));
+}
+
+// Builds Metadata with no inputs and one output whose shape is empty, then checks that both the
+// raw struct and the accessors report the empty lists as having zero elements.
+TEST(Metadata, ZeroElementLists) {
+  using TensorInfo = tvm::runtime::metadata::TensorInfo;
+  using tvm::target::metadata::InMemoryMetadataNode;
+  using tvm::target::metadata::InMemoryTensorInfoNode;
+
+  tvm::runtime::metadata::Metadata md(make_object<InMemoryMetadataNode>(
+      TVM_METADATA_VERSION, std::vector<TensorInfo>{},
+      std::vector<TensorInfo>({TensorInfo(make_object<InMemoryTensorInfoNode>(
+          tvm::String("Output1"), std::vector<int64_t>{},
+          tvm::runtime::DataType(DLDataType{3, 4, 5})))}),
+      "default"));
+
+  // Empty input list, seen through the raw struct and through the accessors.
+  EXPECT_THAT(md->data()->num_inputs, Eq(0));
+  EXPECT_THAT(md->inputs().size(), Eq(0));
+  EXPECT_THAT(md->num_inputs(), Eq(0));
+  EXPECT_THAT(md->inputs(), ElementsAre());
+
+  // Empty shape on the single output tensor.
+  const TVMTensorInfo& raw_output = md->data()->outputs[0];
+  EXPECT_THAT(raw_output.num_shape, Eq(0));
+  EXPECT_THAT(md->outputs()[0]->shape().size(), Eq(0));
+  EXPECT_THAT(md->outputs()[0]->shape(), ElementsAre());
+}