You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by le...@apache.org on 2021/12/09 14:46:36 UTC

[tvm] branch main updated: [TIR][USMP] adding the pass to convert to pool offsets (#9418)

This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ce4fe4  [TIR][USMP] adding the pass to convert to pool offsets (#9418)
3ce4fe4 is described below

commit 3ce4fe47ca976695299069630624757075471702
Author: Manupa Karunaratne <ma...@arm.com>
AuthorDate: Thu Dec 9 14:45:16 2021 +0000

    [TIR][USMP] adding the pass to convert to pool offsets (#9418)
    
    * [TIR][USMP] adding the pass to convert to pool offsets
    
    This commit adds a transform pass that consumes
    the pool allocations planned by the memory planning
    algorithm and converts them to pool offsets.
    
    * adds two test cases for a linear structure with two pools
    * adds test case with a single pool for residual structures
    
    Change-Id: I9d31e854461b5c21df72d1452120d286b96791c0
    
    * [TIR][USMP] adding the pass to convert to pool offsets
    
    * Adding a toggle to produce TIR that is TVMScript printable for unit
    testing
    * Fixing the unit tests
    * Ensure deterministic pool variable ordering.
    
    Change-Id: I317675df03327b0ebbf4ca074255384e63f07cd6
    
    * [TIR][USMP] adding the pass to convert to pool offsets
    
    Fixing the references after changes in the memory planning
    algorithm.
    
    Change-Id: Id7c22356fd5de43d10a2b4fc70e978af2c6d599d
    
    * [TIR][USMP] adding the pass to convert to pool offsets
    
    * fixing the lint
    
    Change-Id: I7ff920b92d14a9919c930a4b35a2169c77a57dd1
    
    * [TIR][USMP] adding the pass to convert to pool offsets
    
    * removing unnecessary definitions
    * remove global var map
    * adding explanation for let bindings to pointer type
    
    Change-Id: I31bd1a9f3057ee7f06252263565b0f75c51e6d13
    
    * [TIR][USMP] adding the pass to convert to pool offsets
    
    * rebase changes
    * making imports absolute
    * fixing typos and removing unnecessary lines
    
    Change-Id: I4c94b9955b001513fecb39ca94f81b1ad99c7bfc
    
    * [TIR][USMP] adding the pass to convert to pool offsets
    
    * fixing typos
    
    Change-Id: I42c557fd394aefdf8c2e825c4e88770eb0732f9b
---
 include/tvm/tir/usmp/utils.h                       |  48 ++
 python/tvm/script/tir/__init__.py                  |   2 +-
 python/tvm/script/tir/ty.py                        |   1 +
 python/tvm/tir/usmp/__init__.py                    |   1 +
 python/tvm/tir/usmp/{ => transform}/__init__.py    |   3 +-
 .../usmp/{__init__.py => transform/_ffi_api.py}    |   8 +-
 python/tvm/tir/usmp/transform/transform.py         |  46 ++
 src/printer/text_printer.h                         |   7 +-
 src/tir/ir/stmt.cc                                 |   9 +-
 .../convert_pool_allocations_to_offsets.cc         | 349 ++++++++++++++
 src/tir/usmp/utils.cc                              |  39 ++
 ...ransform_convert_pool_allocations_to_offsets.py | 523 +++++++++++++++++++++
 12 files changed, 1025 insertions(+), 11 deletions(-)

diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h
index 145c61d..30c8f2d 100644
--- a/include/tvm/tir/usmp/utils.h
+++ b/include/tvm/tir/usmp/utils.h
@@ -226,6 +226,44 @@ class PoolAllocation : public ObjectRef {
 };
 
 /*!
+ * \brief This object contains information post-allocation for PoolInfo objects
+ */
+struct AllocatedPoolInfoNode : public Object {
+  /*! \brief The assigned PoolInfo object */
+  PoolInfo pool_info;
+  /*! \brief The size allocated in this pool */
+  Integer allocated_size;
+  /*! \brief An optional Var associated with the pool (e.g. the function param carrying it) */
+  Optional<Var> pool_var;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("pool_info", &pool_info);
+    v->Visit("allocated_size", &allocated_size);
+    v->Visit("pool_var", &pool_var);
+  }
+
+  bool SEqualReduce(const AllocatedPoolInfoNode* other, SEqualReducer equal) const {
+    return equal(pool_info, other->pool_info) && equal(allocated_size, other->allocated_size) &&
+           equal(pool_var, other->pool_var);
+  }
+
+  void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(pool_info);
+    hash_reduce(allocated_size);
+    hash_reduce(pool_var);
+  }
+
+  static constexpr const char* _type_key = "tir.usmp.AllocatedPoolInfo";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AllocatedPoolInfoNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AllocatedPoolInfoNode.
+ */
+class AllocatedPoolInfo : public ObjectRef {
+ public:
+  /*!
+   * \brief Construct an AllocatedPoolInfo.
+   * \param pool_info The PoolInfo that was allocated into.
+   * \param allocated_size The size allocated in the pool.
+   * \param pool_var The Var associated with the pool. The default is an undefined (null) Var,
+   *        which leaves pool_var unset; a default-constructed Var() would instead be a freshly
+   *        created variable and would be recorded as if it were the pool var.
+   */
+  TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size,
+                            Var pool_var = Var(ObjectPtr<Object>(nullptr)));
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode);
+};
+
+/*!
  * \brief Convert the IR-bound BufferInfo map to an array of BufferInfo
  *
  * \param buffer_info_map IR-bound BufferInfo map
@@ -248,6 +286,16 @@ Integer CalculateExtentsSize(const AllocateNode* op);
 
 }  // namespace usmp
 }  // namespace tir
+
+namespace attr {
+/*!
+ * \brief This is a BaseFunc (and IRModule) attribute that records which params carry
+ * memory pools. Its value is an Array<tir::usmp::AllocatedPoolInfo>; the pool_var of each
+ * element is the function param that carries that pool.
+ */
+static constexpr const char* kPoolArgs = "pool_args";
+
+}  // namespace attr
+
 }  // namespace tvm
 
 #endif  // TVM_TIR_USMP_UTILS_H_
diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py
index 472b3de..de40459 100644
--- a/python/tvm/script/tir/__init__.py
+++ b/python/tvm/script/tir/__init__.py
@@ -17,7 +17,7 @@
 """TVMScript for TIR"""
 
 # Type system
-from .ty import int8, int16, int32, int64, float16, float32, float64
+from .ty import uint8, int8, int16, int32, int64, float16, float32, float64
 from .ty import boolean, handle, Ptr, Tuple, Buffer
 
 from .prim_func import prim_func
diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py
index 2808e7a..0432692 100644
--- a/python/tvm/script/tir/ty.py
+++ b/python/tvm/script/tir/ty.py
@@ -137,6 +137,7 @@ class GenericBufferType(SpecialStmt):  # pylint: disable=too-few-public-methods,
         pass  # pylint: disable=unnecessary-pass
 
 
+uint8 = ConcreteType("uint8")
 int8 = ConcreteType("int8")
 int16 = ConcreteType("int16")
 int32 = ConcreteType("int32")
