You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by le...@apache.org on 2021/12/09 14:46:36 UTC
[tvm] branch main updated: [TIR][USMP] adding the pass to convert to pool offsets (#9418)
This is an automated email from the ASF dual-hosted git repository.
leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3ce4fe4 [TIR][USMP] adding the pass to convert to pool offsets (#9418)
3ce4fe4 is described below
commit 3ce4fe47ca976695299069630624757075471702
Author: Manupa Karunaratne <ma...@arm.com>
AuthorDate: Thu Dec 9 14:45:16 2021 +0000
[TIR][USMP] adding the pass to convert to pool offsets (#9418)
* [TIR][USMP] adding the pass to convert to pool offsets
This commit adds a transform pass that consumes
the pool allocations planned by the memory planning algorithm
and converts them to pool offsets.
* adds two test cases for a linear structure with two pools
* adds test case with a single pool for residual structures
Change-Id: I9d31e854461b5c21df72d1452120d286b96791c0
* [TIR][USMP] adding the pass to convert to pool offsets
* Adding a toggle to produce TIR that is TVMScript printable for unit
testing
* Fixing the unit tests
* Ensure deterministic pool variable ordering.
Change-Id: I317675df03327b0ebbf4ca074255384e63f07cd6
* [TIR][USMP] adding the pass to convert to pool offsets
Fixing the references after changes in the memory planning
algorithm.
Change-Id: Id7c22356fd5de43d10a2b4fc70e978af2c6d599d
* [TIR][USMP] adding the pass to convert to pool offsets
* fixing the lint
Change-Id: I7ff920b92d14a9919c930a4b35a2169c77a57dd1
* [TIR][USMP] adding the pass to convert to pool offsets
* removing unnecessary definitions
* remove global var map
* adding explanation for let bindings to pointer type
Change-Id: I31bd1a9f3057ee7f06252263565b0f75c51e6d13
* [TIR][USMP] adding the pass to convert to pool offsets
* rebase changes
* making imports absolute
* fixing typos and removing unnecessary lines
Change-Id: I4c94b9955b001513fecb39ca94f81b1ad99c7bfc
* [TIR][USMP] adding the pass to convert to pool offsets
* fixing typos
Change-Id: I42c557fd394aefdf8c2e825c4e88770eb0732f9b
---
include/tvm/tir/usmp/utils.h | 48 ++
python/tvm/script/tir/__init__.py | 2 +-
python/tvm/script/tir/ty.py | 1 +
python/tvm/tir/usmp/__init__.py | 1 +
python/tvm/tir/usmp/{ => transform}/__init__.py | 3 +-
.../usmp/{__init__.py => transform/_ffi_api.py} | 8 +-
python/tvm/tir/usmp/transform/transform.py | 46 ++
src/printer/text_printer.h | 7 +-
src/tir/ir/stmt.cc | 9 +-
.../convert_pool_allocations_to_offsets.cc | 349 ++++++++++++++
src/tir/usmp/utils.cc | 39 ++
...ransform_convert_pool_allocations_to_offsets.py | 523 +++++++++++++++++++++
12 files changed, 1025 insertions(+), 11 deletions(-)
diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h
index 145c61d..30c8f2d 100644
--- a/include/tvm/tir/usmp/utils.h
+++ b/include/tvm/tir/usmp/utils.h
@@ -226,6 +226,44 @@ class PoolAllocation : public ObjectRef {
};
/*!
+ * \brief This object contains information post-allocation for PoolInfo objects
+ */
+struct AllocatedPoolInfoNode : public Object {
+ /*! \brief The PoolInfo object that this allocation record refers to */
+ PoolInfo pool_info;
+ /*! \brief The size allocated within this pool (presumably in bytes -- confirm against callers) */
+ Integer allocated_size;
+ /*! \brief An optional tir::Var associated with the pool; unset when none is provided */
+ Optional<Var> pool_var;
+
+ /*! \brief Expose the three fields to the attribute visitor under their field names */
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("pool_info", &pool_info);
+ v->Visit("allocated_size", &allocated_size);
+ v->Visit("pool_var", &pool_var);
+ }
+
+ /*! \brief Structural equality: all three fields have to compare equal */
+ bool SEqualReduce(const AllocatedPoolInfoNode* other, SEqualReducer equal) const {
+ return equal(pool_info, other->pool_info) && equal(allocated_size, other->allocated_size) &&
+ equal(pool_var, other->pool_var);
+ }
+
+ /*! \brief Structural hash: folds the three fields in declaration order */
+ void SHashReduce(SHashReducer hash_reduce) const {
+ hash_reduce(pool_info);
+ hash_reduce(allocated_size);
+ hash_reduce(pool_var);
+ }
+
+ static constexpr const char* _type_key = "tir.usmp.AllocatedPoolInfo";
+ TVM_DECLARE_FINAL_OBJECT_INFO(AllocatedPoolInfoNode, Object);
+};
+
+/*!
+ * \brief Managed reference class for AllocatedPoolInfoNode
+ */
+class AllocatedPoolInfo : public ObjectRef {
+ public:
+ /*!
+ * \brief Create an AllocatedPoolInfo record
+ * \param pool_info The pool being described
+ * \param allocated_size The size allocated within that pool
+ * \param pool_var The variable associated with the pool (optional)
+ *
+ * NOTE(review): the default argument is `Var()`; confirm that this yields an
+ * undefined reference rather than a fresh variable, because the constructor
+ * only stores pool_var when `pool_var.defined()` holds.
+ */
+ TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var = Var());
+ TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode);
+};
+
+/*!
* \brief Convert the IR-bound BufferInfo map to an array of BufferInfo
*
* \param buffer_info_map IR-bound BufferInfo map
@@ -248,6 +286,16 @@ Integer CalculateExtentsSize(const AllocateNode* op);
} // namespace usmp
} // namespace tir
+
+namespace attr {
+/*!
+ * \brief This is an attribute that records the pool variables appended to the
+ * parameters of a function, in the form of an Array<AllocatedPoolInfo>. The
+ * ConvertPoolAllocationsToOffsets pass attaches it to the PrimFuncs it
+ * rewrites and to the resulting IRModule.
+ */
+static constexpr const char* kPoolArgs = "pool_args";
+
+} // namespace attr
+
} // namespace tvm
#endif // TVM_TIR_USMP_UTILS_H_
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py
index 472b3de..de40459 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/tir/__init__.py
@@ -17,7 +17,7 @@
"""TVMScript for TIR"""
# Type system
-from .ty import int8, int16, int32, int64, float16, float32, float64
+from .ty import uint8, int8, int16, int32, int64, float16, float32, float64
from .ty import boolean, handle, Ptr, Tuple, Buffer
from .prim_func import prim_func
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py
index 2808e7a..0432692 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/tir/ty.py
@@ -137,6 +137,7 @@ class GenericBufferType(SpecialStmt): # pylint: disable=too-few-public-methods,
pass # pylint: disable=unnecessary-pass
+uint8 = ConcreteType("uint8")
int8 = ConcreteType("int8")
int16 = ConcreteType("int16")
int32 = ConcreteType("int32")
diff --git a/python/tvm/tir/usmp/__init__.py b/python/tvm/tir/usmp/__init__.py
index 8aa0d4c..514727d 100644
--- a/python/tvm/tir/usmp/__init__.py
+++ b/python/tvm/tir/usmp/__init__.py
@@ -18,4 +18,5 @@
"""Namespace for Unified Static Memory Planner"""
from . import analysis
+from . import transform
from .utils import BufferInfo
diff --git a/python/tvm/tir/usmp/__init__.py b/python/tvm/tir/usmp/transform/__init__.py
similarity index 93%
copy from python/tvm/tir/usmp/__init__.py
copy to python/tvm/tir/usmp/transform/__init__.py
index 8aa0d4c..1a9d833 100644
--- a/python/tvm/tir/usmp/__init__.py
+++ b/python/tvm/tir/usmp/transform/__init__.py
@@ -17,5 +17,4 @@
# pylint: disable=unused-import, redefined-builtin
"""Namespace for Unified Static Memory Planner"""
-from . import analysis
-from .utils import BufferInfo
+from .transform import convert_pool_allocations_to_offsets
diff --git a/python/tvm/tir/usmp/__init__.py b/python/tvm/tir/usmp/transform/_ffi_api.py
similarity index 83%
copy from python/tvm/tir/usmp/__init__.py
copy to python/tvm/tir/usmp/transform/_ffi_api.py
index 8aa0d4c..7973ca5 100644
--- a/python/tvm/tir/usmp/__init__.py
+++ b/python/tvm/tir/usmp/transform/_ffi_api.py
@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=unused-import, redefined-builtin
-"""Namespace for Unified Static Memory Planner"""
+"""FFI APIs for tvm.tir.usmp.transform"""
+import tvm._ffi
-from . import analysis
-from .utils import BufferInfo
+
+tvm._ffi._init_api("tir.usmp.transform", __name__)
diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py
new file mode 100644
index 0000000..f472172
--- /dev/null
+++ b/python/tvm/tir/usmp/transform/transform.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""USMP Transform Python API for passes"""
+# pylint: disable=invalid-name
+
+from typing import Dict
+
+import tvm
+from tvm.tir import Stmt
+from tvm.tir.usmp.utils import PoolAllocation
+from . import _ffi_api
+
+
+def convert_pool_allocations_to_offsets(
+ pool_allocations: Dict[Stmt, PoolAllocation], emit_tvmscript_printable: bool = False
+) -> tvm.transform.Pass:
+ """Create the pass that converts pool allocations to Load nodes with offsets from pools.
+
+ Parameters
+ ----------
+ pool_allocations : Dict[Stmt, PoolAllocation]
+ Allocate or AllocateConst node to pool allocation mapping
+ (the C++ converter currently checks that every key is an Allocate node)
+ emit_tvmscript_printable : bool
+ A toggle to emit a TVMScript printable IRModule for unit tests,
+ removing all attributes that should be attached for integration
+
+ Returns
+ -------
+ ret: tvm.transform.Pass
+ The registered pass that converts the allocations to offsets.
+ """
+ return _ffi_api.ConvertPoolAllocationsToOffsets(pool_allocations, emit_tvmscript_printable)
diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h
index ebd667a..97146b8 100644
--- a/src/printer/text_printer.h
+++ b/src/printer/text_printer.h
@@ -449,10 +449,11 @@ class TextPrinter {
Doc PrintFinal(const ObjectRef& node) {
Doc doc;
- if (node->IsInstance<IRModuleNode>()) {
+ if (node.defined() && node->IsInstance<IRModuleNode>()) {
doc << PrintMod(Downcast<IRModule>(node));
- } else if (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>() ||
- node->IsInstance<tir::StmtNode>()) {
+ } else if (node.defined() &&
+ (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>() ||
+ node->IsInstance<tir::StmtNode>())) {
doc << tir_text_printer_.Print(node);
} else {
doc << relay_text_printer_.PrintFinal(node);
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 0d42c20..078561c 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -35,7 +35,14 @@ namespace tir {
LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) {
ICHECK(value.defined());
ICHECK(body.defined());
- ICHECK_EQ(value.dtype(), var.dtype());
+ auto vdtype = value.dtype();
+ // It is still valid to bind a pointer type
+ // var to a value that is of type handle.
+ if (var->type_annotation.as<PointerTypeNode>()) {
+ ICHECK(vdtype.is_handle());
+ } else {
+ ICHECK_EQ(value.dtype(), var.dtype());
+ }
ObjectPtr<LetStmtNode> node = make_object<LetStmtNode>();
node->var = std::move(var);
diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
new file mode 100644
index 0000000..5ebf3c5
--- /dev/null
+++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
@@ -0,0 +1,349 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/usmp/transform/convert_pool_allocations_to_offsets.cc
+ * \brief This pass would convert the pool allocations to offsets from pools
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <stack>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+/*!
+ * \brief The StmtExpr mutator class to replace allocate nodes
+ * with offsets within memory pools
+ *
+ * This mutator class will add Pool variables recursively to every PrimFunc
+ * starting from the main PrimFunc. All allocate nodes that have been
+ * memory planned will be mutated into an offset using a Let binding.
+ */
+class PoolAllocationToOffsetConverter : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Record the planned allocations and derive the size required of every pool.
+   * \param module The IRModule to convert; the converter works on a shallow copy of it.
+   * \param pool_allocations The allocate node to PoolAllocation map.
+   * \param emit_tvmscript_printable Whether to emit TVMScript printable TIR (for unit tests).
+   */
+  PoolAllocationToOffsetConverter(const IRModule& module,
+                                  const Map<tir::Stmt, PoolAllocation>& pool_allocations,
+                                  bool emit_tvmscript_printable = false)
+      : pool_allocations_(pool_allocations), emit_tvmscript_printable_(emit_tvmscript_printable) {
+    module_ = module->ShallowCopy();
+    // A pool has to reach as far as the furthest-reaching allocation placed inside it.
+    for (const auto& kv : pool_allocations) {
+      // TODO(@manupa-arm): add AllocateConstNode when it is available
+      ICHECK(kv.first->IsInstance<AllocateNode>());
+      Allocate allocate = Downcast<Allocate>(kv.first);
+      PoolAllocation planned = kv.second;
+      PoolInfo pool = planned->pool_info;
+      int start_offset = planned->byte_offset->value;
+      int end_offset = start_offset + CalculateExtentsSize(allocate.operator->());
+      auto known = all_pools_sizes_.find(pool);
+      if (known == all_pools_sizes_.end()) {
+        all_pools_sizes_[pool] = end_offset;
+      } else if (known->second < end_offset) {
+        known->second = end_offset;
+      }
+    }
+
+    for (const auto& pool_and_size : all_pools_sizes_) {
+      allocated_pool_ordering_.push_back(
+          AllocatedPoolInfo(pool_and_size.first, pool_and_size.second));
+    }
+    // Order the pools by name so that the pool params are appended in a reproducible order.
+    std::sort(allocated_pool_ordering_.begin(), allocated_pool_ordering_.end(),
+              [](const AllocatedPoolInfo& lhs, const AllocatedPoolInfo& rhs) {
+                return lhs->pool_info->pool_name < rhs->pool_info->pool_name;
+              });
+  }
+  /*! \brief Run the conversion and return the mutated IRModule */
+  IRModule operator()();
+
+ private:
+  PrimExpr VisitExpr_(const CallNode* op) override;
+  Stmt VisitStmt_(const AllocateNode* op) override;
+  PrimExpr VisitExpr_(const LoadNode* op) override;
+  Stmt VisitStmt_(const StoreNode* op) override;
+
+  /*! \brief The updated signature of a function, held
+   * while the body of that function is being mutated
+   */
+  struct ScopeInfo {
+    Array<tir::Var> params;
+    Map<PoolInfo, tir::Var> pools_to_params;
+    Array<AllocatedPoolInfo> allocated_pool_params;
+    Map<tir::Var, Buffer> buffer_map;
+  };
+
+  /*! \brief One ScopeInfo is pushed when the mutation
+   * enters a function and popped when it leaves, so the
+   * top of the stack always describes the function
+   * currently being mutated.
+   */
+  std::stack<ScopeInfo> scope_stack;
+  /*! \brief Helper that captures, in a ScopeInfo, the
+   * signature of a PrimFunc extended with one pool
+   * variable per pool.
+   */
+  ScopeInfo UpdateFunctionScopeInfo(const PrimFunc& original_func);
+  /*! \brief Helper that creates the PrimFunc with pool
+   * variables; it calls UpdateFunctionScopeInfo and
+   * mutates the body of the function.
+   */
+  PrimFunc CreatePrimFuncWithPoolParams(const PrimFunc& original_primfunc);
+  /*! \brief Helper that appends the pool args at the
+   * callsite of a function.
+   */
+  Array<PrimExpr> AppendPoolParamsToArgs(const Array<PrimExpr>& args);
+  /*! \brief Arguments that referred to the buffer var of
+   * a converted Allocate node are replaced by the var
+   * bound by the Let node that took its place.
+   */
+  Array<PrimExpr> ReplaceAllocateArgsWithLetArgs(const Array<PrimExpr>& args);
+
+  /*! \brief The tir::Var map to PoolInfo objects */
+  Map<tir::Var, PoolInfo> primfunc_args_to_pool_info_map_;
+  /*! \brief The buffer var map to their allocate nodes */
+  Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map_;
+  /*! \brief The IRModule being constructed/mutated */
+  IRModule module_;
+  /*! \brief The input allocate node to PoolAllocation map */
+  Map<tir::Stmt, PoolAllocation> pool_allocations_;
+  /*! \brief The pools sorted by name, giving a unique order of pool args for functions */
+  std::vector<AllocatedPoolInfo> allocated_pool_ordering_;
+  /*! \brief The pool sizes calculated at construction */
+  std::unordered_map<PoolInfo, int, ObjectPtrHash, ObjectPtrEqual> all_pools_sizes_;
+  /*! \brief While the body of a converted allocate is mutated, its buffer var maps to the
+   * tir::Var that is let-bound to the position in the pool designated by the PoolAllocation
+   */
+  Map<tir::Var, tir::Var> allocate_buf_to_let_var_;
+  /*! \brief A counter to give references to pools a reproducible unique set of names */
+  int pool_var_count_ = 0;
+  /*! \brief This toggles to remove non tvmscript printable items for IRModule for unit tests */
+  bool emit_tvmscript_printable_ = false;
+  /*! \brief The PrimFuncs that were already created with pool params */
+  std::unordered_set<PrimFunc, ObjectPtrHash, ObjectPtrEqual> visited_primfuncs;
+};
+
+/*!
+ * Builds the ScopeInfo of a function: the original params and buffer map,
+ * extended with one pool variable (and a matching uint8 buffer) per pool,
+ * following the order of allocated_pool_ordering_.
+ */
+PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::UpdateFunctionScopeInfo(
+    const PrimFunc& original_func) {
+  ScopeInfo si;
+  si.params = original_func->params;
+  si.buffer_map = original_func->buffer_map;
+  for (const AllocatedPoolInfo& allocated_pool_info : allocated_pool_ordering_) {
+    PoolInfo pool_info = allocated_pool_info->pool_info;
+    // The running counter keeps the generated names unique across functions.
+    String pool_ref_name = pool_info->pool_name + "_" + std::to_string(pool_var_count_++);
+    String var_name = pool_ref_name + "_var";
+    DataType elem_dtype = DataType::UInt(8);
+    Var buffer_var(var_name, PointerType(PrimType(elem_dtype), "global"));
+    Var pool_var;
+    // In TVMScript printable mode the param is a plain handle instead of a typed pointer.
+    if (!emit_tvmscript_printable_) {
+      pool_var = Var(var_name, PointerType(PrimType(elem_dtype), "global"));
+    } else {
+      pool_var = Var(var_name, DataType::Handle(8));
+    }
+    si.params.push_back(pool_var);
+    si.pools_to_params.Set(pool_info, pool_var);
+    si.allocated_pool_params.push_back(AllocatedPoolInfo(
+        allocated_pool_info->pool_info, allocated_pool_info->allocated_size, pool_var));
+
+    int pool_size = all_pools_sizes_[pool_info];
+    String buffer_var_name = pool_ref_name + "_buffer_var";
+    si.buffer_map.Set(pool_var, Buffer(buffer_var, elem_dtype, {pool_size}, {1}, 1, buffer_var_name,
+                                       16, 1, BufferType::kDefault));
+  }
+  return si;
+}
+
+/*!
+ * Returns the given PrimFunc extended with the pool params and with its body
+ * mutated. A PrimFunc that this converter has already produced is returned
+ * unchanged.
+ */
+PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams(
+    const PrimFunc& original_primfunc) {
+  // Only create the new function if it was not modified with pool params
+  if (visited_primfuncs.find(original_primfunc) != visited_primfuncs.end()) {
+    return original_primfunc;
+  }
+  ScopeInfo si = UpdateFunctionScopeInfo(original_primfunc);
+  this->scope_stack.push(si);
+  Stmt new_body = this->VisitStmt(original_primfunc->body);
+  this->scope_stack.pop();
+  DictAttrs original_attrs = original_primfunc->attrs;
+  // We dont need attrs of PrimFunc that might include non printable attrs such as target
+  // for unit tests where emit_tvmscript_printable_ is to be used.
+  if (emit_tvmscript_printable_) {
+    original_attrs = DictAttrs();
+  }
+  PrimFunc ret =
+      PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs);
+  if (!emit_tvmscript_printable_) {
+    ret = WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params);
+  }
+  // Record the function that is actually returned (attribute included), in both modes,
+  // so that a later call site of the same function does not extend it a second time.
+  visited_primfuncs.insert(ret);
+  return ret;
+}
+
+/*!
+ * Returns the mutated call arguments followed by the data var of each pool
+ * buffer of the function currently being mutated.
+ */
+Array<PrimExpr> PoolAllocationToOffsetConverter::AppendPoolParamsToArgs(
+    const Array<PrimExpr>& args) {
+  Array<PrimExpr> new_args;
+  for (const auto& arg : args) {
+    new_args.push_back(VisitExpr(arg));
+  }
+  const ScopeInfo& top_scope = this->scope_stack.top();
+  // Walk the ordered array (not the pools_to_params map) so that the appended args follow
+  // the same order in which UpdateFunctionScopeInfo appended the pool params.
+  for (const AllocatedPoolInfo& allocated_pool_info : top_scope.allocated_pool_params) {
+    tir::Var pool_var = allocated_pool_info->pool_var.value();
+    Buffer buffer_var = top_scope.buffer_map[pool_var];
+    new_args.push_back(buffer_var->data);
+  }
+  return new_args;
+}
+
+/*!
+ * Returns a copy of args in which every var that names a converted allocate
+ * buffer is swapped for the let-bound var that replaced it.
+ */
+Array<PrimExpr> PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs(
+    const Array<PrimExpr>& args) {
+  Array<PrimExpr> ret;
+  for (const PrimExpr& arg : args) {
+    // Anything that is not a converted allocate buffer var is forwarded untouched.
+    PrimExpr replacement = arg;
+    if (arg->IsInstance<VarNode>()) {
+      Var candidate = Downcast<Var>(arg);
+      if (allocate_buf_to_let_var_.count(candidate)) {
+        replacement = allocate_buf_to_let_var_[candidate];
+      }
+    }
+    ret.push_back(replacement);
+  }
+  return ret;
+}
+
+/*!
+ * Rewrites calls: callees that are PrimFuncs (looked up by name in the module
+ * for call_extern / tvm_call_cpacked, or used directly as the call op) are
+ * extended with the pool params and the call site passes the pools along.
+ * Arguments that name converted allocate buffers are replaced by their
+ * let-bound vars.
+ */
+PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) {
+  if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
+    String func_name = Downcast<StringImm>(op->args[0])->value;
+    Array<PrimExpr> new_args;
+    if (module_->ContainGlobalVar(func_name)) {
+      // The callee lives in this module: extend it and update the module entry.
+      GlobalVar gv = module_->GetGlobalVar(func_name);
+      PrimFunc func = Downcast<PrimFunc>(module_->Lookup(gv));
+      PrimFunc prim_func = CreatePrimFuncWithPoolParams(func);
+      module_->Update(gv, prim_func);
+      new_args = AppendPoolParamsToArgs(op->args);
+      new_args = ReplaceAllocateArgsWithLetArgs(new_args);
+    } else {
+      new_args = ReplaceAllocateArgsWithLetArgs(op->args);
+    }
+    return Call(op->dtype, op->op, new_args);
+  }
+  if (op->op->IsInstance<PrimFuncNode>()) {
+    PrimFunc func = Downcast<PrimFunc>(op->op);
+    PrimFunc prim_func = CreatePrimFuncWithPoolParams(func);
+    // The pool args are appended exactly once: the callee gains one param per pool, so
+    // appending them a second time would leave the call with more args than params.
+    Array<PrimExpr> new_args = AppendPoolParamsToArgs(op->args);
+    new_args = ReplaceAllocateArgsWithLetArgs(new_args);
+    return Call(op->dtype, prim_func, new_args);
+  }
+  return StmtExprMutator::VisitExpr_(op);
+}
+
+/*!
+ * Replaces a memory planned Allocate node by a Let binding whose value is the
+ * address of the planned byte offset inside the pool buffer. Allocate nodes
+ * without a pool allocation take the default mutation.
+ */
+Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) {
+  if (!pool_allocations_.count(GetRef<Allocate>(op))) {
+    return StmtExprMutator::VisitStmt_(op);
+  }
+  ScopeInfo scope_info = scope_stack.top();
+  PoolAllocation pool_allocation = pool_allocations_[GetRef<Allocate>(op)];
+  Var param = scope_info.pools_to_params[pool_allocation->pool_info];
+  Buffer buffer_var = scope_info.buffer_map[param];
+  Load load_node =
+      Load(DataType::UInt(8), buffer_var->data, pool_allocation->byte_offset, op->condition);
+  Call address_of_load = Call(DataType::Handle(8), builtin::address_of(), {load_node});
+  Var tir_var;
+  if (emit_tvmscript_printable_) {
+    tir_var = Var(op->buffer_var->name_hint + "_let", DataType::Handle(8));
+  } else {
+    tir_var = Var(op->buffer_var->name_hint + "_let", op->buffer_var->type_annotation);
+  }
+  // The mapping is only alive while the body of this allocate is being mutated.
+  allocate_buf_to_let_var_.Set(op->buffer_var, tir_var);
+  Stmt new_body = VisitStmt(op->body);
+  allocate_buf_to_let_var_.erase(op->buffer_var);
+  return LetStmt(tir_var, address_of_load, new_body);
+}
+
+/*!
+ * Redirects a Store into a converted allocate buffer to the let-bound var that
+ * replaced it. Value, index and predicate are all mutated, so that loads and
+ * calls nested in any of them are rewritten as well.
+ */
+Stmt PoolAllocationToOffsetConverter::VisitStmt_(const StoreNode* op) {
+  if (allocate_buf_to_let_var_.find(op->buffer_var) != allocate_buf_to_let_var_.end()) {
+    PrimExpr new_value = VisitExpr(op->value);
+    // The index can itself read from a converted buffer, so it has to be visited too.
+    PrimExpr new_index = VisitExpr(op->index);
+    PrimExpr new_predicate = VisitExpr(op->predicate);
+    return Store(allocate_buf_to_let_var_[op->buffer_var], new_value, new_index, new_predicate);
+  }
+  return StmtExprMutator::VisitStmt_(op);
+}
+
+/*!
+ * Redirects a Load from a converted allocate buffer to the let-bound var that
+ * replaced it. Index and predicate are both mutated, so that loads and calls
+ * nested in either of them are rewritten as well.
+ */
+PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const LoadNode* op) {
+  if (allocate_buf_to_let_var_.find(op->buffer_var) != allocate_buf_to_let_var_.end()) {
+    // The index can itself read from a converted buffer, so it has to be visited too.
+    PrimExpr new_index = VisitExpr(op->index);
+    PrimExpr new_predicate = VisitExpr(op->predicate);
+    return Load(op->dtype, allocate_buf_to_let_var_[op->buffer_var], new_index, new_predicate);
+  }
+  return StmtExprMutator::VisitExpr_(op);
+}
+
+/*!
+ * Converts the module starting from its main function (looked up through
+ * tvm_run_func_suffix) and returns the mutated module.
+ */
+IRModule PoolAllocationToOffsetConverter::operator()() {
+  GlobalVar gv = module_->GetGlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix);
+  PrimFunc main_func = Downcast<PrimFunc>(module_->Lookup(gv));
+  ScopeInfo si = UpdateFunctionScopeInfo(main_func);
+  this->scope_stack.push(si);
+  Stmt main_func_body = this->VisitStmt(main_func->body);
+  this->scope_stack.pop();
+
+  if (emit_tvmscript_printable_) {
+    // We dont need attrs of PrimFunc that might include non printable attrs such as target
+    // for unit tests where emit_tvmscript_printable_ is to be used.
+    main_func =
+        PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, DictAttrs());
+    module_->Update(gv, main_func);
+    return this->module_;
+  }
+
+  // Otherwise both the main function and the module are tagged with the pool args.
+  main_func =
+      PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, main_func->attrs);
+  main_func = WithAttr(main_func, tvm::attr::kPoolArgs, si.allocated_pool_params);
+  module_->Update(gv, main_func);
+  return WithAttr(this->module_, tvm::attr::kPoolArgs, si.allocated_pool_params);
+}
+
+namespace transform {
+
+/*!
+ * Creates the module pass that runs PoolAllocationToOffsetConverter with the
+ * given pool allocations.
+ */
+tvm::transform::Pass ConvertPoolAllocationsToOffsets(
+    const Map<tir::Stmt, PoolAllocation>& pool_allocations,
+    Bool emit_tvmscript_printable = Bool(false)) {
+  // Both arguments are captured by value, as the lambda is handed over to the created pass.
+  auto pass_func = [pool_allocations, emit_tvmscript_printable](
+                       IRModule m, tvm::transform::PassContext ctx) {
+    const bool printable = emit_tvmscript_printable->value != 0;
+    PoolAllocationToOffsetConverter converter(m, pool_allocations, printable);
+    return Downcast<IRModule>(converter());
+  };
+  return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.ConvertPoolAllocationsToOffsets",
+                                          {});
+}
+
+TVM_REGISTER_GLOBAL("tir.usmp.transform.ConvertPoolAllocationsToOffsets")
+    .set_body_typed(ConvertPoolAllocationsToOffsets);
+
+} // namespace transform
+
+} // namespace usmp
+} // namespace tir
+} // namespace tvm
diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc
index 7a6a683..14b3d26 100644
--- a/src/tir/usmp/utils.cc
+++ b/src/tir/usmp/utils.cc
@@ -135,6 +135,30 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ")";
});
+/*!
+ * Creates the underlying node from the given fields; pool_var is only stored
+ * when it is a defined reference.
+ */
+AllocatedPoolInfo::AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var) {
+  ObjectPtr<AllocatedPoolInfoNode> node = make_object<AllocatedPoolInfoNode>();
+  node->pool_info = pool_info;
+  node->allocated_size = allocated_size;
+  // An undefined pool_var leaves the optional field of the node unset.
+  if (pool_var.defined()) {
+    node->pool_var = pool_var;
+  }
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(AllocatedPoolInfoNode);
+// Constructor exposed as the global "tir.usmp.AllocatedPoolInfo"; it takes no pool_var,
+// so the default argument of the C++ constructor is used.
+TVM_REGISTER_GLOBAL("tir.usmp.AllocatedPoolInfo")
+ .set_body_typed([](PoolInfo pool_info, Integer allocated_size) {
+ return AllocatedPoolInfo(pool_info, allocated_size);
+ });
+
+// Repr printer: prints pool_info and allocated_size only; pool_var is not printed.
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<AllocatedPoolInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const AllocatedPoolInfoNode*>(ref.get());
+ p->stream << "AllocatedPoolInfoNode(\n"
+ << "pool_info=" << node->pool_info << ",\n allocated_size=" << node->allocated_size
+ << ")";
+ });
+
Array<BufferInfo> CreateArrayBufferInfo(const Map<BufferInfo, Stmt>& buffer_info_map) {
Array<BufferInfo> ret;
for (const auto& kv : buffer_info_map) {
@@ -144,6 +168,19 @@ Array<BufferInfo> CreateArrayBufferInfo(const Map<BufferInfo, Stmt>& buffer_info
return ret;
}
+/*!
+ * Re-keys the planned pool allocations: each BufferInfo key is replaced by the
+ * statement that buffer_info_to_stmt associates with it.
+ */
+Map<Stmt, PoolAllocation> AssignStmtPoolAllocations(
+    const Map<BufferInfo, Stmt>& buffer_info_to_stmt,
+    const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation) {
+  Map<Stmt, PoolAllocation> stmt_to_pool_allocation;
+  for (const auto& entry : buffer_info_to_pool_allocation) {
+    stmt_to_pool_allocation.Set(buffer_info_to_stmt[entry.first], entry.second);
+  }
+  return stmt_to_pool_allocation;
+}
+
Integer CalculateExtentsSize(const AllocateNode* op) {
size_t element_size_bytes = op->dtype.bytes();
size_t num_elements = 1;
@@ -163,6 +200,8 @@ TVM_REGISTER_GLOBAL("tir.usmp.CreateArrayBufferInfo")
return (CreateArrayBufferInfo(buffer_info_map));
});
+TVM_REGISTER_GLOBAL("tir.usmp.AssignStmtPoolAllocations").set_body_typed(AssignStmtPoolAllocations);
+
} // namespace usmp
} // namespace tir
} // namespace tvm
diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
new file mode 100644
index 0000000..fc61577
--- /dev/null
+++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
@@ -0,0 +1,523 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import sys
+
+import tvm
+from tvm.script import tir as T
+from tvm.tir import stmt_functor
+from tvm.tir.usmp import utils as usmp_utils
+from tvm.target import Target
+
+
+def _get_primfuncs_from_module(module):
+ primfuncs = list()
+ for gv, primfunc in module.functions.items():
+ primfuncs.append(primfunc)
+ return primfuncs
+
+
+def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
+ """Helper to assign poolinfos to allocate nodes in a tir.PrimFunc"""
+
+ def set_poolinfos(stmt):
+ if isinstance(stmt, tvm.tir.Allocate):
+ return tvm.tir.Allocate(
+ buffer_var=stmt.buffer_var,
+ dtype=stmt.dtype,
+ extents=stmt.extents,
+ condition=stmt.condition,
+ body=stmt.body,
+ annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos},
+ )
+
+ return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos))
+
+
+def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
+ """Helper to assign poolinfos to allocate nodes in a IRModule"""
+ ret = tvm.IRModule()
+ for global_var, basefunc in mod.functions.items():
+ if isinstance(basefunc, tvm.tir.PrimFunc):
+ ret[global_var] = assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos)
+ return ret
+
+
+def _assign_targets_to_primfuncs_irmodule(mod, target):
+ """Helper to assign target for PrimFunc in a IRModule"""
+ ret = tvm.IRModule()
+ for global_var, basefunc in mod.functions.items():
+ if isinstance(basefunc, tvm.tir.PrimFunc):
+ ret[global_var] = basefunc.with_attr("target", target)
+ return ret
+
+
+# fmt: off
+@tvm.script.ir_module
+class LinearStructure:  # Unplanned input module: run_model calls three operator functions one after another.
+ @T.prim_func
+ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
+ placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+ placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+ T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+ # body
+ for ax0_ax1_fused_1 in T.serial(0, 224):
+ for ax2_1, ax3_inner_1 in T.grid(224, 3):
+ T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True)
+
+ @T.prim_func
+ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True})
+ placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+ placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+ placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1)
+ T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+ # body
+ PaddedInput_7 = T.allocate([157323], "int16", "global")  # workspace written by the next loop nest and read back by the convolution loop
+ for i0_i1_fused_7 in T.serial(0, 229):
+ for i2_7, i3_7 in T.grid(229, 3):
+ T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True)
+ for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
+ Conv2dOutput_7 = T.allocate([64], "int32", "global")  # accumulator workspace allocated inside the outer loop
+ for ff_3 in T.serial(0, 64):
+ T.store(Conv2dOutput_7, ff_3, 0, True)
+ for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
+ T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True)
+ for ax3_inner_7 in T.serial(0, 64):
+ T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True)
+
+ @T.prim_func
+ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True})
+ placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+ T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+ # body
+ tensor_2 = T.allocate([200704], "uint8", "global")  # workspace written by the first loop nest, then read when storing into T_cast_7
+ for ax0_ax1_fused_4 in T.serial(0, 56):
+ for ax2_4 in T.serial(0, 56):
+ for ax3_init in T.serial(0, 64):
+ T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True)
+ for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
+ T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dt [...]
+ for ax0_ax1_fused_5 in T.serial(0, 56):
+ for ax2_5, ax3_3 in T.grid(56, 64):
+ T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True)
+
+ @T.prim_func
+ def run_model(input: T.handle, output: T.handle) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "run_model", "runner_function": True})
+ # body
+ T.attr("default", "device_id", 0)
+ T.attr("default", "device_type", 1)
+ sid_9 = T.allocate([301056], "int8", "global")  # passed as the last argument of the cast_subtract call and the first of the conv2d call
+ sid_8 = T.allocate([802816], "int8", "global")  # passed as the last argument of the conv2d call and the first of the max_pool2d_cast call
+ T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32"))
+ T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32"))
+ T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32"))
+# fmt: on
+
+
+# fmt: off
+@tvm.script.ir_module
+class LinearStructurePlanned:  # Reference module: workspaces are let-bound addresses into pool buffers that every function receives as extra handle params.
+ @T.prim_func
+ def run_model(input: T.handle, output: T.handle, fast_memory_0_var: T.handle, slow_memory_1_var: T.handle) -> None:
+ fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ # body
+ T.attr("default", "device_id", 0)
+ T.attr("default", "device_type", 1)
+ sid_9_let: T.handle = T.address_of(T.load("uint8", slow_memory_1_buffer_var.data, 1117472), dtype="handle")  # address of element 1117472 of the uint8 slow-memory buffer
+ sid_8_let: T.handle = T.address_of(T.load("uint8", slow_memory_1_buffer_var.data, 0), dtype="handle")  # address of element 0 of the same buffer
+ T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32"))  # both pool pointers are appended to each call
+ T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32"))
+ T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8_let, output, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32"))
+
+ @T.prim_func
+ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.handle, slow_memory_7_var: T.handle) -> None:
+ placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8")
+ T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16")
+ fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ # body
+ tensor_2_let: T.handle = T.address_of(T.load("uint8", fast_memory_6_buffer_var.data, 0), dtype="handle")  # address of element 0 of the fast-memory buffer
+ for ax0_ax1_fused_4, ax2_4 in T.grid(56, 56):
+ for ax3_init in T.serial(0, 64):
+ T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init, T.uint8(0), True)
+ for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
+ T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2, T.max(T.load("uint8", tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2), T.if_then_else(ax0_ax1_fused_4 * 2 + rv0_rv1_fused_1 // 3 < 112 and ax2_4 * 2 + rv0_rv1_fused_1 % 3 < 112, T.load("uint8", placeholder_29.data, ax0_ax1_fused_4 * 14336 + rv0_rv1_fused_1 // 3 * 7168 + ax2_4 * 128 + rv0_rv1_fused_1 % 3 * 64 + ax3_2), T.uint8(0), dtype="uint8")), True)
+ for ax0_ax1_fused_5, ax2_5, ax3_3 in T.grid(56, 56, 64):
+ T.store(T_cast_7.data, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3, T.cast(T.load("uint8", tensor_2_let, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3), "int16"), True)
+
+ @T.prim_func
+ def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.handle, slow_memory_3_var: T.handle) -> None:
+ placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8")
+ placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16")
+ T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16")
+ fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ # body
+ for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3):
+ T.store(T_subtract_1.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1, T.cast(T.load("uint8", placeholder_4.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1), "int16") - T.load("int16", placeholder_5.data, 0), True)
+
+ @T.prim_func
+ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.handle, slow_memory_5_var: T.handle) -> None:
+ placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16")
+ placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16")
+ placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32")
+ T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8")
+ fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ # body
+ PaddedInput_7_let: T.handle = T.address_of(T.load("uint8", slow_memory_5_buffer_var.data, 802816), dtype="handle")  # address of element 802816 of the slow-memory buffer
+ for i0_i1_fused_7, i2_7, i3_7 in T.grid(229, 229, 3):
+ T.store(PaddedInput_7_let, i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7, T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, T.load("int16", placeholder_65.data, i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350), T.int16(0), dtype="int16"), True)
+ for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
+ Conv2dOutput_7_let: T.handle = T.address_of(T.load("uint8", fast_memory_4_buffer_var.data, 0), dtype="handle")  # address of element 0 of the fast-memory buffer
+ for ff_3 in T.serial(0, 64):
+ T.store(Conv2dOutput_7_let, ff_3, 0, True)
+ for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
+ T.store(Conv2dOutput_7_let, ff_3, T.load("int32", Conv2dOutput_7_let, ff_3) + T.cast(T.load("int16", PaddedInput_7_let, ax0_ax1_fused_ax2_fused_7 // 112 * 1374 + ry_2 * 687 + ax0_ax1_fused_ax2_fused_7 % 112 * 6 + rx_2 * 3 + rc_7), "int32") * T.cast(T.load("int16", placeholder_66.data, ry_2 * 1344 + rx_2 * 192 + rc_7 * 64 + ff_3), "int32"), True)
+ for ax3_inner_7 in T.serial(0, 64):
+ T.store(T_cast_21.data, ax0_ax1_fused_ax2_fused_7 * 64 + ax3_inner_7, T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_7_let, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True)
+# fmt: on
+
+
+def test_mobilenet_subgraph():
+ target = Target("c")
+ fast_memory_pool = usmp_utils.PoolInfo(
+ pool_name="fast_memory",
+ target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+ size_hint_bytes=200704,
+ )
+ slow_memory_pool = usmp_utils.PoolInfo(
+ pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}
+ )
+ tir_mod = LinearStructure
+ tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
+ tir_mod = assign_poolinfos_to_allocates_in_irmodule(
+ tir_mod, [fast_memory_pool, slow_memory_pool]
+ )
+ main_func = tir_mod["run_model"]
+ buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+ buffer_info_map = buffer_analysis.buffer_info_stmts
+
+ fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
+ buffer_info_arr = fcreate_array_bi(buffer_info_map)
+ fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
+ buffer_pool_allocations = fusmp_algo_greedy_by_size(
+ buffer_info_arr, buffer_analysis.memory_pressure
+ )
+ fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
+ pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations)
+ tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
+ pool_allocations, emit_tvmscript_printable=True
+ )(tir_mod)
+
+ tir_mod_with_offsets_ref = LinearStructurePlanned
+ tir_mod_with_offsets_ref = tvm.script.from_source(
+ tir_mod_with_offsets_ref.script(show_meta=False)
+ )
+ # The TIR produced fails on roundtrip TVMScript testing.
+ # Therefore, indicates the TVMScript produced here and/or the parser
+ # is lacking functionality. Thus for these tests, uses a string
+ # version of the TVMScript for each function as a check instead.
+ for gv, func in tir_mod_with_offsets_ref.functions.items():
+ assert str(tir_mod_with_offsets_ref[gv.name_hint].script()) == str(
+ tir_mod_with_offsets[gv.name_hint].script()
+ )
+
+
+# fmt: off
+@tvm.script.ir_module
+class ResnetStructure:  # Unplanned input module: run_model chains five operator calls and allocates four intermediate buffers.
+ @T.prim_func
+ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True})
+ placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8")
+ placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
+ T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16")
+ # body
+ for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
+ T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True)
+
+ @T.prim_func
+ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True})
+ placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16")
+ placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16")
+ placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32")
+ T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16")
+ # body
+ PaddedInput_1 = T.allocate([379456], "int16", "global")  # workspace written by the next loop nest and read back by the convolution loop
+ for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
+ T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True)
+ for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
+ Conv2dOutput_1 = T.allocate([64], "int32", "global")  # accumulator workspace allocated inside the outer loop
+ for ff_1 in T.serial(0, 64):
+ T.store(Conv2dOutput_1, ff_1, 0, True)
+ for ry, rx, rc_1 in T.grid(3, 3, 64):
+ T.store(Conv2dOutput_1, ff_1, T.load("int32", Conv2dOutput_1, ff_1) + T.cast(T.load("int16", PaddedInput_1, T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True)
+ for ax3_inner_2 in T.serial(0, 64):
+ T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True)
+
+ @T.prim_func
+ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True})
+ placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16")
+ placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16")
+ placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32")
+ T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32")
+ # body
+ PaddedInput_2 = T.allocate([360000], "int16", "global")
+ for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64):
+ T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True)
+ for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625):
+ Conv2dOutput_2 = T.allocate([64], "int32", "global")
+ for ax3_outer_1 in T.serial(0, 4):
+ for ff_2 in T.serial(0, 64):
+ T.store(Conv2dOutput_2, ff_2, 0, True)
+ for rc_2 in T.serial(0, 64):
+ T.store(Conv2dOutput_2, ff_2, T.load("int32", Conv2dOutput_2, ff_2) + T.cast(T.load("int16", PaddedInput_2, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True)
+ for ax3_inner_3 in T.serial(0, 64):
+ T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True)
+
+ @T.prim_func
+ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True})
+ placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16")
+ placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16")
+ placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32")
+ placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32")
+ T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8")
+ # body
+ PaddedInput_3 = T.allocate([360000], "int16", "global")
+ for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64):
+ T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True)
+ for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625):
+ Conv2dOutput_3 = T.allocate([64], "int32", "global")
+ for ax3_outer_2 in T.serial(0, 4):
+ for ff_3 in T.serial(0, 64):
+ T.store(Conv2dOutput_3, ff_3, 0, True)
+ for rc_3 in T.serial(0, 64):
+ T.store(Conv2dOutput_3, ff_3, T.load("int32", Conv2dOutput_3, ff_3) + T.cast(T.load("int16", PaddedInput_3, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True)
+ for ax3_inner_4 in T.serial(0, 64):
+ T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 [...]
+
+ @T.prim_func
+ def run_model(input: T.handle, output: T.handle) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "run_model", "runner_function": True})
+ # body
+ T.attr("default", "device_id", 0)
+ T.attr("default", "device_type", 1)
+ sid_2 = T.allocate([720000], "int8", "global")  # passed to three calls below: the first, the second and the last
+ sid_6 = T.allocate([5760000], "int8", "global")  # passed to the fourth call and to the last call
+ sid_7 = T.allocate([720000], "int8", "global")  # passed to the third and the fourth call
+ sid_8 = T.allocate([720000], "int8", "global")  # passed to the second and the third call
+ T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2, dtype="int32"))
+ T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8, dtype="int32"))
+ T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7, dtype="int32"))
+ T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6, dtype="int32"))
+ T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6, output, dtype="int32"))
+
+ @T.prim_func
+ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None:
+ # function attr dict
+ T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True})
+ placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16")
+ placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16")
+ placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32")
+ T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
+ # body
+ PaddedInput = T.allocate([360000], "int16", "global")
+ for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
+ T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True)
+ for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
+ Conv2dOutput = T.allocate([64], "int32", "global")
+ for ff in T.serial(0, 64):
+ T.store(Conv2dOutput, ff, 0, True)
+ for rc in T.serial(0, 64):
+ T.store(Conv2dOutput, ff, T.load("int32", Conv2dOutput, ff) + T.cast(T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True)
+ for ax3_inner_1 in T.serial(0, 64):
+ T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True)
+# fmt: on
+
+
+# fmt: off
+@tvm.script.ir_module
+class ResnetStructurePlanned:
+ @T.prim_func
+ def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None:
+ placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8")
+ placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
+ T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16")
+ global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ # body
+ for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
+ T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True)
+
+ @T.prim_func
+ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle) -> None:
+ placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16")
+ placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16")
+ placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32")
+ placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32")
+ T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8")
+ global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ # body
+ PaddedInput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 6480000), dtype="handle")
+ for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64):
+ T.store(PaddedInput_3_let, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True)
+ for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625):
+ Conv2dOutput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 7200000), dtype="handle")
+ for ax3_outer_2 in T.serial(0, 4):
+ for ff_3 in T.serial(0, 64):
+ T.store(Conv2dOutput_3_let, ff_3, 0, True)
+ for rc_3 in T.serial(0, 64):
+ T.store(Conv2dOutput_3_let, ff_3, T.load("int32", Conv2dOutput_3_let, ff_3) + T.cast(T.load("int16", PaddedInput_3_let, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True)
+ for ax3_inner_4 in T.serial(0, 64):
+ T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3_let, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * [...]
+
+ @T.prim_func
+ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.handle) -> None:
+ placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16")
+ placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16")
+ placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32")
+ T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32")
+ global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ # body
+ PaddedInput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7200000), dtype="handle")
+ for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64):
+ T.store(PaddedInput_2_let, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True)
+ for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625):
+ Conv2dOutput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7920000), dtype="handle")
+ for ax3_outer_1 in T.serial(0, 4):
+ for ff_2 in T.serial(0, 64):
+ T.store(Conv2dOutput_2_let, ff_2, 0, True)
+ for rc_2 in T.serial(0, 64):
+ T.store(Conv2dOutput_2_let, ff_2, T.load("int32", Conv2dOutput_2_let, ff_2) + T.cast(T.load("int16", PaddedInput_2_let, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True)
+ for ax3_inner_3 in T.serial(0, 64):
+ T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2_let, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True)
+
+ @T.prim_func
+ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None:
+ # conv2d with a [1, 1, 64, 64] weight buffer, followed by an int32 add, a
+ # T.q_multiply_shift, a clamp and two casts. The last parameter is a handle to a
+ # uint8 workspace buffer; the intermediates below are let-bound to addresses inside
+ # that buffer instead of being allocated in this function.
+ # NOTE(review): presumably one of the expected post-planning functions that
+ # test_resnet_subgraph compares against; the enclosing class header is not visible here.
+ placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16")
+ placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16")
+ placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32")
+ T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
+ global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ # body
+ # PaddedInput is bound to byte offset 7200000 of the workspace. The loop below writes
+ # 75 * 75 * 64 int16 values (720000 bytes), so the region ends at byte 7920000.
+ PaddedInput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7200000), dtype="handle")
+ # Element-for-element copy of placeholder_7 (same index on both sides; no border is added).
+ for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
+ T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True)
+ for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
+ # Conv2dOutput is bound to byte offset 7920000, directly after PaddedInput. It is
+ # indexed as 64 int32 values (256 bytes), which reaches the end of the 7920256-byte buffer.
+ Conv2dOutput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7920000), dtype="handle")
+ # Zero the accumulator for output channel ff, then accumulate int32 products over rc.
+ for ff in T.serial(0, 64):
+ T.store(Conv2dOutput_let, ff, 0, True)
+ for rc in T.serial(0, 64):
+ T.store(Conv2dOutput_let, ff, T.load("int32", Conv2dOutput_let, ff) + T.cast(T.load("int16", PaddedInput_let, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True)
+ # Add the int32 value from placeholder_9, apply T.q_multiply_shift, clamp to [0, 255],
+ # then cast uint8 -> int16 and store into the output buffer T_cast_3.
+ for ax3_inner_1 in T.serial(0, 64):
+ T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_let, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True)
+
+ @T.prim_func
+ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle) -> None:
+ # conv2d with a [3, 3, 64, 64] weight buffer over a zero-padded copy of the input,
+ # followed by an int32 add, a T.q_multiply_shift, a clamp and two casts. The last
+ # parameter is a handle to a uint8 workspace buffer; the intermediates below are
+ # let-bound to addresses inside that buffer instead of being allocated in this function.
+ # NOTE(review): presumably one of the expected post-planning functions that
+ # test_resnet_subgraph compares against; the enclosing class header is not visible here.
+ placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16")
+ placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16")
+ placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32")
+ T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16")
+ global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ # body
+ # PaddedInput_1 is bound to byte offset 0 of the workspace. The loop below writes
+ # 77 * 77 * 64 int16 values (758912 bytes).
+ PaddedInput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 0), dtype="handle")
+ # Positions with both row and column indices in [1, 76) take the input value; all other
+ # positions are written as int16 zero. The "- 4864" is 4800 + 64, i.e. it shifts the
+ # padded row/column indices back by one each when indexing placeholder_13.
+ for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
+ T.store(PaddedInput_1_let, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True)
+ for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
+ # Conv2dOutput_1 is bound to byte offset 7200000 and indexed as 64 int32 values.
+ Conv2dOutput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 7200000), dtype="handle")
+ # Zero the accumulator for output channel ff_1, then accumulate int32 products over
+ # the 3 x 3 x 64 (ry, rx, rc_1) reduction grid.
+ for ff_1 in T.serial(0, 64):
+ T.store(Conv2dOutput_1_let, ff_1, 0, True)
+ for ry, rx, rc_1 in T.grid(3, 3, 64):
+ T.store(Conv2dOutput_1_let, ff_1, T.load("int32", Conv2dOutput_1_let, ff_1) + T.cast(T.load("int16", PaddedInput_1_let, ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True)
+ # Add the int32 value from placeholder_15, apply T.q_multiply_shift, clamp to [0, 255],
+ # then cast uint8 -> int16 and store into the output buffer T_cast_5.
+ for ax3_inner_2 in T.serial(0, 64):
+ T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1_let, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True)
+
+ @T.prim_func
+ def run_model(input: T.handle, output: T.handle, global_workspace_0_var: T.handle) -> None:
+ # Entry function: binds each intermediate tensor (sid_*) to an address inside the uint8
+ # workspace buffer received as the last parameter, then calls the five operator
+ # functions in sequence, passing the same workspace data pointer to each of them.
+ # NOTE(review): presumably part of the expected post-planning module that
+ # test_resnet_subgraph compares against; the enclosing class header is not visible here.
+ global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+ # body
+ T.attr("default", "device_id", 0)
+ T.attr("default", "device_type", 1)
+ # Byte offsets into the workspace for each intermediate. sid_7 and sid_8 are bound to
+ # the same offset (6480000).
+ sid_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 5760000), dtype="handle")
+ sid_6_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 0), dtype="handle")
+ sid_7_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle")
+ sid_8_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle")
+ # Call chain: input -> sid_2 -> sid_8 -> sid_7 -> sid_6, and the final call takes
+ # sid_2 and sid_6 together with `output`.
+ T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32"))
+ T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32"))
+ T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32"))
+ T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7_let, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6_let, global_workspace_0_buffer_var.data, dtype="int32"))
+ T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6_let, output, global_workspace_0_buffer_var.data, dtype="int32"))
+ __tvm_meta__ = None
+# fmt: on
+
+
+def test_resnet_subgraph():
+ target = Target("c")
+ global_workspace_pool = usmp_utils.PoolInfo(
+ pool_name="global_workspace",
+ target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+ )
+ tir_mod = ResnetStructure
+ tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
+ tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
+ main_func = tir_mod["run_model"]
+ buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+ buffer_info_map = buffer_analysis.buffer_info_stmts
+
+ fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
+ buffer_info_arr = fcreate_array_bi(buffer_info_map)
+ fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
+ buffer_pool_allocations = fusmp_algo_greedy_by_size(
+ buffer_info_arr, buffer_analysis.memory_pressure
+ )
+ fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
+ pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations)
+ tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
+ pool_allocations, emit_tvmscript_printable=True
+ )(tir_mod)
+
+ tir_mod_with_offsets_ref = ResnetStructurePlanned
+
+ # The TIR produced fails on roundtrip TVMScript testing.
+ # Therefore, indicates the TVMScript produced here and/or the parser
+ # is lacking functionality. Thus for these tests, uses a string
+ # version of the TVMScript for each function as a check instead.
+ for gv, func in tir_mod_with_offsets_ref.functions.items():
+ assert str(tir_mod_with_offsets_ref[gv.name_hint].script()) == str(
+ tir_mod_with_offsets[gv.name_hint].script()
+ )
+
+
+if __name__ == "__main__":
+ pytest.main([__file__] + sys.argv[1:])