You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2022/02/22 17:35:13 UTC
[tvm] branch main updated: [runtime] Add Metadata classes for AOTExecutor (#10282)
This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 33082e0 [runtime] Add Metadata classes for AOTExecutor (#10282)
33082e0 is described below
commit 33082e0032fb57b0516ad7e3eabd11fe0203437e
Author: Andrew Reusch <ar...@gmail.com>
AuthorDate: Tue Feb 22 09:34:23 2022 -0800
[runtime] Add Metadata classes for AOTExecutor (#10282)
* Add new Metadata classes and base implementation.
* These were autogenerated in the original PR, but are checked in
as plain code until we can revisit the auto-generator approach.
* address masa comments
* Add documentation per Manupa's comments, and move kMetadataVersion namespace.
* remove get_name function, used for debugging
* clang-format
---
include/tvm/runtime/metadata.h | 160 ++++++++++++++++++++++++
include/tvm/runtime/metadata_base.h | 198 ++++++++++++++++++++++++++++++
include/tvm/support/span.h | 103 ++++++++++++++++
src/runtime/metadata.cc | 56 +++++++++
src/target/metadata.cc | 47 +++++++
src/target/metadata.h | 173 ++++++++++++++++++++++++++
tests/cpp/aot_metadata_test.cc | 236 ++++++++++++++++++++++++++++++++++++
7 files changed, 973 insertions(+)
diff --git a/include/tvm/runtime/metadata.h b/include/tvm/runtime/metadata.h
new file mode 100644
index 0000000..b716d41
--- /dev/null
+++ b/include/tvm/runtime/metadata.h
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/metadata.h
+ * \brief Defines types which can be used in Metadata.
+ */
+#ifndef TVM_RUNTIME_METADATA_H_
+#define TVM_RUNTIME_METADATA_H_
+
+#include <inttypes.h>
+#ifdef __cplusplus
+#include <memory>
+#include <string>
+#include <vector>
+#endif
+#include <tvm/runtime/c_runtime_api.h>
+#ifdef __cplusplus
+#include <tvm/runtime/metadata_base.h>
+#include <tvm/support/span.h>
+#endif
+
+// Version number recorded in emitted artifacts for runtime checking.
+#define TVM_METADATA_VERSION 1
+
+#ifdef __cplusplus
+// C++-only: this header is also meant to be consumed from C (note the __cplusplus-guarded
+// includes above and the extern "C" block below), and C has no namespaces.
+namespace tvm {
+namespace runtime {
+namespace metadata {
+/*!
+ * \brief Version of metadata emitted and understood by this compiler/runtime.
+ * Should be populated into the `version` field of all TVMMetadata.
+ * C code should use the TVM_METADATA_VERSION macro instead.
+ */
+static constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION;
+}  // namespace metadata
+}  // namespace runtime
+}  // namespace tvm
+#endif  // defined(__cplusplus)
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/*!
+ * \brief Top-level metadata structure. Holds all other metadata types.
+ *
+ * Plain C struct: it has no destructor, so it never frees the arrays or strings it points to;
+ * whoever fills it in must keep them alive for as long as the struct is used (presumably static
+ * .rodata in emitted code -- confirm against the code generator).
+ *
+ * Field order is significant: instances are built with positional aggregate initializers
+ * elsewhere in this change (the C++ InMemory metadata classes and the AOT metadata test).
+ */
+struct TVMMetadata {
+ /*! \brief Version identifier for this metadata; expected to equal TVM_METADATA_VERSION. */
+ int64_t version;
+ /*! \brief Inputs to the AOT run_model function.
+ * The order of the elements is the same as in the arguments to run_model. That is to say,
+ * this array specifies the first `num_inputs` arguments to run_model.
+ */
+ const struct TVMTensorInfo* inputs;
+ /*! \brief Number of elements in `inputs` array. */
+ int64_t num_inputs;
+ /*! \brief Outputs of the AOT run_model function.
+ * The order of the elements is the same as in the arguments to run_model. That is to say,
+ * this array specifies the last `num_outputs` arguments to run_model.
+ */
+ const struct TVMTensorInfo* outputs;
+ /*! \brief Number of elements in `outputs` array. */
+ int64_t num_outputs;
+ /*! \brief Name of the model, as passed to tvm.relay.build (NUL-terminated C string). */
+ const char* mod_name;
+};
+
+/*!
+ * \brief Describes one tensor argument to `run_model`.
+ * NOTE: while TIR allows for other types of arguments, such as scalars, the AOT run_model
+ * function does not currently accept these. Therefore it's not possible to express those
+ * in this metadata. A future patch may modify this.
+ *
+ * Like TVMMetadata, this struct only points at its name and shape storage; it does not own them.
+ */
+struct TVMTensorInfo {
+ /*! \brief Name of the tensor, as specified in the Relay program. */
+ const char* name;
+ /*! \brief Shape of the tensor: pointer to `num_shape` dimension sizes. */
+ const int64_t* shape;
+ /*! \brief Rank of this tensor, i.e. the number of entries in `shape`. */
+ int64_t num_shape;
+ /*! \brief Data type of one element of this tensor. */
+ DLDataType dtype;
+};
+#ifdef __cplusplus
+} // extern "C"
+#include <tvm/runtime/object.h>
+namespace tvm {
+namespace runtime {
+namespace metadata {
+
+class Metadata;
+class TensorInfo;
+
+/*!
+ * \brief Read-only view over a C ::TVMMetadata struct.
+ *
+ * Only the pointer handed to the constructor is kept; the struct is neither copied nor freed,
+ * so it must outlive this node. Array fields are exposed through ArrayAccessor, which wraps
+ * each element in a TensorInfo reference when it is accessed.
+ */
+class MetadataNode : public MetadataBaseNode {
+ public:
+  static constexpr const char* _type_key = "metadata.MetadataNode";
+
+  explicit MetadataNode(const struct ::TVMMetadata* data) : data_(data) {}
+
+  /*! \brief Metadata format version recorded in the wrapped struct. */
+  inline int64_t version() const { return data_->version; }
+
+  /*! \brief Number of entries reachable through inputs(). */
+  inline int64_t num_inputs() const { return data_->num_inputs; }
+  /*! \brief Lazily-wrapping accessor over the model inputs (defined out of line). */
+  ArrayAccessor<struct TVMTensorInfo, TensorInfo> inputs();
+
+  /*! \brief Number of entries reachable through outputs(). */
+  inline int64_t num_outputs() const { return data_->num_outputs; }
+  /*! \brief Lazily-wrapping accessor over the model outputs (defined out of line). */
+  ArrayAccessor<struct TVMTensorInfo, TensorInfo> outputs();
+
+  /*! \brief Model name, copied into a fresh String on every call. */
+  inline ::tvm::runtime::String mod_name() const {
+    return ::tvm::runtime::String(data_->mod_name);
+  }
+
+  /*! \brief The wrapped C struct. */
+  const struct ::TVMMetadata* data() const { return data_; }
+
+  TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode);
+
+ private:
+  const struct ::TVMMetadata* data_;
+};
+
+/*!
+ * \brief Reference class for MetadataNode.
+ *
+ * Built from a pointer to a C ::TVMMetadata. The node stores that pointer without copying the
+ * struct, so the struct must outlive every reference created from it.
+ */
+class Metadata : public MetadataBase {
+ public:
+ explicit Metadata(const struct ::TVMMetadata* data);
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Metadata, MetadataBase, MetadataNode);
+};
+
+/*!
+ * \brief Read-only view over a C ::TVMTensorInfo struct.
+ *
+ * Only the pointer passed to the constructor is stored; the struct must outlive this node.
+ */
+class TensorInfoNode : public MetadataBaseNode {
+ public:
+  static constexpr const char* _type_key = "metadata.TensorInfoNode";
+
+  explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_(data) {}
+
+  /*! \brief Tensor name, copied into a fresh String on every call. */
+  inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); }
+
+  /*! \brief Rank of the tensor, i.e. the length of shape(). */
+  inline int64_t num_shape() const { return data_->num_shape; }
+
+  /*! \brief Non-owning view over the `num_shape` dimension sizes. */
+  inline ::tvm::support::Span<const int64_t, int64_t> shape() const {
+    const int64_t* first_dim = data_->shape;
+    const int64_t* past_last_dim = first_dim + data_->num_shape;
+    return ::tvm::support::Span<const int64_t, int64_t>(first_dim, past_last_dim);
+  }
+
+  /*! \brief Element type, converted from the stored DLDataType. */
+  inline ::tvm::runtime::DataType dtype() const { return ::tvm::runtime::DataType(data_->dtype); }
+
+  /*! \brief The wrapped C struct. */
+  const struct ::TVMTensorInfo* data() const { return data_; }
+
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, MetadataBaseNode);
+
+ private:
+  const struct ::TVMTensorInfo* data_;
+};
+
+/*!
+ * \brief Reference class for TensorInfoNode.
+ *
+ * Built from a pointer to a C ::TVMTensorInfo. The node stores that pointer without copying the
+ * struct, so the struct (and its name/shape storage) must outlive every reference to it.
+ */
+class TensorInfo : public MetadataBase {
+ public:
+ explicit TensorInfo(const struct ::TVMTensorInfo* data);
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorInfo, MetadataBase, TensorInfoNode);
+};
+
+} // namespace metadata
+} // namespace runtime
+} // namespace tvm
+#endif // defined(__cplusplus)
+
+#endif // TVM_RUNTIME_METADATA_H_
diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h
new file mode 100644
index 0000000..9674319
--- /dev/null
+++ b/include/tvm/runtime/metadata_base.h
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/metadata_base.h
+ * \brief Defines types which can be used in Metadata.
+ */
+#ifndef TVM_RUNTIME_METADATA_BASE_H_
+#define TVM_RUNTIME_METADATA_BASE_H_
+
+#include <tvm/ir/expr.h>
+#include <tvm/runtime/object.h>
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+namespace metadata {
+
+/*!
+ * \brief Common base class for all Metadata.
+ *
+ * This class is used by the visitor classes as an internal check to verify that all parts of
+ * the Metadata struct used in codegen are Metadata objects. Every Metadata node type in this
+ * change (MetadataNode, TensorInfoNode, MetadataArrayNode) derives from it.
+ */
+class MetadataBaseNode : public ::tvm::runtime::Object {
+ public:
+ static constexpr const char* _type_key = "metadata.MetadataBaseNode";
+ TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object);
+};
+
+/*! \brief Reference class for the common MetadataBaseNode class.
+ *
+ * Base of all Metadata reference types (Metadata, TensorInfo, MetadataArray).
+ */
+class MetadataBase : public ::tvm::runtime::ObjectRef {
+ public:
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataBase, ::tvm::runtime::ObjectRef, MetadataBaseNode);
+};
+
+template <typename C, class Ref>
+class ArrayAccessor;
+
+/*!
+ * \brief An iterator implementation that lazily instantiates the C++ wrapping Metadata class.
+ *
+ * Holds an index into a parent ArrayAccessor. Dereferencing calls `(*parent)[index]`, so a new
+ * Ref is constructed on every dereference. Incrementing saturates at `parent->size()`, so
+ * advancing an end() iterator leaves it equal to end().
+ *
+ * \tparam C The C struct type stored in the underlying array.
+ * \tparam Ref The C++ reference (wrapper) type produced for each element.
+ */
+template <typename C, class Ref>
+class ArrayIterator {
+ public:
+  ArrayIterator(size_t index, const ArrayAccessor<C, Ref>* parent)
+      : index_{index}, parent_{parent} {}
+
+  /*!
+   * \brief Materialize the element at the current position.
+   *
+   * const-qualified so that const iterators (and const references to iterators) can be
+   * dereferenced; the parent accessor's operator[] is itself const.
+   */
+  inline Ref operator*() const { return (*parent_)[index_]; }
+
+  /*! \brief Advance to the next element, never moving past the parent's size(). */
+  inline ArrayIterator<C, Ref>& operator++() {
+    if (index_ < parent_->size()) {
+      index_++;
+    }
+
+    return *this;
+  }
+
+  /*! \brief Iterators compare equal when they refer to the same accessor and the same index. */
+  inline bool operator==(const ArrayIterator<C, Ref>& other) const {
+    return parent_ == other.parent_ && index_ == other.index_;
+  }
+
+  inline bool operator!=(const ArrayIterator<C, Ref>& other) const { return !operator==(other); }
+
+ private:
+  size_t index_;
+  const ArrayAccessor<C, Ref>* parent_;
+};
+
+/*!
+ * \brief Span-like, read-only view over a C array of metadata structs.
+ *
+ * Elements are exposed as Metadata wrapper references (Ref), which are built on demand from the
+ * address of the corresponding C element rather than stored. Out-of-range indexing throws
+ * std::runtime_error.
+ *
+ * \tparam C The C struct type of each array element.
+ * \tparam Ref The ObjectRef subclass constructed from a `const C*`.
+ */
+template <typename C, class Ref>
+class ArrayAccessor {
+ public:
+  using value_type = Ref;
+  using iterator = ArrayIterator<C, Ref>;
+  using const_iterator = iterator;
+
+  template <typename T = typename std::enable_if<std::is_base_of<ObjectRef, Ref>::value>::type>
+  ArrayAccessor(const C* data, size_t num_data) : data_(data), num_data_(num_data) {}
+
+  /*! \brief Number of elements in the underlying C array. */
+  inline size_t size() const { return num_data_; }
+
+  /*! \brief Wrap element `index`; throws std::runtime_error when index >= size(). */
+  inline Ref operator[](size_t index) const {
+    if (index < num_data_) {
+      return Ref(data_ + index);
+    }
+    throw std::runtime_error("Index out of range");
+  }
+
+  inline iterator begin() const { return iterator{0, this}; }
+
+  inline iterator end() const { return iterator{num_data_, this}; }
+
+ private:
+  const C* data_;
+  size_t num_data_;
+};
+
+/*! \brief A specialization of ArrayAccessor for String.
+ * This class is needed because the String constructor signature is different from the typical
+ * Metadata subclass.
+ *
+ * The constructor takes `const char* const*`, which is exactly the `const C*` parameter of the
+ * primary template with C = const char*. Callers holding `const char**` (or `char**`) still
+ * compile through the implicit qualification conversion, and read-only string tables are now
+ * accepted as well.
+ */
+template <>
+class ArrayAccessor<const char*, ::tvm::runtime::String> {
+ public:
+  using value_type = ::tvm::runtime::String;
+  using iterator = ArrayIterator<const char*, ::tvm::runtime::String>;
+  using const_iterator = iterator;
+
+  ArrayAccessor(const char* const* data, size_t num_data) : data_{data}, num_data_{num_data} {}
+
+  /*! \brief Number of strings in the underlying C array. */
+  inline size_t size() const { return num_data_; }
+
+  /*! \brief Copy string `index` into a String; throws std::runtime_error when out of range. */
+  inline ::tvm::runtime::String operator[](size_t index) const {
+    if (index >= num_data_) {
+      throw std::runtime_error("Index out of range");
+    }
+    return ::tvm::runtime::String(data_[index]);
+  }
+
+  inline iterator begin() const { return iterator{0, this}; }
+
+  inline iterator end() const { return iterator{num_data_, this}; }
+
+ private:
+  const char* const* data_;
+  size_t num_data_;
+};
+
+/*! \brief Enumerates the primitive types which can be part of a Metadata instance.
+ *
+ * These are separate from TIR DataType because TIR does not model structs.
+ * This is an unscoped enum, so the enumerators are visible directly in tvm::runtime::metadata.
+ * Each enumerator has an explicit value; NOTE(review): presumably these must stay stable if
+ * they are ever recorded in emitted artifacts -- confirm before renumbering.
+ */
+enum MetadataTypeIndex : uint8_t {
+ kUint64 = 0,
+ // Used in this change to tag TVMTensorInfo::shape arrays.
+ kInt64 = 1,
+ kBool = 2,
+ kString = 3,
+ kHandle = 4,
+ // A nested Metadata struct; used in this change to tag arrays of TVMTensorInfo.
+ kMetadata = 5,
+};
+
+/*!
+ * \brief Metadata node pairing an Array with the element-type information needed to emit it.
+ *
+ * Emitting an array requires knowing what its elements are, which a bare Array<ObjectRef> does
+ * not record; this node carries that alongside the data.
+ *
+ * \note Only the `struct_name` pointer is stored; the string is not copied and must outlive this
+ * node. Callers in this change pass string literals or nullptr.
+ */
+class MetadataArrayNode : public MetadataBaseNode {
+ public:
+  MetadataArrayNode(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name)
+      : array(::std::move(array)), type_index(type_index), struct_name(struct_name) {}
+
+  /*! \brief The array elements. */
+  Array<ObjectRef> array;
+  /*! \brief Type tag shared by every element of `array`. */
+  MetadataTypeIndex type_index;
+  /*!
+   * \brief C struct name of the elements. In this change, kMetadata arrays pass the element
+   * struct name and the kInt64 shape array passes nullptr.
+   */
+  const char* struct_name;
+
+  static constexpr const char* _type_key = "metadata.MetadataArrayNode";
+  TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode);
+};
+
+/*! \brief Reference class for MetadataArrayNode.
+ *
+ * The constructor forwards its arguments to a new MetadataArrayNode; `struct_name` is stored as
+ * a raw pointer there, so it must outlive the node.
+ */
+class MetadataArray : public MetadataBase {
+ public:
+ MetadataArray(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name);
+
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode);
+};
+
+} // namespace metadata
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_METADATA_BASE_H_
diff --git a/include/tvm/support/span.h b/include/tvm/support/span.h
new file mode 100644
index 0000000..faa849c
--- /dev/null
+++ b/include/tvm/support/span.h
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file tvm/support/span.h
+ * \brief Reimplementation of part of C++-20 style span.
+ */
+#ifndef TVM_SUPPORT_SPAN_H_
+#define TVM_SUPPORT_SPAN_H_
+
+#include <cstddef>
+#include <iterator>
+#include <type_traits>
+#include <vector>
+
+namespace tvm {
+namespace support {
+
+/*!
+ * \brief A partial implementation of the C++20 std::span.
+ *
+ * At the time of writing, TVM must compile against C++14.
+ *
+ * \tparam T Element type of the underlying storage (may be const-qualified).
+ * \tparam W Type produced when an element is read; `operator*` and `operator[]` return
+ *           `W(element)` by value.
+ *
+ * NOTE(review): this class uses CHECK_GE / ICHECK_GE / ICHECK_LT, but this header does not
+ * include <tvm/runtime/logging.h>; it relies on the includer to provide those macros.
+ */
+template <class T, class W>
+class Span {
+ public:
+  using value_type = W;
+  using const_W = typename ::std::add_const<W>::type;
+
+  /*! \brief Input iterator yielding `W1(*ptr)` for each element of [ptr, end). */
+  template <class W1>
+  class iterator_base {
+   public:
+    // Member types spelled out rather than inherited from std::iterator (deprecated in C++17);
+    // the set of types is the same as std::iterator<std::input_iterator_tag, W> provided.
+    using iterator_category = std::input_iterator_tag;
+    using value_type = W;
+    using difference_type = std::ptrdiff_t;
+    using pointer = W*;
+    using reference = W&;
+
+    inline iterator_base(T* ptr, T* end) : ptr_{ptr}, end_{end} { CHECK_GE(end, ptr); }
+
+    inline W1 operator*() const { return W1(*ptr_); }
+
+    /*! \brief Advance, stopping at the end pointer. */
+    inline iterator_base<W1>& operator++() {
+      if (ptr_ != end_) ptr_++;
+      return *this;
+    }
+
+    inline bool operator==(iterator_base<W1> other) const {
+      return ptr_ == other.ptr_ && end_ == other.end_;
+    }
+
+    inline bool operator!=(iterator_base<W1> other) const { return !(*this == other); }
+
+    /*! \brief Implicit conversion from a mutable-value iterator to the const-value iterator. */
+    template <class X = W1, typename = ::std::enable_if_t<!::std::is_const<X>::value> >
+    inline operator iterator_base<const_W>() const {
+      return iterator_base<const_W>(ptr_, end_);
+    }
+
+   private:
+    T* ptr_;
+    T* end_;
+  };
+
+  using iterator = iterator_base<W>;
+  using const_iterator = iterator_base<const_W>;
+
+  inline Span(T* begin, int num_elements) : begin_{begin}, end_{begin + num_elements} {}
+  inline Span(T* begin, T* end) : begin_{begin}, end_{end} {}
+
+  inline iterator begin() const { return iterator(begin_, end_); }
+
+  inline iterator end() const { return iterator(end_, end_); }
+
+  size_t size() const { return end_ - begin_; }
+
+  /*!
+   * \brief Bounds-checked element access.
+   *
+   * The index is validated before any pointer arithmetic, so out-of-range (including negative)
+   * indices fail the check instead of forming an invalid pointer.
+   */
+  inline W operator[](int i) const {
+    ICHECK_GE(i, 0) << "Span access out of bounds: " << i;
+    ICHECK_LT(i, end_ - begin_) << "Span access out of bounds: " << i;
+    return W(begin_[i]);
+  }
+
+  /*! \brief Copy every element (converted to W) into a new std::vector. */
+  inline operator std::vector<W>() const { return std::vector<W>(begin(), end()); }
+
+ protected:
+  T* begin_;
+  T* end_;
+};
+
+} // namespace support
+} // namespace tvm
+
+#endif // TVM_SUPPORT_SPAN_H_
diff --git a/src/runtime/metadata.cc b/src/runtime/metadata.cc
new file mode 100644
index 0000000..7ca333b
--- /dev/null
+++ b/src/runtime/metadata.cc
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file metadata.cc
+ * \brief Implementations of the runtime component of Metadata.
+ */
+
+#include <tvm/runtime/metadata.h>
+
+namespace tvm {
+namespace runtime {
+namespace metadata {
+
+// Both accessors alias the arrays inside the wrapped TVMMetadata (nothing is copied); the
+// int64_t element counts are converted to the accessor's size_t length.
+ArrayAccessor<struct TVMTensorInfo, TensorInfo> MetadataNode::inputs() {
+  const size_t input_count = static_cast<size_t>(data_->num_inputs);
+  return ArrayAccessor<struct TVMTensorInfo, TensorInfo>(data_->inputs, input_count);
+}
+ArrayAccessor<struct TVMTensorInfo, TensorInfo> MetadataNode::outputs() {
+  const size_t output_count = static_cast<size_t>(data_->num_outputs);
+  return ArrayAccessor<struct TVMTensorInfo, TensorInfo>(data_->outputs, output_count);
+}
+
+TVM_REGISTER_OBJECT_TYPE(MetadataBaseNode);
+
+// `array` is a by-value sink parameter, so it is moved into the node rather than copied; this
+// avoids an extra reference-count increment/decrement on the underlying array object.
+MetadataArray::MetadataArray(Array<ObjectRef> array, MetadataTypeIndex type_index,
+                             const char* struct_name)
+    : MetadataBase{make_object<MetadataArrayNode>(::std::move(array), type_index, struct_name)} {}
+
+TVM_REGISTER_OBJECT_TYPE(MetadataArrayNode);
+
+// Reference-class constructors: each allocates a node that keeps the given struct pointer
+// (the struct itself is not copied), and each node type is registered with the object system.
+Metadata::Metadata(const struct ::TVMMetadata* data)
+ : MetadataBase{make_object<MetadataNode>(data)} {}
+TVM_REGISTER_OBJECT_TYPE(MetadataNode);
+
+TensorInfo::TensorInfo(const struct ::TVMTensorInfo* data)
+ : MetadataBase{make_object<TensorInfoNode>(data)} {}
+TVM_REGISTER_OBJECT_TYPE(TensorInfoNode);
+
+} // namespace metadata
+} // namespace runtime
+} // namespace tvm
diff --git a/src/target/metadata.cc b/src/target/metadata.cc
new file mode 100644
index 0000000..adf4cba
--- /dev/null
+++ b/src/target/metadata.cc
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file metadata.cc
+ * \brief Implementations of the compiler extensions for Metadata.
+ */
+
+#include "metadata.h"
+
+#include <tvm/node/reflection.h>
+
+namespace tvm {
+namespace target {
+namespace metadata {
+
+// Register the VisitAttrs-capable metadata node with the reflection system. The creator uses the
+// default constructor, which wraps a nullptr TVMMetadata.
+// NOTE(review): VisitAttrs reads through that pointer (version(), inputs(), ...), so a node made
+// by this creator must not be visited until real data is attached -- confirm how it is used.
+// NOTE(review): VisitableMetadataNode declares no _type_key of its own, so this presumably
+// registers under the inherited "metadata.MetadataNode" key -- confirm.
+TVM_REGISTER_REFLECTION_VTABLE(VisitableMetadataNode,
+ ::tvm::detail::ReflectionTrait<VisitableMetadataNode>)
+ .set_creator([](const std::string&) -> ObjectPtr<Object> {
+ return ::tvm::runtime::make_object<VisitableMetadataNode>();
+ });
+
+// Same as above for tensor info: the default-constructed node wraps a nullptr TVMTensorInfo.
+TVM_REGISTER_REFLECTION_VTABLE(VisitableTensorInfoNode,
+ ::tvm::detail::ReflectionTrait<VisitableTensorInfoNode>)
+ .set_creator([](const std::string&) -> ObjectPtr<Object> {
+ return ::tvm::runtime::make_object<VisitableTensorInfoNode>();
+ });
+
+} // namespace metadata
+} // namespace target
+} // namespace tvm
diff --git a/src/target/metadata.h b/src/target/metadata.h
new file mode 100644
index 0000000..2621d5d
--- /dev/null
+++ b/src/target/metadata.h
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/target/metadata.h
+ * \brief Extends Metadata for use in the compiler.
+ */
+#ifndef TVM_TARGET_METADATA_H_
+#define TVM_TARGET_METADATA_H_
+
+#include <tvm/runtime/metadata.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace tvm {
+namespace target {
+namespace metadata {
+
+/*!
+ * \brief MetadataNode subclass that adds the VisitAttrs reflection hook.
+ *
+ * Compiled into libtvm.so only: libtvm_runtime.so leaves reflection out to keep its code size
+ * down. The generic metadata code generators walk Metadata through this VisitAttrs during
+ * compilation; this is the same rationale that applies to the other Visitable subclasses.
+ */
+class VisitableMetadataNode : public ::tvm::runtime::metadata::MetadataNode {
+ public:
+  explicit VisitableMetadataNode(const struct ::TVMMetadata* data) : MetadataNode{data} {}
+  VisitableMetadataNode() : MetadataNode{nullptr} {}
+
+  /*!
+   * \brief Report every TVMMetadata field to `v`, in struct field order.
+   *
+   * Scalars are copied into locals and arrays are wrapped in fresh MetadataArray objects, so the
+   * visitor only sees temporaries; anything it writes back is discarded.
+   */
+  void VisitAttrs(AttrVisitor* v) {
+    // Wrap each element exposed by a TensorInfo accessor into an Array tagged as a struct array.
+    auto wrap_tensor_infos = [](auto accessor, int64_t count) {
+      Array<ObjectRef> elements;
+      elements.reserve(count);
+      for (int64_t idx = 0; idx < count; ++idx) {
+        elements.push_back(::tvm::runtime::metadata::TensorInfo{accessor[idx]});
+      }
+      return ::tvm::runtime::metadata::MetadataArray{
+          elements, ::tvm::runtime::metadata::MetadataTypeIndex::kMetadata, "TVMTensorInfo"};
+    };
+
+    int64_t version_value = version();
+    v->Visit("version", &version_value);
+
+    auto inputs_value = wrap_tensor_infos(inputs(), num_inputs());
+    v->Visit("inputs", &inputs_value);
+    int64_t num_inputs_value = num_inputs();
+    v->Visit("num_inputs", &num_inputs_value);
+
+    auto outputs_value = wrap_tensor_infos(outputs(), num_outputs());
+    v->Visit("outputs", &outputs_value);
+    int64_t num_outputs_value = num_outputs();
+    v->Visit("num_outputs", &num_outputs_value);
+
+    ::std::string mod_name_value(data()->mod_name);
+    v->Visit("mod_name", &mod_name_value);
+  }
+};
+
+/*!
+ * \brief Subclass of MetadataNode which also owns the backing C structures.
+ *
+ * This class (and other InMemory subclasses) are used during compilation to instantiate Metadata
+ * instances whose storage lives outside of .rodata. This class exists because the Module returned
+ * from tvm.relay.build must also be ready to run inference.
+ *
+ * The TVMTensorInfo structs of the given inputs/outputs are copied into owned arrays. Those
+ * copies still point at name/shape storage reachable from the original TensorInfo objects, so
+ * the objects are retained in inputs_objs_ / outputs_objs_ to keep that storage alive.
+ * `storage_` points into the members declared before it, which are therefore initialized first.
+ */
+class InMemoryMetadataNode : public ::tvm::target::metadata::VisitableMetadataNode {
+ public:
+  InMemoryMetadataNode()
+      : InMemoryMetadataNode(0 /* version */, {} /* inputs */, {} /* outputs */,
+                             "" /* mod_name */) {}
+  InMemoryMetadataNode(int64_t version,
+                       const ::std::vector<::tvm::runtime::metadata::TensorInfo>& inputs,
+                       const ::std::vector<::tvm::runtime::metadata::TensorInfo>& outputs,
+                       const ::tvm::runtime::String mod_name)
+      : VisitableMetadataNode{&storage_},
+        inputs_{::std::make_unique<struct TVMTensorInfo[]>(inputs.size())},
+        inputs_objs_{inputs},
+        outputs_{::std::make_unique<struct TVMTensorInfo[]>(outputs.size())},
+        outputs_objs_{outputs},
+        mod_name_{mod_name},
+        storage_{version, nullptr, 0, nullptr, 0, mod_name_.c_str()} {
+    storage_.inputs = inputs_.get();
+    storage_.num_inputs = inputs.size();
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      inputs_[i] = *inputs[i]->data();
+    }
+    storage_.outputs = outputs_.get();
+    storage_.num_outputs = outputs.size();
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      outputs_[i] = *outputs[i]->data();
+    }
+  }
+
+ private:
+  // Array form of unique_ptr: these buffers come from an array allocation and must be released
+  // with delete[] (the scalar unique_ptr<T> would call plain delete, which is undefined).
+  ::std::unique_ptr<struct TVMTensorInfo[]> inputs_;
+  std::vector<::tvm::runtime::metadata::TensorInfo> inputs_objs_;
+  ::std::unique_ptr<struct TVMTensorInfo[]> outputs_;
+  std::vector<::tvm::runtime::metadata::TensorInfo> outputs_objs_;
+  ::std::string mod_name_;
+  struct ::TVMMetadata storage_;
+};
+
+/*!
+ * \brief TensorInfoNode subclass that adds the VisitAttrs reflection hook (compiler-only, like
+ * VisitableMetadataNode).
+ */
+class VisitableTensorInfoNode : public ::tvm::runtime::metadata::TensorInfoNode {
+ public:
+  explicit VisitableTensorInfoNode(const struct ::TVMTensorInfo* data) : TensorInfoNode{data} {}
+  VisitableTensorInfoNode() : TensorInfoNode{nullptr} {}
+
+  /*!
+   * \brief Report every TVMTensorInfo field to `v`, in struct field order.
+   *
+   * Values are copied into temporaries, so anything the visitor writes back is discarded.
+   */
+  void VisitAttrs(AttrVisitor* v) {
+    ::std::string name_cpp{data()->name};
+    v->Visit("name", &name_cpp);
+    auto shape_array = Array<ObjectRef>();
+    auto shape_accessor = shape();
+    shape_array.reserve(num_shape());
+    for (int64_t i = 0; i < num_shape(); ++i) {
+      // Built as an Int(64) immediate to match the kInt64 tag below; going through Integer(int)
+      // would silently truncate dimensions that do not fit in 32 bits.
+      shape_array.push_back(::tvm::Integer(
+          ::tvm::IntImm(::tvm::runtime::DataType::Int(64), shape_accessor[i])));
+    }
+    ::tvm::runtime::metadata::MetadataArray shape_metadata_array{
+        shape_array, ::tvm::runtime::metadata::MetadataTypeIndex::kInt64, nullptr};
+    v->Visit("shape", &shape_metadata_array);
+    int64_t num_shape_cpp = num_shape();
+    v->Visit("num_shape", &num_shape_cpp);
+    ::tvm::runtime::DataType dtype_cpp{dtype()};
+    v->Visit("dtype", &dtype_cpp);
+  }
+};
+
+/*!
+ * \brief VisitableTensorInfoNode subclass that owns its backing TVMTensorInfo, name and shape.
+ *
+ * `storage_` points into `name_` and `shape_`, which are declared (and therefore initialized)
+ * before it.
+ */
+class InMemoryTensorInfoNode : public ::tvm::target::metadata::VisitableTensorInfoNode {
+ public:
+  InMemoryTensorInfoNode() : InMemoryTensorInfoNode("", {}, ::tvm::runtime::DataType(0, 0, 0)) {}
+  InMemoryTensorInfoNode(const ::tvm::runtime::String& name, const ::std::vector<int64_t>& shape,
+                         ::tvm::runtime::DataType dtype)
+      : VisitableTensorInfoNode{&storage_},
+        name_{name},
+        shape_{::std::make_unique<int64_t[]>(shape.size())},
+        storage_{name_.c_str(), nullptr, 0, dtype} {
+    storage_.shape = shape_.get();
+    storage_.num_shape = shape.size();
+    for (size_t i = 0; i < shape.size(); ++i) {
+      shape_[i] = shape[i];
+    }
+  }
+
+ private:
+  ::std::string name_;
+  // Array form of unique_ptr so the array allocation is released with delete[], not delete.
+  ::std::unique_ptr<int64_t[]> shape_;
+  struct ::TVMTensorInfo storage_;
+};
+
+} // namespace metadata
+} // namespace target
+} // namespace tvm
+
+#endif // TVM_TARGET_METADATA_H_
diff --git a/tests/cpp/aot_metadata_test.cc b/tests/cpp/aot_metadata_test.cc
new file mode 100644
index 0000000..7307622
--- /dev/null
+++ b/tests/cpp/aot_metadata_test.cc
@@ -0,0 +1,236 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <tvm/ir/type.h>
+#include <tvm/node/reflection.h>
+#include <tvm/runtime/metadata.h>
+
+#include "../src/target/metadata.h"
+
+namespace {
+
+// Static C-struct fixtures shared by the tests in this file.
+// TVMTensorInfo initializer order: name, shape, num_shape, dtype.
+const int64_t kNormalInput1Shape[4] = {1, 5, 5, 3};
+const int64_t kNormalOutput1Shape[3] = {3, 8, 8};
+
+// Both inputs point at the same shape array; they differ in name and dtype.
+const struct TVMTensorInfo kNormalInputs[2] = {
+    {"input1", kNormalInput1Shape, 4, DLDataType{1, 2, 3}},
+    {"input2", kNormalInput1Shape, 4, DLDataType{2, 3, 4}},
+};
+
+const struct TVMTensorInfo kNormalOutputs[1] = {
+    {"output1", kNormalOutput1Shape, 3, DLDataType{3, 4, 5}},
+};
+
+const struct TVMMetadata kNormal = {
+    TVM_METADATA_VERSION,  // version
+    kNormalInputs,         // inputs
+    2,                     // num_inputs
+    kNormalOutputs,        // outputs
+    1,                     // num_outputs
+    "default",             // mod_name
+};
+
+}  // namespace
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::StrEq;
+using ::tvm::runtime::Downcast;
+
+// Reads the static kNormal fixture back through the runtime Metadata accessors.
+TEST(Metadata, ParseStruct) {
+  using ::tvm::runtime::DataType;
+
+  tvm::runtime::metadata::Metadata md(&kNormal);
+  EXPECT_THAT(md->version(), Eq(TVM_METADATA_VERSION));
+  EXPECT_THAT(md->num_inputs(), Eq(2));
+
+  auto inputs = md->inputs();
+  EXPECT_THAT(inputs.size(), Eq(2));
+  {
+    auto tensor = inputs[0];
+    EXPECT_THAT(tensor->name(), Eq("input1"));
+    EXPECT_THAT(tensor->shape(), ElementsAre(1, 5, 5, 3));
+    EXPECT_THAT(tensor->dtype(), Eq(DataType(DLDataType{1, 2, 3})));
+  }
+  {
+    auto tensor = inputs[1];
+    EXPECT_THAT(tensor->name(), Eq("input2"));
+    EXPECT_THAT(tensor->shape(), ElementsAre(1, 5, 5, 3));
+    EXPECT_THAT(tensor->dtype(), Eq(DataType(DLDataType{2, 3, 4})));
+  }
+
+  EXPECT_THAT(md->num_outputs(), Eq(1));
+  auto outputs = md->outputs();
+  EXPECT_THAT(outputs.size(), Eq(1));
+  {
+    auto tensor = outputs[0];
+    EXPECT_THAT(tensor->name(), Eq("output1"));
+    EXPECT_THAT(tensor->shape(), ElementsAre(3, 8, 8));
+    EXPECT_THAT(tensor->dtype(), Eq(DataType(DLDataType{3, 4, 5})));
+  }
+
+  EXPECT_THAT(md->mod_name(), Eq("default"));
+}
+
+// AttrVisitor that flattens every visited field into two parallel vectors:
+// keys[i] is the field name and values[i] is that field boxed as an ObjectRef.
+class TestVisitor : public tvm::AttrVisitor {
+ public:
+  using Element = ::std::tuple<::std::string, ::tvm::runtime::ObjectRef>;
+
+  void Visit(const char* key, double* value) final {
+    keys.emplace_back(key);
+    values.emplace_back(::tvm::FloatImm(::tvm::runtime::DataType::Float(64), *value));
+  }
+  void Visit(const char* key, int64_t* value) final {
+    keys.emplace_back(key);
+    values.emplace_back(::tvm::IntImm(::tvm::runtime::DataType::Int(64), *value));
+  }
+  void Visit(const char* key, uint64_t* value) final {
+    keys.emplace_back(key);
+    // IntImm takes an int64_t payload, so reuse the same 64 bits as signed.
+    const int64_t as_signed = *reinterpret_cast<const int64_t*>(value);
+    values.emplace_back(::tvm::IntImm(::tvm::runtime::DataType::UInt(64), as_signed));
+  }
+  void Visit(const char* key, int* value) final {
+    keys.emplace_back(key);
+    values.emplace_back(::tvm::IntImm(::tvm::runtime::DataType::Int(64), *value));
+  }
+  void Visit(const char* key, bool* value) final {
+    keys.emplace_back(key);
+    values.emplace_back(::tvm::Bool(*value));
+  }
+  void Visit(const char* key, std::string* value) final {
+    keys.emplace_back(key);
+    values.emplace_back(::tvm::runtime::String(*value));
+  }
+  void Visit(const char* key, tvm::runtime::DataType* value) final {
+    keys.emplace_back(key);
+    values.emplace_back(::tvm::PrimType(*value));
+  }
+  void Visit(const char* key, tvm::runtime::NDArray* value) final {
+    keys.emplace_back(key);
+    values.emplace_back(*value);
+  }
+  // Raw pointers are not expected in metadata; fail loudly if one shows up.
+  void Visit(const char* key, void** value) final { CHECK(false) << "Do not expect this type"; }
+
+  void Visit(const char* key, ::tvm::runtime::ObjectRef* value) final {
+    keys.emplace_back(key);
+    values.emplace_back(*value);
+  }
+
+  std::vector<std::string> keys;
+  std::vector<::tvm::runtime::ObjectRef> values;
+};
+
+// Walks the kNormal fixture with TestVisitor and checks each reported field.
+TEST(Metadata, Visitor) {
+  tvm::runtime::metadata::Metadata md = tvm::runtime::metadata::Metadata(&kNormal);
+  TestVisitor v;
+  ::tvm::ReflectionVTable::Global()->VisitAttrs(md.operator->(), &v);
+
+  // ASSERT (not EXPECT): v.values is indexed by position below.
+  ASSERT_THAT(v.keys, ElementsAre(StrEq("version"), StrEq("inputs"), StrEq("num_inputs"),
+                                  StrEq("outputs"), StrEq("num_outputs"), StrEq("mod_name")));
+  ASSERT_THAT(v.values.size(), Eq(v.keys.size()));
+
+  EXPECT_THAT(Downcast<tvm::IntImm>(v.values[0])->value, Eq(TVM_METADATA_VERSION));
+
+  // Inputs arrive as a MetadataArray of TensorInfo objects.
+  auto input_array = Downcast<tvm::runtime::metadata::MetadataArray>(v.values[1]);
+  EXPECT_THAT(input_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata));
+  EXPECT_THAT(input_array->struct_name, StrEq("TVMTensorInfo"));
+  ASSERT_THAT(input_array->array.size(), Eq(2));
+
+  auto input1 = Downcast<tvm::runtime::metadata::TensorInfo>(input_array->array[0]);
+  EXPECT_THAT(input1->name(), StrEq("input1"));
+  EXPECT_THAT(input1->shape(), ElementsAre(1, 5, 5, 3));
+  EXPECT_THAT(input1->dtype(), Eq(tvm::runtime::DataType(DLDataType{1, 2, 3})));
+
+  auto input2 = Downcast<tvm::runtime::metadata::TensorInfo>(input_array->array[1]);
+  EXPECT_THAT(input2->name(), StrEq("input2"));
+  EXPECT_THAT(input2->shape(), ElementsAre(1, 5, 5, 3));
+  EXPECT_THAT(input2->dtype(), Eq(tvm::runtime::DataType(DLDataType{2, 3, 4})));
+
+  auto num_inputs = Downcast<tvm::IntImm>(v.values[2]);
+  EXPECT_THAT(num_inputs->value, Eq(2));
+
+  auto output_array = Downcast<tvm::runtime::metadata::MetadataArray>(v.values[3]);
+  EXPECT_THAT(output_array->type_index, Eq(tvm::runtime::metadata::MetadataTypeIndex::kMetadata));
+  EXPECT_THAT(output_array->struct_name, StrEq("TVMTensorInfo"));
+  ASSERT_THAT(output_array->array.size(), Eq(1));
+
+  auto output1 = Downcast<tvm::runtime::metadata::TensorInfo>(output_array->array[0]);
+  EXPECT_THAT(output1->name(), Eq("output1"));
+  EXPECT_THAT(output1->shape(), ElementsAre(3, 8, 8));
+  EXPECT_THAT(output1->dtype(), Eq(tvm::runtime::DataType(DLDataType{3, 4, 5})));
+
+  auto num_outputs = Downcast<tvm::IntImm>(v.values[4]);
+  EXPECT_THAT(num_outputs->value, Eq(1));
+
+  // TestVisitor boxes std::string fields as tvm::runtime::String.
+  EXPECT_THAT(Downcast<tvm::runtime::String>(v.values[5]), Eq("default"));
+}
+
+using ::tvm::runtime::make_object;
+// Builds metadata from InMemory*Node objects and checks the C struct it exposes.
+TEST(Metadata, InMemory) {
+  tvm::runtime::metadata::Metadata md =
+      tvm::runtime::metadata::Metadata(make_object<tvm::target::metadata::InMemoryMetadataNode>(
+          TVM_METADATA_VERSION,
+          std::vector<tvm::runtime::metadata::TensorInfo>(
+              {tvm::runtime::metadata::TensorInfo(
+                   make_object<tvm::target::metadata::InMemoryTensorInfoNode>(
+                       tvm::String("Input1"), std::vector<int64_t>{1, 5, 5, 3},
+                       tvm::runtime::DataType(DLDataType{1, 2, 3}))),
+               tvm::runtime::metadata::TensorInfo(
+                   make_object<tvm::target::metadata::InMemoryTensorInfoNode>(
+                       tvm::String("Input2"), std::vector<int64_t>{1, 5, 5, 3},
+                       tvm::runtime::DataType(DLDataType{2, 3, 4})))}),
+          std::vector<tvm::runtime::metadata::TensorInfo>({tvm::runtime::metadata::TensorInfo(
+              make_object<tvm::target::metadata::InMemoryTensorInfoNode>(
+                  tvm::String("Output1"), std::vector<int64_t>{3, 8, 8},
+                  tvm::runtime::DataType(DLDataType{3, 4, 5})))}),
+          "default"));
+
+  auto md_data = md->data();
+  EXPECT_THAT(md_data->version, Eq(TVM_METADATA_VERSION));
+  // ASSERT the counts: the checks below index the raw C arrays directly, so a
+  // wrong count would otherwise read out of bounds.
+  ASSERT_THAT(md_data->num_inputs, Eq(2));
+  ASSERT_THAT(md_data->num_outputs, Eq(1));
+
+  auto input0 = &md_data->inputs[0];
+  EXPECT_THAT(input0->name, StrEq("Input1"));
+  EXPECT_THAT(std::vector<int64_t>(input0->shape, input0->shape + input0->num_shape),
+              ElementsAre(1, 5, 5, 3));
+  EXPECT_THAT(tvm::runtime::DataType(input0->dtype),
+              Eq(tvm::runtime::DataType(DLDataType({1, 2, 3}))));
+
+  auto input1 = &md_data->inputs[1];
+  EXPECT_THAT(input1->name, StrEq("Input2"));
+  EXPECT_THAT(std::vector<int64_t>(input1->shape, input1->shape + input1->num_shape),
+              ElementsAre(1, 5, 5, 3));
+  EXPECT_THAT(tvm::runtime::DataType(input1->dtype),
+              Eq(tvm::runtime::DataType(DLDataType({2, 3, 4}))));
+
+  auto output0 = &md_data->outputs[0];
+  EXPECT_THAT(output0->name, StrEq("Output1"));
+  EXPECT_THAT(std::vector<int64_t>(output0->shape, output0->shape + output0->num_shape),
+              ElementsAre(3, 8, 8));
+  EXPECT_THAT(tvm::runtime::DataType(output0->dtype),
+              Eq(tvm::runtime::DataType(DLDataType({3, 4, 5}))));
+
+  EXPECT_THAT(md_data->mod_name, StrEq("default"));
+}
+
+// Checks that an empty input list and a zero-rank output shape are handled.
+TEST(Metadata, ZeroElementLists) {
+  tvm::runtime::metadata::Metadata md =
+      tvm::runtime::metadata::Metadata(make_object<tvm::target::metadata::InMemoryMetadataNode>(
+          TVM_METADATA_VERSION, std::vector<tvm::runtime::metadata::TensorInfo>({}),
+          std::vector<tvm::runtime::metadata::TensorInfo>({tvm::runtime::metadata::TensorInfo(
+              make_object<tvm::target::metadata::InMemoryTensorInfoNode>(
+                  tvm::String("Output1"), std::vector<int64_t>{},
+                  tvm::runtime::DataType(DLDataType{3, 4, 5})))}),
+          "default"));
+
+  EXPECT_THAT(md->data()->num_inputs, Eq(0));
+  EXPECT_THAT(md->inputs().size(), Eq(0));
+  EXPECT_THAT(md->num_inputs(), Eq(0));
+  EXPECT_THAT(md->inputs(), ElementsAre());
+
+  // outputs[0] below indexes a raw C array; only valid when num_outputs >= 1.
+  ASSERT_THAT(md->data()->num_outputs, Eq(1));
+  const auto& output0 = md->data()->outputs[0];
+  EXPECT_THAT(output0.num_shape, Eq(0));
+  EXPECT_THAT(md->outputs()[0]->shape().size(), Eq(0));
+  EXPECT_THAT(md->outputs()[0]->shape(), ElementsAre());
+}