diff --git a/python/tvm/tir/usmp/__init__.py b/python/tvm/tir/usmp/__init__.py
index 8aa0d4c..514727d 100644
--- a/python/tvm/tir/usmp/__init__.py
+++ b/python/tvm/tir/usmp/__init__.py
@@ -18,4 +18,5 @@
 """Namespace for Unified Static Memory Planner"""
 
 from . import analysis
+from . import transform
 from .utils import BufferInfo
diff --git a/python/tvm/tir/usmp/__init__.py b/python/tvm/tir/usmp/transform/__init__.py
similarity index 93%
copy from python/tvm/tir/usmp/__init__.py
copy to python/tvm/tir/usmp/transform/__init__.py
index 8aa0d4c..1a9d833 100644
--- a/python/tvm/tir/usmp/__init__.py
+++ b/python/tvm/tir/usmp/transform/__init__.py
@@ -17,5 +17,4 @@
 # pylint: disable=unused-import, redefined-builtin
 """Namespace for Unified Static Memory Planner"""
 
-from . import analysis
-from .utils import BufferInfo
+from .transform import convert_pool_allocations_to_offsets
diff --git a/python/tvm/tir/usmp/__init__.py b/python/tvm/tir/usmp/transform/_ffi_api.py
similarity index 83%
copy from python/tvm/tir/usmp/__init__.py
copy to python/tvm/tir/usmp/transform/_ffi_api.py
index 8aa0d4c..7973ca5 100644
--- a/python/tvm/tir/usmp/__init__.py
+++ b/python/tvm/tir/usmp/transform/_ffi_api.py
@@ -14,8 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=unused-import, redefined-builtin
-"""Namespace for Unified Static Memory Planner"""
+"""FFI APIs for tvm.tir.usmp.analysis"""
+import tvm._ffi
 
