Posted to commits@tvm.apache.org by ar...@apache.org on 2021/06/29 23:16:47 UTC
[tvm] branch main updated: Decoupling AOT from graph memory planner (#8096)
This is an automated email from the ASF dual-hosted git repository.
areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b803bab Decoupling AOT from graph memory planner (#8096)
b803bab is described below
commit b803bab4a3e072d00ca08d0dcf9a5585194b98d4
Author: Giuseppe Rossini <gi...@arm.com>
AuthorDate: Wed Jun 30 00:16:29 2021 +0100
Decoupling AOT from graph memory planner (#8096)
* Fix an issue with storage-rewrite pass and packed functions
Change-Id: I13888471d4b8927a4012d6a8e749fb7a8935dd77
* Rebasing
Change-Id: I7aa12e0217b8a2e1ff2a97a7c5fdda6b7597ae64
* Addressing comments
Change-Id: If9f1ee190690f9a810fe41eb1933d736f1eb4ec3
* Add a pass to legalize packed calls
Change-Id: I8aa43d3a1b837b03a5cf3c6b32fc760bd78d3436
* Add a unit test for the legalization pass
Change-Id: I5b0d75380ff660dd5a0acf5b14fa84bb992fbec4
* rebasing
Change-Id: I52ceab5cf6e9b54390cb36c18dbb8e22505d8e18
* Use common StorageInfo
Change-Id: Ia8b7de1373f167ca7d0d69a99846d417405bbe48
---
include/tvm/tir/transform.h | 5 +
python/tvm/tir/transform/transform.py | 11 +
src/relay/backend/aot_executor_codegen.cc | 354 ++++++++++++++-------
src/tir/transforms/ir_utils.h | 23 ++
src/tir/transforms/legalize_packed_calls.cc | 121 +++++++
src/tir/transforms/lower_tvm_builtin.cc | 10 -
tests/python/relay/aot/aot_test_utils.py | 47 ++-
tests/python/relay/aot/test_crt_aot.py | 48 +++
.../unittest/test_aot_legalize_packed_call.py | 80 +++++
9 files changed, 573 insertions(+), 126 deletions(-)
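The centerpiece is the new tir.LegalizePackedCalls pass, exposed from both C++ and Python. As a minimal sketch of how it is driven (modeled on the unit test added at the end of this diff; `mod` is assumed to be an IRModule whose PrimFuncs contain tvm_call_cpacked calls with raw arguments):

    import tvm
    from tvm import tir

    # `mod` is an assumed IRModule containing tvm_call_cpacked calls whose
    # arguments are raw handles rather than TVMValues.
    legalized = tir.transform.LegalizePackedCalls()(mod)
    print(legalized)  # each argument is now wrapped in a stack-allocated TVMValue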
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 2113d58..5ee847e 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -419,6 +419,11 @@ TVM_DLL Pass ConvertBlocksToOpaque();
TVM_DLL Pass CompactBufferAllocation();
/*!
+ * \brief Legalize packed calls by wrapping their arguments into TVMValues.
+ */
+TVM_DLL Pass LegalizePackedCalls();
+
+/*!
* \brief Flatten the multi-dimensional BufferLoad and BufferStore
* to single dimensional Load/Store. Also remove Block to
* ensure that the flattened TIR can not be scheduled again.
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 8a32a7e..51330f8 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -451,6 +451,17 @@ def LowerTVMBuiltin():
return _ffi_api.LowerTVMBuiltin()
+def LegalizePackedCalls():
+ """Legalize packed calls to have its arguments wrapped in TVMValues
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The result pass
+ """
+ return _ffi_api.LegalizePackedCalls()
+
+
def LowerIntrin():
"""Lower target specific intrinsic calls.
diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 93935af..9b495ad 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -31,6 +31,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
+#include <tvm/tir/transform.h>
#include <algorithm>
#include <list>
@@ -46,50 +47,175 @@ namespace backend {
using IntegerArray = Array<Integer>;
using TargetsMap = std::unordered_map<int, Target>;
+using StorageMap =
+ std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
-class AotReturnSidVisitor : public ExprVisitor {
+/**
+ * This is an on-demand allocator for AOT. A new temporary
+ * (storage allocator identifier) is allocated for each operation.
+ */
+class AOTOnDemandAllocator : public ExprVisitor {
public:
- explicit AotReturnSidVisitor(Map<Expr, Array<IntegerArray>> storage_device_map)
- : storage_device_map_{storage_device_map}, return_sid_{-1} {}
+ // run the visitor on a function.
+ void Run(const Function& func) {
+ node_device_map_ = CollectDeviceInfo(func);
- IntegerArray FindReturnSid(Function func) {
- VisitExpr(func->body);
- return return_sid_;
+ for (Expr param : func->params) {
+ CreateStorage(param.operator->());
+ }
+
+ GetStorage(func->body);
}
- protected:
- void AssignReturnSid(Expr e) {
- auto iter = storage_device_map_.find(e);
- if (iter != storage_device_map_.end()) {
- return_sid_ = (*iter).second[0];
+ std::vector<int> GetReturnIds() const { return return_ids_; }
+
+ StorageMap GetStorageMap() const { return storage_device_map_; }
+
+ void VisitExpr_(const ConstantNode* op) final {
+ CreateStorage(op);
+ AssignReturnSid(GetRef<Expr>(op));
+ }
+
+ void VisitExpr_(const CallNode* op) final {
+ // create token for the call node.
+ CreateStorage(op);
+ for (Expr arg : op->args) {
+ GetStorage(arg);
}
+ AssignReturnSid(GetRef<Expr>(op));
}
- void VisitExpr_(const ConstantNode* cn) override {
- ExprVisitor::VisitExpr_(cn);
- AssignReturnSid(GetRef<Expr>(cn));
+ void VisitExpr_(const VarNode* op) final {
+ ExprVisitor::VisitExpr_(op);
+ AssignReturnSid(GetRef<Expr>(op));
}
- void VisitExpr_(const VarNode* vn) override {
- ExprVisitor::VisitExpr_(vn);
- AssignReturnSid(GetRef<Expr>(vn));
+ void VisitExpr_(const FunctionNode* op) final {
+ // do not recurse into sub function.
}
- void VisitExpr_(const CallNode* cn) override {
- ExprVisitor::VisitExpr_(cn);
- AssignReturnSid(GetRef<Expr>(cn));
+ void VisitExpr_(const GlobalVarNode* op) final {
+ // Do nothing.
}
- void VisitExpr_(const LetNode* op) override { VisitExpr(op->body); }
+ void VisitExpr_(const OpNode* op) final {
+ // Do nothing.
+ }
- void VisitExpr_(const TupleNode* tn) override {
- ExprVisitor::VisitExpr_(tn);
- AssignReturnSid(GetRef<Expr>(tn));
+ void VisitExpr_(const TupleNode* op) final {
+ std::vector<int64_t> storage_ids;
+ std::vector<DLDeviceType> device_types;
+ std::vector<int64_t> storage_sizes_in_bytes;
+ Expr expr = GetRef<Expr>(op);
+ for (Expr field : op->fields) {
+ auto sid = GetStorage(field);
+ storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end());
+ device_types.insert(device_types.end(), sid->device_types.begin(), sid->device_types.end());
+ storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(),
+ sid->storage_sizes_in_bytes.begin(),
+ sid->storage_sizes_in_bytes.end());
+ }
+ storage_device_map_[expr] = StorageInfo(storage_ids, device_types, storage_sizes_in_bytes);
+ AssignReturnSid(expr);
}
+ void VisitExpr_(const TupleGetItemNode* op) final {
+ Expr expr = GetRef<Expr>(op);
+ auto sids = GetStorage(op->tuple);
+ ICHECK_LT(static_cast<size_t>(op->index), sids->storage_ids.size());
+ storage_device_map_[expr] =
+ StorageInfo({sids->storage_ids[op->index]}, {sids->device_types[op->index]},
+ {sids->storage_sizes_in_bytes[op->index]});
+ AssignReturnSid(expr);
+ }
+
+ void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; }
+
+ void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "let is not supported."; }
+
private:
- Map<Expr, Array<IntegerArray>> storage_device_map_;
- IntegerArray return_sid_;
+ void AssignReturnSid(Expr e) {
+ if (storage_device_map_.find(e) != storage_device_map_.end()) {
+ StorageInfo& sinfo = storage_device_map_[e];
+ return_ids_.clear();
+ for (auto sid : sinfo->storage_ids) {
+ return_ids_.push_back(sid);
+ }
+ }
+ }
+ /*!
+ * \brief ceil(size/word_size) to get number of words.
+ * \param size The original size.
+ * \param word_size The element size.
+ */
+ static size_t DivRoundUp(size_t size, size_t word_size) {
+ return (size + word_size - 1) / word_size;
+ }
+ /*!
+ * \brief Get the memory requirement of a tensor type.
+ * \param ttype The tensor type.
+ * \return The required memory size in bytes.
+ */
+ size_t GetMemorySizeBytes(const TensorTypeNode* ttype) {
+ ICHECK(ttype != nullptr);
+ size_t size = 1;
+ for (IndexExpr dim : ttype->shape) {
+ const int64_t* pval = tir::as_const_int(dim);
+ ICHECK(pval != nullptr) << "Cannot allocate memory for symbolic tensor shape " << ttype->shape;
+ ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape " << *pval;
+ size *= static_cast<size_t>(pval[0]);
+ }
+ size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8);
+ return size;
+ }
+ /*!
+ * \brief Get the necessary storage for the expression.
+ * \param expr The expression.
+ * \return The corresponding StorageInfo.
+ */
+ StorageInfo GetStorage(const Expr& expr) {
+ this->VisitExpr(expr);
+ auto it = storage_device_map_.find(expr);
+ ICHECK(it != storage_device_map_.end());
+ return it->second;
+ }
+
+ /*!
+ * \brief Create storage for the expression.
+ * \param op The expression node.
+ */
+ void CreateStorage(const ExprNode* op) {
+ std::vector<int64_t> storage_ids;
+ std::vector<DLDeviceType> device_types;
+ std::vector<int64_t> storage_sizes_in_bytes;
+ Expr expr = GetRef<Expr>(op);
+ int device_type_int =
+ node_device_map_.count(GetRef<Expr>(op)) ? node_device_map_[expr]->value : 0;
+ if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
+ for (Type t : tuple_type->fields) {
+ const auto* ttype = t.as<TensorTypeNode>();
+ ICHECK(ttype);
+ storage_ids.push_back(next_available_sid_++);
+ storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype));
+ device_types.push_back(DLDeviceType(device_type_int));
+ }
+ } else {
+ const auto* ttype = op->checked_type().as<TensorTypeNode>();
+ ICHECK(ttype);
+ storage_ids.push_back(next_available_sid_++);
+ storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype));
+ device_types.push_back(DLDeviceType(device_type_int));
+ }
+ storage_device_map_[expr] = StorageInfo(storage_ids, device_types, storage_sizes_in_bytes);
+ }
+ /*! \brief Mapping of expression -> StorageInfo */
+ StorageMap storage_device_map_;
+ /*! \brief Mapping of expression -> device type */
+ Map<Expr, Integer> node_device_map_;
+ /*! \brief Current id of the allocated temporary */
+ int next_available_sid_{0};
+ /*! \brief The set of intermediate tensors that are return variables */
+ std::vector<int> return_ids_;
};
/*! \brief Code generator for AOT executor */
@@ -120,65 +246,24 @@ class AOTExecutorCodegen : public ExprVisitor {
* \brief Return a vector of variables that represents the sids for the given Relay Expr
*/
std::vector<tir::Var> PackSid(Expr expr) {
- Array<IntegerArray> sids = storage_device_map_[expr];
- std::vector<tir::Var> sid_vars;
+ std::vector<tir::Var> buffer_vars;
+ StorageInfo& sinfo = storage_device_map_[expr];
// Note that an expression can have multiple sids associated with it
// e.g., returning multiple values from a function
- for (const auto& sid : sids[0]) {
+ for (auto sid : sinfo->storage_ids) {
// Determine if an sid is an output buffer
- int sid_int = static_cast<int>((sid.as<IntImmNode>())->value);
- auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid_int);
+ auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
- sid_vars.push_back(main_signature_[input_vars_.size() + output_index]);
+ buffer_vars.push_back(main_signature_[input_vars_.size() + output_index]);
continue;
}
- // Pack the sid inside the TVMValue
- auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle());
- auto sid_value = sids_table_[sid];
- if (!use_unpacked_api_) {
- tvm::PrimExpr set_tensor =
- tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
- {sid_array, 0, tir::builtin::kArrData, sid_value});
- stmts_.push_back(
- tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor)));
- } else {
- stmts_.push_back(tir::LetStmt(sid_array, sid_value, tir::Evaluate(0)));
- }
-
- sid_vars.push_back(sid_array);
+ auto sid_value = sids_table_[sid];
+ buffer_vars.push_back(sid_value);
}
- return sid_vars;
- }
-
- /*!
- * \brief Utility function to return a parameter associated with an expression
- * \param expr Relay Expression associated with the parameter
- * \return Variable that represents the DLTensor associated with the parameters
- */
- tir::Var PackParam(Expr expr) {
- int param_sid = param_storage_ids_[params_by_expr_[expr]];
- auto param_array = te::Var(MakeString("param_", param_sid, "_array"), DataType::Handle());
-
- // Compose the lookup_call using a local stack
- Array<tir::Stmt> lookup_call;
- // Set the param to the value returned by lookup_call
- auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
- {tir::StringImm(params_by_expr_[expr])});
-
- if (!use_unpacked_api_) {
- tvm::PrimExpr set_param_array =
- tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
- {param_array, 0, tir::builtin::kArrData, param_handle});
- stmts_.push_back(
- tir::LetStmt(param_array, StackAlloca("arg_value", 1), tir::Evaluate(set_param_array)));
- } else {
- stmts_.push_back(tir::LetStmt(param_array, param_handle, tir::Evaluate(0)));
- }
-
- return param_array;
+ return buffer_vars;
}
/*!
@@ -190,9 +275,6 @@ class AOTExecutorCodegen : public ExprVisitor {
// Input variable
int main_index = std::distance(input_vars_.begin(), input_iter);
return {main_signature_[main_index]};
- } else if (params_by_expr_.find(arg) != params_by_expr_.end()) {
- // Parameter of the network
- return {PackParam(arg)};
} else {
// Storage identifier (i.e., intermediate memory)
return PackSid(arg);
@@ -208,8 +290,14 @@ class AOTExecutorCodegen : public ExprVisitor {
// Pack the inputs
for (Expr arg : call->args) {
- auto var_arg = FindExpr(arg);
- args.push_back(var_arg[0]);
+ if (params_by_expr_.find(arg) != params_by_expr_.end()) {
+ auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
+ {tir::StringImm(params_by_expr_[arg])});
+ args.push_back(param_handle);
+ } else {
+ auto var_arg = FindExpr(arg);
+ args.push_back(var_arg[0]);
+ }
}
auto ret_expr = Downcast<Expr>(call);
@@ -237,7 +325,7 @@ class AOTExecutorCodegen : public ExprVisitor {
* TODO(giuseros): we should try to avoid unnecessary copy to the output, e.g., in a
* copy-on-write fashion.
*/
- void CopyToOutput(te::Var out, te::Var in, size_t size) {
+ void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) {
// Define intermediate DLTensor to load/store the data
auto tmp0 = te::Var("tmp0", DataType::Handle());
auto tmp1 = te::Var("tmp1", DataType::Handle());
@@ -249,10 +337,15 @@ class AOTExecutorCodegen : public ExprVisitor {
PrimExpr tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
{out, 0, tir::builtin::kArrData});
if (use_unpacked_api_) {
- retval_get = in;
tostore = out;
}
+ // Read the input directly when the unpacked API is used or when the caller
+ // passed an already unpacked value (e.g., when copying a param to the output)
+ if (use_unpacked_api_ || !pack_input) {
+ retval_get = in;
+ }
+
// Copy the variable from the input to the output
tir::Stmt copy = tir::For(
loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial,
@@ -390,8 +483,8 @@ class AOTExecutorCodegen : public ExprVisitor {
}
ICHECK_GE(storage_device_map_.count(expr), 0);
- auto& device_type = storage_device_map_[expr][1];
- auto call_dev_type = device_type[0]->value;
+ StorageInfo& sinfo = storage_device_map_[expr];
+ auto call_dev_type = sinfo->device_types[0];
// Normal Relay Function
if (targets_.size() == 1) {
// homogeneous execution.
@@ -425,17 +518,23 @@ class AOTExecutorCodegen : public ExprVisitor {
void VisitExpr_(const VarNode* op) override {
Expr expr = GetRef<Expr>(op);
+ StorageInfo& sinfo = storage_device_map_[expr];
// If the Var node is an output node we need to copy the content of the variable to the output
// It's safe to check the SID here because Var StorageTokens are never reallocated
- Array<IntegerArray> sids = storage_device_map_[expr];
-
- auto output_iter = std::find(return_sid_.begin(), return_sid_.end(),
- static_cast<int>((sids[0][0].as<IntImmNode>())->value));
+ auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
- auto var_expr = FindExpr(expr);
- CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0], sids[2][0]);
+ if (params_by_expr_.find(expr) != params_by_expr_.end()) {
+ auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
+ {tir::StringImm(params_by_expr_[expr])});
+ CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle,
+ /*pack_input*/ true, sinfo->storage_sizes_in_bytes[0]);
+ } else {
+ auto var_expr = FindExpr(expr);
+ CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0],
+ /*pack_input*/ true, sinfo->storage_sizes_in_bytes[0]);
+ }
}
}
@@ -443,19 +542,20 @@ class AOTExecutorCodegen : public ExprVisitor {
Expr expr = GetRef<Expr>(op);
size_t index = params_.size();
std::string name = "p" + std::to_string(index);
-
- param_storage_ids_[name] = storage_device_map_[expr][0][0]->value;
+ StorageInfo& sinfo = storage_device_map_[expr];
+ param_storage_ids_[name] = sinfo->storage_ids[0];
params_[name] = op->data;
params_by_expr_.Set(expr, name);
// If the Constant node is an output node we need to copy the content of the parameter to the
// output. A Constant node can only produce a single output
- Array<IntegerArray> sids = storage_device_map_[expr];
- auto output_iter = std::find(return_sid_.begin(), return_sid_.end(),
- static_cast<int>((sids[0][0].as<IntImmNode>())->value));
+ auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
- CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr), sids[2][0]);
+ auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
+ {tir::StringImm(params_by_expr_[expr])});
+ CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle, false,
+ sinfo->storage_sizes_in_bytes[0]);
}
}
@@ -495,7 +595,9 @@ class AOTExecutorCodegen : public ExprVisitor {
throw std::invalid_argument("match case not yet implemented");
}
- // Create the main PrimFunc to execute the graph
+ // Create the main PrimFunc to execute the graph. Please note that
+ // the packed function calls don't pack their arguments. The AOT
+ // runner function needs to be legalized by the LegalizePackedCalls pass.
tir::PrimFunc CreateMainFunc(unsigned int relay_params) {
tir::Stmt body = tir::SeqStmt(stmts_);
@@ -511,9 +613,9 @@ class AOTExecutorCodegen : public ExprVisitor {
continue;
}
- for (unsigned int i = 0; i < kv.second[0].size(); i++) {
- int size = kv.second[2][i];
- int sid = static_cast<int>((kv.second[0][i].as<IntImmNode>())->value);
+ for (unsigned int i = 0; i < kv.second->storage_ids.size(); i++) {
+ int size = kv.second->storage_sizes_in_bytes[i];
+ int sid = kv.second->storage_ids[i];
if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) {
continue;
@@ -523,6 +625,8 @@ class AOTExecutorCodegen : public ExprVisitor {
// so we don't pay the price of allocation for every inference
if (!allocated[sid]) {
body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size}, tir::const_true(), body);
+ body = tir::AttrStmt(sids_table_[sid], tir::attr::storage_scope, tir::StringImm("global"),
+ body);
}
allocated[sid] = true;
}
@@ -578,7 +682,8 @@ class AOTExecutorCodegen : public ExprVisitor {
std::unordered_map<std::string, int64_t> param_storage_ids_;
/*! \brief plan memory of device result */
- Map<Expr, Array<IntegerArray>> storage_device_map_;
+ StorageMap storage_device_map_;
+ /*! \brief mapping sid -> tir::Var */
std::unordered_map<int, te::Var> sids_table_;
/*! \brief lowered funcs */
std::unordered_map<std::string, IRModule> lowered_funcs_;
@@ -589,7 +694,7 @@ class AOTExecutorCodegen : public ExprVisitor {
/*! \brief the set of statements that make the program */
std::vector<tir::Stmt> stmts_;
/*! \brief the list of return sids (note that the function might return more than one output) */
- IntegerArray return_sid_;
+ std::vector<int> return_sid_;
/*! \brief the module name we use to mangle the function names */
String mod_name_;
@@ -602,9 +707,11 @@ class AOTExecutorCodegen : public ExprVisitor {
compile_engine_(CompileEngine::Global()) {}
LoweredOutput Codegen(relay::Function func, String mod_name) {
- // Get the module, storage map and token sizes
- auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
- storage_device_map_ = (*pf)(func);
+ auto aot_allocator = AOTOnDemandAllocator();
+ aot_allocator.Run(func);
+
+ // Retrieve the storage map
+ storage_device_map_ = aot_allocator.GetStorageMap();
mod_name_ = mod_name;
for (auto input : func->params) {
@@ -614,20 +721,23 @@ class AOTExecutorCodegen : public ExprVisitor {
// Define the storage allocator ids
for (auto kv : storage_device_map_) {
- for (const auto& sid : kv.second[0]) {
- te::Var sid_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8))));
- sids_table_[sid] = sid_var;
+ for (auto sid : kv.second->storage_ids) {
+ te::Var buffer_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8))));
+ sids_table_[sid] = buffer_var;
}
}
- // Find the return sid
- return_sid_ = AotReturnSidVisitor(storage_device_map_).FindReturnSid(func);
+ // Retrieve the return sids
+ return_sid_ = aot_allocator.GetReturnIds();
for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) {
main_signature_.push_back(tir::Var("output", DataType::Handle()));
}
VisitExpr(func->body);
+ // Create the runner function. Please note that the function is not legal yet
+ // because the packed call arguments are not wrapped in TVMValues. To make this happen we need
+ // to run the LegalizePackedCalls pass.
auto prim_func = CreateMainFunc(func->params.size());
UpdateMainWorkspaceSize(prim_func, func);
LoweredOutput ret;
@@ -649,14 +759,28 @@ class AOTExecutorCodegen : public ExprVisitor {
}
ret.external_mods = compile_engine_->LowerExternalFunctions();
+ // Build the TIR IRModule
+ Map<GlobalVar, BaseFunc> symbol_map;
+ symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
+ IRModule mod_run(symbol_map);
+
+ // Apply storage rewrite pass to the runner function to do memory planning
+ auto storage_rewrite = tir::transform::StorageRewrite();
+ mod_run = storage_rewrite(mod_run);
+
+ // Legalize the runner function if needed. This means that the arguments of all
+ // packed calls need to be wrapped in TVMValues (unless use_unpacked_api is set)
+ if (!use_unpacked_api_) {
+ auto pack_calls = tir::transform::LegalizePackedCalls();
+ mod_run = pack_calls(mod_run);
+ }
+
+ // Update the lowered functions
auto target_host_str = target_host_->str();
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
- ret.lowered_funcs[target_host_str]->Add(
- GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
+ ret.lowered_funcs[target_host_str]->Update(mod_run);
} else {
- Map<GlobalVar, BaseFunc> symbol_map;
- symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
- ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map));
+ ret.lowered_funcs.Set(target_host_str, mod_run);
}
ret.function_metadata = std::move(function_metadata_);
ret.metadata = runtime::Metadata(input_vars_.size(), return_sid_.size(),
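The on-demand allocator above sizes each tensor as element_count * DivRoundUp(bits * lanes, 8) bytes (see GetMemorySizeBytes). A small illustrative sketch of that arithmetic, with example shapes chosen for this note:

    def storage_size_bytes(shape, bits, lanes=1):
        """Mirror of GetMemorySizeBytes: elements * ceil(bits * lanes / 8)."""
        size = 1
        for dim in shape:
            size *= dim
        return size * ((bits * lanes + 7) // 8)

    assert storage_size_bytes((10, 5), 32) == 200              # float32 10x5
    assert storage_size_bytes((1, 224, 224, 3), 8) == 150528   # uint8 NHWC input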
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index 3b4e693..906ff8a 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -29,6 +29,8 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
+#include <limits>
+#include <string>
#include <vector>
namespace tvm {
@@ -162,6 +164,27 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
}
/*!
+ * \brief Create an int32 constant
+ * \param index the value of the constant
+ * \return the PrimExpr that represents the constant
+ */
+inline PrimExpr ConstInt32(size_t index) {
+ ICHECK_LE(index, std::numeric_limits<int>::max());
+ return make_const(DataType::Int(32), static_cast<int>(index));
+}
+
+/*!
+ * \brief Allocate TVMValues on the stack
+ * \param type The type of allocation
+ * \param num The number of TVMValues to allocate
+ * \return A PrimExpr representing the allocated buffer
+ */
+inline PrimExpr StackAlloca(std::string type, size_t num) {
+ Array<PrimExpr> args = {StringImm(type), ConstInt32(num)};
+ return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args);
+}
+
+/*!
* \brief Convert an IR node to SSA form.
* \param stmt The source statement to be converted.
* \return The converted form.
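StackAlloca (hoisted here from lower_tvm_builtin.cc) emits a tvm_stack_alloca builtin call; this is exactly what each tvm_value_N variable gets let-bound to by the legalization pass. A hedged one-liner constructing the equivalent expression from Python:

    from tvm import tir

    # Equivalent of StackAlloca("array", 1): reserve one TVMValue-sized "array"
    # slot on the stack and return it as a handle.
    alloca = tir.call_intrin("handle", "tir.tvm_stack_alloca", "array", 1)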
diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc
new file mode 100644
index 0000000..424da1e
--- /dev/null
+++ b/src/tir/transforms/legalize_packed_calls.cc
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file legalize_packed_calls.cc
+ * \brief Rewrite packed calls in AOT so that their arguments are wrapped in TVMValues
+ */
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <unordered_map>
+
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+using InputMap =
+ std::unordered_map<PrimExpr, bool, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+/**
+ * This is a legalization pass used only in AOT. It traverses the TIR to legalize
+ * packed calls by wrapping their arguments in TVMValues (via the tvm_struct_set builtin)
+ */
+class PackedCallLegalizer : public StmtExprMutator {
+ public:
+ Stmt Legalize(const InputMap& params, tir::Stmt body) {
+ inputs_ = params;
+ return StmtExprMutator::VisitStmt(body);
+ }
+
+ Stmt VisitStmt_(const EvaluateNode* op) final {
+ if (tir::is_const_int(op->value)) return StmtExprMutator::VisitStmt_(op);
+ const CallNode* call = op->value.as<CallNode>();
+ // Given a packed call f(A,B,C), we need a set of new statements
+ // let A_packed = tvm_struct_set(tvm_value1, A)
+ // let B_packed = tvm_struct_set(tvm_value2, B)
+ // let C_packed = tvm_struct_set(tvm_value3, C)
+ // call_packed(f, A_packed, B_packed, C_packed)
+ std::vector<Stmt> new_stmts;
+ if (call) {
+ if (call->op.same_as(builtin::tvm_call_cpacked())) {
+ Array<PrimExpr> packed_args{call->args[0]};
+ std::vector<tir::Var> tvm_values;
+ for (unsigned i = 1; i < call->args.size(); i++) {
+ // No need to pack inputs of the prim_func
+ if (inputs_[call->args[i]] == true) {
+ packed_args.push_back(call->args[i]);
+ } else {
+ // Pack the argument inside a TVMValue
+ std::stringstream ss;
+ ss << "tvm_value_" << tvm_value_index_++;
+ auto sid_array = tir::Var(ss.str(), DataType::Handle());
+ tvm_values.push_back(sid_array);
+
+ new_stmts.push_back(tir::Evaluate(
+ tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+ {sid_array, 0, tir::builtin::kArrData, call->args[i]})));
+ packed_args.push_back(sid_array);
+ }
+ }
+ // Evaluate the packed call
+ new_stmts.push_back(tir::Evaluate(tir::Call(call->dtype, call->op, packed_args)));
+ tir::Stmt call_stmt = tir::SeqStmt(new_stmts);
+
+ // Allocate the TVMValues on the stack and define the variables
+ for (auto v : tvm_values) {
+ call_stmt = LetStmt(v, StackAlloca("array", 1), call_stmt);
+ }
+ return call_stmt;
+ }
+ }
+ return StmtExprMutator::VisitStmt_(op);
+ }
+
+ private:
+ InputMap inputs_; // Store the inputs to the primfunc that don't need to be packed.
+ int tvm_value_index_{0}; // Index of the next tvm_value variable to create
+};
+
+namespace transform {
+
+Pass LegalizePackedCalls() {
+ auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+ auto* n = f.CopyOnWrite();
+
+ // Mark the primfunc parameters as inputs that do not need packing
+ InputMap inputs;
+ for (auto i : f->params) {
+ inputs[i] = true;
+ }
+ n->body = PackedCallLegalizer().Legalize(inputs, std::move(n->body));
+ return std::move(f);
+ };
+ return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LegalizePackedCalls").set_body_typed(LegalizePackedCalls);
+} // namespace transform
+
+} // namespace tir
+} // namespace tvm
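Note that the loop wrapping call_stmt iterates the collected TVMValues in creation order, so the last value becomes the outermost LetStmt; that is why the Expected module in the new unit test opens with tvm_value_2. A quick way to inspect the rewritten TIR (hedged sketch; `Module` as defined in that test):

    from tvm import tir

    mod = Module()  # TVMScript module from the unit test below
    out = tir.transform.LegalizePackedCalls()(mod)
    print(out)  # nested lets: tvm_value_2 -> tvm_value_1 -> tvm_value_0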
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 0e2e612..8b70817 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -34,16 +34,6 @@
namespace tvm {
namespace tir {
-inline PrimExpr ConstInt32(size_t index) {
- ICHECK_LE(index, std::numeric_limits<int>::max());
- return make_const(DataType::Int(32), static_cast<int>(index));
-}
-
-inline PrimExpr StackAlloca(std::string type, size_t num) {
- Array<PrimExpr> args = {StringImm(type), ConstInt32(num)};
- return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args);
-}
-
// Calculate the statistics of packed function.
// These information are needed during codegen.
class BuiltinLower : public StmtExprMutator {
diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py
index a18a0fa..836ff4b 100644
--- a/tests/python/relay/aot/aot_test_utils.py
+++ b/tests/python/relay/aot/aot_test_utils.py
@@ -42,6 +42,46 @@ def mangle_name(mod_name, name):
return mod_name + "_" + name
+def convert_to_relay(
+ tflite_model_buf,
+ input_data,
+ input_node,
+):
+ """Convert a tflite model buffer in a Relay module"""
+
+ def convert_to_list(x):
+ if not isinstance(x, list):
+ x = [x]
+ return x
+
+ # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
+ try:
+ import tflite.Model
+
+ tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+ except AttributeError:
+ import tflite
+
+ tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+ except ImportError:
+ raise ImportError("The tflite package must be installed")
+
+ input_data = convert_to_list(input_data)
+ input_node = convert_to_list(input_node)
+
+ shape_dict = {}
+ dtype_dict = {}
+ for i, e in enumerate(input_node):
+ shape_dict[e] = input_data[i].shape
+ dtype_dict[e] = input_data[i].dtype.name
+
+ mod, params = relay.frontend.from_tflite(
+ tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
+ )
+ mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
+ return mod, params
+
+
def subprocess_with_stdout_and_log(cmd, cwd, logfile, stdout):
"""
This method runs a process and logs the output to both a log file and stdout
@@ -221,6 +261,7 @@ def compile_and_run(
params=None,
workspace_byte_alignment=8,
mod_name=None,
+ enable_op_fusion=True,
):
"""
This method verifies the generated source
@@ -232,7 +273,11 @@ def compile_and_run(
if not use_calculated_workspaces:
cflags += "-DTVM_CRT_STACK_ALLOCATOR_ENABLE_LIFO_CHECK "
- with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+ config = {"tir.disable_vectorize": True}
+ if not enable_op_fusion:
+ config["relay.FuseOps.max_depth"] = 1
+
+ with tvm.transform.PassContext(opt_level=3, config=config):
lib = tvm.relay.build(mod, target, target_host=target, params=params, mod_name=mod_name)
tmp_path = utils.tempdir()
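The new enable_op_fusion knob simply caps Relay operator fusion through PassContext config, which keeps add and transpose in separate PrimFuncs for the in-place test below. A minimal sketch of the same configuration outside the helper (`mod`, `target`, and `params` assumed):

    import tvm

    config = {"tir.disable_vectorize": True, "relay.FuseOps.max_depth": 1}
    with tvm.transform.PassContext(opt_level=3, config=config):
        lib = tvm.relay.build(mod, target, target_host=target, params=params)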
diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py
index 36596a4..13cbfa7 100644
--- a/tests/python/relay/aot/test_crt_aot.py
+++ b/tests/python/relay/aot/test_crt_aot.py
@@ -465,5 +465,53 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5),
)
+def test_quant_mobilenet_tfl():
+ """Since in AOT we pass directly the output buffer from the user, in quantized networks sharing the output buffers is not possible.
+ This is because the output data type is int8 and the intermediate buffer are int32 or int16. We use mobilenet quantized to stress this
+ situation and verify that the output buffer sharing is disabled in AOT."""
+ pytest.importorskip("tflite")
+
+ import tvm.relay.testing.tf as tf_testing
+
+ tflite_model_file = tf_testing.get_workload_official(
+ "https://storage.googleapis.com/download.tensorflow.org/"
+ "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
+ "mobilenet_v1_1.0_224_quant.tflite",
+ )
+ with open(tflite_model_file, "rb") as f:
+ tflite_model_buf = f.read()
+ data_shape = (1, 224, 224, 3)
+ in_min, in_max = (0, 255)
+ data = np.random.randint(in_min, high=in_max, size=data_shape, dtype="uint8")
+ mod, params = convert_to_relay(tflite_model_buf, data, "input")
+ inputs = {"input": data}
+ output_list = generate_ref_data(mod, inputs, params)
+ input_list = [inputs["input"]]
+ compile_and_run(mod, input_list, output_list, "--unpacked-api=0", True, params)
+
+
+@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"])
+def test_transpose(target_options):
+ """Test that non-inpleaceable operations (e.g., transpose) do not happen in-place."""
+
+ dtype = "float32"
+ x = relay.var("x", shape=(10, 5), dtype=dtype)
+ y = relay.var("y", shape=(10, 5), dtype=dtype)
+ t = relay.var("z", shape=(), dtype=dtype)
+ a = relay.add(x, y)
+ b = relay.transpose(a)
+ z = relay.add(b, t)
+ # Check result.
+ func = relay.Function([x, y, t], z)
+ x_data = np.random.rand(10, 5).astype(dtype)
+ y_data = np.random.rand(10, 5).astype(dtype)
+ t_data = np.random.uniform(size=()).astype(dtype)
+ inputs = {"x": x_data, "y": y_data, "z": t_data}
+
+ output_list = generate_ref_data(func, inputs)
+ input_list = [inputs["x"], inputs["y"], inputs["z"]]
+ compile_and_run(func, input_list, output_list, target_options, True, enable_op_fusion=False)
+
+
if __name__ == "__main__":
pytest.main([__file__])
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py
new file mode 100644
index 0000000..626af0c
--- /dev/null
+++ b/tests/python/unittest/test_aot_legalize_packed_call.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import tvm
+from tvm.script import ty
+from tvm import te, tir
+import numpy as np
+import tvm.testing
+import pytest
+
+
+@tvm.script.tir
+class Module:
+ def tir_packed_call() -> None:
+ A = tir.var("handle")
+ B = tir.var("handle")
+ C = tir.var("handle")
+ # body
+ tir.evaluate(
+ tir.tvm_call_cpacked(
+ "tvm_test_cpacked",
+ A,
+ B,
+ C,
+ dtype="int32",
+ )
+ )
+
+
+@tvm.script.tir
+class Expected:
+ def tir_packed_call() -> None:
+ A = tir.var("handle")
+ B = tir.var("handle")
+ C = tir.var("handle")
+
+ # body
+ tvm_value_2 = tir.var("handle")
+ tvm_value_1 = tir.var("handle")
+ tvm_value_0 = tir.var("handle")
+ with tir.let(tvm_value_2, tir.tvm_stack_alloca("array", 1, dtype="handle")):
+ with tir.let(tvm_value_1, tir.tvm_stack_alloca("array", 1, dtype="handle")):
+ with tir.let(tvm_value_0, tir.tvm_stack_alloca("array", 1, dtype="handle")):
+ tir.evaluate(tir.tvm_struct_set(tvm_value_0, 0, 1, A, dtype="handle"))
+ tir.evaluate(tir.tvm_struct_set(tvm_value_1, 0, 1, B, dtype="handle"))
+ tir.evaluate(tir.tvm_struct_set(tvm_value_2, 0, 1, C, dtype="handle"))
+ tir.evaluate(
+ tir.tvm_call_cpacked(
+ "tvm_test_cpacked",
+ tvm_value_0,
+ tvm_value_1,
+ tvm_value_2,
+ dtype="int32",
+ )
+ )
+
+
+def test_aot_packed_call():
+ mod = Module()
+ expected = Expected()
+ out = tir.transform.LegalizePackedCalls()(mod)
+ tvm.ir.assert_structural_equal(expected, out, map_free_vars=True)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])