You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/07/24 22:50:00 UTC
[incubator-tvm] branch master updated: [Relay][VM] Allow to config
allocator type and refactor vm code structure (#6105)
This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 922e0a0 [Relay][VM] Allow to config allocator type and refactor vm code structure (#6105)
922e0a0 is described below
commit 922e0a05c8cc8fce7964d3f9907fde6981c5f72d
Author: Haichen Shen <sh...@gmail.com>
AuthorDate: Fri Jul 24 15:49:45 2020 -0700
[Relay][VM] Allow to config allocator type and refactor vm code structure (#6105)
* [Relay][VM] Allow to config allocator type and refactor vm code structure
* fix doc
* fix
* update
* trigger ci
* trigger ci
* trigger ci
* trigger ci
* fix doc warning
---
docs/dev/virtual_machine.rst | 11 +-
include/tvm/relay/interpreter.h | 9 +-
include/tvm/runtime/container.h | 17 +
include/tvm/runtime/vm.h | 826 ----------------------
include/tvm/runtime/vm/bytecode.h | 377 ++++++++++
include/tvm/runtime/vm/executable.h | 230 ++++++
{src => include/tvm}/runtime/vm/memory_manager.h | 36 +-
include/tvm/runtime/vm/vm.h | 289 ++++++++
python/tvm/_ffi/runtime_ctypes.py | 3 +
python/tvm/relay/backend/vm.py | 3 +-
python/tvm/runtime/profiler_vm.py | 18 +-
python/tvm/runtime/vm.py | 81 ++-
src/relay/backend/build_module.cc | 1 -
src/relay/backend/vm/compiler.cc | 2 +-
src/relay/backend/vm/compiler.h | 2 +-
src/relay/backend/vm/inline_primitives.cc | 1 -
src/relay/backend/vm/lambda_lift.cc | 1 -
src/relay/backend/vm/removed_unused_funcs.cc | 1 -
src/runtime/container.cc | 3 -
src/runtime/vm/bytecode.cc | 610 ++++++++++++++++
src/runtime/vm/executable.cc | 3 +-
src/runtime/vm/memory_manager.cc | 52 +-
src/runtime/vm/naive_allocator.h | 5 +-
src/runtime/vm/pooled_allocator.h | 5 +-
src/runtime/vm/profiler/vm.cc | 1 -
src/runtime/vm/profiler/vm.h | 2 +-
src/runtime/vm/serialize_util.h | 2 +-
src/runtime/vm/vm.cc | 626 +---------------
tests/python/frontend/tensorflow/test_forward.py | 3 +-
tests/python/relay/benchmarking/benchmark_vm.py | 3 +-
tests/python/relay/test_external_codegen.py | 3 +-
tests/python/relay/test_json_runtime.py | 3 +-
tests/python/relay/test_pass_annotate_target.py | 3 +-
tests/python/relay/test_pass_partition_graph.py | 3 +-
tests/python/relay/test_vm.py | 4 +-
tests/python/relay/test_vm_serialization.py | 6 +-
tests/python/unittest/test_runtime_vm_profiler.py | 3 +-
37 files changed, 1723 insertions(+), 1525 deletions(-)
diff --git a/docs/dev/virtual_machine.rst b/docs/dev/virtual_machine.rst
index 5878003..059878f 100644
--- a/docs/dev/virtual_machine.rst
+++ b/docs/dev/virtual_machine.rst
@@ -276,11 +276,11 @@ VM.
Currently, three types of objects, ``NDArray``, ``ADT``, and ``Closure`` objects, are used
to represent tensor, tuple/list, and closure data, respectively. More details
for each of them can be found at `include/tvm/runtime/ndarray.h`_,
-`include/tvm/runtime/vm.h`_, and `include/tvm/runtime/container.h`_, respectively.
+`include/tvm/runtime/vm/vm.h`_, and `include/tvm/runtime/container.h`_, respectively.
.. _include/tvm/runtime/ndarray.h: https://github.com/apache/incubator-tvm/blob/master/include/tvm/runtime/ndarray.h
-.. _include/tvm/runtime/vm.h: https://github.com/apache/incubator-tvm/blob/master/include/tvm/runtime/vm.h
+.. _include/tvm/runtime/vm/vm.h: https://github.com/apache/incubator-tvm/blob/master/include/tvm/runtime/vm/vm.h
.. _include/tvm/runtime/container.h: https://github.com/apache/incubator-tvm/blob/master/include/tvm/runtime/container.h
@@ -321,7 +321,12 @@ VM Compiler
An important part of this infrastructure is a compiler from Relay's full IR into a sequence of bytecode.
The VM compiler transforms a ``tvm::relay::Module`` into a ``tvm::relay::vm::Executable``. The executable
-contains a set of compiled functions, the compiled functions are contained in ``tvm::relay::vm::Function``. The functions contain metadata about the function as well as its compiled bytecode. The emitted executable object then can be loaded and run by a ``tvm::relay::vm::VirtualMachine`` object. For full definitions of the data structures, please see `include/tvm/runtime/vm.h`_.
+contains a set of compiled functions; the compiled functions are contained in ``tvm::relay::vm::Function``.
+The functions contain metadata about the function as well as its compiled bytecode. The emitted executable
+object then can be loaded and run by a ``tvm::relay::vm::VirtualMachine`` object. For full definitions of the
+data structures, please see `include/tvm/runtime/vm/executable.h`_ and `include/tvm/runtime/vm/vm.h`_.
+
+.. _include/tvm/runtime/vm/executable.h: https://github.com/apache/incubator-tvm/blob/master/include/tvm/runtime/vm/executable.h
Optimizations
~~~~~~~~~~~~~
diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h
index bda73ed..8a41ab7 100644
--- a/include/tvm/relay/interpreter.h
+++ b/include/tvm/relay/interpreter.h
@@ -38,7 +38,6 @@
#include <tvm/relay/expr.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/object.h>
-#include <tvm/runtime/vm.h>
#include <tvm/target/target.h>
namespace tvm {
@@ -67,7 +66,7 @@ runtime::TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, DLCont
Target target);
/*! \brief The container type of Closures used by the interpreter. */
-class InterpreterClosureObj : public runtime::vm::ClosureObj {
+class InterpreterClosureObj : public runtime::ClosureObj {
public:
/*! \brief The set of free variables in the closure.
*
@@ -89,13 +88,13 @@ class InterpreterClosureObj : public runtime::vm::ClosureObj {
}
static constexpr const char* _type_key = "interpreter.Closure";
- TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterClosureObj, runtime::vm::ClosureObj);
+ TVM_DECLARE_FINAL_OBJECT_INFO(InterpreterClosureObj, runtime::ClosureObj);
};
-class InterpreterClosure : public runtime::vm::Closure {
+class InterpreterClosure : public runtime::Closure {
public:
TVM_DLL InterpreterClosure(tvm::Map<Var, ObjectRef> env, Function func);
- TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::vm::Closure, InterpreterClosureObj);
+ TVM_DEFINE_OBJECT_REF_METHODS(InterpreterClosure, runtime::Closure, InterpreterClosureObj);
};
/*! \brief The container type of RecClosure. */
diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h
index 5467ae4..f8fa09d 100644
--- a/include/tvm/runtime/container.h
+++ b/include/tvm/runtime/container.h
@@ -1671,6 +1671,23 @@ struct PackedFuncValueConverter<Optional<T>> {
}
};
+/*!
+ * \brief An object representing a closure. This object is used by both the
+ * Relay VM and interpreter.
+ */
+class ClosureObj : public Object {
+ public:
+ static constexpr const uint32_t _type_index = TypeIndex::kRuntimeClosure;
+ static constexpr const char* _type_key = "runtime.Closure";
+ TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
+};
+
+/*! \brief reference to closure. */
+class Closure : public ObjectRef {
+ public:
+ TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
+};
+
} // namespace runtime
// expose the functions to the root namespace.
diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h
deleted file mode 100644
index cb98715..0000000
--- a/include/tvm/runtime/vm.h
+++ /dev/null
@@ -1,826 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/runtime/vm.h
- * \brief A virtual machine for executing Relay programs.
- */
-#ifndef TVM_RUNTIME_VM_H_
-#define TVM_RUNTIME_VM_H_
-
-#include <tvm/runtime/memory.h>
-#include <tvm/runtime/object.h>
-#include <tvm/runtime/packed_func.h>
-#include <tvm/runtime/registry.h>
-
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-namespace tvm {
-namespace runtime {
-namespace vm {
-
-/*!
- * \brief An object representing a closure. This object is used by both the
- * Relay VM and interpreter.
- */
-class ClosureObj : public Object {
- public:
- static constexpr const uint32_t _type_index = TypeIndex::kRuntimeClosure;
- static constexpr const char* _type_key = "runtime.Closure";
- TVM_DECLARE_BASE_OBJECT_INFO(ClosureObj, Object);
-};
-
-/*! \brief reference to closure. */
-class Closure : public ObjectRef {
- public:
- TVM_DEFINE_OBJECT_REF_METHODS(Closure, ObjectRef, ClosureObj);
-};
-
-/*!
- * \brief An object representing a vm closure.
- */
-class VMClosureObj : public ClosureObj {
- public:
- /*!
- * \brief The index into the function list. The function could be any
- * function object that is compatible to the VM runtime.
- */
- size_t func_index;
- /*! \brief The free variables of the closure. */
- std::vector<ObjectRef> free_vars;
-
- static constexpr const char* _type_key = "vm.Closure";
- TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, ClosureObj);
-};
-
-/*! \brief reference to closure. */
-class VMClosure : public Closure {
- public:
- VMClosure(size_t func_index, std::vector<ObjectRef> free_vars);
- TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, Closure, VMClosureObj);
-};
-
-/*! \brief Magic number for NDArray list file */
-constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;
-
-/*! \brief A register name. */
-using RegName = int64_t;
-
-/*! \brief An alias for the integer type used ubiquitously
- * in the VM.
- */
-using Index = int64_t;
-
-/*! \brief An enumeration of Relay's opcodes.
- *
- * The opcode is used to implement instruction
- * as a tagged union.
- */
-enum class Opcode {
- Move = 0U,
- Ret = 1U,
- Invoke = 2U,
- InvokeClosure = 3U,
- InvokePacked = 4U,
- AllocTensor = 5U,
- AllocTensorReg = 6U,
- AllocADT = 7U,
- AllocClosure = 8U,
- GetField = 9U,
- If = 10U,
- LoadConst = 11U,
- Goto = 12U,
- GetTag = 13U,
- LoadConsti = 14U,
- Fatal = 15U,
- AllocStorage = 16U,
- ShapeOf = 17U,
- ReshapeTensor = 18U,
-};
-
-/*! \brief A single virtual machine instruction.
- *
- * The representation of the instruction is as
- * a tagged union.
- *
- * The first field represents which instruction,
- * and by extension which field of the union
- * is active.
- */
-struct Instruction {
- /*! \brief The instruction opcode. */
- Opcode op;
-
- /*! \brief The destination register. */
- RegName dst;
-
- union {
- struct /* AllocTensor Operands */ {
- /*! \brief The storage to allocate from. */
- RegName storage;
- /*! \brief The offset into the storage to allocate from. */
- Index offset;
- /*! \brief The number of dimensions. */
- uint32_t ndim;
- /*! \brief The shape of tensor. */
- int64_t* shape;
- /*! \brief The datatype of tensor to be allocated. */
- DLDataType dtype;
- } alloc_tensor;
- struct /* AllocTensorReg Operands */ {
- /*! \brief The storage to allocate from. */
- RegName storage;
- /*! \brief The offset into the storage to allocate from. */
- Index offset;
- /*! \brief The register to read the shape out of. */
- RegName shape_register;
- /*! \brief The datatype of tensor to be allocated. */
- DLDataType dtype;
- } alloc_tensor_reg;
- struct /* InvokeClosure Operands */ {
- /*! \brief The register containing the closure. */
- RegName closure;
- /*! \brief The number of arguments to the closure. */
- Index num_closure_args;
- /*! \brief The closure arguments as an array. */
- RegName* closure_args;
- };
- struct /* Return Operands */ {
- /*! \brief The register to return. */
- RegName result;
- };
- struct /* Move Operands */ {
- /*! \brief The source register for a move operation. */
- RegName from;
- };
- struct /* InvokePacked Operands */ {
- /*! \brief The index into the packed function table. */
- Index packed_index;
- /*! \brief The arity of the packed function. */
- Index arity;
- /*! \brief The number of outputs produced by the packed function. */
- Index output_size;
- /*! \brief The arguments to pass to the packed function. */
- RegName* packed_args;
- };
- struct /* If Operands */ {
- /*! \brief The register containing the test value. */
- RegName test;
- /*! \brief The register containing the target value. */
- RegName target;
- /*! \brief The program counter offset for the true branch. */
- Index true_offset;
- /*! \brief The program counter offset for the false branch. */
- Index false_offset;
- } if_op;
- struct /* Invoke Operands */ {
- /*! \brief The function to call. */
- Index func_index;
- /*! \brief The number of arguments to the function. */
- Index num_args;
- /*! \brief The registers containing the arguments. */
- RegName* invoke_args_registers;
- };
- struct /* LoadConst Operands */ {
- /*! \brief The index into the constant pool. */
- Index const_index;
- };
- struct /* LoadConsti Operands */ {
- /*! \brief The integer constant value. */
- Index val;
- } load_consti;
- struct /* Jump Operands */ {
- /*! \brief The jump offset. */
- Index pc_offset;
- };
- struct /* Proj Operands */ {
- /*! \brief The register to project from. */
- RegName object;
- /*! \brief The field to read out. */
- Index field_index;
- };
- struct /* GetTag Operands */ {
- /*! \brief The register to project from. */
- RegName object;
- } get_tag;
- struct /* AllocADT Operands */ {
- /*! \brief The datatype's constructor tag. */
- Index constructor_tag;
- /*! \brief The number of fields to store in the datatype. */
- Index num_fields;
- /*! \brief The fields as an array. */
- RegName* datatype_fields;
- };
- struct /* AllocClosure Operands */ {
- /*! \brief The index into the function table. */
- Index clo_index;
- /*! \brief The number of free variables to capture. */
- Index num_freevar;
- /*! \brief The free variables as an array. */
- RegName* free_vars;
- };
- struct /* AllocStorage Operands */ {
- /*! \brief The size of the allocation. */
- RegName allocation_size;
- /*! \brief The alignment of the allocation. */
- Index alignment;
- /*! \brief The hint of the dtype. */
- DLDataType dtype_hint;
- } alloc_storage;
- struct /* ShapeOf Operands */ {
- RegName tensor;
- } shape_of;
- struct /* ReshapeTensor Operands */ {
- RegName tensor;
- RegName newshape;
- } reshape_tensor;
- };
-
- /*!
- * \brief Construct a return instruction.
- * \param return_reg The register containing the return value.
- * \return The return instruction.
- */
- static Instruction Ret(RegName return_reg);
- /*!
- * \brief Construct a fatal instruction.
- * \return The fatal instruction.
- */
- static Instruction Fatal();
- /*!
- * \brief Construct an invoke packed instruction.
- * \param packed_index The index of the packed function.
- * \param arity The arity of the function.
- * \param output_size The number of outputs of the packed function.
- * \param args The argument registers.
- * \return The invoke packed instruction.
- */
- static Instruction InvokePacked(Index packed_index, Index arity, Index output_size,
- const std::vector<RegName>& args);
- /*!
- * \brief Construct an allocate tensor instruction with constant shape.
- * \param storage The storage to allocate out of.
- * \param offset The offset to allocate at.
- * \param shape The shape of the tensor.
- * \param dtype The dtype of the tensor.
- * \param dst The destination register.
- * \return The allocate tensor instruction.
- */
- static Instruction AllocTensor(RegName storage, Index offset, const std::vector<int64_t>& shape,
- DLDataType dtype, RegName dst);
- /*!
- * \brief Construct an allocate tensor instruction with register.
- * \param storage The storage to allocate out of.
- * \param offset The offset into the storage to allocate from.
- * \param shape_register The register containing the shape.
- * \param dtype The dtype of the tensor.
- * \param dst The destination register.
- * \return The allocate tensor instruction.
- */
- static Instruction AllocTensorReg(RegName storage, Index offset, RegName shape_register,
- DLDataType dtype, RegName dst);
- /*!
- * \brief Construct an allocate datatype instruction.
- * \param tag The datatype tag.
- * \param num_fields The number of fields for the datatype.
- * \param fields The registers containing the fields.
- * \param dst The register name of the destination.
- * \return The allocate datatype instruction.
- */
- static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields,
- RegName dst);
- /*!
- * \brief Construct an allocate closure instruction.
- * \param func_index The index of the function table.
- * \param num_freevar The number of free variables.
- * \param free_vars The registers of the free variables.
- * \param dst The destination register.
- * \return The allocate closure instruction.
- */
- static Instruction AllocClosure(Index func_index, Index num_freevar,
- const std::vector<RegName>& free_vars, RegName dst);
- /*!
- * \brief Construct a get field instruction.
- * \param object_reg The register containing the object to project from.
- * \param field_index The field to read out of the object.
- * \param dst The destination register.
- * \return The get field instruction.
- */
- static Instruction GetField(RegName object_reg, Index field_index, RegName dst);
- /*!
- * \brief Construct a get_tag instruction.
- * \param object_reg The register containing the object to project from.
- * \param dst The destination register.
- * \return The get_tag instruction.
- */
- static Instruction GetTag(RegName object_reg, RegName dst);
- /*!
- * \brief Construct an if instruction.
- * \param test The register containing the test value.
- * \param target The register containing the target value.
- * \param true_branch The offset to the true branch.
- * \param false_branch The offset to the false branch.
- * \return The if instruction.
- */
- static Instruction If(RegName test, RegName target, Index true_branch, Index false_branch);
- /*!
- * \brief Construct a goto instruction.
- * \param pc_offset The offset from the current pc.
- * \return The goto instruction.
- */
- static Instruction Goto(Index pc_offset);
- /*!
- * \brief Construct an invoke instruction.
- * \param func_index The index of the function to invoke.
- * \param args The registers containing the arguments.
- * \param dst The destination register.
- * \return The invoke instruction.
- */
- static Instruction Invoke(Index func_index, const std::vector<RegName>& args, RegName dst);
- /*!
- * \brief Construct an invoke closure instruction.
- * \param closure The register of the closure to invoke.
- * \param args The registers containing the arguments.
- * \param dst The destination register.
- * \return The invoke closure instruction.
- */
- static Instruction InvokeClosure(RegName closure, const std::vector<RegName>& args, RegName dst);
- /*!
- * \brief Construct a load constant instruction.
- * \param const_index The index of the constant.
- * \param dst The destination register.
- * \return The load constant instruction.
- */
- static Instruction LoadConst(Index const_index, RegName dst);
- /*!
- * \brief Construct a load_constanti instruction.
- * \param val The integer constant value.
- * \param dst The destination register.
- * \return The load_constanti instruction.
- */
- static Instruction LoadConsti(Index val, RegName dst);
- /*!
- * \brief Construct a move instruction.
- * \param src The source register.
- * \param dst The destination register.
- * \return The move instruction.
- */
- static Instruction Move(RegName src, RegName dst);
-
- /*!
- * \brief Allocate a storage block.
- * \param size The size of the allocation.
- * \param alignment The allocation's alignment.
- * \param dtype_hint The data type hint for the allocator.
- * \param dst The destination to place the storage.
- * \return The alloc storage instruction.
- */
- static Instruction AllocStorage(RegName size, Index alignment, DLDataType dtype_hint,
- RegName dst);
-
- /*!
- * \brief Get the shape of an input tensor.
- * \param tensor The input tensor.
- * \param dst The destination to store the shape of the given tensor.
- * \return The shape of instruction.
- */
- static Instruction ShapeOf(RegName tensor, RegName dst);
-
- /*!
- * \brief Reshape the tensor given the new shape.
- * \param tensor The input tensor.
- * \param newshape The shape tensor.
- * \param dst The destination to store the output tensor with new shape.
- * \return The reshape tensor instruction.
- */
- static Instruction ReshapeTensor(RegName tensor, RegName newshape, RegName dst);
-
- Instruction();
- Instruction(const Instruction& instr);
- Instruction& operator=(const Instruction& instr);
- ~Instruction();
-
- friend std::ostream& operator<<(std::ostream& os, const Instruction&);
-};
-
-/*!
- * \brief A representation of a Relay function in the VM.
- *
- * Contains metadata about the compiled function, as
- * well as the compiled VM instructions.
- */
-struct VMFunction {
- /*! \brief The function's name. */
- std::string name;
- /*! \brief The function parameter names. */
- std::vector<std::string> params;
- /*! \brief The instructions representing the function. */
- std::vector<Instruction> instructions;
- /*! \brief The size of the frame for this function */
- Index register_file_size;
-
- VMFunction(const std::string& name, std::vector<std::string> params,
- const std::vector<Instruction>& instructions, Index register_file_size)
- : name(name),
- params(params),
- instructions(instructions),
- register_file_size(register_file_size) {}
-
- VMFunction() {}
-
- friend std::ostream& operator<<(std::ostream& os, const VMFunction&);
-};
-
-/*!
- * \brief A representation of a stack frame.
- *
- * A stack frame is a record containing the information needed
- * to restore the caller's virtual machine state after returning
- * from a function call.
- */
-struct VMFrame {
- /*! \brief The return program counter. */
- Index pc;
- /*! \brief The index into the function table, points to the caller. */
- Index func_index;
- /*! \brief The number of arguments. */
- Index args;
- /*! \brief A pointer into the caller function's instructions. */
- const Instruction* code;
-
- /*! \brief Statically allocated space for objects */
- std::vector<ObjectRef> register_file;
-
- /*! \brief Register in caller's frame to put return value */
- RegName caller_return_register;
-
- VMFrame(Index pc, Index func_index, Index args, const Instruction* code, Index register_file_size)
- : pc(pc),
- func_index(func_index),
- args(args),
- code(code),
- register_file(register_file_size),
- caller_return_register(0) {}
-};
-
-/*!
- * \brief The executable emitted by the VM compiler.
- *
- * The executable contains information (e.g. data in different memory regions)
- * to run in a virtual machine.
- *
- * - Global section, containing all globals.
- * - Constant section, storing the constant pool.
- * - Primitive name section, containing the function name of the primitive ops
- * used by the virtual machine.
- * - Code section, handling the VM functions and bytecode.
- */
-class Executable : public ModuleNode {
- public:
- /*!
- * \brief Get a PackedFunc from an executable module.
- *
- * \param name the name of the function.
- * \param sptr_to_self The shared_ptr that points to this module node.
- *
- * \return PackedFunc or nullptr when it is not available.
- */
- PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
-
- /*!
- * \brief Serialize the executable into global section, constant section, and
- * code section.
- *
- * \return The binary representation of the VM.
- */
- TVMByteArray Save();
-
- /*!
- * \brief Load the saved VM executable.
- *
- * \param code The bytecode in string.
- * \param lib The compiled runtime library.
- *
- * \return exe The constructed executable.
- */
- static runtime::Module Load(const std::string& code, const runtime::Module lib);
-
- /*!
- * \brief Get the serialized form of the `functions`. This is
- * essentially bytecode serialization.
- *
- * \return The serialized vm bytecode.
- *
- * \note The bytecode is in the following format:
- * func_name reg_file_size num_instructions
- * param1 param2 ... paramM
- * instruction1
- * instruction2
- * ...
- * instructionN
- *
- * Each instruction is printed in the following format:
- * opcode num_fields field1 ... fieldX # The text format.
- *
- * Serializing an `Instruction` requires us to deal with the bytecode. Each line
- * of the instructions could be serialized as the following format:
- * hash, opcode, f1, f2, ..., fX, field with variable length
- * 1. hash: the hash of the instruction. This number will be used to help us
- * validate if an instruction is well-formed during deserialization.
- * 2. opcode: the opcode code of the instruction.
- * 3. f1, f2, ..., fX. These fields together represent the fixed fields in
- * an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
- * example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
- * 4. The rest of the line indicates the field with variable length, e.g.,
- * the shape of a tensor, the args used by an `InvokePacked` instruction, etc.
-
- * The field starting from # is only used for debugging. The serialized code
- * doesn't contain it, therefore the deserializer doesn't need to handle it.
- */
- std::string GetBytecode() const;
-
- /*!
- * \brief Print the detailed statistics of the given code, i.e. number of
- * globals and constants, etc.
- */
- std::string Stats() const;
-
- /*!
- * \brief Get the `lib` module in an executable. Users have the flexibility to call
- * `export_library` from the frontend to save the library to disk.
- *
- * \return The runtime module that contains the hardware dependent code.
- */
- runtime::Module GetLib() const { return lib; }
-
- /*!
- * \brief Get the arity of the VM Function.
- * \param func Function name.
- * \return The number of parameters.
- */
- int GetFunctionArity(std::string func) const;
-
- /*!
- * \brief Get the parameter name given the function name and parameter index.
- * \param func Function name.
- * \param index Parameter index.
- * \return The parameter name.
- */
- std::string GetFunctionParameterName(std::string func, uint32_t index) const;
-
- virtual ~Executable() {}
-
- const char* type_key() const final { return "VMExecutable"; }
-
- /*! \brief The runtime module/library that contains both the host and also the device
- * code when executing on non-CPU devices. */
- runtime::Module lib;
- /*! \brief The global constant pool. */
- std::vector<ObjectRef> constants;
- /*! \brief A map from globals (as strings) to their index in the function map. */
- std::unordered_map<std::string, Index> global_map;
- /*! \brief A mapping from the packed function (as string) to the index that
- * corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
- */
- std::unordered_map<std::string, Index> primitive_map;
- /*! \brief The virtual machine's function table. */
- std::vector<VMFunction> functions;
-
- private:
- /*!
- * \brief Save the globals.
- *
- * \param strm The input stream.
- */
- void SaveGlobalSection(dmlc::Stream* strm);
-
- /*!
- * \brief Save the constant pool.
- *
- * \param strm The input stream.
- */
- void SaveConstantSection(dmlc::Stream* strm);
-
- /*!
- * \brief Save primitive op names.
- *
- * \param strm The input stream.
- */
- void SavePrimitiveOpNames(dmlc::Stream* strm);
-
- /*!
- * \brief Save the vm functions.
- *
- * \param strm The input stream.
- */
- void SaveCodeSection(dmlc::Stream* strm);
-
- /*!
- * \brief Load the globals.
- *
- * \param strm The input stream.
- */
- void LoadGlobalSection(dmlc::Stream* strm);
-
- /*!
- * \brief Load the constant pool.
- *
- * \param strm The input stream.
- */
- void LoadConstantSection(dmlc::Stream* strm);
-
- /*!
- * \brief Load primitive op names.
- *
- * \param strm The input stream.
- */
- void LoadPrimitiveOpNames(dmlc::Stream* strm);
-
- /*!
- * \brief Load the vm functions.
- *
- * \param strm The input stream.
- */
- void LoadCodeSection(dmlc::Stream* strm);
-
- /*! \brief The serialized bytecode. */
- std::string code_;
-};
-
-/*!
- * \brief The virtual machine.
- *
- * The virtual machine contains all the current execution state,
- * as well as the executable.
- *
- * The goal is to have a single self-contained object,
- * enabling one to easily pass around VMs, execute them on
- * multiple threads, or serialize them to disk or over the
- * wire.
- */
-class VirtualMachine : public runtime::ModuleNode {
- public:
- /*!
- * \brief Get a PackedFunc from module.
- *
- * The PackedFunc may not be fully initialized,
- * there might still be first time running overhead when
- * executing the function on certain devices.
- * For benchmarking, use prepare to eliminate
- *
- * \param name the name of the function.
- * \param sptr_to_self The shared_ptr that points to this module node.
- *
- * \return PackedFunc(nullptr) when it is not available.
- *
- * \note The function will always remain valid.
- * If the function needs resource from the module(e.g. late linking),
- * it should capture sptr_to_self.
- */
- virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
-
- virtual ~VirtualMachine() {}
-
- const char* type_key() const final { return "VirtualMachine"; }
-
- VirtualMachine() : frames_(), func_index_(0), code_(nullptr), pc_(0), exec_(nullptr) {}
-
- /*!
- * \brief load the executable for the virtual machine.
- * \param exec The executable.
- */
- virtual void LoadExecutable(const Executable* exec);
-
- protected:
- /*! \brief The virtual machine's packed function table. */
- std::vector<PackedFunc> packed_funcs_;
- /*! \brief The current stack of call frames. */
- std::vector<VMFrame> frames_;
- /*! \brief The function table index of the current function. */
- Index func_index_;
- /*! \brief The current pointer to the code section. */
- const Instruction* code_;
- /*! \brief The virtual machine PC. */
- Index pc_;
- /*! \brief The special return register. */
- ObjectRef return_register_;
- /*! \brief The executable the VM will operate on. */
- const Executable* exec_;
- /*! \brief The function name to inputs mapping. */
- std::unordered_map<std::string, std::vector<ObjectRef>> inputs_;
- /*! \brief The set of TVM contexts the VM is currently executing on. */
- std::vector<TVMContext> ctxs_;
-
- /*! \brief Push a call frame on to the call stack. */
- void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func);
-
- /*!
- * \brief Pop a frame off the call stack.
- * \return The number of frames left.
- */
- Index PopFrame();
-
- /*!
- * \brief Write to a VM register.
- * \param reg The register to write to.
- * \param obj The object to write.
- */
- inline void WriteRegister(RegName reg, const ObjectRef& obj);
-
- /*!
- * \brief Read a VM register.
- * \param reg The register to read from.
- * \return The read object.
- */
- inline ObjectRef ReadRegister(RegName reg) const;
-
- /*!
- * \brief Read a VM register and cast it to int64_t
- * \param reg The register to read from.
- * \return The read scalar.
- */
- inline int64_t LoadScalarInt(RegName reg) const;
-
- /*!
- * \brief Invoke a VM function.
- * \param func The function.
- * \param args The arguments to the function.
- * \return The object representing the result.
- */
- ObjectRef Invoke(const VMFunction& func, const std::vector<ObjectRef>& args);
-
- // TODO(@jroesch): I really would like this to be a global variable.
- /*!
- * \brief Invoke a VM function by name.
- * \param name The function's name.
- * \param args The arguments to the function.
- * \return The object representing the result.
- */
- ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);
-
- /*!
- * \brief Invoke a PackedFunction
- *
- * \param packed_index The offset of the PackedFunction in all functions.
- * \param func The PackedFunction to be invoked.
- * \param arg_count The number of arguments to the PackedFunction.
- * \param output_size The number of outputs of the PackedFunction.
- * \param args Arguments to the PackedFunction.
- *
- * \note The return value will be stored in the last output_size slots of args.
- */
- virtual void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
- Index output_size, const std::vector<ObjectRef>& args);
-
- /*!
- * \brief Initialize the virtual machine for a set of contexts.
- * \param contexts The set of TVM contexts.
- */
- void Init(const std::vector<TVMContext>& contexts);
-
- /*! \brief Run VM dispatch loop. */
- void RunLoop();
-
- /*! \brief Get device context for params. */
- TVMContext GetParamsContext() const;
-
- private:
- /*!
- * \brief Invoke a global setting up the VM state to execute.
- *
- * This does not begin execution of the VM.
- */
- void InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args);
-
- /*!
- * \brief The constant pool for runtime. It caches the device dependent
- * object to avoid reallocation of constants during inference.
- */
- std::vector<ObjectRef> const_pool_;
-};
-
-} // namespace vm
-} // namespace runtime
-} // namespace tvm
-
-#endif // TVM_RUNTIME_VM_H_
diff --git a/include/tvm/runtime/vm/bytecode.h b/include/tvm/runtime/vm/bytecode.h
new file mode 100644
index 0000000..89a3164
--- /dev/null
+++ b/include/tvm/runtime/vm/bytecode.h
@@ -0,0 +1,377 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/vm/bytecode.h
+ * \brief The bytecode for Relay virtual machine.
+ */
+#ifndef TVM_RUNTIME_VM_BYTECODE_H_
+#define TVM_RUNTIME_VM_BYTECODE_H_
+
+#include <tvm/runtime/data_type.h>
+
+#include <iostream>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+namespace vm {
+
+/*! \brief A register name. */
+using RegName = int64_t;
+
+/*! \brief An alias for the integer type used ubiquitously
+ * in the VM.
+ */
+using Index = int64_t;
+
+/*! \brief An enumeration of Relay's opcodes.
+ *
+ * The opcode is used to implement instruction
+ * as a tagged union.
+ */
+enum class Opcode {
+ Move = 0U,
+ Ret = 1U,
+ Invoke = 2U,
+ InvokeClosure = 3U,
+ InvokePacked = 4U,
+ AllocTensor = 5U,
+ AllocTensorReg = 6U,
+ AllocADT = 7U,
+ AllocClosure = 8U,
+ GetField = 9U,
+ If = 10U,
+ LoadConst = 11U,
+ Goto = 12U,
+ GetTag = 13U,
+ LoadConsti = 14U,
+ Fatal = 15U,
+ AllocStorage = 16U,
+ ShapeOf = 17U,
+ ReshapeTensor = 18U,
+};
+
+/*! \brief A single virtual machine instruction.
+ *
+ * The representation of the instruction is as
+ * a tagged union.
+ *
+ * The first field represents which instruction,
+ * and by extension which field of the union
+ * is active.
+ */
+struct Instruction {
+ /*! \brief The instruction opcode. */
+ Opcode op;
+
+ /*! \brief The destination register. */
+ RegName dst;
+
+ union {
+ struct /* AllocTensor Operands */ {
+ /*! \brief The storage to allocate from. */
+ RegName storage;
+ /*! \brief The offset into the storage to allocate from. */
+ Index offset;
+ /*! \brief The number of dimensions. */
+ uint32_t ndim;
+ /*! \brief The shape of tensor. */
+ int64_t* shape;
+ /*! \brief The datatype of tensor to be allocated. */
+ DLDataType dtype;
+ } alloc_tensor;
+ struct /* AllocTensorReg Operands */ {
+ /*! \brief The storage to allocate from. */
+ RegName storage;
+ /*! \brief The offset into the storage to allocate from. */
+ Index offset;
+ /*! \brief The register to read the shape out of. */
+ RegName shape_register;
+ /*! \brief The datatype of tensor to be allocated. */
+ DLDataType dtype;
+ } alloc_tensor_reg;
+ struct /* InvokeClosure Operands */ {
+ /*! \brief The register containing the closure. */
+ RegName closure;
+ /*! \brief The number of arguments to the closure. */
+ Index num_closure_args;
+ /*! \brief The closure arguments as an array. */
+ RegName* closure_args;
+ };
+ struct /* Return Operands */ {
+ /*! \brief The register to return. */
+ RegName result;
+ };
+ struct /* Move Operands */ {
+ /*! \brief The source register for a move operation. */
+ RegName from;
+ };
+ struct /* InvokePacked Operands */ {
+ /*! \brief The index into the packed function table. */
+ Index packed_index;
+ /*! \brief The arity of the packed function. */
+ Index arity;
+ /*! \brief The number of outputs produced by the packed function. */
+ Index output_size;
+ /*! \brief The arguments to pass to the packed function. */
+ RegName* packed_args;
+ };
+ struct /* If Operands */ {
+ /*! \brief The register containing the test value. */
+ RegName test;
+ /*! \brief The register containing the target value. */
+ RegName target;
+ /*! \brief The program counter offset for the true branch. */
+ Index true_offset;
+ /*! \brief The program counter offset for the false branch. */
+ Index false_offset;
+ } if_op;
+ struct /* Invoke Operands */ {
+ /*! \brief The function to call. */
+ Index func_index;
+ /*! \brief The number of arguments to the function. */
+ Index num_args;
+ /*! \brief The registers containing the arguments. */
+ RegName* invoke_args_registers;
+ };
+ struct /* LoadConst Operands */ {
+ /*! \brief The index into the constant pool. */
+ Index const_index;
+ };
+ struct /* LoadConsti Operands */ {
+ /*! \brief The integer constant value to load. */
+ Index val;
+ } load_consti;
+ struct /* Jump Operands */ {
+ /*! \brief The jump offset. */
+ Index pc_offset;
+ };
+ struct /* Proj Operands */ {
+ /*! \brief The register to project from. */
+ RegName object;
+ /*! \brief The field to read out. */
+ Index field_index;
+ };
+ struct /* GetTag Operands */ {
+ /*! \brief The register to project from. */
+ RegName object;
+ } get_tag;
+ struct /* AllocADT Operands */ {
+ /*! \brief The datatype's constructor tag. */
+ Index constructor_tag;
+ /*! \brief The number of fields to store in the datatype. */
+ Index num_fields;
+ /*! \brief The fields as an array. */
+ RegName* datatype_fields;
+ };
+ struct /* AllocClosure Operands */ {
+ /*! \brief The index into the function table. */
+ Index clo_index;
+ /*! \brief The number of free variables to capture. */
+ Index num_freevar;
+ /*! \brief The free variables as an array. */
+ RegName* free_vars;
+ };
+ struct /* AllocStorage Operands */ {
+ /*! \brief The size of the allocation. */
+ RegName allocation_size;
+ /*! \brief The alignment of the allocation. */
+ Index alignment;
+ /*! \brief The hint of the dtype. */
+ DLDataType dtype_hint;
+ } alloc_storage;
+ struct /* ShapeOf Operands */ {
+ RegName tensor;
+ } shape_of;
+ struct /* ReshapeTensor Operands */ {
+ RegName tensor;
+ RegName newshape;
+ } reshape_tensor;
+ };
+
+ /*!
+ * \brief Construct a return instruction.
+ * \param return_reg The register containing the return value.
+ * \return The return instruction.
+ */
+ static Instruction Ret(RegName return_reg);
+ /*!
+ * \brief Construct a fatal instruction.
+ * \return The fatal instruction.
+ */
+ static Instruction Fatal();
+ /*!
+ * \brief Construct an invoke packed instruction.
+ * \param packed_index The index of the packed function.
+ * \param arity The arity of the function.
+ * \param output_size The number of outputs of the packed function.
+ * \param args The argument registers.
+ * \return The invoke packed instruction.
+ */
+ static Instruction InvokePacked(Index packed_index, Index arity, Index output_size,
+ const std::vector<RegName>& args);
+ /*!
+ * \brief Construct an allocate tensor instruction with constant shape.
+ * \param storage The storage to allocate out of.
+ * \param offset The offset to allocate at.
+ * \param shape The shape of the tensor.
+ * \param dtype The dtype of the tensor.
+ * \param dst The destination register.
+ * \return The allocate tensor instruction.
+ */
+ static Instruction AllocTensor(RegName storage, Index offset, const std::vector<int64_t>& shape,
+ DLDataType dtype, RegName dst);
+ /*!
+ * \brief Construct an allocate tensor instruction with register.
+ * \param storage The storage to allocate out of.
+ * \param offset The offset into the storage to allocate from.
+ * \param shape_register The register containing the shape.
+ * \param dtype The dtype of the tensor.
+ * \param dst The destination register.
+ * \return The allocate tensor instruction.
+ */
+ static Instruction AllocTensorReg(RegName storage, Index offset, RegName shape_register,
+ DLDataType dtype, RegName dst);
+ /*!
+ * \brief Construct an allocate datatype instruction.
+ * \param tag The datatype tag.
+ * \param num_fields The number of fields for the datatype.
+ * \param fields The registers containing the fields.
+ * \param dst The register name of the destination.
+ * \return The allocate datatype instruction.
+ */
+ static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields,
+ RegName dst);
+ /*!
+ * \brief Construct an allocate closure instruction.
+ * \param func_index The index of the function table.
+ * \param num_freevar The number of free variables.
+ * \param free_vars The registers of the free variables.
+ * \param dst The destination register.
+ * \return The allocate closure instruction.
+ */
+ static Instruction AllocClosure(Index func_index, Index num_freevar,
+ const std::vector<RegName>& free_vars, RegName dst);
+ /*!
+ * \brief Construct a get field instruction.
+ * \param object_reg The register containing the object to project from.
+ * \param field_index The field to read out of the object.
+ * \param dst The destination register.
+ * \return The get field instruction.
+ */
+ static Instruction GetField(RegName object_reg, Index field_index, RegName dst);
+ /*!
+ * \brief Construct a get_tag instruction.
+ * \param object_reg The register containing the object to project from.
+ * \param dst The destination register.
+ * \return The get_tag instruction.
+ */
+ static Instruction GetTag(RegName object_reg, RegName dst);
+ /*!
+ * \brief Construct an if instruction.
+ * \param test The register containing the test value.
+ * \param target The register containing the target value.
+ * \param true_branch The offset to the true branch.
+ * \param false_branch The offset to the false branch.
+ * \return The if instruction.
+ */
+ static Instruction If(RegName test, RegName target, Index true_branch, Index false_branch);
+ /*!
+ * \brief Construct a goto instruction.
+ * \param pc_offset The offset from the current pc.
+ * \return The goto instruction.
+ */
+ static Instruction Goto(Index pc_offset);
+ /*!
+ * \brief Construct an invoke instruction.
+ * \param func_index The index of the function to invoke.
+ * \param args The registers containing the arguments.
+ * \param dst The destination register.
+ * \return The invoke instruction.
+ */
+ static Instruction Invoke(Index func_index, const std::vector<RegName>& args, RegName dst);
+ /*!
+ * \brief Construct an invoke closure instruction.
+ * \param closure The register of the closure to invoke.
+ * \param args The registers containing the arguments.
+ * \param dst The destination register.
+ * \return The invoke closure instruction.
+ */
+ static Instruction InvokeClosure(RegName closure, const std::vector<RegName>& args, RegName dst);
+ /*!
+ * \brief Construct a load constant instruction.
+ * \param const_index The index of the constant.
+ * \param dst The destination register.
+ * \return The load constant instruction.
+ */
+ static Instruction LoadConst(Index const_index, RegName dst);
+ /*!
+ * \brief Construct a load_constanti instruction.
+ * \param val The integer constant value.
+ * \param dst The destination register.
+ * \return The load_constanti instruction.
+ */
+ static Instruction LoadConsti(Index val, RegName dst);
+ /*!
+ * \brief Construct a move instruction.
+ * \param src The source register.
+ * \param dst The destination register.
+ * \return The move instruction.
+ */
+ static Instruction Move(RegName src, RegName dst);
+ /*!
+ * \brief Allocate a storage block.
+ * \param size The size of the allocation.
+ * \param alignment The allocation's alignment.
+ * \param dtype_hint The data type hint for the allocator.
+ * \param dst The destination to place the storage.
+ * \return The alloc storage instruction.
+ */
+ static Instruction AllocStorage(RegName size, Index alignment, DLDataType dtype_hint,
+ RegName dst);
+ /*!
+ * \brief Get the shape of an input tensor.
+ * \param tensor The input tensor.
+ * \param dst The destination to store the shape of the given tensor.
+ * \return The shape_of instruction.
+ */
+ static Instruction ShapeOf(RegName tensor, RegName dst);
+ /*!
+ * \brief Reshape the tensor given the new shape.
+ * \param tensor The input tensor.
+ * \param newshape The shape tensor.
+ * \param dst The destination to store the output tensor with new shape.
+ * \return The reshape tensor instruction.
+ */
+ static Instruction ReshapeTensor(RegName tensor, RegName newshape, RegName dst);
+
+ Instruction();
+ Instruction(const Instruction& instr);
+ Instruction& operator=(const Instruction& instr);
+ ~Instruction();
+
+ friend std::ostream& operator<<(std::ostream& os, const Instruction&);
+};
+
+} // namespace vm
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_VM_BYTECODE_H_
diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h
new file mode 100644
index 0000000..cc38da7
--- /dev/null
+++ b/include/tvm/runtime/vm/executable.h
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/vm/executable.h
+ * \brief The Relay virtual machine executable.
+ */
+#ifndef TVM_RUNTIME_VM_EXECUTABLE_H_
+#define TVM_RUNTIME_VM_EXECUTABLE_H_
+
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/vm/bytecode.h>
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+namespace vm {
+
+struct VMFunction;
+
+/*!
+ * \brief The executable emitted by the VM compiler.
+ *
+ * The executable contains information (e.g. data in different memory regions)
+ * to run in a virtual machine.
+ *
+ * - Global section, containing all globals.
+ * - Constant section, storing the constant pool.
+ * - Primitive name section, containing the function name of the primitive ops
+ * used by the virtual machine.
+ * - Code section, handling the VM functions and bytecode.
+ */
+class Executable : public ModuleNode {
+ public:
+ /*!
+ * \brief Get a PackedFunc from an executable module.
+ *
+ * \param name the name of the function.
+ * \param sptr_to_self The shared_ptr that points to this module node.
+ *
+ * \return PackedFunc or nullptr when it is not available.
+ */
+ PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final;
+
+ /*!
+ * \brief Serialize the executable into global section, constant section, and
+ * code section.
+ *
+ * \return The binary representation of the VM.
+ */
+ TVMByteArray Save();
+
+ /*!
+ * \brief Load the saved VM executable.
+ *
+ * \param code The bytecode in string.
+ * \param lib The compiled runtime library.
+ *
+ * \return exe The constructed executable.
+ */
+ static runtime::Module Load(const std::string& code, const runtime::Module lib);
+
+ /*!
+ * \brief Get the serialized form of the `functions`. This is
+ * essentially bytecode serialization.
+ *
+ * \return The serialized vm bytecode.
+ *
+ * \note The bytecode is in the following format:
+ * func_name reg_file_size num_instructions
+ * param1 param2 ... paramM
+ * instruction1
+ * instruction2
+ * ...
+ * instructionN
+ *
+ * Each instruction is printed in the following format:
+ * opcode num_fields field1 ... fieldX # The text format.
+ *
+ * Serializing an `Instruction` requires us to deal with the bytecode. Each line
+ * of the instructions could be serialized as the following format:
+ * hash, opcode, f1, f2, ..., fX, field with variable length
+ * 1. hash: the hash of the instruction. This number will be used to help us
+ * validate if an instruction is well-formed during deserialization.
+ * 2. opcode: the opcode code of the instruction.
+ * 3. f1, f2, ..., fX. These fields together represent the fixed fields in
+ * an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For
+ * example, `DLDataType` will be unpacked into three fields (code, bits, lanes).
+ * 4. The rest of the line indicates the field with variable length, e.g.,
+ * the shape of a tensor, the args used by an `InvokePacked` instruction, etc.
+
+ * The field starting from # is only used for debugging. The serialized code
+ * doesn't contain it, therefore the deserializer doesn't need to handle it.
+ */
+ std::string GetBytecode() const;
+
+ /*!
+ * \brief Print the detailed statistics of the given code, i.e. number of
+ * globals and constants, etc.
+ */
+ std::string Stats() const;
+
+ /*!
+ * \brief Get the `lib` module in an executable. Users have the flexibility to call
+ * `export_library` from the frontend to save the library to disk.
+ *
+ * \return The runtime module that contains the hardware dependent code.
+ */
+ runtime::Module GetLib() const { return lib; }
+
+ /*!
+ * \brief Get the arity of the VM Function.
+ * \param func Function name.
+ * \return The number of parameters.
+ */
+ int GetFunctionArity(std::string func) const;
+
+ /*!
+ * \brief Get the parameter name given the function name and parameter index.
+ * \param func Function name.
+ * \param index Parameter index.
+ * \return The parameter name.
+ */
+ std::string GetFunctionParameterName(std::string func, uint32_t index) const;
+
+ virtual ~Executable() {}
+
+ const char* type_key() const final { return "VMExecutable"; }
+
+ /*! \brief The runtime module/library that contains both the host and also the device
+ * code when executing on non-CPU devices. */
+ runtime::Module lib;
+ /*! \brief The global constant pool. */
+ std::vector<ObjectRef> constants;
+ /*! \brief A map from globals (as strings) to their index in the function map. */
+ std::unordered_map<std::string, Index> global_map;
+ /*! \brief A mapping from the packed function (as string) to the index that
+ * corresponds to the position of the `packed_funcs` list in a `VirtualMachine` object.
+ */
+ std::unordered_map<std::string, Index> primitive_map;
+ /*! \brief The virtual machine's function table. */
+ std::vector<VMFunction> functions;
+
+ private:
+ /*!
+ * \brief Save the globals.
+ *
+ * \param strm The output stream.
+ */
+ void SaveGlobalSection(dmlc::Stream* strm);
+
+ /*!
+ * \brief Save the constant pool.
+ *
+ * \param strm The output stream.
+ */
+ void SaveConstantSection(dmlc::Stream* strm);
+
+ /*!
+ * \brief Save primitive op names.
+ *
+ * \param strm The output stream.
+ */
+ void SavePrimitiveOpNames(dmlc::Stream* strm);
+
+ /*!
+ * \brief Save the vm functions.
+ *
+ * \param strm The output stream.
+ */
+ void SaveCodeSection(dmlc::Stream* strm);
+
+ /*!
+ * \brief Load the globals.
+ *
+ * \param strm The input stream.
+ */
+ void LoadGlobalSection(dmlc::Stream* strm);
+
+ /*!
+ * \brief Load the constant pool.
+ *
+ * \param strm The input stream.
+ */
+ void LoadConstantSection(dmlc::Stream* strm);
+
+ /*!
+ * \brief Load primitive op names.
+ *
+ * \param strm The input stream.
+ */
+ void LoadPrimitiveOpNames(dmlc::Stream* strm);
+
+ /*!
+ * \brief Load the vm functions.
+ *
+ * \param strm The input stream.
+ */
+ void LoadCodeSection(dmlc::Stream* strm);
+
+ /*! \brief The serialized bytecode. */
+ std::string code_;
+};
+
+} // namespace vm
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_VM_EXECUTABLE_H_
diff --git a/src/runtime/vm/memory_manager.h b/include/tvm/runtime/vm/memory_manager.h
similarity index 82%
rename from src/runtime/vm/memory_manager.h
rename to include/tvm/runtime/vm/memory_manager.h
index f59d584..c983cb0 100644
--- a/src/runtime/vm/memory_manager.h
+++ b/include/tvm/runtime/vm/memory_manager.h
@@ -18,7 +18,7 @@
*/
/*!
- * \file src/runtime/memory_manager.h
+ * \file tvm/runtime/vm/memory_manager.h
* \brief Abstract device memory management API
*/
#ifndef TVM_RUNTIME_VM_MEMORY_MANAGER_H_
@@ -64,17 +64,24 @@ struct Buffer {
TVMContext ctx;
};
+enum AllocatorType {
+ kNaive = 1,
+ kPooled,
+};
+
class Allocator {
public:
- Allocator() {}
-
+ explicit Allocator(AllocatorType type) : type_(type) {}
+ virtual ~Allocator() = default;
/*! \brief Allocate an empty NDArray using from the allocator.
* \param shape The shape of the NDArray.
- * \param alignment The datatype of the NDArray.
+ * \param dtype The datatype of the NDArray.
* \param ctx The context where the array is allocated.
* \return The empty NDArray.
*/
NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx);
+ /*! \brief Return the allocator type. */
+ inline AllocatorType type() const { return type_; }
/*! \brief Allocate a buffer given a size, alignment and type.
* \param nbytes The size of the buffer.
* \param alignment The alignment of the buffer.
@@ -90,21 +97,34 @@ class Allocator {
* \return The amount of memory currently allocated.
*/
virtual size_t UsedMemory() const = 0;
- virtual ~Allocator() = default;
+
+ private:
+ AllocatorType type_;
};
class MemoryManager {
public:
static MemoryManager* Global();
-
- Allocator* GetAllocator(TVMContext ctx);
+ /*!
+ * \brief Get or create an allocator given the context and allocator type.
+ * \param ctx The TVM context
+ * \param type The allocator type
+ * \return The memory allocator.
+ */
+ static Allocator* GetOrCreateAllocator(TVMContext ctx, AllocatorType type);
+ /*!
+ * \brief Get an allocator given the context.
+ * \param ctx The TVM context
+ * \return The memory allocator.
+ */
+ static Allocator* GetAllocator(TVMContext ctx);
private:
MemoryManager() {}
private:
std::mutex mu_;
- std::unordered_map<TVMContext, std::unique_ptr<Allocator> > allocators_;
+ std::unordered_map<TVMContext, std::unique_ptr<Allocator>> allocators_;
};
/*! \brief An object representing a storage allocation. */
diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h
new file mode 100644
index 0000000..273b8fe
--- /dev/null
+++ b/include/tvm/runtime/vm/vm.h
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/runtime/vm/vm.h
+ * \brief The Relay virtual machine runtime.
+ */
+#ifndef TVM_RUNTIME_VM_VM_H_
+#define TVM_RUNTIME_VM_VM_H_
+
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/runtime/vm/bytecode.h>
+#include <tvm/runtime/vm/executable.h>
+#include <tvm/runtime/vm/memory_manager.h>
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace runtime {
+namespace vm {
+
+/*!
+ * \brief An object representing a vm closure.
+ */
+class VMClosureObj : public ClosureObj {
+ public:
+ /*!
+ * \brief The index into the function list. The function could be any
+ * function object that is compatible to the VM runtime.
+ */
+ size_t func_index;
+ /*! \brief The free variables of the closure. */
+ std::vector<ObjectRef> free_vars;
+
+ static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
+ static constexpr const char* _type_key = "vm.Closure";
+ TVM_DECLARE_FINAL_OBJECT_INFO(VMClosureObj, ClosureObj);
+};
+
+/*! \brief reference to closure. */
+class VMClosure : public Closure {
+ public:
+ VMClosure(size_t func_index, std::vector<ObjectRef> free_vars);
+ TVM_DEFINE_OBJECT_REF_METHODS(VMClosure, Closure, VMClosureObj);
+};
+
+/*!
+ * \brief A representation of a Relay function in the VM.
+ *
+ * Contains metadata about the compiled function, as
+ * well as the compiled VM instructions.
+ */
+struct VMFunction {
+ /*! \brief The function's name. */
+ std::string name;
+ /*! \brief The function parameter names. */
+ std::vector<std::string> params;
+ /*! \brief The instructions representing the function. */
+ std::vector<Instruction> instructions;
+ /*! \brief The size of the frame for this function */
+ Index register_file_size;
+
+ VMFunction(const std::string& name, std::vector<std::string> params,
+ const std::vector<Instruction>& instructions, Index register_file_size)
+ : name(name),
+ params(params),
+ instructions(instructions),
+ register_file_size(register_file_size) {}
+
+ VMFunction() {}
+
+ friend std::ostream& operator<<(std::ostream& os, const VMFunction&);
+};
+
+/*!
+ * \brief A representation of a stack frame.
+ *
+ * A stack frame is a record containing the information needed
+ * to restore the caller's virtual machine state after returning
+ * from a function call.
+ */
+struct VMFrame {
+ /*! \brief The return program counter. */
+ Index pc;
+ /*! \brief The index into the function table, points to the caller. */
+ Index func_index;
+ /*! \brief The number of arguments. */
+ Index args;
+ /*! \brief A pointer into the caller function's instructions. */
+ const Instruction* code;
+
+ /*! \brief Statically allocated space for objects */
+ std::vector<ObjectRef> register_file;
+
+ /*! \brief Register in caller's frame to put return value */
+ RegName caller_return_register;
+
+ VMFrame(Index pc, Index func_index, Index args, const Instruction* code, Index register_file_size)
+ : pc(pc),
+ func_index(func_index),
+ args(args),
+ code(code),
+ register_file(register_file_size),
+ caller_return_register(0) {}
+};
+
+/*!
+ * \brief The virtual machine.
+ *
+ * The virtual machine contains all the current execution state,
+ * as well as the executable.
+ *
+ * The goal is to have a single self-contained object,
+ * enabling one to easily pass around VMs, execute them on
+ * multiple threads, or serialize them to disk or over the
+ * wire.
+ */
+class VirtualMachine : public runtime::ModuleNode {
+ public:
+ /*!
+ * \brief Get a PackedFunc from module.
+ *
+ * The PackedFunc may not be fully initialized,
+ * there might still be first time running overhead when
+ * executing the function on certain devices.
+ * For benchmarking, use prepare to eliminate that overhead.
+ *
+ * \param name the name of the function.
+ * \param sptr_to_self The shared_ptr that points to this module node.
+ *
+ * \return PackedFunc(nullptr) when it is not available.
+ *
+ * \note The function will always remain valid.
+ * If the function needs resource from the module(e.g. late linking),
+ * it should capture sptr_to_self.
+ */
+ virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
+
+ virtual ~VirtualMachine() {}
+
+ const char* type_key() const final { return "VirtualMachine"; }
+
+ VirtualMachine() : frames_(), func_index_(0), code_(nullptr), pc_(0), exec_(nullptr) {}
+
+ /*!
+ * \brief load the executable for the virtual machine.
+ * \param exec The executable.
+ */
+ virtual void LoadExecutable(const Executable* exec);
+
+ protected:
+ /*! \brief Push a call frame on to the call stack. */
+ void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func);
+
+ /*!
+ * \brief Pop a frame off the call stack.
+ * \return The number of frames left.
+ */
+ Index PopFrame();
+
+ /*!
+ * \brief Write to a VM register.
+ * \param reg The register to write to.
+ * \param obj The object to write.
+ */
+ inline void WriteRegister(RegName reg, const ObjectRef& obj);
+
+ /*!
+ * \brief Read a VM register.
+ * \param reg The register to read from.
+ * \return The read object.
+ */
+ inline ObjectRef ReadRegister(RegName reg) const;
+
+ /*!
+ * \brief Read a VM register and cast it to int64_t
+ * \param reg The register to read from.
+ * \return The read scalar.
+ */
+ inline int64_t LoadScalarInt(RegName reg) const;
+
+ /*!
+ * \brief Invoke a VM function.
+ * \param func The function.
+ * \param args The arguments to the function.
+ * \return The object representing the result.
+ */
+ ObjectRef Invoke(const VMFunction& func, const std::vector<ObjectRef>& args);
+
+ // TODO(@jroesch): I really would like this to be a global variable.
+ /*!
+ * \brief Invoke a VM function by name.
+ * \param name The function's name.
+ * \param args The arguments to the function.
+ * \return The object representing the result.
+ */
+ ObjectRef Invoke(const std::string& name, const std::vector<ObjectRef>& args);
+
+ /*!
+ * \brief Invoke a PackedFunction
+ *
+ * \param packed_index The offset of the PackedFunction in all functions.
+ * \param func The PackedFunction to be invoked.
+ * \param arg_count The number of arguments to the PackedFunction.
+ * \param output_size The number of outputs of the PackedFunction.
+ * \param args Arguments to the PackedFunction.
+ *
+ * \note The return value will be stored in the last output_size slots of args.
+ */
+ virtual void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
+ Index output_size, const std::vector<ObjectRef>& args);
+
+ /*!
+ * \brief Initialize the virtual machine for a set of contexts.
+ * \param contexts The set of TVM contexts.
+ * \param alloc_types The allocator types for each context.
+ */
+ void Init(const std::vector<TVMContext>& contexts, const std::vector<AllocatorType>& alloc_types);
+
+ /*! \brief Run VM dispatch loop. */
+ void RunLoop();
+
+ /*! \brief Get device context for params. */
+ TVMContext GetParamsContext() const;
+
+ /*!
+ * \brief Invoke a global setting up the VM state to execute.
+ *
+ * This does not begin execution of the VM.
+ */
+ void InvokeGlobal(const VMFunction& func, const std::vector<ObjectRef>& args);
+
+ protected:
+ /*! \brief The virtual machine's packed function table. */
+ std::vector<PackedFunc> packed_funcs_;
+ /*! \brief The current stack of call frames. */
+ std::vector<VMFrame> frames_;
+ /*! \brief The function table index of the current function. */
+ Index func_index_;
+ /*! \brief The current pointer to the code section. */
+ const Instruction* code_;
+ /*! \brief The virtual machine PC. */
+ Index pc_;
+ /*! \brief The special return register. */
+ ObjectRef return_register_;
+ /*! \brief The executable the VM will operate on. */
+ const Executable* exec_;
+ /*! \brief The function name to inputs mapping. */
+ std::unordered_map<std::string, std::vector<ObjectRef>> inputs_;
+ /*! \brief The set of TVM contexts the VM is currently executing on. */
+ std::vector<TVMContext> ctxs_;
+ /*! \brief The mapping from TVM context to memory allocator. */
+ std::unordered_map<TVMContext, Allocator*> allocators_;
+ /*!
+ * \brief The constant pool for runtime. It caches the device dependent
+ * object to avoid reallocation of constants during inference.
+ */
+ std::vector<ObjectRef> const_pool_;
+};
+
+} // namespace vm
+} // namespace runtime
+} // namespace tvm
+
+#endif // TVM_RUNTIME_VM_VM_H_
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index a7bfb32..dcc9528 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -274,6 +274,9 @@ class TVMContext(ctypes.Structure):
def __ne__(self, other):
return not self.__eq__(other)
+ def __hash__(self):
+ return hash(str(self))
+
def __repr__(self):
if self.device_type >= RPC_SESS_MASK:
tbl_id = self.device_type / RPC_SESS_MASK - 1
diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py
index 16d4724..cb7761b 100644
--- a/python/tvm/relay/backend/vm.py
+++ b/python/tvm/relay/backend/vm.py
@@ -248,8 +248,7 @@ class VMExecutor(Executor):
self.ctx = ctx
self.target = target
self.executable = compile(mod, target)
- self.vm = vm_rt.VirtualMachine(self.executable)
- self.vm.init(ctx)
+ self.vm = vm_rt.VirtualMachine(self.executable, ctx)
def _make_executor(self, expr=None):
main = self.mod["main"]
diff --git a/python/tvm/runtime/profiler_vm.py b/python/tvm/runtime/profiler_vm.py
index 9d60483..5df10e5 100644
--- a/python/tvm/runtime/profiler_vm.py
+++ b/python/tvm/runtime/profiler_vm.py
@@ -32,15 +32,15 @@ def enabled():
class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime."""
- def __init__(self, mod):
- super(VirtualMachineProfiler, self).__init__(mod)
- m = mod.module if isinstance(mod, vm.Executable) else mod
- self.mod = _ffi_api._VirtualMachineDebug(m)
- self._init = self.mod["init"]
- self._invoke = self.mod["invoke"]
- self._get_stat = self.mod["get_stat"]
- self._set_input = self.mod["set_input"]
- self._reset = self.mod["reset"]
+ def __init__(self, exe, ctx, memory_cfg=None):
+ super(VirtualMachineProfiler, self).__init__(exe, ctx, memory_cfg)
+ self.module = _ffi_api._VirtualMachineDebug(exe.module)
+ self._init = self.module["init"]
+ self._invoke = self.module["invoke"]
+ self._get_stat = self.module["get_stat"]
+ self._set_input = self.module["set_input"]
+ self._reset = self.module["reset"]
+ self._setup_ctx(ctx, memory_cfg)
def get_stat(self, sort_by_time=True):
"""Get the statistics of executed ops.
diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py
index d7d9451..f88f43d 100644
--- a/python/tvm/runtime/vm.py
+++ b/python/tvm/runtime/vm.py
@@ -131,8 +131,7 @@ class Executable(object):
des_exec = tvm.runtime.vm.Executable.load_exec(loaded_code, loaded_code)
# execute the deserialized executable.
x_data = np.random.rand(10, 10).astype('float32')
- des_vm = tvm.runtime.vm.VirtualMachine(des_exec)
- des_vm.init(ctx)
+ des_vm = tvm.runtime.vm.VirtualMachine(des_exec, ctx)
res = des_vm.run(x_data)
print(res.asnumpy())
"""
@@ -273,29 +272,61 @@ class Executable(object):
class VirtualMachine(object):
- """Relay VM runtime."""
-
- def __init__(self, mod):
- if not isinstance(mod, (Executable, tvm.runtime.Module)):
- raise TypeError("mod is expected to be the type of Executable or " +
- "tvm.runtime.Module, but received {}".format(type(mod)))
- m = mod.module if isinstance(mod, Executable) else mod
- self.mod = _ffi_api._VirtualMachine(m)
- self._exec = mod
- self._init = self.mod["init"]
- self._invoke = self.mod["invoke"]
- self._set_input = self.mod["set_input"]
-
- def init(self, ctx):
- """Initialize the context in the VM.
-
- Parameters
- ----------
- ctx : :py:class:`TVMContext`
- The runtime context to run the code on.
- """
- args = [ctx.device_type, ctx.device_id]
- self._init(*args)
+ """Relay VM runtime.
+
+ Parameters
+ ----------
+ exe : Executable
+ The VM executable.
+
+ ctx : tvm.runtime.TVMContext or List[tvm.runtime.TVMContext]
+ The context to deploy the module
+
+ memory_cfg : str or Dict[tvm.runtime.TVMContext, str], optional
+ Config the type of memory allocator. The allocator type can be ["naive",
+ "pooled"]. If memory_cfg is None, all contexts will use pooled allocator
+ by default. If memory_cfg is string, all contexts will use the specified
+ allocator type. If memory_cfg is a dict, each context uses the allocator
+ type specified in the dict, or pooled allocator if not specified in the
+ dict.
+ """
+
+ NAIVE_ALLOCATOR = 1
+ POOLED_ALLOCATOR = 2
+
+ def __init__(self, exe, ctx, memory_cfg=None):
+ if not isinstance(exe, Executable):
+ raise TypeError("exe is expected to be the type of Executable, " +
+ "but received {}".format(type(exe)))
+ self.module = _ffi_api._VirtualMachine(exe.module)
+ self._exec = exe
+ self._init = self.module["init"]
+ self._invoke = self.module["invoke"]
+ self._set_input = self.module["set_input"]
+ self._setup_ctx(ctx, memory_cfg)
+
+ def _setup_ctx(self, ctx, memory_cfg):
+ """Init context and allocators."""
+ if isinstance(ctx, tvm.runtime.TVMContext):
+ ctx = [ctx]
+ default_alloc_type = VirtualMachine.POOLED_ALLOCATOR
+ if memory_cfg is None:
+ memory_cfg = {}
+ elif isinstance(memory_cfg, str):
+ assert memory_cfg in ["naive", "pooled"]
+ if memory_cfg == "naive":
+ default_alloc_type = VirtualMachine.NAIVE_ALLOCATOR
+ memory_cfg = {}
+ elif not isinstance(memory_cfg, dict):
+ raise TypeError("memory_cfg is expected be string or dictionary, " +
+ "but received {}".format(type(memory_cfg)))
+ init_args = []
+ for context in ctx:
+ init_args.append(context.device_type)
+ init_args.append(context.device_id)
+ alloc_type = memory_cfg[context] if context in memory_cfg else default_alloc_type
+ init_args.append(alloc_type)
+ self._init(*init_args)
def set_input(self, func_name, *args, **kwargs):
"""Set the input to a function.
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index b57c0eb..1392798 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -27,7 +27,6 @@
#include <tvm/relay/qnn/transform.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/device_api.h>
-#include <tvm/runtime/vm.h>
#include <memory>
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index ab11c6c..b811911 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -31,7 +31,7 @@
#include <tvm/relay/interpreter.h>
#include <tvm/relay/qnn/transform.h>
#include <tvm/relay/transform.h>
-#include <tvm/runtime/vm.h>
+#include <tvm/runtime/vm/vm.h>
#include <tvm/support/logging.h>
#include <tvm/te/operation.h>
diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h
index 8b1df7f..d1e1f7e 100644
--- a/src/relay/backend/vm/compiler.h
+++ b/src/relay/backend/vm/compiler.h
@@ -29,7 +29,7 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/transform.h>
-#include <tvm/runtime/vm.h>
+#include <tvm/runtime/vm/vm.h>
#include <tvm/support/logging.h>
#include <tvm/tir/function.h>
diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc
index cf4f533..650df99 100644
--- a/src/relay/backend/vm/inline_primitives.cc
+++ b/src/relay/backend/vm/inline_primitives.cc
@@ -25,7 +25,6 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
-#include <tvm/runtime/vm.h>
#include <tvm/support/logging.h>
#include <iostream>
diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc
index 011c7d2..22b8364 100644
--- a/src/relay/backend/vm/lambda_lift.cc
+++ b/src/relay/backend/vm/lambda_lift.cc
@@ -28,7 +28,6 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
-#include <tvm/runtime/vm.h>
#include <tvm/support/logging.h>
#include <iostream>
diff --git a/src/relay/backend/vm/removed_unused_funcs.cc b/src/relay/backend/vm/removed_unused_funcs.cc
index 4e8713b..cdf898f 100644
--- a/src/relay/backend/vm/removed_unused_funcs.cc
+++ b/src/relay/backend/vm/removed_unused_funcs.cc
@@ -26,7 +26,6 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
-#include <tvm/runtime/vm.h>
#include <tvm/support/logging.h>
#include <iostream>
diff --git a/src/runtime/container.cc b/src/runtime/container.cc
index 62220a8..2532432 100644
--- a/src/runtime/container.cc
+++ b/src/runtime/container.cc
@@ -25,13 +25,10 @@
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>
-#include <tvm/runtime/vm.h>
namespace tvm {
namespace runtime {
-using namespace vm;
-
TVM_REGISTER_GLOBAL("runtime.GetADTTag").set_body([](TVMArgs args, TVMRetValue* rv) {
ObjectRef obj = args[0];
const auto& adt = Downcast<ADT>(obj);
diff --git a/src/runtime/vm/bytecode.cc b/src/runtime/vm/bytecode.cc
new file mode 100644
index 0000000..edfd3ac
--- /dev/null
+++ b/src/runtime/vm/bytecode.cc
@@ -0,0 +1,610 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/vm/bytecode.cc
+ * \brief The bytecode for Relay virtual machine.
+ */
+
+#include <dmlc/logging.h>
+#include <tvm/runtime/vm/bytecode.h>
+#include <tvm/support/logging.h>
+
+#include <sstream>
+
+namespace tvm {
+namespace runtime {
+namespace vm {
+
+Instruction::Instruction() {}
+
+template <typename T>
+static T* Duplicate(T* src, Index size) {
+ auto dst = new T[size];
+ std::copy(src, src + size, dst);
+ return dst;
+}
+
+Instruction::Instruction(const Instruction& instr) {
+ this->op = instr.op;
+ this->dst = instr.dst;
+
+ switch (instr.op) {
+ case Opcode::Move:
+ this->from = instr.from;
+ return;
+ case Opcode::Fatal:
+ return;
+ case Opcode::Ret:
+ this->result = instr.result;
+ return;
+ case Opcode::AllocTensor:
+ this->alloc_tensor.storage = instr.alloc_tensor.storage;
+ this->alloc_tensor.offset = instr.alloc_tensor.offset;
+ this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
+ this->alloc_tensor.shape =
+ Duplicate<int64_t>(instr.alloc_tensor.shape, instr.alloc_tensor.ndim);
+ this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
+ return;
+ case Opcode::AllocTensorReg:
+ this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage;
+ this->alloc_tensor_reg.offset = instr.alloc_tensor_reg.offset;
+ this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
+ this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
+ return;
+ case Opcode::AllocADT:
+ this->constructor_tag = instr.constructor_tag;
+ this->num_fields = instr.num_fields;
+ this->datatype_fields = Duplicate<RegName>(instr.datatype_fields, instr.num_fields);
+ return;
+ case Opcode::AllocClosure:
+ this->clo_index = instr.clo_index;
+ this->num_freevar = instr.num_freevar;
+ this->free_vars = Duplicate<RegName>(instr.free_vars, instr.num_freevar);
+ return;
+ case Opcode::InvokePacked:
+ this->packed_index = instr.packed_index;
+ this->arity = instr.arity;
+ this->output_size = instr.output_size;
+ this->packed_args = Duplicate<RegName>(instr.packed_args, instr.arity);
+ return;
+ case Opcode::InvokeClosure:
+ this->closure = instr.closure;
+ this->num_closure_args = instr.num_closure_args;
+ this->closure_args = Duplicate<RegName>(instr.closure_args, instr.num_closure_args);
+ return;
+ case Opcode::Invoke:
+ this->func_index = instr.func_index;
+ this->num_args = instr.num_args;
+ this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args);
+ return;
+ case Opcode::If:
+ this->if_op = instr.if_op;
+ return;
+ case Opcode::LoadConst:
+ this->const_index = instr.const_index;
+ return;
+ case Opcode::LoadConsti:
+ this->load_consti = instr.load_consti;
+ return;
+ case Opcode::GetField:
+ this->object = instr.object;
+ this->field_index = instr.field_index;
+ return;
+ case Opcode::GetTag:
+ this->get_tag = instr.get_tag;
+ return;
+ case Opcode::Goto:
+ this->pc_offset = instr.pc_offset;
+ return;
+ case Opcode::AllocStorage:
+ this->alloc_storage = instr.alloc_storage;
+ return;
+ case Opcode::ShapeOf:
+ this->shape_of.tensor = instr.shape_of.tensor;
+ return;
+ case Opcode::ReshapeTensor:
+ this->reshape_tensor.tensor = instr.reshape_tensor.tensor;
+ this->reshape_tensor.newshape = instr.reshape_tensor.newshape;
+ return;
+ default:
+ std::ostringstream out;
+ out << "Invalid instruction " << static_cast<int>(instr.op);
+ throw std::runtime_error(out.str());
+ }
+}
+
+/*!
+ * \brief Release a heap array if the pointer is non-null.
+ *
+ * Every caller in this file passes a buffer produced by Duplicate(), which
+ * allocates with `new T[size]`, so the matching array form `delete[]` must be
+ * used here; scalar `delete` on such a pointer is undefined behaviour.
+ *
+ * \param t Pointer to the first element of the array to release, or nullptr.
+ */
+template <typename T>
+static inline void FreeIf(T* t) {
+  if (t != nullptr) {
+    delete[] t;
+  }
+}
+
+/*!
+ * \brief Copy-assign an instruction, deep-copying any register/shape array
+ *  owned by the source.
+ *
+ * The opcode and destination register are copied first, then the payload that
+ * belongs to the source opcode. Array payloads are duplicated so that both
+ * instructions own independent buffers.
+ *
+ * \param instr The instruction to copy from.
+ * \return A reference to this instruction.
+ * \throws std::runtime_error if the opcode of \p instr is not recognised.
+ */
+Instruction& Instruction::operator=(const Instruction& instr) {
+  // Self-assignment must be a no-op: the array cases below release this
+  // object's buffer before duplicating from `instr`, which would otherwise
+  // read from memory that has just been freed.
+  if (this == &instr) {
+    return *this;
+  }
+
+  this->op = instr.op;
+  this->dst = instr.dst;
+
+  switch (instr.op) {
+    case Opcode::Move:
+      this->from = instr.from;
+      return *this;
+    case Opcode::Fatal:
+      return *this;
+    case Opcode::LoadConsti:
+      this->load_consti = instr.load_consti;
+      return *this;
+    case Opcode::Ret:
+      this->result = instr.result;
+      return *this;
+    case Opcode::AllocTensor:
+      // Copy the storage register from the source (previously self-assigned).
+      this->alloc_tensor.storage = instr.alloc_tensor.storage;
+      this->alloc_tensor.offset = instr.alloc_tensor.offset;
+      this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
+      // NOTE(review): unlike the other array cases, the previous shape buffer
+      // is not released here; kept as in the original - confirm whether a
+      // release is intended.
+      this->alloc_tensor.shape =
+          Duplicate<int64_t>(instr.alloc_tensor.shape, instr.alloc_tensor.ndim);
+      this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
+      return *this;
+    case Opcode::AllocTensorReg:
+      this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage;
+      this->alloc_tensor_reg.offset = instr.alloc_tensor_reg.offset;
+      this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
+      this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
+      return *this;
+    case Opcode::AllocADT:
+      this->constructor_tag = instr.constructor_tag;
+      this->num_fields = instr.num_fields;
+      FreeIf(this->datatype_fields);
+      this->datatype_fields = Duplicate<RegName>(instr.datatype_fields, instr.num_fields);
+      return *this;
+    case Opcode::AllocClosure:
+      this->clo_index = instr.clo_index;
+      this->num_freevar = instr.num_freevar;
+      FreeIf(this->free_vars);
+      this->free_vars = Duplicate<RegName>(instr.free_vars, instr.num_freevar);
+      return *this;
+    case Opcode::InvokePacked:
+      this->packed_index = instr.packed_index;
+      this->arity = instr.arity;
+      this->output_size = instr.output_size;
+      FreeIf(this->packed_args);
+      this->packed_args = Duplicate<RegName>(instr.packed_args, instr.arity);
+      return *this;
+    case Opcode::InvokeClosure:
+      this->closure = instr.closure;
+      this->num_closure_args = instr.num_closure_args;
+      FreeIf(this->closure_args);
+      this->closure_args = Duplicate<RegName>(instr.closure_args, instr.num_closure_args);
+      return *this;
+    case Opcode::Invoke:
+      this->func_index = instr.func_index;
+      this->num_args = instr.num_args;
+      FreeIf(this->invoke_args_registers);
+      this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args);
+      return *this;
+    case Opcode::If:
+      this->if_op = instr.if_op;
+      return *this;
+    case Opcode::LoadConst:
+      this->const_index = instr.const_index;
+      return *this;
+    case Opcode::GetField:
+      this->object = instr.object;
+      this->field_index = instr.field_index;
+      return *this;
+    case Opcode::GetTag:
+      this->get_tag = instr.get_tag;
+      return *this;
+    case Opcode::Goto:
+      this->pc_offset = instr.pc_offset;
+      return *this;
+    case Opcode::AllocStorage:
+      this->alloc_storage = instr.alloc_storage;
+      return *this;
+    case Opcode::ShapeOf:
+      this->shape_of.tensor = instr.shape_of.tensor;
+      return *this;
+    case Opcode::ReshapeTensor:
+      // Mirrors the copy constructor; without this case assigning a
+      // ReshapeTensor instruction fell through to the error branch.
+      this->reshape_tensor.tensor = instr.reshape_tensor.tensor;
+      this->reshape_tensor.newshape = instr.reshape_tensor.newshape;
+      return *this;
+    default:
+      std::ostringstream out;
+      out << "Invalid instruction " << static_cast<int>(instr.op);
+      throw std::runtime_error(out.str());
+  }
+}
+
+/*!
+ * \brief Destroy an instruction, releasing the array payload owned by opcodes
+ *  that carry one (tensor shape, ADT fields, closure free variables, and the
+ *  argument lists of the invoke instructions).
+ *
+ * Opcodes whose payload consists only of plain values need no cleanup. An
+ * unrecognised opcode is reported through LOG(FATAL).
+ */
+Instruction::~Instruction() {
+  switch (this->op) {
+    // Opcodes without heap-allocated payload.
+    case Opcode::Move:
+    case Opcode::Ret:
+    case Opcode::AllocTensorReg:
+    case Opcode::If:
+    case Opcode::LoadConst:
+    case Opcode::GetField:
+    case Opcode::GetTag:
+    case Opcode::Goto:
+    case Opcode::LoadConsti:
+    case Opcode::AllocStorage:
+    case Opcode::ShapeOf:
+    case Opcode::ReshapeTensor:
+    case Opcode::Fatal:
+      return;
+    // Opcodes that own an array allocated with new[].
+    case Opcode::AllocTensor:
+      delete[] this->alloc_tensor.shape;
+      return;
+    case Opcode::AllocADT:
+      delete[] this->datatype_fields;
+      return;
+    case Opcode::AllocClosure:
+      delete[] this->free_vars;
+      return;
+    case Opcode::InvokePacked:
+      delete[] this->packed_args;
+      return;
+    case Opcode::InvokeClosure:
+      delete[] this->closure_args;
+      return;
+    case Opcode::Invoke:
+      delete[] this->invoke_args_registers;
+      return;
+    default:
+      LOG(FATAL) << "Invalid instruction " << static_cast<int>(this->op);
+  }
+}
+
+Instruction Instruction::Ret(RegName result) {
+ Instruction instr;
+ instr.op = Opcode::Ret;
+ instr.result = result;
+ return instr;
+}
+
+Instruction Instruction::Fatal() {
+ Instruction instr;
+ instr.op = Opcode::Fatal;
+ return instr;
+}
+
+Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index output_size,
+ const std::vector<RegName>& args) {
+ Instruction instr;
+ instr.op = Opcode::InvokePacked;
+ instr.packed_index = packed_index;
+ instr.arity = arity;
+ instr.output_size = output_size;
+ instr.packed_args = new RegName[arity];
+ for (Index i = 0; i < arity; ++i) {
+ instr.packed_args[i] = args[i];
+ }
+ return instr;
+}
+
+Instruction Instruction::AllocTensor(RegName storage, RegName offset,
+ const std::vector<int64_t>& shape, DLDataType dtype,
+ RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::AllocTensor;
+ instr.dst = dst;
+ instr.alloc_tensor.storage = storage;
+ instr.alloc_tensor.offset = offset;
+ instr.alloc_tensor.ndim = shape.size();
+ instr.alloc_tensor.shape = new int64_t[shape.size()];
+ for (size_t i = 0; i < shape.size(); ++i) {
+ instr.alloc_tensor.shape[i] = shape[i];
+ }
+ instr.alloc_tensor.dtype = dtype;
+ return instr;
+}
+
+Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName shape_register,
+ DLDataType dtype, RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::AllocTensorReg;
+ instr.dst = dst;
+ instr.alloc_tensor_reg.storage = storage;
+ instr.alloc_tensor_reg.offset = offset;
+ instr.alloc_tensor_reg.shape_register = shape_register;
+ instr.alloc_tensor_reg.dtype = dtype;
+ return instr;
+}
+
+Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType dtype_hint,
+ RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::AllocStorage;
+ instr.dst = dst;
+ instr.alloc_storage.allocation_size = size;
+ instr.alloc_storage.alignment = alignment;
+ instr.alloc_storage.dtype_hint = dtype_hint;
+ return instr;
+}
+
+Instruction Instruction::ShapeOf(RegName tensor, RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::ShapeOf;
+ instr.dst = dst;
+ instr.shape_of.tensor = tensor;
+ return instr;
+}
+
+Instruction Instruction::ReshapeTensor(RegName tensor, RegName newshape, RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::ReshapeTensor;
+ instr.dst = dst;
+ instr.reshape_tensor.tensor = tensor;
+ instr.reshape_tensor.newshape = newshape;
+ return instr;
+}
+
+Instruction Instruction::AllocADT(Index tag, Index num_fields,
+ const std::vector<RegName>& datatype_fields, RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::AllocADT;
+ instr.dst = dst;
+ instr.constructor_tag = tag;
+ instr.num_fields = num_fields;
+ instr.datatype_fields = new RegName[num_fields];
+ for (Index i = 0; i < num_fields; ++i) {
+ instr.datatype_fields[i] = datatype_fields[i];
+ }
+ return instr;
+}
+
+Instruction Instruction::AllocClosure(Index func_index, Index free_vars,
+ const std::vector<RegName>& free_var_register, RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::AllocClosure;
+ instr.dst = dst;
+ instr.clo_index = func_index;
+ instr.num_freevar = free_vars;
+ instr.free_vars = new RegName[instr.num_freevar];
+ for (Index i = 0; i < instr.num_freevar; ++i) {
+ instr.free_vars[i] = free_var_register[i];
+ }
+ return instr;
+}
+
+Instruction Instruction::GetField(RegName object, Index field_index, RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::GetField;
+ instr.dst = dst;
+ instr.object = object;
+ instr.field_index = field_index;
+ return instr;
+}
+
+Instruction Instruction::GetTag(RegName object, RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::GetTag;
+ instr.dst = dst;
+ instr.get_tag.object = object;
+ return instr;
+}
+
+Instruction Instruction::If(RegName test, RegName target, Index true_branch, Index false_branch) {
+ Instruction instr;
+ instr.op = Opcode::If;
+ instr.if_op.test = test;
+ instr.if_op.target = target;
+ instr.if_op.true_offset = true_branch;
+ instr.if_op.false_offset = false_branch;
+ return instr;
+}
+
+Instruction Instruction::Goto(Index pc_offset) {
+ Instruction instr;
+ instr.op = Opcode::Goto;
+ instr.pc_offset = pc_offset;
+ return instr;
+}
+
+Instruction Instruction::Invoke(Index func_index, const std::vector<RegName>& args_registers,
+ RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::Invoke;
+ instr.dst = dst;
+ instr.func_index = func_index;
+ instr.num_args = args_registers.size();
+ instr.invoke_args_registers = new RegName[instr.num_args];
+ for (Index i = 0; i < instr.num_args; ++i) {
+ instr.invoke_args_registers[i] = args_registers[i];
+ }
+ return instr;
+}
+
+Instruction Instruction::InvokeClosure(RegName closure, const std::vector<RegName>& args,
+ RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::InvokeClosure;
+ instr.dst = dst;
+ instr.closure = closure;
+ instr.num_closure_args = args.size();
+ instr.closure_args = new RegName[args.size()];
+ for (size_t i = 0; i < args.size(); ++i) {
+ instr.closure_args[i] = args[i];
+ }
+ return instr;
+}
+
+Instruction Instruction::LoadConst(Index const_index, RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::LoadConst;
+ instr.dst = dst;
+ instr.const_index = const_index;
+ return instr;
+}
+
+Instruction Instruction::LoadConsti(Index val, RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::LoadConsti;
+ instr.dst = dst;
+ instr.load_consti.val = val;
+ return instr;
+}
+
+Instruction Instruction::Move(RegName src, RegName dst) {
+ Instruction instr;
+ instr.op = Opcode::Move;
+ instr.dst = dst;
+ instr.from = src;
+ return instr;
+}
+
+void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) {
+ switch (dtype.code) {
+ case kDLInt:
+ os << "int";
+ break;
+ case kDLUInt:
+ os << "uint";
+ break;
+ case kDLFloat:
+ os << "float";
+ break;
+ }
+
+ os << int(dtype.bits);
+ if (dtype.lanes != 1) {
+ os << "x" << dtype.lanes;
+ }
+}
+
+template <typename T>
+std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") {
+ if (cnt == 0) {
+ return "";
+ }
+ std::ostringstream oss;
+ oss << items[offset];
+ for (int i = 1; i < cnt; ++i) {
+ oss << delim << items[offset + i];
+ }
+ return oss.str();
+}
+
+void InstructionPrint(std::ostream& os, const Instruction& instr) {
+ switch (instr.op) {
+ case Opcode::Move: {
+ os << "move $" << instr.dst << " $" << instr.from;
+ break;
+ }
+ case Opcode::Ret: {
+ os << "ret $" << instr.result;
+ break;
+ }
+ case Opcode::Fatal: {
+ os << "fatal";
+ break;
+ }
+ case Opcode::InvokePacked: {
+ os << "invoke_packed PackedFunc[" << instr.packed_index << "] (in: $"
+ << StrJoin<RegName>(instr.packed_args, 0, instr.arity - instr.output_size, ", $")
+ << ", out: $"
+ << StrJoin<RegName>(instr.packed_args, instr.arity - instr.output_size, instr.output_size,
+ ", $")
+ << ")";
+ break;
+ }
+ case Opcode::AllocTensor: {
+ os << "alloc_tensor $" << instr.dst << " $" << instr.alloc_tensor.storage << " $"
+ << instr.alloc_tensor.offset << " ["
+ << StrJoin<int64_t>(instr.alloc_tensor.shape, 0, instr.alloc_tensor.ndim) << "] ";
+ DLDatatypePrint(os, instr.alloc_tensor.dtype);
+ break;
+ }
+ case Opcode::AllocTensorReg: {
+ os << "alloc_tensor_reg $" << instr.dst << " $" << instr.alloc_tensor_reg.storage << " $"
+ << instr.alloc_tensor_reg.offset << " $" << instr.alloc_tensor_reg.shape_register << " ";
+ DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
+ break;
+ }
+ case Opcode::AllocADT: {
+ os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$"
+ << StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]";
+ break;
+ }
+ case Opcode::AllocClosure: {
+ os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index << "]($"
+ << StrJoin<RegName>(instr.free_vars, 0, instr.num_freevar, ",$") << ")";
+ break;
+ }
+ case Opcode::If: {
+ os << "if "
+ << "$" << instr.if_op.test << " $" << instr.if_op.target << " " << instr.if_op.true_offset
+ << " " << instr.if_op.false_offset;
+ break;
+ }
+ case Opcode::Invoke: {
+ os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($"
+ << StrJoin<RegName>(instr.invoke_args_registers, 0, instr.num_args, ",$") << ")";
+ break;
+ }
+ case Opcode::InvokeClosure: {
+ os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($"
+ << StrJoin<RegName>(instr.closure_args, 0, instr.num_closure_args, ",$") << ")";
+ break;
+ }
+ case Opcode::LoadConst: {
+ os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]";
+ break;
+ }
+ case Opcode::LoadConsti: {
+ os << "load_consti $" << instr.dst << " " << instr.load_consti.val;
+ break;
+ }
+ case Opcode::GetField: {
+ os << "get_field $" << instr.dst << " $" << instr.object << "[" << instr.field_index << "]";
+ break;
+ }
+ case Opcode::GetTag: {
+ os << "get_tag $" << instr.dst << " $" << instr.get_tag.object;
+ break;
+ }
+ case Opcode::Goto: {
+ os << "goto " << instr.pc_offset;
+ break;
+ }
+ case Opcode::AllocStorage: {
+ os << "alloc_storage $" << instr.dst << " $" << instr.alloc_storage.allocation_size << " "
+ << instr.alloc_storage.alignment << " "
+ << DLDataType2String(instr.alloc_storage.dtype_hint);
+ break;
+ }
+ case Opcode::ShapeOf: {
+ os << "shape_of $" << instr.dst << " $" << instr.shape_of.tensor;
+ break;
+ }
+ case Opcode::ReshapeTensor: {
+ os << "reshape_tensor $" << instr.dst << " $" << instr.reshape_tensor.tensor << " $"
+ << instr.reshape_tensor.newshape;
+ break;
+ }
+ default:
+ LOG(FATAL) << "should never hit this case" << static_cast<int>(instr.op);
+ break;
+ }
+}
+
+std::ostream& operator<<(std::ostream& os, const Instruction& instr) {
+ InstructionPrint(os, instr);
+ return os;
+}
+
+} // namespace vm
+} // namespace runtime
+} // namespace tvm
diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc
index 4944778..9987621 100644
--- a/src/runtime/vm/executable.cc
+++ b/src/runtime/vm/executable.cc
@@ -25,7 +25,8 @@
#include <dmlc/memory_io.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
-#include <tvm/runtime/vm.h>
+#include <tvm/runtime/vm/executable.h>
+#include <tvm/runtime/vm/vm.h>
#include <algorithm>
#include <iomanip>
diff --git a/src/runtime/vm/memory_manager.cc b/src/runtime/vm/memory_manager.cc
index 4c220bb..4d443d9 100644
--- a/src/runtime/vm/memory_manager.cc
+++ b/src/runtime/vm/memory_manager.cc
@@ -21,7 +21,7 @@
* \file tvm/runtime/vm/memory_manager.cc
* \brief Allocate and manage memory for the runtime.
*/
-#include "memory_manager.h"
+#include <tvm/runtime/vm/memory_manager.h>
#include <memory>
#include <utility>
@@ -37,7 +37,7 @@ static void BufferDeleter(Object* obj) {
auto* ptr = static_cast<NDArray::Container*>(obj);
CHECK(ptr->manager_ctx != nullptr);
Buffer* buffer = reinterpret_cast<Buffer*>(ptr->manager_ctx);
- MemoryManager::Global()->GetAllocator(buffer->ctx)->Free(*(buffer));
+ MemoryManager::GetAllocator(buffer->ctx)->Free(*(buffer));
delete buffer;
delete ptr;
}
@@ -114,15 +114,49 @@ MemoryManager* MemoryManager::Global() {
return &memory_manager;
}
+Allocator* MemoryManager::GetOrCreateAllocator(TVMContext ctx, AllocatorType type) {
+ MemoryManager* m = MemoryManager::Global();
+ std::lock_guard<std::mutex> lock(m->mu_);
+ if (m->allocators_.find(ctx) == m->allocators_.end()) {
+ std::unique_ptr<Allocator> alloc;
+ switch (type) {
+ case kNaive: {
+ DLOG(INFO) << "New naive allocator for " << DeviceName(ctx.device_type) << "("
+ << ctx.device_id << ")";
+ alloc.reset(new NaiveAllocator(ctx));
+ break;
+ }
+ case kPooled: {
+ DLOG(INFO) << "New pooled allocator for " << DeviceName(ctx.device_type) << "("
+ << ctx.device_id << ")";
+ alloc.reset(new PooledAllocator(ctx));
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unknown allocator type: " << type;
+ }
+ auto ret = alloc.get();
+ m->allocators_.emplace(ctx, std::move(alloc));
+ return ret;
+ }
+ auto alloc = m->allocators_.at(ctx).get();
+ if (alloc->type() != type) {
+ LOG(WARNING) << "The type of existing allocator for " << DeviceName(ctx.device_type) << "("
+ << ctx.device_id << ") is different from the request type (" << alloc->type()
+ << " vs " << type << ")";
+ }
+ return alloc;
+}
+
Allocator* MemoryManager::GetAllocator(TVMContext ctx) {
- std::lock_guard<std::mutex> lock(mu_);
- if (allocators_.find(ctx) == allocators_.end()) {
- DLOG(INFO) << "New allocator for " << DeviceName(ctx.device_type) << "(" << ctx.device_id
- << ")";
- std::unique_ptr<Allocator> alloc(new NaiveAllocator(ctx));
- allocators_.emplace(ctx, std::move(alloc));
+ MemoryManager* m = MemoryManager::Global();
+ std::lock_guard<std::mutex> lock(m->mu_);
+ auto it = m->allocators_.find(ctx);
+ if (it == m->allocators_.end()) {
+ LOG(FATAL) << "Allocator for " << DeviceName(ctx.device_type) << "(" << ctx.device_id
+ << ") has not been created yet.";
}
- return allocators_.at(ctx).get();
+ return it->second.get();
}
NDArray Allocator::Empty(std::vector<int64_t> shape, DLDataType dtype, DLContext ctx) {
diff --git a/src/runtime/vm/naive_allocator.h b/src/runtime/vm/naive_allocator.h
index 5ac2ca6..301acf8 100644
--- a/src/runtime/vm/naive_allocator.h
+++ b/src/runtime/vm/naive_allocator.h
@@ -24,18 +24,17 @@
#define TVM_RUNTIME_VM_NAIVE_ALLOCATOR_H_
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/vm/memory_manager.h>
#include <atomic>
-#include "memory_manager.h"
-
namespace tvm {
namespace runtime {
namespace vm {
class NaiveAllocator final : public Allocator {
public:
- explicit NaiveAllocator(TVMContext ctx) : Allocator(), used_memory_(0), ctx_(ctx) {}
+ explicit NaiveAllocator(TVMContext ctx) : Allocator(kNaive), used_memory_(0), ctx_(ctx) {}
Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override {
Buffer buf;
diff --git a/src/runtime/vm/pooled_allocator.h b/src/runtime/vm/pooled_allocator.h
index e09628f..4226ef7 100644
--- a/src/runtime/vm/pooled_allocator.h
+++ b/src/runtime/vm/pooled_allocator.h
@@ -24,14 +24,13 @@
#define TVM_RUNTIME_VM_POOLED_ALLOCATOR_H_
#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/vm/memory_manager.h>
#include <atomic>
#include <mutex>
#include <unordered_map>
#include <vector>
-#include "memory_manager.h"
-
namespace tvm {
namespace runtime {
namespace vm {
@@ -41,7 +40,7 @@ class PooledAllocator final : public Allocator {
static constexpr size_t kDefaultPageSize = 4096;
explicit PooledAllocator(TVMContext ctx, size_t page_size = kDefaultPageSize)
- : Allocator(), page_size_(page_size), used_memory_(0), ctx_(ctx) {}
+ : Allocator(kPooled), page_size_(page_size), used_memory_(0), ctx_(ctx) {}
~PooledAllocator() { ReleaseAll(); }
diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc
index 6e4682d..7273b56 100644
--- a/src/runtime/vm/profiler/vm.cc
+++ b/src/runtime/vm/profiler/vm.cc
@@ -25,7 +25,6 @@
#include "vm.h"
#include <tvm/runtime/registry.h>
-#include <tvm/runtime/vm.h>
#include <algorithm>
#include <chrono>
diff --git a/src/runtime/vm/profiler/vm.h b/src/runtime/vm/profiler/vm.h
index c286828..797d414 100644
--- a/src/runtime/vm/profiler/vm.h
+++ b/src/runtime/vm/profiler/vm.h
@@ -25,7 +25,7 @@
#ifndef TVM_RUNTIME_VM_PROFILER_VM_H_
#define TVM_RUNTIME_VM_PROFILER_VM_H_
-#include <tvm/runtime/vm.h>
+#include <tvm/runtime/vm/vm.h>
#include <memory>
#include <string>
diff --git a/src/runtime/vm/serialize_util.h b/src/runtime/vm/serialize_util.h
index 8bd1f86..d52b73d 100644
--- a/src/runtime/vm/serialize_util.h
+++ b/src/runtime/vm/serialize_util.h
@@ -26,7 +26,7 @@
#include <dmlc/common.h>
#include <dmlc/memory_io.h>
-#include <tvm/runtime/vm.h>
+#include <tvm/runtime/vm/executable.h>
#include <functional>
#include <string>
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 24fc110..9af5202 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -19,32 +19,30 @@
/*!
* \file src/runtime/vm/vm.cc
- * \brief The Relay virtual machine.
+ * \brief The Relay virtual machine runtime.
*/
#include <dmlc/memory_io.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
-#include <tvm/runtime/vm.h>
+#include <tvm/runtime/vm/vm.h>
#include <tvm/support/logging.h>
#include <algorithm>
#include <chrono>
#include <iostream>
-#include <sstream>
#include <stdexcept>
#include <vector>
-#include "memory_manager.h"
-#include "naive_allocator.h"
-
using namespace tvm::runtime;
namespace tvm {
namespace runtime {
namespace vm {
+TVM_REGISTER_OBJECT_TYPE(VMClosureObj);
+
VMClosure::VMClosure(size_t func_index, std::vector<ObjectRef> free_vars) {
auto ptr = make_object<VMClosureObj>();
ptr->func_index = func_index;
@@ -52,588 +50,6 @@ VMClosure::VMClosure(size_t func_index, std::vector<ObjectRef> free_vars) {
data_ = std::move(ptr);
}
-inline Storage make_storage(size_t size, size_t alignment, DLDataType dtype_hint, TVMContext ctx) {
- // We could put cache in here, from ctx to storage allocator.
- auto storage_obj = SimpleObjAllocator().make_object<StorageObj>();
- auto alloc = MemoryManager::Global()->GetAllocator(ctx);
- DCHECK(alloc != nullptr) << "allocator must not null";
- storage_obj->buffer = alloc->Alloc(size, alignment, dtype_hint);
- return Storage(storage_obj);
-}
-
-Instruction::Instruction() {}
-
-template <typename T>
-static T* Duplicate(T* src, Index size) {
- auto dst = new T[size];
- std::copy(src, src + size, dst);
- return dst;
-}
-
-Instruction::Instruction(const Instruction& instr) {
- this->op = instr.op;
- this->dst = instr.dst;
-
- switch (instr.op) {
- case Opcode::Move:
- this->from = instr.from;
- return;
- case Opcode::Fatal:
- return;
- case Opcode::Ret:
- this->result = instr.result;
- return;
- case Opcode::AllocTensor:
- this->alloc_tensor.storage = instr.alloc_tensor.storage;
- this->alloc_tensor.offset = instr.alloc_tensor.offset;
- this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
- this->alloc_tensor.shape =
- Duplicate<int64_t>(instr.alloc_tensor.shape, instr.alloc_tensor.ndim);
- this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
- return;
- case Opcode::AllocTensorReg:
- this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage;
- this->alloc_tensor_reg.offset = instr.alloc_tensor_reg.offset;
- this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
- this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
- return;
- case Opcode::AllocADT:
- this->constructor_tag = instr.constructor_tag;
- this->num_fields = instr.num_fields;
- this->datatype_fields = Duplicate<RegName>(instr.datatype_fields, instr.num_fields);
- return;
- case Opcode::AllocClosure:
- this->clo_index = instr.clo_index;
- this->num_freevar = instr.num_freevar;
- this->free_vars = Duplicate<RegName>(instr.free_vars, instr.num_freevar);
- return;
- case Opcode::InvokePacked:
- this->packed_index = instr.packed_index;
- this->arity = instr.arity;
- this->output_size = instr.output_size;
- this->packed_args = Duplicate<RegName>(instr.packed_args, instr.arity);
- return;
- case Opcode::InvokeClosure:
- this->closure = instr.closure;
- this->num_closure_args = instr.num_closure_args;
- this->closure_args = Duplicate<RegName>(instr.closure_args, instr.num_closure_args);
- return;
- case Opcode::Invoke:
- this->func_index = instr.func_index;
- this->num_args = instr.num_args;
- this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args);
- return;
- case Opcode::If:
- this->if_op = instr.if_op;
- return;
- case Opcode::LoadConst:
- this->const_index = instr.const_index;
- return;
- case Opcode::LoadConsti:
- this->load_consti = instr.load_consti;
- return;
- case Opcode::GetField:
- this->object = instr.object;
- this->field_index = instr.field_index;
- return;
- case Opcode::GetTag:
- this->get_tag = instr.get_tag;
- return;
- case Opcode::Goto:
- this->pc_offset = instr.pc_offset;
- return;
- case Opcode::AllocStorage:
- this->alloc_storage = instr.alloc_storage;
- return;
- case Opcode::ShapeOf:
- this->shape_of.tensor = instr.shape_of.tensor;
- return;
- case Opcode::ReshapeTensor:
- this->reshape_tensor.tensor = instr.reshape_tensor.tensor;
- this->reshape_tensor.newshape = instr.reshape_tensor.newshape;
- return;
- default:
- std::ostringstream out;
- out << "Invalid instruction " << static_cast<int>(instr.op);
- throw std::runtime_error(out.str());
- }
-}
-
-template <typename T>
-static inline void FreeIf(T* t) {
- if (t != nullptr) {
- delete t;
- }
-}
-
-Instruction& Instruction::operator=(const Instruction& instr) {
- this->op = instr.op;
- this->dst = instr.dst;
-
- switch (instr.op) {
- case Opcode::Move:
- this->from = instr.from;
- return *this;
- case Opcode::Fatal:
- return *this;
- case Opcode::LoadConsti:
- this->load_consti = instr.load_consti;
- return *this;
- case Opcode::Ret:
- this->result = instr.result;
- return *this;
- case Opcode::AllocTensor:
- this->alloc_tensor.storage = this->alloc_tensor.storage;
- this->alloc_tensor.offset = instr.alloc_tensor.offset;
- this->alloc_tensor.ndim = instr.alloc_tensor.ndim;
- this->alloc_tensor.shape =
- Duplicate<int64_t>(instr.alloc_tensor.shape, instr.alloc_tensor.ndim);
- this->alloc_tensor.dtype = instr.alloc_tensor.dtype;
- return *this;
- case Opcode::AllocTensorReg:
- this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage;
- this->alloc_tensor_reg.offset = instr.alloc_tensor_reg.offset;
- this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register;
- this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype;
- return *this;
- case Opcode::AllocADT:
- this->constructor_tag = instr.constructor_tag;
- this->num_fields = instr.num_fields;
- FreeIf(this->datatype_fields);
- this->datatype_fields = Duplicate<RegName>(instr.datatype_fields, instr.num_fields);
- return *this;
- case Opcode::AllocClosure:
- this->clo_index = instr.clo_index;
- this->num_freevar = instr.num_freevar;
- FreeIf(this->free_vars);
- this->free_vars = Duplicate<RegName>(instr.free_vars, instr.num_freevar);
- return *this;
- case Opcode::InvokePacked:
- this->packed_index = instr.packed_index;
- this->arity = instr.arity;
- this->output_size = instr.output_size;
- FreeIf(this->packed_args);
- this->packed_args = Duplicate<RegName>(instr.packed_args, instr.arity);
- return *this;
- case Opcode::InvokeClosure:
- this->closure = instr.closure;
- this->num_closure_args = instr.num_closure_args;
- FreeIf(this->closure_args);
- this->closure_args = Duplicate<RegName>(instr.closure_args, instr.num_closure_args);
- return *this;
- case Opcode::Invoke:
- this->func_index = instr.func_index;
- this->num_args = instr.num_args;
- FreeIf(this->invoke_args_registers);
- this->invoke_args_registers = Duplicate<RegName>(instr.invoke_args_registers, instr.num_args);
- return *this;
- case Opcode::If:
- this->if_op = instr.if_op;
- return *this;
- case Opcode::LoadConst:
- this->const_index = instr.const_index;
- return *this;
- case Opcode::GetField:
- this->object = instr.object;
- this->field_index = instr.field_index;
- return *this;
- case Opcode::GetTag:
- this->get_tag = instr.get_tag;
- return *this;
- case Opcode::Goto:
- this->pc_offset = instr.pc_offset;
- return *this;
- case Opcode::AllocStorage:
- this->alloc_storage = instr.alloc_storage;
- return *this;
- case Opcode::ShapeOf:
- this->shape_of.tensor = instr.shape_of.tensor;
- return *this;
- default:
- std::ostringstream out;
- out << "Invalid instruction " << static_cast<int>(instr.op);
- throw std::runtime_error(out.str());
- }
-}
-
-Instruction::~Instruction() {
- switch (this->op) {
- case Opcode::Move:
- case Opcode::Ret:
- case Opcode::AllocTensorReg:
- case Opcode::If:
- case Opcode::LoadConst:
- case Opcode::GetField:
- case Opcode::GetTag:
- case Opcode::Goto:
- case Opcode::LoadConsti:
- case Opcode::AllocStorage:
- case Opcode::ShapeOf:
- case Opcode::ReshapeTensor:
- case Opcode::Fatal:
- return;
- case Opcode::AllocTensor:
- delete[] this->alloc_tensor.shape;
- return;
- case Opcode::AllocADT:
- delete[] this->datatype_fields;
- return;
- case Opcode::AllocClosure:
- delete[] this->free_vars;
- return;
- case Opcode::InvokePacked:
- delete[] this->packed_args;
- return;
- case Opcode::InvokeClosure:
- delete[] this->closure_args;
- return;
- case Opcode::Invoke:
- delete[] this->invoke_args_registers;
- return;
- default:
- std::ostringstream out;
- LOG(FATAL) << "Invalid instruction " << static_cast<int>(this->op);
- }
-}
-
-Instruction Instruction::Ret(RegName result) {
- Instruction instr;
- instr.op = Opcode::Ret;
- instr.result = result;
- return instr;
-}
-
-Instruction Instruction::Fatal() {
- Instruction instr;
- instr.op = Opcode::Fatal;
- return instr;
-}
-
-Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index output_size,
- const std::vector<RegName>& args) {
- Instruction instr;
- instr.op = Opcode::InvokePacked;
- instr.packed_index = packed_index;
- instr.arity = arity;
- instr.output_size = output_size;
- instr.packed_args = new RegName[arity];
- for (Index i = 0; i < arity; ++i) {
- instr.packed_args[i] = args[i];
- }
- return instr;
-}
-
-Instruction Instruction::AllocTensor(RegName storage, RegName offset,
- const std::vector<int64_t>& shape, DLDataType dtype,
- RegName dst) {
- Instruction instr;
- instr.op = Opcode::AllocTensor;
- instr.dst = dst;
- instr.alloc_tensor.storage = storage;
- instr.alloc_tensor.offset = offset;
- instr.alloc_tensor.ndim = shape.size();
- instr.alloc_tensor.shape = new int64_t[shape.size()];
- for (size_t i = 0; i < shape.size(); ++i) {
- instr.alloc_tensor.shape[i] = shape[i];
- }
- instr.alloc_tensor.dtype = dtype;
- return instr;
-}
-
-Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName shape_register,
- DLDataType dtype, RegName dst) {
- Instruction instr;
- instr.op = Opcode::AllocTensorReg;
- instr.dst = dst;
- instr.alloc_tensor_reg.storage = storage;
- instr.alloc_tensor_reg.offset = offset;
- instr.alloc_tensor_reg.shape_register = shape_register;
- instr.alloc_tensor_reg.dtype = dtype;
- return instr;
-}
-
-Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType dtype_hint,
- RegName dst) {
- Instruction instr;
- instr.op = Opcode::AllocStorage;
- instr.dst = dst;
- instr.alloc_storage.allocation_size = size;
- instr.alloc_storage.alignment = alignment;
- instr.alloc_storage.dtype_hint = dtype_hint;
- return instr;
-}
-
-Instruction Instruction::ShapeOf(RegName tensor, RegName dst) {
- Instruction instr;
- instr.op = Opcode::ShapeOf;
- instr.dst = dst;
- instr.shape_of.tensor = tensor;
- return instr;
-}
-
-Instruction Instruction::ReshapeTensor(RegName tensor, RegName newshape, RegName dst) {
- Instruction instr;
- instr.op = Opcode::ReshapeTensor;
- instr.dst = dst;
- instr.reshape_tensor.tensor = tensor;
- instr.reshape_tensor.newshape = newshape;
- return instr;
-}
-
-Instruction Instruction::AllocADT(Index tag, Index num_fields,
- const std::vector<RegName>& datatype_fields, RegName dst) {
- Instruction instr;
- instr.op = Opcode::AllocADT;
- instr.dst = dst;
- instr.constructor_tag = tag;
- instr.num_fields = num_fields;
- instr.datatype_fields = new RegName[num_fields];
- for (Index i = 0; i < num_fields; ++i) {
- instr.datatype_fields[i] = datatype_fields[i];
- }
- return instr;
-}
-
-Instruction Instruction::AllocClosure(Index func_index, Index free_vars,
- const std::vector<RegName>& free_var_register, RegName dst) {
- Instruction instr;
- instr.op = Opcode::AllocClosure;
- instr.dst = dst;
- instr.clo_index = func_index;
- instr.num_freevar = free_vars;
- instr.free_vars = new RegName[instr.num_freevar];
- for (Index i = 0; i < instr.num_freevar; ++i) {
- instr.free_vars[i] = free_var_register[i];
- }
- return instr;
-}
-
-Instruction Instruction::GetField(RegName object, Index field_index, RegName dst) {
- Instruction instr;
- instr.op = Opcode::GetField;
- instr.dst = dst;
- instr.object = object;
- instr.field_index = field_index;
- return instr;
-}
-
-Instruction Instruction::GetTag(RegName object, RegName dst) {
- Instruction instr;
- instr.op = Opcode::GetTag;
- instr.dst = dst;
- instr.get_tag.object = object;
- return instr;
-}
-
-Instruction Instruction::If(RegName test, RegName target, Index true_branch, Index false_branch) {
- Instruction instr;
- instr.op = Opcode::If;
- instr.if_op.test = test;
- instr.if_op.target = target;
- instr.if_op.true_offset = true_branch;
- instr.if_op.false_offset = false_branch;
- return instr;
-}
-
-Instruction Instruction::Goto(Index pc_offset) {
- Instruction instr;
- instr.op = Opcode::Goto;
- instr.pc_offset = pc_offset;
- return instr;
-}
-
-Instruction Instruction::Invoke(Index func_index, const std::vector<RegName>& args_registers,
- RegName dst) {
- Instruction instr;
- instr.op = Opcode::Invoke;
- instr.dst = dst;
- instr.func_index = func_index;
- instr.num_args = args_registers.size();
- instr.invoke_args_registers = new RegName[instr.num_args];
- for (Index i = 0; i < instr.num_args; ++i) {
- instr.invoke_args_registers[i] = args_registers[i];
- }
- return instr;
-}
-
-Instruction Instruction::InvokeClosure(RegName closure, const std::vector<RegName>& args,
- RegName dst) {
- Instruction instr;
- instr.op = Opcode::InvokeClosure;
- instr.dst = dst;
- instr.closure = closure;
- instr.num_closure_args = args.size();
- instr.closure_args = new RegName[args.size()];
- for (size_t i = 0; i < args.size(); ++i) {
- instr.closure_args[i] = args[i];
- }
- return instr;
-}
-
-Instruction Instruction::LoadConst(Index const_index, RegName dst) {
- Instruction instr;
- instr.op = Opcode::LoadConst;
- instr.dst = dst;
- instr.const_index = const_index;
- return instr;
-}
-
-Instruction Instruction::LoadConsti(Index val, RegName dst) {
- Instruction instr;
- instr.op = Opcode::LoadConsti;
- instr.dst = dst;
- instr.load_consti.val = val;
- return instr;
-}
-
-Instruction Instruction::Move(RegName src, RegName dst) {
- Instruction instr;
- instr.op = Opcode::Move;
- instr.dst = dst;
- instr.from = src;
- return instr;
-}
-
-void DLDatatypePrint(std::ostream& os, const DLDataType& dtype) {
- switch (dtype.code) {
- case kDLInt:
- os << "int";
- break;
- case kDLUInt:
- os << "uint";
- break;
- case kDLFloat:
- os << "float";
- break;
- }
-
- os << int(dtype.bits);
- if (dtype.lanes != 1) {
- os << "x" << dtype.lanes;
- }
-}
-
-template <typename T>
-std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") {
- if (cnt == 0) {
- return "";
- }
- std::ostringstream oss;
- oss << items[offset];
- for (int i = 1; i < cnt; ++i) {
- oss << delim << items[offset + i];
- }
- return oss.str();
-}
-
-void InstructionPrint(std::ostream& os, const Instruction& instr) {
- switch (instr.op) {
- case Opcode::Move: {
- os << "move $" << instr.dst << " $" << instr.from;
- break;
- }
- case Opcode::Ret: {
- os << "ret $" << instr.result;
- break;
- }
- case Opcode::Fatal: {
- os << "fatal";
- break;
- }
- case Opcode::InvokePacked: {
- os << "invoke_packed PackedFunc[" << instr.packed_index << "] (in: $"
- << StrJoin<RegName>(instr.packed_args, 0, instr.arity - instr.output_size, ", $")
- << ", out: $"
- << StrJoin<RegName>(instr.packed_args, instr.arity - instr.output_size, instr.output_size,
- ", $")
- << ")";
- break;
- }
- case Opcode::AllocTensor: {
- os << "alloc_tensor $" << instr.dst << " $" << instr.alloc_tensor.storage << " $"
- << instr.alloc_tensor.offset << " ["
- << StrJoin<int64_t>(instr.alloc_tensor.shape, 0, instr.alloc_tensor.ndim) << "] ";
- DLDatatypePrint(os, instr.alloc_tensor.dtype);
- break;
- }
- case Opcode::AllocTensorReg: {
- os << "alloc_tensor_reg $" << instr.dst << " $" << instr.alloc_tensor_reg.storage << " $"
- << instr.alloc_tensor_reg.offset << " $" << instr.alloc_tensor_reg.shape_register << " ";
- DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
- break;
- }
- case Opcode::AllocADT: {
- os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$"
- << StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]";
- break;
- }
- case Opcode::AllocClosure: {
- os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index << "]($"
- << StrJoin<RegName>(instr.free_vars, 0, instr.num_freevar, ",$") << ")";
- break;
- }
- case Opcode::If: {
- os << "if "
- << "$" << instr.if_op.test << " $" << instr.if_op.target << " " << instr.if_op.true_offset
- << " " << instr.if_op.false_offset;
- break;
- }
- case Opcode::Invoke: {
- os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($"
- << StrJoin<RegName>(instr.invoke_args_registers, 0, instr.num_args, ",$") << ")";
- break;
- }
- case Opcode::InvokeClosure: {
- os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($"
- << StrJoin<RegName>(instr.closure_args, 0, instr.num_closure_args, ",$") << ")";
- break;
- }
- case Opcode::LoadConst: {
- os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]";
- break;
- }
- case Opcode::LoadConsti: {
- os << "load_consti $" << instr.dst << " " << instr.load_consti.val;
- break;
- }
- case Opcode::GetField: {
- os << "get_field $" << instr.dst << " $" << instr.object << "[" << instr.field_index << "]";
- break;
- }
- case Opcode::GetTag: {
- os << "get_tag $" << instr.dst << " $" << instr.get_tag.object;
- break;
- }
- case Opcode::Goto: {
- os << "goto " << instr.pc_offset;
- break;
- }
- case Opcode::AllocStorage: {
- os << "alloc_storage $" << instr.dst << " $" << instr.alloc_storage.allocation_size << " "
- << instr.alloc_storage.alignment << " "
- << DLDataType2String(instr.alloc_storage.dtype_hint);
- break;
- }
- case Opcode::ShapeOf: {
- os << "shape_of $" << instr.dst << " $" << instr.shape_of.tensor;
- break;
- }
- case Opcode::ReshapeTensor: {
- os << "reshape_tensor $" << instr.dst << " $" << instr.reshape_tensor.tensor << " $"
- << instr.reshape_tensor.newshape;
- break;
- }
- default:
- LOG(FATAL) << "should never hit this case" << static_cast<int>(instr.op);
- break;
- }
-}
-
-std::ostream& operator<<(std::ostream& os, const Instruction& instr) {
- InstructionPrint(os, instr);
- return os;
-}
-
void VMFunctionPrint(std::ostream& os, const VMFunction& vm_func) {
os << vm_func.name << ": " << std::endl;
for (size_t i = 0; i < vm_func.instructions.size(); ++i) {
@@ -707,16 +123,19 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name,
});
} else if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
- CHECK_EQ(args.size() % 2, 0);
+ CHECK_EQ(args.size() % 3, 0);
std::vector<TVMContext> contexts;
- for (int i = 0; i < args.size() / 2; ++i) {
+ std::vector<AllocatorType> alloc_types;
+ for (int i = 0; i < args.size() / 3; ++i) {
TVMContext ctx;
- int device_type = args[i * 2];
+ int device_type = args[i * 3];
ctx.device_type = DLDeviceType(device_type);
- ctx.device_id = args[i * 2 + 1];
+ ctx.device_id = args[i * 3 + 1];
+ int type = args[i * 3 + 2];
contexts.push_back(ctx);
+ alloc_types.push_back(AllocatorType(type));
}
- this->Init(contexts);
+ this->Init(contexts, alloc_types);
});
} else if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
@@ -793,9 +212,6 @@ ObjectRef VirtualMachine::Invoke(const VMFunction& func, const std::vector<Objec
InvokeGlobal(func, args);
RunLoop();
- // TODO(wweic) ctx could be obtained from the ctxs list.
- auto alloc = MemoryManager::Global()->GetAllocator(ctxs_[0]);
- DLOG(INFO) << "Memory used: " << alloc->UsedMemory() << " B";
return return_register_;
}
@@ -864,7 +280,15 @@ void VirtualMachine::LoadExecutable(const Executable* exec) {
}
}
-void VirtualMachine::Init(const std::vector<TVMContext>& ctxs) { ctxs_ = ctxs; }
+void VirtualMachine::Init(const std::vector<TVMContext>& ctxs,
+ const std::vector<AllocatorType>& alloc_types) {
+ CHECK_EQ(ctxs.size(), alloc_types.size());
+ ctxs_ = ctxs;
+ for (size_t i = 0; i < ctxs.size(); ++i) {
+ auto alloc = MemoryManager::GetOrCreateAllocator(ctxs[i], alloc_types[i]);
+ allocators_.emplace(ctxs[i], alloc);
+ }
+}
inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) {
frames_.back().register_file[r] = val;
@@ -1090,7 +514,13 @@ void VirtualMachine::RunLoop() {
DLOG(INFO) << "AllocStorage: allocation_size=" << size << "alignment=" << alignment
<< "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint);
- auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs_[0]);
+ auto storage_obj = SimpleObjAllocator().make_object<StorageObj>();
+ auto it = allocators_.find(ctxs_[0]);
+ CHECK(it != allocators_.end())
+ << "Did you forget to init the VirtualMachine with contexts?";
+ auto alloc = it->second;
+ storage_obj->buffer = alloc->Alloc(size, alignment, instr.alloc_storage.dtype_hint);
+ Storage storage(storage_obj);
WriteRegister(instr.dst, storage);
pc_++;
goto main_loop;
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 3b9d4d4..5c6bd6f 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -135,8 +135,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
elif mode == 'vm':
with tvm.transform.PassContext(opt_level=opt_level, disabled_pass=disabled_pass):
vm_exec = relay.vm.compile(mod, target="llvm", params=params)
- vm = VirtualMachine(vm_exec)
- vm.init(tvm.cpu())
+ vm = VirtualMachine(vm_exec, tvm.cpu())
inputs = {}
for e, i in zip(input_node, input_data):
inputs[e] = tvm.nd.array(i)
diff --git a/tests/python/relay/benchmarking/benchmark_vm.py b/tests/python/relay/benchmarking/benchmark_vm.py
index a6e05be..80e9e41 100644
--- a/tests/python/relay/benchmarking/benchmark_vm.py
+++ b/tests/python/relay/benchmarking/benchmark_vm.py
@@ -61,8 +61,7 @@ def benchmark_execution(mod,
number=2, repeat=20):
with tvm.transform.PassContext(opt_level=3):
exe = vm.compile(mod, target, params=params)
- rly_vm = vm_rt.VirtualMachine(exe)
- rly_vm.init(ctx)
+ rly_vm = vm_rt.VirtualMachine(exe, ctx)
result = rly_vm.run(data)
if measure:
diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py
index 6771bd1..216d23e 100644
--- a/tests/python/relay/test_external_codegen.py
+++ b/tests/python/relay/test_external_codegen.py
@@ -55,8 +55,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
code, lib = exe.save()
lib = update_lib(lib)
exe = runtime.vm.Executable.load_exec(code, lib)
- vm = runtime.vm.VirtualMachine(exe)
- vm.init(ctx)
+ vm = runtime.vm.VirtualMachine(exe, ctx)
out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
diff --git a/tests/python/relay/test_json_runtime.py b/tests/python/relay/test_json_runtime.py
index a886692..cf3b2b2 100644
--- a/tests/python/relay/test_json_runtime.py
+++ b/tests/python/relay/test_json_runtime.py
@@ -71,8 +71,7 @@ def check_result(mod,
exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save()
exe = runtime.vm.Executable.load_exec(code, lib)
- vm = runtime.vm.VirtualMachine(exe)
- vm.init(ctx)
+ vm = runtime.vm.VirtualMachine(exe, ctx)
out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), ref_result, rtol=tol, atol=tol)
diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py
index 273c27b..46989da 100644
--- a/tests/python/relay/test_pass_annotate_target.py
+++ b/tests/python/relay/test_pass_annotate_target.py
@@ -56,8 +56,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
code, lib = exe.save()
lib = update_lib(lib)
exe = runtime.vm.Executable.load_exec(code, lib)
- vm = runtime.vm.VirtualMachine(exe)
- vm.init(ctx)
+ vm = runtime.vm.VirtualMachine(exe, ctx)
out = vm.run(**map_inputs)
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 84474f6..58bb16d 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -200,8 +200,7 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm",
code, lib = exe.save()
lib = update_lib(lib)
exe = runtime.vm.Executable.load_exec(code, lib)
- vm = runtime.vm.VirtualMachine(exe)
- vm.init(ctx)
+ vm = runtime.vm.VirtualMachine(exe, ctx)
outs = vm.run(**map_inputs)
outs = outs if isinstance(outs, runtime.container.ADT) else [outs]
results = result if isinstance(result, list) else [result]
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 91214cb..d3bb084 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -45,7 +45,6 @@ def check_result(args, expected_result, mod=None):
if "cuda" in target:
continue
vm = relay.create_executor('vm', ctx=ctx, target=target, mod=mod)
-
rts_result = vm.evaluate()(*args)
tvm.testing.assert_allclose(expected_result, rts_result.asnumpy())
@@ -57,8 +56,7 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
assert isinstance(f, tvm.IRModule), "expected expression or module"
mod = f
exe = relay.vm.compile(mod, target)
- vm = runtime.vm.VirtualMachine(exe)
- vm.init(ctx)
+ vm = runtime.vm.VirtualMachine(exe, ctx)
return vm.invoke("main", *args)
def vmobj_to_list(o):
diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py
index 95e6c6f..d1bcdcc 100644
--- a/tests/python/relay/test_vm_serialization.py
+++ b/tests/python/relay/test_vm_serialization.py
@@ -45,8 +45,7 @@ def get_serialized_output(mod, *data, params=None, target="llvm",
exe = create_exec(mod, target, params=params)
code, lib = exe.save()
des_exec = _vm.Executable.load_exec(code, lib)
- des_vm = _vm.VirtualMachine(des_exec)
- des_vm.init(ctx)
+ des_vm = _vm.VirtualMachine(des_exec, ctx)
result = des_vm.run(*data)
return result
@@ -135,8 +134,7 @@ def test_save_load():
# deserialize.
des_exec = _vm.Executable.load_exec(loaded_code, loaded_lib)
- des_vm = _vm.VirtualMachine(des_exec)
- des_vm.init(tvm.cpu())
+ des_vm = _vm.VirtualMachine(des_exec, tvm.cpu())
res = des_vm.run(x_data)
tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data)
diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py
index 064b733..97b54c6 100644
--- a/tests/python/unittest/test_runtime_vm_profiler.py
+++ b/tests/python/unittest/test_runtime_vm_profiler.py
@@ -29,8 +29,7 @@ def test_basic():
if not profiler_vm.enabled():
return
exe = relay.vm.compile(mod, target, params=params)
- vm = profiler_vm.VirtualMachineProfiler(exe)
- vm.init(ctx)
+ vm = profiler_vm.VirtualMachineProfiler(exe, ctx)
data = np.random.rand(1, 3, 224, 224).astype('float32')
res = vm.invoke("main", [data])