-from . import analysis
-from .utils import BufferInfo
+
+tvm._ffi._init_api("tir.usmp.transform", __name__)
diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py
new file mode 100644
index 0000000..f472172
--- /dev/null
+++ b/python/tvm/tir/usmp/transform/transform.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""USMP Transform Python API for passes"""
+# pylint: disable=invalid-name
+
+from typing import Dict
+
+import tvm
+from tvm.tir import Stmt
+from tvm.tir.usmp.utils import PoolAllocation
+from . import _ffi_api
+
+
+def convert_pool_allocations_to_offsets(
+    pool_allocations: Dict[Stmt, PoolAllocation], emit_tvmscript_printable: bool = False
+) -> tvm.transform.Pass:
+    """Convert pool allocations to Load nodes with offsets from pools.
+
+    Parameters
+    ----------
+    pool_allocations : Dict[Stmt, PoolAllocation]
+        Allocate or AllocateConst node to pool allocation mapping
+    emit_tvmscript_printable : bool
+        A toggle to emit TVMScript printable IRModule for unit tests
+        removing all attributes that should be attached for integration
+
+    Returns
+    -------
+    ret: tvm.transform.Pass
+        The registered pass that converts the allocations to offsets.
+    """
+    return _ffi_api.ConvertPoolAllocationsToOffsets(pool_allocations, emit_tvmscript_printable)
diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h
index ebd667a..97146b8 100644
--- a/src/printer/text_printer.h
+++ b/src/printer/text_printer.h
@@ -449,10 +449,11 @@ class TextPrinter {
 
   Doc PrintFinal(const ObjectRef& node) {
     Doc doc;
-    if (node->IsInstance<IRModuleNode>()) {
+    if (node.defined() && node->IsInstance<IRModuleNode>()) {
       doc << PrintMod(Downcast<IRModule>(node));
-    } else if (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>() ||
-               node->IsInstance<tir::StmtNode>()) {
+    } else if (node.defined() &&
+               (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>() ||
+                node->IsInstance<tir::StmtNode>())) {
       doc << tir_text_printer_.Print(node);
     } else {
       doc << relay_text_printer_.PrintFinal(node);
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 0d42c20..078561c 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -35,7 +35,14 @@ namespace tir {
 LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) {
   ICHECK(value.defined());
   ICHECK(body.defined());
-  ICHECK_EQ(value.dtype(), var.dtype());
+  auto vdtype = value.dtype();
+  // It is still valid to bind a pointer type
+  // var to a value that is of type handle.
+  if (var->type_annotation.as<PointerTypeNode>()) {
+    ICHECK(vdtype.is_handle());
+  } else {
+    ICHECK_EQ(value.dtype(), var.dtype());
+  }
 
   ObjectPtr<LetStmtNode> node = make_object<LetStmtNode>();
   node->var = std::move(var);
diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
new file mode 100644
index 0000000..5ebf3c5
--- /dev/null
+++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
@@ -0,0 +1,349 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/usmp/transform/convert_pool_allocations_to_offsets.cc
+ * \brief This pass converts the pool allocations to offsets from pools
+ */
+
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <stack>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+/*!
+ * \brief The StmtExpr mutator class to replace allocate nodes
+ * with offsets within memory pools
+ *
+ * This mutator class will add Pool variables recursively to every PrimFunc
+ * starting from the main PrimFunc. For all allocate nodes, that have been
+ * memory planned, will be mutated into an offset using a Let binding.
+ */
+class PoolAllocationToOffsetConverter : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Construct the converter and compute the required size of every pool.
+   * \param module The IRModule to convert; a shallow copy of it is mutated.
+   * \param pool_allocations The Allocate node to PoolAllocation map from memory planning.
+   * \param emit_tvmscript_printable Produce TVMScript-printable IR (used by unit tests).
+   */
+  PoolAllocationToOffsetConverter(const IRModule& module,
+                                  const Map<tir::Stmt, PoolAllocation>& pool_allocations,
+                                  bool emit_tvmscript_printable = false)
+      : pool_allocations_(pool_allocations), emit_tvmscript_printable_(emit_tvmscript_printable) {
+    module_ = module->ShallowCopy();
+    // A pool has to be large enough for the allocation that ends furthest into it.
+    for (const auto& kv : pool_allocations) {
+      // TODO(@manupa-arm): add AllocateConstNode when it is available
+      ICHECK(kv.first->IsInstance<AllocateNode>());
+      Allocate allocate_node = Downcast<Allocate>(kv.first);
+      PoolAllocation pool_allocation = kv.second;
+      PoolInfo pool_info = pool_allocation->pool_info;
+      int byte_pool_offset = pool_allocation->byte_offset->value;
+      int required_pool_size_for_allocation =
+          byte_pool_offset + CalculateExtentsSize(allocate_node.operator->());
+      auto it = all_pools_sizes_.find(pool_info);
+      if (it == all_pools_sizes_.end()) {
+        all_pools_sizes_[pool_info] = required_pool_size_for_allocation;
+      } else if (it->second < required_pool_size_for_allocation) {
+        it->second = required_pool_size_for_allocation;
+      }
+    }
+
+    for (const auto& kv : all_pools_sizes_) {
+      allocated_pool_ordering_.push_back(AllocatedPoolInfo(kv.first, kv.second));
+    }
+    // Sort by pool name so pool params are appended to every function in a deterministic
+    // order that does not depend on the hash order of all_pools_sizes_.
+    std::sort(allocated_pool_ordering_.begin(), allocated_pool_ordering_.end(),
+              [](const AllocatedPoolInfo& lhs, const AllocatedPoolInfo& rhs) {
+                return lhs->pool_info->pool_name < rhs->pool_info->pool_name;
+              });
+  }
+  /*! \brief Run the conversion starting from the module's run function. */
+  IRModule operator()();
+
+ private:
+  PrimExpr VisitExpr_(const CallNode* op) override;
+  Stmt VisitStmt_(const AllocateNode* op) override;
+  PrimExpr VisitExpr_(const LoadNode* op) override;
+  Stmt VisitStmt_(const StoreNode* op) override;
+
+  /*! \brief This is a structure where the modified function
+   * signature is kept while body of the function is mutated
+   */
+  struct ScopeInfo {
+    /*! \brief The function params, including the appended pool params */
+    Array<tir::Var> params;
+    /*! \brief The pool param of each pool */
+    Map<PoolInfo, tir::Var> pools_to_params;
+    /*! \brief The pools in the order their params were appended, with pool_var set */
+    Array<AllocatedPoolInfo> allocated_pool_params;
+    /*! \brief The function buffer map, including a buffer for every pool param */
+    Map<tir::Var, Buffer> buffer_map;
+  };
+
+  /*! \brief The function scope information that are needed
+   * in the mutation of the function need to be stacked and
+   * popped when each function is entered/exited in the
+   * mutation process.
+   */
+  std::stack<ScopeInfo> scope_stack;
+  /*! \brief Each PrimFunc signature needs to be updated
+   * with pool variables. This is a helper function to
+   * capture the updated information to ScopeInfo object.
+   */
+  ScopeInfo UpdateFunctionScopeInfo(const PrimFunc& original_func);
+  /*! \brief This is a helper to create the PrimFunc with
+   * pool variables that calls the UpdateFunctionScopeInfo
+   * inside of it.
+   */
+  PrimFunc CreatePrimFuncWithPoolParams(const PrimFunc& original_primfunc);
+  /*! \brief This is a helper to append the pool args to
+   * the callsite of the function.
+   */
+  Array<PrimExpr> AppendPoolParamsToArgs(const Array<PrimExpr>& args);
+  /*! \brief Some arguments that used to be Allocate nodes
+   * should be replaced by Let nodes in the pass that loads
+   * the space from a pool variable.
+   */
+  Array<PrimExpr> ReplaceAllocateArgsWithLetArgs(const Array<PrimExpr>& args);
+
+  /*! \brief The IRModule being constructed/mutated */
+  IRModule module_;
+  /*! \brief The input allocate node to PoolAllocation map */
+  Map<tir::Stmt, PoolAllocation> pool_allocations_;
+  /*! \brief The set of ordered pools to ensure an unique order of args for functions */
+  std::vector<AllocatedPoolInfo> allocated_pool_ordering_;
+  /*! \brief The storage of calculated pool size at init */
+  std::unordered_map<PoolInfo, int, ObjectPtrHash, ObjectPtrEqual> all_pools_sizes_;
+  /*! \brief After mutation, each allocate buffer is replaced with tir::Var that is let bounded
+   * to position from a pool as designated by a PoolAllocation
+   */
+  Map<tir::Var, tir::Var> allocate_buf_to_let_var_;
+  /*! \brief A counter to give references to pools a reproducible unique set of names */
+  int pool_var_count_ = 0;
+  /*! \brief This toggles to remove non tvmscript printable items for IRModule for unit tests */
+  bool emit_tvmscript_printable_ = false;
+  /*! \brief The PrimFuncs already rewritten with pool params, so that a function reached from
+   * several call sites is not given pool params more than once
+   */
+  std::unordered_set<PrimFunc, ObjectPtrHash, ObjectPtrEqual> visited_primfuncs;
+};
+
+PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::UpdateFunctionScopeInfo(
+    const PrimFunc& original_func) {
+  ScopeInfo si;
+  si.params = original_func->params;
+  si.buffer_map = original_func->buffer_map;
+  // One extra param (with a backing buffer) is appended per pool, in allocated_pool_ordering_
+  // order, so every rewritten function lists its pool params in the same order.
+  for (const AllocatedPoolInfo& allocated_pool_info : allocated_pool_ordering_) {
+    PoolInfo pool_info = allocated_pool_info->pool_info;
+    // pool_var_count_ is shared across all functions, giving each pool reference a unique,
+    // reproducible name.
+    String pool_ref_name = pool_info->pool_name + "_" + std::to_string(pool_var_count_++);
+    String var_name = pool_ref_name + "_var";
+    DataType elem_dtype = DataType::UInt(8);
+    // buffer_var is the data pointer of the pool Buffer; pool_var is the function param that
+    // the buffer is bound to through buffer_map.
+    Var buffer_var(var_name, PointerType(PrimType(elem_dtype), "global"));
+    Var pool_var;
+    if (!emit_tvmscript_printable_) {
+      pool_var = Var(var_name, PointerType(PrimType(elem_dtype), "global"));
+    } else {
+      // The printable mode uses a plain handle-typed param instead of a pointer-typed one.
+      pool_var = Var(var_name, DataType::Handle(8));
+    }
+    si.params.push_back(pool_var);
+    si.pools_to_params.Set(pool_info, pool_var);
+    si.allocated_pool_params.push_back(AllocatedPoolInfo(
+        allocated_pool_info->pool_info, allocated_pool_info->allocated_size, pool_var));
+
+    int pool_size = all_pools_sizes_[pool_info];
+    String buffer_var_name = pool_ref_name + "_buffer_var";
+    si.buffer_map.Set(pool_var, Buffer(buffer_var, elem_dtype, {pool_size}, {1}, 1, buffer_var_name,
+                                       16, 1, BufferType::kDefault));
+  }
+  return si;
+}
+
+PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams(
+    const PrimFunc& original_primfunc) {
+  // A function that was already rewritten is returned unchanged so that it does not receive
+  // a second set of pool params when it is called from more than one call site.
+  if (visited_primfuncs.find(original_primfunc) != visited_primfuncs.end()) {
+    return original_primfunc;
+  }
+  ScopeInfo si = UpdateFunctionScopeInfo(original_primfunc);
+  this->scope_stack.push(si);
+  Stmt new_body = this->VisitStmt(original_primfunc->body);
+  this->scope_stack.pop();
+  DictAttrs original_attrs = original_primfunc->attrs;
+  // We don't need attrs of PrimFunc that might include non printable attrs such as target
+  // for unit tests where emit_tvmscript_printable_ is to be used.
+  if (emit_tvmscript_printable_) {
+    original_attrs = DictAttrs();
+  }
+  PrimFunc ret =
+      PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs);
+  if (!emit_tvmscript_printable_) {
+    ret = WithAttr(std::move(ret), tvm::attr::kPoolArgs, si.allocated_pool_params);
+  }
+  // Record exactly the function that is returned (and stored into the module by the caller),
+  // in both modes, so that later lookups of it hit the check above.
+  visited_primfuncs.insert(ret);
+  return ret;
+}
+
+Array<PrimExpr> PoolAllocationToOffsetConverter::AppendPoolParamsToArgs(
+    const Array<PrimExpr>& args) {
+  Array<PrimExpr> new_args;
+  for (const auto& arg : args) {
+    new_args.push_back(VisitExpr(arg));
+  }
+  const ScopeInfo& top_scope = this->scope_stack.top();
+  // Append the pool args in allocated_pool_params order, which is the order in which pool
+  // params were appended to every function signature. Iterating the hash-based
+  // pools_to_params map instead could misalign args with the callee's params.
+  for (const AllocatedPoolInfo& allocated_pool_info : top_scope.allocated_pool_params) {
+    tir::Var pool_var = allocated_pool_info->pool_var.value();
+    Buffer buffer_var = top_scope.buffer_map[pool_var];
+    new_args.push_back(buffer_var->data);
+  }
+  return new_args;
+}
+
+Array<PrimExpr> PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs(
+    const Array<PrimExpr>& args) {
+  // Any arg that is the buffer var of a pooled allocation is swapped for its Let-bound var;
+  // every other arg is passed through untouched.
+  Array<PrimExpr> replaced_args;
+  for (const PrimExpr& arg : args) {
+    PrimExpr replacement = arg;
+    if (arg->IsInstance<VarNode>()) {
+      Var arg_var = Downcast<Var>(arg);
+      auto let_it = allocate_buf_to_let_var_.find(arg_var);
+      if (let_it != allocate_buf_to_let_var_.end()) {
+        replacement = (*let_it).second;
+      }
+    }
+    replaced_args.push_back(replacement);
+  }
+  return replaced_args;
+}
+
+PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) {
+  if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
+    String func_name = Downcast<StringImm>(op->args[0])->value;
+    Array<PrimExpr> new_args;
+    if (module_->ContainGlobalVar(func_name)) {
+      // The callee lives in this module: give it pool params and pass the pools along.
+      GlobalVar gv = module_->GetGlobalVar(func_name);
+      PrimFunc func = Downcast<PrimFunc>(module_->Lookup(gv));
+      PrimFunc prim_func = CreatePrimFuncWithPoolParams(func);
+      module_->Update(gv, prim_func);
+      new_args = AppendPoolParamsToArgs(op->args);
+      new_args = ReplaceAllocateArgsWithLetArgs(new_args);
+    } else {
+      // Visit the args as in the branch above, so loads of pooled buffers nested in them are
+      // rewritten too instead of referring to the removed allocation.
+      for (const PrimExpr& arg : op->args) {
+        new_args.push_back(VisitExpr(arg));
+      }
+      new_args = ReplaceAllocateArgsWithLetArgs(new_args);
+    }
+    return Call(op->dtype, op->op, new_args);
+  }
+  if (op->op->IsInstance<PrimFuncNode>()) {
+    PrimFunc func = Downcast<PrimFunc>(op->op);
+    PrimFunc prim_func = CreatePrimFuncWithPoolParams(func);
+    // Pool args are appended exactly once, matching the single set of pool params that
+    // CreatePrimFuncWithPoolParams adds to the callee.
+    Array<PrimExpr> new_args = AppendPoolParamsToArgs(op->args);
+    new_args = ReplaceAllocateArgsWithLetArgs(new_args);
+    return Call(op->dtype, prim_func, new_args);
+  }
+  return StmtExprMutator::VisitExpr_(op);
+}
+
+Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) {
+  Allocate allocate = GetRef<Allocate>(op);
+  if (!pool_allocations_.count(allocate)) {
+    // Allocations that were not memory planned are left as they are.
+    return StmtExprMutator::VisitStmt_(op);
+  }
+  // Everything needed from the current scope is read before the body is visited.
+  const ScopeInfo& current_scope = scope_stack.top();
+  PoolAllocation planned = pool_allocations_[allocate];
+  Var pool_param = current_scope.pools_to_params[planned->pool_info];
+  Buffer pool_buffer = current_scope.buffer_map[pool_param];
+  // The allocation becomes the address of the byte at its planned offset inside the pool.
+  Call offset_address =
+      Call(DataType::Handle(8), builtin::address_of(),
+           {Load(DataType::UInt(8), pool_buffer->data, planned->byte_offset, op->condition)});
+  String let_name = op->buffer_var->name_hint + "_let";
+  Var let_var = emit_tvmscript_printable_ ? Var(let_name, DataType::Handle(8))
+                                          : Var(let_name, op->buffer_var->type_annotation);
+  // Uses of the buffer var inside the body are redirected to let_var while it is visited.
+  allocate_buf_to_let_var_.Set(op->buffer_var, let_var);
+  Stmt rewritten_body = VisitStmt(op->body);
+  allocate_buf_to_let_var_.erase(op->buffer_var);
+  return LetStmt(let_var, offset_address, rewritten_body);
+}
+
+Stmt PoolAllocationToOffsetConverter::VisitStmt_(const StoreNode* op) {
+  if (allocate_buf_to_let_var_.find(op->buffer_var) != allocate_buf_to_let_var_.end()) {
+    // Redirect the store to the Let-bound pool offset. The index is visited as well, since it
+    // may itself load from a pooled buffer; the source span is preserved.
+    return Store(allocate_buf_to_let_var_[op->buffer_var], VisitExpr(op->value),
+                 VisitExpr(op->index), VisitExpr(op->predicate), op->span);
+  }
+  return StmtExprMutator::VisitStmt_(op);
+}
+
+PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const LoadNode* op) {
+  if (allocate_buf_to_let_var_.find(op->buffer_var) != allocate_buf_to_let_var_.end()) {
+    // Redirect the load to the Let-bound pool offset. The index is visited as well, since it
+    // may itself load from a pooled buffer; the source span is preserved.
+    return Load(op->dtype, allocate_buf_to_let_var_[op->buffer_var], VisitExpr(op->index),
+                VisitExpr(op->predicate), op->span);
+  }
+  return StmtExprMutator::VisitExpr_(op);
+}
+
+IRModule PoolAllocationToOffsetConverter::operator()() {
+  // Traversal starts at the module's run function and reaches other functions via calls.
+  GlobalVar main_gv = module_->GetGlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix);
+  PrimFunc original_main = Downcast<PrimFunc>(module_->Lookup(main_gv));
+  ScopeInfo main_scope = UpdateFunctionScopeInfo(original_main);
+  scope_stack.push(main_scope);
+  Stmt rewritten_body = VisitStmt(original_main->body);
+  scope_stack.pop();
+
+  if (emit_tvmscript_printable_) {
+    // Attrs (which may hold non-printable items such as target) are dropped for unit tests.
+    module_->Update(main_gv, PrimFunc(main_scope.params, rewritten_body, original_main->ret_type,
+                                      main_scope.buffer_map, DictAttrs()));
+    return module_;
+  }
+
+  PrimFunc rewritten_main = PrimFunc(main_scope.params, rewritten_body, original_main->ret_type,
+                                     main_scope.buffer_map, original_main->attrs);
+  rewritten_main =
+      WithAttr(rewritten_main, tvm::attr::kPoolArgs, main_scope.allocated_pool_params);
+  module_->Update(main_gv, rewritten_main);
+  return WithAttr(module_, tvm::attr::kPoolArgs, main_scope.allocated_pool_params);
+}
+
+namespace transform {
+
+tvm::transform::Pass ConvertPoolAllocationsToOffsets(
+    const Map<tir::Stmt, PoolAllocation>& pool_allocations,
+    Bool emit_tvmscript_printable = Bool(false)) {
+  // A fresh converter is built every time the pass runs on a module.
+  auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) -> IRModule {
+    const bool printable = emit_tvmscript_printable->value != 0;
+    PoolAllocationToOffsetConverter converter(m, pool_allocations, printable);
+    return converter();
+  };
+  return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.ConvertPoolAllocationsToOffsets",
+                                          {});
+}
+
+TVM_REGISTER_GLOBAL("tir.usmp.transform.ConvertPoolAllocationsToOffsets")
+    .set_body_typed(ConvertPoolAllocationsToOffsets);
+
+}  // namespace transform
+
+}  // namespace usmp
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc
index 7a6a683..14b3d26 100644
--- a/src/tir/usmp/utils.cc
+++ b/src/tir/usmp/utils.cc
@@ -135,6 +135,30 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
                 << ")";
     });
 
+AllocatedPoolInfo::AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var) {
+  ObjectPtr<AllocatedPoolInfoNode> node = make_object<AllocatedPoolInfoNode>();
+  node->pool_info = std::move(pool_info);
+  node->allocated_size = std::move(allocated_size);
+  // pool_var stays unset (NullOpt) unless a defined Var is supplied.
+  if (pool_var.defined()) {
+    node->pool_var = std::move(pool_var);
+  }
+  data_ = std::move(node);
+}
+
+TVM_REGISTER_NODE_TYPE(AllocatedPoolInfoNode);
+TVM_REGISTER_GLOBAL("tir.usmp.AllocatedPoolInfo")
+    .set_body_typed([](PoolInfo pool_info, Integer allocated_size) {
+      return AllocatedPoolInfo(pool_info, allocated_size);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<AllocatedPoolInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const AllocatedPoolInfoNode*>(ref.get());
+      // Print every field with the same indentation. pool_var is included because it is part
+      // of the node's structural identity (SEqualReduce/SHashReduce).
+      p->stream << "AllocatedPoolInfoNode(\n"
+                << "  pool_info=" << node->pool_info << ",\n"
+                << "  allocated_size=" << node->allocated_size << ",\n"
+                << "  pool_var=" << node->pool_var << ")";
+    });
+
 Array<BufferInfo> CreateArrayBufferInfo(const Map<BufferInfo, Stmt>& buffer_info_map) {
   Array<BufferInfo> ret;
   for (const auto& kv : buffer_info_map) {
@@ -144,6 +168,19 @@ Array<BufferInfo> CreateArrayBufferInfo(const Map<BufferInfo, Stmt>& buffer_info
   return ret;
 }
 
+Map<Stmt, PoolAllocation> AssignStmtPoolAllocations(
+    const Map<BufferInfo, Stmt>& buffer_info_to_stmt,
+    const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation) {
+  // Re-key each planned PoolAllocation by the statement its BufferInfo was created from.
+  // NOTE: assumes every planned BufferInfo has an entry in buffer_info_to_stmt.
+  Map<Stmt, PoolAllocation> stmt_to_pool_allocation;
+  for (const auto& entry : buffer_info_to_pool_allocation) {
+    Stmt source_stmt = buffer_info_to_stmt[entry.first];
+    stmt_to_pool_allocation.Set(source_stmt, entry.second);
+  }
+  return stmt_to_pool_allocation;
+}
+
 Integer CalculateExtentsSize(const AllocateNode* op) {
   size_t element_size_bytes = op->dtype.bytes();
   size_t num_elements = 1;
@@ -163,6 +200,8 @@ TVM_REGISTER_GLOBAL("tir.usmp.CreateArrayBufferInfo")
       return (CreateArrayBufferInfo(buffer_info_map));
     });
 
+TVM_REGISTER_GLOBAL("tir.usmp.AssignStmtPoolAllocations")
+    .set_body_typed(AssignStmtPoolAllocations);
+
 }  // namespace usmp
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
new file mode 100644
index 0000000..fc61577
--- /dev/null
+++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py
@@ -0,0 +1,523 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import sys
+
+import tvm
+from tvm.script import tir as T
+from tvm.tir import stmt_functor
+from tvm.tir.usmp import utils as usmp_utils
+from tvm.target import Target
+
+
+def _get_primfuncs_from_module(module):
+    primfuncs = list()
+    for gv, primfunc in module.functions.items():
+        primfuncs.append(primfunc)
+    return primfuncs
+
+
+def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
+    """Helper to assign poolinfos to allocate nodes in a tir.PrimFunc"""
+
+    def set_poolinfos(stmt):
+        if isinstance(stmt, tvm.tir.Allocate):
+            return tvm.tir.Allocate(
+                buffer_var=stmt.buffer_var,
+                dtype=stmt.dtype,
+                extents=stmt.extents,
+                condition=stmt.condition,
+                body=stmt.body,
+                annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos},
+            )
+
+    return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos))
+
+
+def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
+    """Helper to assign poolinfos to allocate nodes in a IRModule"""
+    ret = tvm.IRModule()
+    for global_var, basefunc in mod.functions.items():
+        if isinstance(basefunc, tvm.tir.PrimFunc):
+            ret[global_var] = assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos)
+    return ret
+
+
+def _assign_targets_to_primfuncs_irmodule(mod, target):
+    """Helper to assign target for PrimFunc in a IRModule"""
+    ret = tvm.IRModule()
+    for global_var, basefunc in mod.functions.items():
+        if isinstance(basefunc, tvm.tir.PrimFunc):
+            ret[global_var] = basefunc.with_attr("target", target)
+    return ret
+
+
+# fmt: off
+# Input module for test_mobilenet_subgraph: run_model calls three operator
+# PrimFuncs in sequence. sid_9 carries the cast/subtract output into the conv2d,
+# and sid_8 carries the conv2d output into the max_pool2d.
+@tvm.script.ir_module
+class LinearStructure:
+    # Elementwise cast + subtract over a 1x224x224x3 tensor; no internal allocate.
+    @T.prim_func
+    def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
+        placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        for ax0_ax1_fused_1 in T.serial(0, 224):
+            for ax2_1, ax3_inner_1 in T.grid(224, 3):
+                T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True)
+
+    # conv2d + add + fixed-point multiply + clip + cast. Internal allocates:
+    # PaddedInput_7 (229*229*3 int16 elements) and, inside the outer serial loop,
+    # Conv2dOutput_7 (64 int32 elements).
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True})
+        placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1)
+        T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        PaddedInput_7 = T.allocate([157323], "int16", "global")
+        for i0_i1_fused_7 in T.serial(0, 229):
+            for i2_7, i3_7 in T.grid(229, 3):
+                T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True)
+        for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
+            Conv2dOutput_7 = T.allocate([64], "int32", "global")
+            for ff_3 in T.serial(0, 64):
+                T.store(Conv2dOutput_7, ff_3, 0, True)
+                for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
+                    T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True)
+            for ax3_inner_7 in T.serial(0, 64):
+                T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True)
+
+    # max_pool2d + cast. Internal allocate: tensor_2 (56*56*64 uint8 elements).
+    # NOTE(review): one store below is truncated ("[...]") in this archived copy.
+    @T.prim_func
+    def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True})
+        placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        tensor_2 = T.allocate([200704], "uint8", "global")
+        for ax0_ax1_fused_4 in T.serial(0, 56):
+            for ax2_4 in T.serial(0, 56):
+                for ax3_init in T.serial(0, 64):
+                    T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True)
+                for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
+                    T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dt [...]
+        for ax0_ax1_fused_5 in T.serial(0, 56):
+            for ax2_5, ax3_3 in T.grid(56, 64):
+                T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True)
+
+    # Entry function: sid_9 (301056 int8 = 1x224x224x3 int16 bytes) and sid_8
+    # (802816 int8 = 1x112x112x64 uint8 bytes) hold the intermediates between calls.
+    @T.prim_func
+    def run_model(input: T.handle, output: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "run_model", "runner_function": True})
+        # body
+        T.attr("default", "device_id", 0)
+        T.attr("default", "device_type", 1)
+        sid_9 = T.allocate([301056], "int8", "global")
+        sid_8 = T.allocate([802816], "int8", "global")
+        T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32"))
+# fmt: on
+
+
+# fmt: off
+# Expected result of convert_pool_allocations_to_offsets(..., emit_tvmscript_printable=True)
+# applied to LinearStructure with the pool assignment made in test_mobilenet_subgraph.
+# Every PrimFunc gains one extra handle parameter per pool (fast_memory, slow_memory),
+# matched as a flat uint8 buffer. Every allocate is replaced by a `<name>_let` handle
+# that takes the address of a byte offset inside one of the pools:
+#   fast_memory (200704 bytes):  tensor_2 @ 0, Conv2dOutput_7 @ 0
+#   slow_memory (1418528 bytes): sid_8 @ 0, PaddedInput_7 @ 802816, sid_9 @ 1117472
+# The test compares each function's printed TVMScript against the pass output.
+# NOTE(review): the pool match_buffers use elem_offset=1; this mirrors the pass
+# output being compared against -- confirm it is intended.
+@tvm.script.ir_module
+class LinearStructurePlanned:
+    # Operator calls forward both pool handles after their original arguments.
+    @T.prim_func
+    def run_model(input: T.handle, output: T.handle, fast_memory_0_var: T.handle, slow_memory_1_var: T.handle) -> None:
+        fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        # body
+        T.attr("default", "device_id", 0)
+        T.attr("default", "device_type", 1)
+        sid_9_let: T.handle = T.address_of(T.load("uint8", slow_memory_1_buffer_var.data, 1117472), dtype="handle")
+        sid_8_let: T.handle = T.address_of(T.load("uint8", slow_memory_1_buffer_var.data, 0), dtype="handle")
+        T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8_let, output, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32"))
+
+    # tensor_2 now lives at offset 0 of fast_memory.
+    @T.prim_func
+    def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.handle, slow_memory_7_var: T.handle) -> None:
+        placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8")
+        T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16")
+        fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        # body
+        tensor_2_let: T.handle = T.address_of(T.load("uint8", fast_memory_6_buffer_var.data, 0), dtype="handle")
+        for ax0_ax1_fused_4, ax2_4 in T.grid(56, 56):
+            for ax3_init in T.serial(0, 64):
+                T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init, T.uint8(0), True)
+            for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64):
+                T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2, T.max(T.load("uint8", tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2), T.if_then_else(ax0_ax1_fused_4 * 2 + rv0_rv1_fused_1 // 3 < 112 and ax2_4 * 2 + rv0_rv1_fused_1 % 3 < 112, T.load("uint8", placeholder_29.data, ax0_ax1_fused_4 * 14336 + rv0_rv1_fused_1 // 3 * 7168 + ax2_4 * 128 + rv0_rv1_fused_1 % 3 * 64 + ax3_2), T.uint8(0), dtype="uint8")), True)
+        for ax0_ax1_fused_5, ax2_5, ax3_3 in T.grid(56, 56, 64):
+            T.store(T_cast_7.data, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3, T.cast(T.load("uint8", tensor_2_let, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3), "int16"), True)
+
+    # No internal allocate; only the two pool parameters are added.
+    @T.prim_func
+    def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.handle, slow_memory_3_var: T.handle) -> None:
+        placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8")
+        placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16")
+        T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16")
+        fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        # body
+        for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3):
+            T.store(T_subtract_1.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1, T.cast(T.load("uint8", placeholder_4.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1), "int16") - T.load("int16", placeholder_5.data, 0), True)
+
+    # PaddedInput_7 -> slow_memory @ 802816; Conv2dOutput_7 -> fast_memory @ 0.
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.handle, slow_memory_5_var: T.handle) -> None:
+        placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16")
+        placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16")
+        placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32")
+        T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8")
+        fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        # body
+        PaddedInput_7_let: T.handle = T.address_of(T.load("uint8", slow_memory_5_buffer_var.data, 802816), dtype="handle")
+        for i0_i1_fused_7, i2_7, i3_7 in T.grid(229, 229, 3):
+            T.store(PaddedInput_7_let, i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7, T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, T.load("int16", placeholder_65.data, i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350), T.int16(0), dtype="int16"), True)
+        for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544):
+            Conv2dOutput_7_let: T.handle = T.address_of(T.load("uint8", fast_memory_4_buffer_var.data, 0), dtype="handle")
+            for ff_3 in T.serial(0, 64):
+                T.store(Conv2dOutput_7_let, ff_3, 0, True)
+                for ry_2, rx_2, rc_7 in T.grid(7, 7, 3):
+                    T.store(Conv2dOutput_7_let, ff_3, T.load("int32", Conv2dOutput_7_let, ff_3) + T.cast(T.load("int16", PaddedInput_7_let, ax0_ax1_fused_ax2_fused_7 // 112 * 1374 + ry_2 * 687 + ax0_ax1_fused_ax2_fused_7 % 112 * 6 + rx_2 * 3 + rc_7), "int32") * T.cast(T.load("int16", placeholder_66.data, ry_2 * 1344 + rx_2 * 192 + rc_7 * 64 + ff_3), "int32"), True)
+            for ax3_inner_7 in T.serial(0, 64):
+                T.store(T_cast_21.data, ax0_ax1_fused_ax2_fused_7 * 64 + ax3_inner_7, T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_7_let, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True)
+# fmt: on
+
+
+def test_mobilenet_subgraph():
+    target = Target("c")
+    fast_memory_pool = usmp_utils.PoolInfo(
+        pool_name="fast_memory",
+        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+        size_hint_bytes=200704,
+    )
+    slow_memory_pool = usmp_utils.PoolInfo(
+        pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}
+    )
+    tir_mod = LinearStructure
+    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
+    tir_mod = assign_poolinfos_to_allocates_in_irmodule(
+        tir_mod, [fast_memory_pool, slow_memory_pool]
+    )
+    main_func = tir_mod["run_model"]
+    buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+    buffer_info_map = buffer_analysis.buffer_info_stmts
+
+    fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
+    buffer_info_arr = fcreate_array_bi(buffer_info_map)
+    fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
+    buffer_pool_allocations = fusmp_algo_greedy_by_size(
+        buffer_info_arr, buffer_analysis.memory_pressure
+    )
+    fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
+    pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations)
+    tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
+        pool_allocations, emit_tvmscript_printable=True
+    )(tir_mod)
+
+    tir_mod_with_offsets_ref = LinearStructurePlanned
+    tir_mod_with_offsets_ref = tvm.script.from_source(
+        tir_mod_with_offsets_ref.script(show_meta=False)
+    )
+    # The TIR produced fails on roundtrip TVMScript testing.
+    # Therefore, indicates the TVMScript produced here and/or the parser
+    # is lacking functionality. Thus for these tests, uses a string
+    # version of the TVMScript for each function as a check instead.
+    for gv, func in tir_mod_with_offsets_ref.functions.items():
+        assert str(tir_mod_with_offsets_ref[gv.name_hint].script()) == str(
+            tir_mod_with_offsets[gv.name_hint].script()
+        )
+
+
+# fmt: off
+# Input module with a residual (skip) connection: run_model calls five operator
+# PrimFuncs, and sid_2 is consumed twice -- by
+# tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast and by the
+# final ..._4200876283395191415_ call, which also reads sid_6.
+@tvm.script.ir_module
+class ResnetStructure:
+    # Elementwise requantize/cast of a 1x75x75x64 uint8 tensor to int16; no internal allocate.
+    @T.prim_func
+    def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True})
+        placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8")
+        placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
+        T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16")
+        # body
+        for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
+            T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True)
+
+    # 3x3 conv2d. Internal allocates: PaddedInput_1 (77*77*64 int16 elements) and,
+    # inside the outer serial loop, Conv2dOutput_1 (64 int32 elements).
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True})
+        placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16")
+        placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16")
+        placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32")
+        T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16")
+        # body
+        PaddedInput_1 = T.allocate([379456], "int16", "global")
+        for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
+            T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True)
+        for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
+            Conv2dOutput_1 = T.allocate([64], "int32", "global")
+            for ff_1 in T.serial(0, 64):
+                T.store(Conv2dOutput_1, ff_1, 0, True)
+                for ry, rx, rc_1 in T.grid(3, 3, 64):
+                    T.store(Conv2dOutput_1, ff_1, T.load("int32", Conv2dOutput_1, ff_1) + T.cast(T.load("int16", PaddedInput_1, T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True)
+            for ax3_inner_2 in T.serial(0, 64):
+                T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True)
+
+    # 1x1 conv2d producing a 1x75x75x256 int32 tensor (sid_6 in run_model).
+    # Internal allocates: PaddedInput_2 (75*75*64 int16) and Conv2dOutput_2 (64 int32).
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True})
+        placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16")
+        placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16")
+        placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32")
+        T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32")
+        # body
+        PaddedInput_2 = T.allocate([360000], "int16", "global")
+        for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64):
+            T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True)
+        for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625):
+            Conv2dOutput_2 = T.allocate([64], "int32", "global")
+            for ax3_outer_1 in T.serial(0, 4):
+                for ff_2 in T.serial(0, 64):
+                    T.store(Conv2dOutput_2, ff_2, 0, True)
+                    for rc_2 in T.serial(0, 64):
+                        T.store(Conv2dOutput_2, ff_2, T.load("int32", Conv2dOutput_2, ff_2) + T.cast(T.load("int16", PaddedInput_2, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True)
+                for ax3_inner_3 in T.serial(0, 64):
+                    T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True)
+
+    # 1x1 conv2d whose output is combined with a second 1x75x75x256 int32 input
+    # (placeholder_25, i.e. sid_6 in run_model); writes the model output.
+    # Internal allocates: PaddedInput_3 (75*75*64 int16) and Conv2dOutput_3 (64 int32).
+    # NOTE(review): the last store below is truncated ("[...]") in this archived copy.
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True})
+        placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16")
+        placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16")
+        placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32")
+        placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32")
+        T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8")
+        # body
+        PaddedInput_3 = T.allocate([360000], "int16", "global")
+        for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64):
+            T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True)
+        for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625):
+            Conv2dOutput_3 = T.allocate([64], "int32", "global")
+            for ax3_outer_2 in T.serial(0, 4):
+                for ff_3 in T.serial(0, 64):
+                    T.store(Conv2dOutput_3, ff_3, 0, True)
+                    for rc_3 in T.serial(0, 64):
+                        T.store(Conv2dOutput_3, ff_3, T.load("int32", Conv2dOutput_3, ff_3) + T.cast(T.load("int16", PaddedInput_3, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True)
+                for ax3_inner_4 in T.serial(0, 64):
+                    T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256  [...]
+
+    # Entry function.
+    # sid_2, sid_7, sid_8: 720000 int8 = 1x75x75x64 int16 bytes.
+    # sid_6: 5760000 int8 = 1x75x75x256 int32 bytes.
+    @T.prim_func
+    def run_model(input: T.handle, output: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "run_model", "runner_function": True})
+        # body
+        T.attr("default", "device_id", 0)
+        T.attr("default", "device_type", 1)
+        sid_2 = T.allocate([720000], "int8", "global")
+        sid_6 = T.allocate([5760000], "int8", "global")
+        sid_7 = T.allocate([720000], "int8", "global")
+        sid_8 = T.allocate([720000], "int8", "global")
+        T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6, output, dtype="int32"))
+
+    # 1x1 conv2d. Internal allocates: PaddedInput (75*75*64 int16 elements) and,
+    # inside the outer serial loop, Conv2dOutput (64 int32 elements).
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True})
+        placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16")
+        placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16")
+        placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32")
+        T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
+        # body
+        PaddedInput = T.allocate([360000], "int16", "global")
+        for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
+            T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True)
+        for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
+            Conv2dOutput = T.allocate([64], "int32", "global")
+            for ff in T.serial(0, 64):
+                T.store(Conv2dOutput, ff, 0, True)
+                for rc in T.serial(0, 64):
+                    T.store(Conv2dOutput, ff, T.load("int32", Conv2dOutput, ff) + T.cast(T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True)
+            for ax3_inner_1 in T.serial(0, 64):
+                T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True)
+# fmt: on
+
+
+# fmt: off
+@tvm.script.ir_module
+class ResnetStructurePlanned:
+    @T.prim_func
+    def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None:
+        placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8")
+        placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32")
+        T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16")
+        global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        # body
+        for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16):
+            T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True)
+
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle) -> None:
+        placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16")
+        placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16")
+        placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32")
+        placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32")
+        T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8")
+        global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        # body
+        PaddedInput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 6480000), dtype="handle")
+        for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64):
+            T.store(PaddedInput_3_let, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True)
+        for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625):
+            Conv2dOutput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 7200000), dtype="handle")
+            for ax3_outer_2 in T.serial(0, 4):
+                for ff_3 in T.serial(0, 64):
+                    T.store(Conv2dOutput_3_let, ff_3, 0, True)
+                    for rc_3 in T.serial(0, 64):
+                        T.store(Conv2dOutput_3_let, ff_3, T.load("int32", Conv2dOutput_3_let, ff_3) + T.cast(T.load("int16", PaddedInput_3_let, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True)
+                for ax3_inner_4 in T.serial(0, 64):
+                    T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3_let, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 *  [...]
+
+    # Planned form of this 1x1 conv2d kernel (presumably the reference compared
+    # against the pass output in test_resnet_subgraph). The trailing
+    # global_workspace_4_var handle is bound to a 7920256-byte uint8 pool, and
+    # the kernel's intermediate buffers are addressed as byte offsets into it.
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.handle) -> None:
+        placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16")
+        placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16")
+        placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32")
+        T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32")
+        global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        # body
+        # PaddedInput_2: 75*75*64 int16 values = 720000 bytes at pool offset
+        # 7200000, i.e. it occupies [7200000, 7920000).
+        PaddedInput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7200000), dtype="handle")
+        for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64):
+            T.store(PaddedInput_2_let, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True)
+        # One iteration per output pixel (5625 = 75 * 75).
+        for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625):
+            # Conv2dOutput_2: 64 int32 accumulators = 256 bytes at offset 7920000,
+            # placed right after PaddedInput_2 and ending exactly at the pool
+            # size (7920000 + 256 = 7920256).
+            Conv2dOutput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7920000), dtype="handle")
+            # The 256 output channels are produced in 4 chunks of 64, reusing
+            # the same 64-entry accumulator for each chunk.
+            for ax3_outer_1 in T.serial(0, 4):
+                for ff_2 in T.serial(0, 64):
+                    T.store(Conv2dOutput_2_let, ff_2, 0, True)
+                    for rc_2 in T.serial(0, 64):
+                        T.store(Conv2dOutput_2_let, ff_2, T.load("int32", Conv2dOutput_2_let, ff_2) + T.cast(T.load("int16", PaddedInput_2_let, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True)
+                for ax3_inner_3 in T.serial(0, 64):
+                    T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2_let, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True)
+
+    # Planned form of this 1x1 conv2d (64 -> 64 channels) kernel. The trailing
+    # global_workspace_2_var handle is bound to the 7920256-byte uint8 pool, and
+    # PaddedInput / Conv2dOutput are handles at fixed byte offsets into it.
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None:
+        placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16")
+        placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16")
+        placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32")
+        T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16")
+        global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        # body
+        # PaddedInput: 75*75*64 int16 values = 720000 bytes, occupying
+        # [7200000, 7920000) of the pool.
+        PaddedInput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7200000), dtype="handle")
+        for i0_i1_fused, i2, i3 in T.grid(75, 75, 64):
+            T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True)
+        # One iteration per output pixel (5625 = 75 * 75).
+        for ax0_ax1_fused_ax2_fused in T.serial(0, 5625):
+            # Conv2dOutput: 64 int32 accumulators = 256 bytes, occupying the
+            # last bytes of the pool, [7920000, 7920256).
+            Conv2dOutput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7920000), dtype="handle")
+            for ff in T.serial(0, 64):
+                T.store(Conv2dOutput_let, ff, 0, True)
+                for rc in T.serial(0, 64):
+                    T.store(Conv2dOutput_let, ff, T.load("int32", Conv2dOutput_let, ff) + T.cast(T.load("int16", PaddedInput_let, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True)
+            for ax3_inner_1 in T.serial(0, 64):
+                T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_let, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True)
+
+    # Planned form of this 3x3 conv2d (64 -> 64 channels, 1-pixel zero padding)
+    # kernel. The trailing global_workspace_3_var handle is bound to the
+    # 7920256-byte uint8 pool; intermediates are handles at byte offsets into it.
+    @T.prim_func
+    def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle) -> None:
+        placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16")
+        placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16")
+        placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32")
+        T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16")
+        global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        # body
+        # PaddedInput_1: 77*77*64 int16 values = 758912 bytes at pool offset 0.
+        PaddedInput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 0), dtype="handle")
+        # Copy the 75x75 input into the interior of a 77x77 buffer, writing zeros
+        # on the 1-pixel border. "- 4864" (= one 4800-element input row plus one
+        # 64-channel pixel) maps padded coordinates back to input coordinates.
+        for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64):
+            T.store(PaddedInput_1_let, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True)
+        # One iteration per output pixel (5625 = 75 * 75).
+        for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625):
+            # Conv2dOutput_1: 64 int32 accumulators = 256 bytes at offset 7200000.
+            Conv2dOutput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 7200000), dtype="handle")
+            for ff_1 in T.serial(0, 64):
+                T.store(Conv2dOutput_1_let, ff_1, 0, True)
+                # Reduce over the 3x3 window (ry, rx) and the 64 input channels.
+                for ry, rx, rc_1 in T.grid(3, 3, 64):
+                    T.store(Conv2dOutput_1_let, ff_1, T.load("int32", Conv2dOutput_1_let, ff_1) + T.cast(T.load("int16", PaddedInput_1_let, ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True)
+            for ax3_inner_2 in T.serial(0, 64):
+                T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1_let, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True)
+
+    # Planned entry point. The sid_* tensors passed between kernels are handles
+    # into a single 7920256-byte uint8 pool (global_workspace_0_var), and the
+    # pool's data pointer is passed as the trailing argument of every call.
+    @T.prim_func
+    def run_model(input: T.handle, output: T.handle, global_workspace_0_var: T.handle) -> None:
+        global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16)
+        # body
+        T.attr("default", "device_id", 0)
+        T.attr("default", "device_type", 1)
+        # Byte layout of the pool (sizes follow from the callee buffer shapes):
+        #   sid_6:        [0, 5760000)        int32 [1, 75, 75, 256]
+        #   sid_2:        [5760000, 6480000)  int16 [1, 75, 75, 64]
+        #   sid_7, sid_8: [6480000, 7200000)  int16 [1, 75, 75, 64]
+        # NOTE(review): sid_8 (input) and sid_7 (output) of the
+        # ..._clip_cast_cast_1 call share offset 6480000. That kernel copies its
+        # whole input into PaddedInput_1 before writing any output, but confirm
+        # this aliasing is what the planner intends rather than a liveness gap.
+        sid_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 5760000), dtype="handle")
+        sid_6_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 0), dtype="handle")
+        sid_7_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle")
+        sid_8_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle")
+        T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7_let, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6_let, global_workspace_0_buffer_var.data, dtype="int32"))
+        # The residual branch: sid_2 (output of the first kernel) is consumed
+        # again here together with sid_6.
+        T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6_let, output, global_workspace_0_buffer_var.data, dtype="int32"))
+    # TVMScript metadata table; left as None (presumably nothing in this module
+    # references metadata objects).
+    __tvm_meta__ = None
+# fmt: on
+
+
+def test_resnet_subgraph():
+    target = Target("c")
+    global_workspace_pool = usmp_utils.PoolInfo(
+        pool_name="global_workspace",
+        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
+    )
+    tir_mod = ResnetStructure
+    tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
+    tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
+    main_func = tir_mod["run_model"]
+    buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+    buffer_info_map = buffer_analysis.buffer_info_stmts
+
+    fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
+    buffer_info_arr = fcreate_array_bi(buffer_info_map)
+    fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
+    buffer_pool_allocations = fusmp_algo_greedy_by_size(
+        buffer_info_arr, buffer_analysis.memory_pressure
+    )
+    fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
+    pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations)
+    tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
+        pool_allocations, emit_tvmscript_printable=True
+    )(tir_mod)
+
+    tir_mod_with_offsets_ref = ResnetStructurePlanned
+
+    # The TIR produced fails on roundtrip TVMScript testing.
+    # Therefore, indicates the TVMScript produced here and/or the parser
+    # is lacking functionality. Thus for these tests, uses a string
+    # version of the TVMScript for each function as a check instead.
+    for gv, func in tir_mod_with_offsets_ref.functions.items():
+        assert str(tir_mod_with_offsets_ref[gv.name_hint].script()) == str(
+            tir_mod_with_offsets[gv.name_hint].script()
+        )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__] + sys.argv[1:])