Posted to commits@tvm.apache.org by ar...@apache.org on 2022/04/26 00:38:05 UTC

[tvm] branch main updated: [USMP] Adding support for U4 usecase (#10785)

This is an automated email from the ASF dual-hosted git repository.

areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ce29f02f4c [USMP] Adding support for U4 usecase (#10785)
ce29f02f4c is described below

commit ce29f02f4cacec8ed346d5f70508cce928b623de
Author: Manupa Karunaratne <ma...@arm.com>
AuthorDate: Tue Apr 26 01:37:59 2022 +0100

    [USMP] Adding support for U4 usecase (#10785)
    
    * [USMP] Adding support for U4 usecase
    
    This commit adds support for placing I/O
    tensors within the workspace buffer.
    
    This is enabled using PassConfig option
    tir.usmp.use_workspace_io. Once it is enabled,
    it will remove the I/O tensors from the TIR
    main PrimFunc and replace them with Allocate
    nodes that are annotated to contain Input and
    Output tensors.
    
    The USMP will plan memory for them accordingly.
    (i.e. it will re-use the space they occupy for
    intermediaries, depending on liveness).
    
    This is only supported with the C Interface API.
    Thus, this commit adds two functions to the
    metadata sources to obtain input and output structs
    that point to locations inside the workspace struct.
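    
    For illustration, a bare-metal application could
    consume the generated interface roughly as in the
    sketch below. The module name (ultimate_cat_spotter,
    borrowed from the tests), the pool field
    (global_workspace), the tensor names (input/output),
    the WORKSPACE_SIZE macro and the helper functions
    are illustrative assumptions, not part of this
    commit:
    
        #include <stdint.h>
        #include <tvmgen_ultimate_cat_spotter.h>
    
        /* Workspace buffer; with tir.usmp.use_workspace_io
           enabled, the I/O tensors live inside it too. */
        static uint8_t workspace[TVMGEN_ULTIMATE_CAT_SPOTTER_WORKSPACE_SIZE];
    
        extern void fill_input(void* input);       /* hypothetical helper */
        extern void consume_output(void* output);  /* hypothetical helper */
    
        int32_t run_inference(void) {
          struct tvmgen_ultimate_cat_spotter_workspace_pools pools = {
              .global_workspace = workspace,
          };
          /* Obtain pointers to the I/O tensors placed in the workspace. */
          struct tvmgen_ultimate_cat_spotter_inputs inputs =
              tvmgen_ultimate_cat_spotter_map_inputs(&pools);
          struct tvmgen_ultimate_cat_spotter_outputs outputs =
              tvmgen_ultimate_cat_spotter_map_outputs(&pools);
          fill_input(inputs.input);
          /* The I/O structs are no longer passed to run;
             only the workspace pools are. */
          int32_t status = tvmgen_ultimate_cat_spotter_run(&pools);
          consume_output(outputs.output);
          return status;
        }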
    
    Change-Id: I4c7e750ead9a880ba900602c17f53a125f97dbf9
    
    * fixup! [USMP] Adding support for U4 usecase
    
    Change-Id: I78f03d36b12b4a5e8eae8d11701f51019489defc
    
    * fixup! [USMP] Adding support for U4 usecase
    
    Change-Id: I857f3d0ba7bc192d56d750c44b232998b2876e7a
---
 include/tvm/tir/usmp/transform.h                   |  11 ++
 include/tvm/tir/usmp/utils.h                       |  47 ++++-
 python/tvm/micro/model_library_format.py           |  31 +--
 src/relay/backend/aot_executor_codegen.cc          | 118 +++++++----
 src/relay/backend/utils.cc                         |   4 +-
 src/relay/backend/utils.h                          |  17 +-
 src/target/source/interface_c.cc                   |  48 ++++-
 src/target/source/source_module.cc                 | 130 ++++++++----
 src/tir/usmp/analysis/extract_buffer_info.cc       |  24 ++-
 .../convert_pool_allocations_to_offsets.cc         |   3 +-
 src/tir/usmp/transform/create_io_allocates.cc      | 219 +++++++++++++++++++++
 src/tir/usmp/unified_static_memory_planner.cc      |  48 +++--
 src/tir/usmp/utils.cc                              |  23 ++-
 tests/cpp/target/source/interface_c_test.cc        |  94 ++++++---
 tests/micro/zephyr/test_utils.py                   |   2 +-
 tests/python/relay/aot/aot_test_utils.py           |  68 ++++---
 tests/python/relay/aot/test_c_device_api.py        |  12 +-
 tests/python/relay/aot/test_crt_aot_usmp.py        | 176 +++++++++++++++++
 .../test_tir_usmp_transform_create_io_allocates.py | 206 +++++++++++++++++++
 19 files changed, 1086 insertions(+), 195 deletions(-)

diff --git a/include/tvm/tir/usmp/transform.h b/include/tvm/tir/usmp/transform.h
index 6de64704bd..ccb684463f 100644
--- a/include/tvm/tir/usmp/transform.h
+++ b/include/tvm/tir/usmp/transform.h
@@ -56,6 +56,17 @@ TVM_DLL Pass ConvertPoolAllocationsToOffsets(const Map<tir::Stmt, PoolAllocation
  */
 TVM_DLL Pass AssignPoolInfo();
 
+/*!
+ * \brief This pass creates Allocate nodes for I/O tensors
+ *
+ * If the user wants to place the I/O tensors in the workspace, this pass is required to be
+ * run. It creates Allocate nodes for the I/O tensors so they can be planned, and removes them
+ * from the function arguments.
+ *
+ * \return the pass
+ */
+TVM_DLL Pass CreateAllocatesForIO();
+
 }  // namespace transform
 }  // namespace usmp
 }  // namespace tir
diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h
index d9f212f489..f7858acb17 100644
--- a/include/tvm/tir/usmp/utils.h
+++ b/include/tvm/tir/usmp/utils.h
@@ -41,10 +41,20 @@ constexpr const char* kUSMPEnableOption = "tir.usmp.enable";
  * \brief PassContext option to select the memory planning algorithm in USMP
  */
 constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm";
+/*!
+ * \brief PassContext option to enable placing I/O tensors in the workspace
+ */
+constexpr const char* kUSMPUseWorkspaceIO = "tir.usmp.use_workspace_io";
 
 namespace tir {
 namespace usmp {
 
+/*!
+ * \brief A special kind to distinguish between I/O tensors to the model
+ * and intermediate tensors of the model
+ */
+enum class BufferInfoKind { kIntermediate = 0, kInput = 1, kOutput = 2 };
+
 /*!
  * \brief Describes an abstract memory buffer that will get allocated inside a pool.
 * The actual memory buffer is represented by PoolAllocationNode after static memory planning.
@@ -65,6 +75,8 @@ struct BufferInfoNode : public Object {
   Integer alignment;
   /*! \brief The liveness conflicting other buffer info objects */
   Array<ObjectRef> conflicts;
+  /*! \brief Whether the BufferInfo object describes an I/O tensor or an intermediary */
+  BufferInfoKind kind;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("name_hint", &name_hint);
@@ -72,12 +84,13 @@ struct BufferInfoNode : public Object {
     v->Visit("pool_candidates", &pool_candidates);
     v->Visit("alignment", &alignment);
     v->Visit("conflicts", &conflicts);
+    v->Visit("kind", &kind);
   }
 
   bool SEqualReduce(const BufferInfoNode* other, SEqualReducer equal) const {
     return equal(name_hint, other->name_hint) && equal(size_bytes, other->size_bytes) &&
            equal(pool_candidates, other->pool_candidates) && equal(alignment, other->alignment) &&
-           equal(conflicts, other->conflicts);
+           equal(conflicts, other->conflicts) && equal(kind, other->kind);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
@@ -86,6 +99,7 @@ struct BufferInfoNode : public Object {
     hash_reduce(alignment);
     hash_reduce(conflicts);
     hash_reduce(pool_candidates);
+    hash_reduce(kind);
   }
   /*!
    * \brief Set the liveness conflicts of this BufferInfo
@@ -101,7 +115,8 @@ struct BufferInfoNode : public Object {
 class BufferInfo : public ObjectRef {
  public:
   TVM_DLL BufferInfo(String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates,
-                     Integer alignment = runtime::kDefaultWorkspaceAlignment);
+                     Integer alignment = runtime::kDefaultWorkspaceAlignment,
+                     BufferInfoKind kind = BufferInfoKind::kIntermediate);
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfo, ObjectRef, BufferInfoNode);
 };
 
@@ -237,6 +252,18 @@ Integer CalculateModuleWorkspaceSize(const IRModule& mod);
  */
 static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_pools";
 
+/*!
+ * \brief The allocate node attribute to indicate it is being used to hold
+ * an input tensor that needs to be initialized with input data.
+ */
+static constexpr const char* kInputTensorAllocate = "input_tensor";
+
+/*!
+ * \brief The allocate node attribute to indicate it is being used to hold
+ * an output tensor.
+ */
+static constexpr const char* kOutputTensorAllocate = "output_tensor";
+
 /*!
  * \brief Calculate the size of the extents in bytes
  *
@@ -254,6 +281,16 @@ Map<Stmt, PoolAllocation> AssignStmtPoolAllocations(
     const Map<BufferInfo, Stmt>& buffer_info_to_stmt,
     const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation);
 
+/*!
+ * \brief Obtains a map of I/O tensor names to their PoolAllocation objects
+ *
+ * \param buffer_info_to_pool_allocation the map of BufferInfo objects to PoolAllocation objects
+ *
+ * This function will obtain pool allocations for I/O tensors if they have been planned
+ */
+Map<String, PoolAllocation> GetIOPoolAllocations(
+    const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation);
+
 }  // namespace usmp
 }  // namespace tir
 
@@ -265,10 +302,10 @@ namespace attr {
 static constexpr const char* kPoolArgs = "pool_args";
 
 /*!
- * \brief This is a IRModule attribute that contains all the PoolInfo objects
- * as an Array.
+ * \brief This is an IRModule attribute that maps I/O tensor names to pool
+ * allocations.
  */
-static constexpr const char* kPoolInfoIRModuleAttr = "pool_infos";
+static constexpr const char* kIOTensorPoolAllocations = "io_tensor_pool_allocations";
 
 }  // namespace attr
 
diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py
index 6b59b34430..6b95220b67 100644
--- a/python/tvm/micro/model_library_format.py
+++ b/python/tvm/micro/model_library_format.py
@@ -47,7 +47,7 @@ class UnsupportedInModelLibraryFormatError(Exception):
 
 
 def generate_c_interface_header(
-    module_name, inputs, outputs, pools, devices, workspace_size, include_path
+    module_name, inputs, outputs, pools, io_pool_allocations, devices, workspace_size, include_path
 ):
     """Generate C Interface header to be included in MLF"""
     mangled_name = to_c_variable_style(prefix_generated_name(module_name))
@@ -55,7 +55,7 @@ def generate_c_interface_header(
 
     interface_c_create = tvm._ffi.get_global_func("runtime.InterfaceCCreate")
     interface_c_module = interface_c_create(
-        module_name, inputs, outputs, pools, devices, workspace_size
+        module_name, inputs, outputs, pools, io_pool_allocations, devices, workspace_size
     )
 
     with open(metadata_header, "w") as header_file:
@@ -281,17 +281,8 @@ def _convert_tuple_to_outputs(ret_type, offset=0):
 
 
 def _get_inputs_and_outputs_from_module(mod):
-    main_func = _get_main_relay_func(mod)
-    inputs = [argument.name_hint for argument in main_func.params]
-
-    if "output_tensor_names" in main_func.attrs:
-        outputs = main_func.attrs["output_tensor_names"]
-    else:
-        if isinstance(main_func.ret_type, TupleType):
-            outputs = _convert_tuple_to_outputs(main_func.ret_type)
-        else:
-            outputs = ["output"]
-
+    inputs = [str(input_var.name) for input_var in mod.executor_codegen_metadata.inputs]
+    outputs = list(mod.executor_codegen_metadata.outputs)
     return inputs, outputs
 
 
@@ -299,6 +290,10 @@ def _get_pools_from_module(mod):
     return list(dict(mod.executor_codegen_metadata.pool_inputs).values())
 
 
+def _get_io_pool_allocation_from_module(mod):
+    return dict(mod.executor_codegen_metadata.io_pool_allocations)
+
+
 def _should_generate_interface_header(mod):
     return "interface-api" in mod.executor and mod.executor["interface-api"] == "c"
 
@@ -369,9 +364,17 @@ def _export_graph_model_library_format(
         inputs, outputs = _get_inputs_and_outputs_from_module(mod)
         devices = mod.get_devices()
         pools = _get_pools_from_module(mod)
+        io_pool_allocations = _get_io_pool_allocation_from_module(mod)
         workspace_size = int(metadata["memory"]["functions"]["main"][0]["workspace_size_bytes"])
         generate_c_interface_header(
-            mod.libmod_name, inputs, outputs, pools, devices, workspace_size, include_path
+            mod.libmod_name,
+            inputs,
+            outputs,
+            pools,
+            io_pool_allocations,
+            devices,
+            workspace_size,
+            include_path,
         )
 
     parameters_dir = tempdir / "parameters"
diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 9a194965de..22d4b1c032 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -784,13 +784,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
   * \brief Create tir::Var for input/output while updating
    * the buffer_maps.
    */
-  void CreateIOVar(const Expr& expr, std::string name) {
+  void CreateIOVar(const Expr& expr, const std::string& original_name,
+                   bool use_unique_name = true) {
     if (expr->IsInstance<TupleNode>()) {
       Tuple tuple = Downcast<Tuple>(expr);
       for (unsigned i = 0; i < tuple->fields.size(); i++) {
-        CreateIOVar(tuple->fields[i], name + std::to_string(i) + "_");
+        CreateIOVar(tuple->fields[i], original_name);
       }
     } else {
+      std::string name = original_name;
+      if (use_unique_name) {
+        name = GetUniqueIOVarName(original_name);
+      }
       tir::Var var = tir::Var(name, DataType::Handle());
       main_signature_.push_back(var);
       auto tensor_type = expr->checked_type().as<TensorTypeNode>();
@@ -804,6 +809,19 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     }
   }
 
+  /*!
+   * \brief Create a unique name for I/O Var
+   */
+  std::string GetUniqueIOVarName(std::string name) {
+    if (io_var_names_.find(name) == io_var_names_.end()) {
+      io_var_names_[name] = 1;
+      return name;
+    } else {
+      io_var_names_[name] = io_var_names_[name] + 1;
+      return name + std::to_string(io_var_names_[name]);
+    }
+  }
+
   /*!
   * \brief Calculate workspace sizes for PrimFuncs in the IRModule
    */
@@ -945,6 +963,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
   std::vector<tir::Stmt> stmts_;
  /*! \brief the list of return sids (note that the function might return more than one output) */
   std::vector<int> return_sid_;
+  /*! \brief A per-I/O-var-name counter to aid generating unique names */
+  std::unordered_map<std::string, int> io_var_names_;
 
  public:
   AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host)
@@ -1032,7 +1052,10 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     for (auto input : lowered_main_func->params) {
       input_vars_.push_back(input);
       std::string input_name = SanitizeName(input->name_hint());
-      CreateIOVar(input, input_name);
+      // We don't want the compiler changing input names in the
+      // event of a sanitization collision. Therefore, we enforce
+      // that the created var uses input_name strictly.
+      CreateIOVar(input, input_name, /*use_unique_name = */ false);
     }
 
     // Define the storage allocator ids
@@ -1052,7 +1075,27 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     // Retrieve the return sids
     return_sid_ = final_aot_allocator.GetReturnIds();
     // Insert outputs to main func signature
-    CreateIOVar(lowered_main_func->body, "output");
+    // If output tensor names were provided use them
+    if (auto opt = func->GetAttr<Array<String>>("output_tensor_names")) {
+      Array<String> output_tensor_names = opt.value();
+      if (lowered_main_func->body->IsInstance<TupleNode>()) {
+        Tuple output_tuple = Downcast<Tuple>(lowered_main_func->body);
+        for (unsigned i = 0; i < output_tuple->fields.size(); i++) {
+          // AoT Executor Codegen does not create these names,
+          // thus they should be used as provided.
+          CreateIOVar(output_tuple->fields[i], output_tensor_names[i],
+                      /*use_unique_name = */ false);
+        }
+      } else {
+        // AoT Executor Codegen does not create these names,
+        // thus they should be used as provided.
+        CreateIOVar(lowered_main_func->body, output_tensor_names[0], /*use_unique_name = */ false);
+      }
+    } else {
+      // If output tensor names are not provided, we generate names of the form
+      // output<x>, where x is a counter used to create unique names.
+      CreateIOVar(lowered_main_func->body, "output");
+    }
 
     CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar, String>>("device_contexts").value());
     VisitExpr(lowered_main_func->body);
@@ -1071,8 +1114,27 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     // AoT Executor codegen works completely on TIR beyond this point, hence removing relay main
     // function and replacing it with its TIR version. We should try to make this a Pass.
     lowered_mod->Remove(lowered_mod->GetGlobalVar("main"));
-    auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
-    lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_module_main), prim_func);
+    auto tir_main_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
+    // Extract additional information around main TIR PrimFunc arguments
+    Array<String> devices = ListDevices();
+    const auto main_func_params_end_iterator =
+        tir_main_func->params.begin() + tir_main_func->params.size();
+    const auto outputs_begin_iterator =
+        main_func_params_end_iterator - return_sid_.size() - devices.size();
+    Array<tir::Var> inputs = Array<tir::Var>(tir_main_func->params.begin(), outputs_begin_iterator);
+    Array<TensorType> input_tensor_types;
+    for (auto i : inputs) {
+      input_tensor_types.push_back(io_tensor_types_[i]);
+    }
+    Array<tir::Var> outputs =
+        Array<tir::Var>(outputs_begin_iterator, main_func_params_end_iterator - devices.size());
+    std::vector<String> output_var_names;
+    for (const tir::Var& output : outputs) {
+      output_var_names.push_back(output->name_hint);
+    }
+
+    Array<TensorType> output_tensor_types{final_aot_allocator.GetReturnTtypes()};
+    lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_module_main), tir_main_func);
     // Parallel for loops are not supported in AoT codegen.
     lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod);
 
@@ -1109,9 +1171,10 @@ class AOTExecutorCodegen : public MixedModeVisitor {
 
     ret.external_mods = external_modules.value();
 
+    // Extract USMP metadata to pass on to the metadata sources
     Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
     std::vector<tir::Var> pool_vars;
-    tir::PrimFunc tir_main_func =
+    tir_main_func =
         Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
     Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
         tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
@@ -1122,41 +1185,16 @@ class AOTExecutorCodegen : public MixedModeVisitor {
         pool_var_info.Set(tir_main_func->params[pool_var_index], allocated_pool_info);
       }
     }
-    Array<String> devices = ListDevices();
-    Array<tir::Var> inputs =
-        Array<tir::Var>(tir_main_func->params.begin(),
-                        tir_main_func->params.begin() + tir_main_func->params.size() -
-                            return_sid_.size() - pool_vars.size() - devices.size());
+    Map<String, tir::usmp::PoolAllocation> io_pool_allocations =
+        lowered_mod
+            ->GetAttr<Map<String, tir::usmp::PoolAllocation>>(tvm::attr::kIOTensorPoolAllocations)
+            .value_or({});
 
-    Array<TensorType> input_tensor_types;
-    for (auto i : inputs) {
-      input_tensor_types.push_back(io_tensor_types_[i]);
-    }
-
-    std::vector<String> output_var_names;
-    if (auto opt = func->GetAttr<Array<String>>("output_tensor_names")) {
-      Array<String> output_tensor_names = opt.value();
-      for (size_t i = 0; i < output_tensor_names.size(); ++i) {
-        output_var_names.push_back(output_tensor_names[i]);
-      }
-    }
-
-    // If output names have not been specified then generate default output names
-    if (output_var_names.size() == 0) {
-      if (return_sid_.size() == 1) {
-        output_var_names.push_back(String("output"));
-      } else {
-        for (size_t i = 0; i < return_sid_.size(); ++i) {
-          output_var_names.push_back(String("output" + std::to_string(i)));
-        }
-      }
-    }
-
-    Array<TensorType> output_tensor_types{final_aot_allocator.GetReturnTtypes()};
+    ret.metadata =
+        ExecutorCodegenMetadata(inputs, input_tensor_types, output_var_names, output_tensor_types,
+                                pool_vars, devices, runtime::kTvmExecutorAot, mod_name,
+                                interface_api, unpacked_api, pool_var_info, io_pool_allocations);
 
-    ret.metadata = ExecutorCodegenMetadata(
-        inputs, input_tensor_types, output_var_names, output_tensor_types, pool_vars, devices,
-        runtime::kTvmExecutorAot, mod_name, interface_api, unpacked_api, pool_var_info);
     return ret;
   }
 
diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc
index 2bddf75566..4a6fe90289 100644
--- a/src/relay/backend/utils.cc
+++ b/src/relay/backend/utils.cc
@@ -185,7 +185,8 @@ ExecutorCodegenMetadata::ExecutorCodegenMetadata(
     Array<tir::Var> inputs, Array<TensorType> input_tensor_types, Array<String> outputs,
     Array<TensorType> output_tensor_types, Array<tir::Var> pools, Array<String> devices,
     String executor, String mod_name, String interface_api, bool unpacked_api,
-    Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_inputs) {
+    Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_inputs,
+    Map<String, tir::usmp::PoolAllocation> io_pool_allocations) {
   auto n = make_object<ExecutorCodegenMetadataNode>();
   n->inputs = inputs;
   n->input_tensor_types = input_tensor_types;
@@ -198,6 +199,7 @@ ExecutorCodegenMetadata::ExecutorCodegenMetadata(
   n->unpacked_api = unpacked_api;
   n->mod_name = mod_name;
   n->pool_inputs = pool_inputs;
+  n->io_pool_allocations = io_pool_allocations;
   data_ = std::move(n);
 }
 
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index a9035b9ae5..a31ff605ca 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -83,6 +83,8 @@ class ExecutorCodegenMetadataNode : public Object {
   bool unpacked_api;
   /*! \brief the input var names that correspond to pool_inputs */
   Optional<Map<tir::Var, tir::usmp::AllocatedPoolInfo>> pool_inputs;
+  /*! \brief the map of I/O tensor names to PoolAllocations, if any */
+  Map<String, tir::usmp::PoolAllocation> io_pool_allocations;
 
   String mod_name = "";
 
@@ -96,6 +98,7 @@ class ExecutorCodegenMetadataNode : public Object {
     v->Visit("executor", &executor);
     v->Visit("unpacked_api", &unpacked_api);
     v->Visit("pool_inputs", &pool_inputs);
+    v->Visit("io_pool_allocations", &io_pool_allocations);
   }
 
   static constexpr const char* _type_key = "MetadataObj";
@@ -107,13 +110,13 @@ class ExecutorCodegenMetadataNode : public Object {
  */
 class ExecutorCodegenMetadata : public ObjectRef {
  public:
-  TVM_DLL ExecutorCodegenMetadata(Array<tir::Var> inputs, Array<TensorType> input_tensor_types,
-                                  Array<String> outputs, Array<TensorType> output_tensor_types,
-                                  Array<tir::Var> pools, Array<String> devices, String executor,
-                                  String mod_name, String interface_api = "packed",
-                                  bool unpacked_api = false,
-                                  Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_inputs =
-                                      Map<tir::Var, tir::usmp::AllocatedPoolInfo>());
+  TVM_DLL ExecutorCodegenMetadata(
+      Array<tir::Var> inputs, Array<TensorType> input_tensor_types, Array<String> outputs,
+      Array<TensorType> output_tensor_types, Array<tir::Var> pools, Array<String> devices,
+      String executor, String mod_name, String interface_api = "packed", bool unpacked_api = false,
+      Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_inputs =
+          Map<tir::Var, tir::usmp::AllocatedPoolInfo>(),
+      Map<String, tir::usmp::PoolAllocation> io_pool_allocations = {{}});
 
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ExecutorCodegenMetadata, ObjectRef,
                                         ExecutorCodegenMetadataNode);
diff --git a/src/target/source/interface_c.cc b/src/target/source/interface_c.cc
index 9f10fd2881..12d930d8f8 100644
--- a/src/target/source/interface_c.cc
+++ b/src/target/source/interface_c.cc
@@ -42,13 +42,15 @@ using namespace tvm::relay::backend;
 class InterfaceCNode : public runtime::ModuleNode {
  public:
   InterfaceCNode(std::string module_name, Array<String> inputs, Array<String> outputs,
-                 Array<tir::usmp::AllocatedPoolInfo> pools, Array<String> devices,
+                 Array<tir::usmp::AllocatedPoolInfo> pools,
+                 Map<String, tir::usmp::PoolAllocation> io_pool_allocations, Array<String> devices,
                  int workspace_size)
       : module_name_(module_name),
         inputs_(inputs),
         outputs_(outputs),
         devices_(devices),
         pools_(FilterExternalPools(pools)),
+        io_pool_allocations_(io_pool_allocations),
         workspace_size_(workspace_size) {}
   const char* type_key() const { return "h"; }
 
@@ -74,6 +76,13 @@ class InterfaceCNode : public runtime::ModuleNode {
       EmitStruct(code, "workspace_pools", pool_names);
     }
 
+    if (!io_pool_allocations_.empty()) {
+      std::string inputs_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "inputs"}));
+      EmitMapIOToPoolsFunction(code, inputs_struct, "map_inputs", inputs_);
+      std::string outputs_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "outputs"}));
+      EmitMapIOToPoolsFunction(code, outputs_struct, "map_outputs", outputs_);
+    }
+
     EmitRunFunction(code);
     // Emit workspace
     EmitIntegerValueMacro(code, "Workspace size", "WORKSPACE_SIZE", workspace_size_);
@@ -152,9 +161,11 @@ class InterfaceCNode : public runtime::ModuleNode {
         ToCVariableStyle(PrefixGeneratedName({module_name_, "workspace_pools"}));
 
     code_stream << "/*!\n"
-                << " * \\brief entrypoint function for TVM module \"" << module_name_ << "\"\n"
-                << " * \\param inputs Input tensors for the module \n"
-                << " * \\param outputs Output tensors for the module \n";
+                << " * \\brief entrypoint function for TVM module \"" << module_name_ << "\"\n";
+    if (io_pool_allocations_.empty()) {
+      code_stream << " * \\param inputs Input tensors for the module \n";
+      code_stream << " * \\param outputs Output tensors for the module \n";
+    }
 
     if (!devices_.empty()) {
       code_stream << " * \\param devices Device context pointers for the module \n";
@@ -167,8 +178,10 @@ class InterfaceCNode : public runtime::ModuleNode {
                 << "int32_t " << run_function << "(\n";
 
     std::stringstream call_args_ss;
-    call_args_ss << "  struct " << inputs_struct << "* inputs,\n";
-    call_args_ss << "  struct " << outputs_struct << "* outputs,\n";
+    if (io_pool_allocations_.empty()) {
+      call_args_ss << "  struct " << inputs_struct << "* inputs,\n";
+      call_args_ss << "  struct " << outputs_struct << "* outputs,\n";
+    }
     if (!devices_.empty()) {
       call_args_ss << "  struct " << devices_struct << "* devices,\n";
     }
@@ -181,6 +194,23 @@ class InterfaceCNode : public runtime::ModuleNode {
     code_stream << call_args_str << "\n);\n";
   }
 
+  void EmitMapIOToPoolsFunction(std::stringstream& code_stream, const std::string& struct_type,
+                                const std::string& function_name,
+                                const Array<String>& tensor_names) {
+    code_stream << "/*!\n"
+                << " * \\brief Maps I/O inside the workspace pools for TVM module \""
+                << module_name_ << "\"\n"
+                << " * \\param workspace_pools Workspace memory pool struct for the module \n"
+                << " * \\return I/O tensor struct for the module \n";
+    std::string map_function = ToCVariableStyle(PrefixGeneratedName({module_name_, function_name}));
+    code_stream << " */\n"
+                << "struct " << struct_type << " " << map_function << "(\n";
+    std::string pools_struct =
+        ToCVariableStyle(PrefixGeneratedName({module_name_, "workspace_pools"}));
+    code_stream << "  struct " << pools_struct << "* workspace_pools\n";
+    code_stream << ");\n\n";
+  }
+
   Array<tir::usmp::AllocatedPoolInfo> FilterExternalPools(
       const Array<tir::usmp::AllocatedPoolInfo>& pools) {
     Array<tir::usmp::AllocatedPoolInfo> external_pools;
@@ -197,14 +227,16 @@ class InterfaceCNode : public runtime::ModuleNode {
   Array<String> outputs_;
   Array<String> devices_;
   Array<tir::usmp::AllocatedPoolInfo> pools_;
+  Map<String, tir::usmp::PoolAllocation> io_pool_allocations_;
   int workspace_size_;
 };
 
 runtime::Module InterfaceCCreate(std::string module_name, Array<String> inputs,
                                  Array<String> outputs, Array<tir::usmp::AllocatedPoolInfo> pools,
+                                 Map<String, tir::usmp::PoolAllocation> io_pool_allocations,
                                  Array<String> devices, int workspace_size) {
-  auto n =
-      make_object<InterfaceCNode>(module_name, inputs, outputs, pools, devices, workspace_size);
+  auto n = make_object<InterfaceCNode>(module_name, inputs, outputs, pools, io_pool_allocations,
+                                       devices, workspace_size);
   return runtime::Module(n);
 }
 
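To make the new interface concrete: for a hypothetical module named
"ultimate_cat_spotter" (the name used by the tests below) whose I/O has been
placed in the workspace, the emitted header would gain declarations along
these lines (a sketch assuming a single input and a single output):

    /*!
     * \brief Maps I/O inside the workspace pools for TVM module "ultimate_cat_spotter"
     * \param workspace_pools Workspace memory pool struct for the module
     * \return I/O tensor struct for the module
     */
    struct tvmgen_ultimate_cat_spotter_inputs tvmgen_ultimate_cat_spotter_map_inputs(
      struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools
    );

    struct tvmgen_ultimate_cat_spotter_outputs tvmgen_ultimate_cat_spotter_map_outputs(
      struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools
    );
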
diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc
index ef5755f3e8..046b7e9606 100644
--- a/src/target/source/source_module.cc
+++ b/src/target/source/source_module.cc
@@ -251,6 +251,26 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
     }
   }
 
+  void GenerateIOWorkspaceMapFunction(const std::string& struct_type,
+                                      const std::string& function_name,
+                                      const Array<String>& tensor_names) {
+    std::string map_function = runtime::get_name_mangled(metadata_->mod_name, function_name);
+    code_ << "struct " << struct_type << " " << map_function << "(\n";
+    std::string pools_struct = runtime::get_name_mangled(metadata_->mod_name, "workspace_pools");
+    code_ << "  struct " << pools_struct << "* workspace_pools\n";
+    code_ << "\n){\n";
+    code_ << "struct " << struct_type << " ret = {\n";
+    for (const String& name : tensor_names) {
+      tir::usmp::PoolAllocation pool_allocation = metadata_->io_pool_allocations[name];
+      code_ << "\t." << name << " = "
+            << "&((uint8_t*)workspace_pools->" << pool_allocation->pool_info->pool_name << ")["
+            << pool_allocation->byte_offset << "],\n";
+    }
+    code_ << "};\n";
+    code_ << "return ret;\n";
+    code_ << "}\n\n";
+  }
+
   bool IsInternalWorkspaceBuffer(const tir::Var& pool_var) {
     if (metadata_->pool_inputs.defined()) {
       Map<tir::Var, tir::usmp::AllocatedPoolInfo> allocated_pool_infos =
@@ -271,16 +291,18 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
 
     {
       std::stringstream call_args_ss;
-      for (const tir::Var& input_var : metadata_->inputs) {
-        if (input_var->type_annotation.defined()) {
-          codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss);
-        } else {
-          codegen_c_base_.PrintType(input_var.dtype(), call_args_ss);
+      if (metadata_->io_pool_allocations.empty()) {
+        for (const tir::Var& input_var : metadata_->inputs) {
+          if (input_var->type_annotation.defined()) {
+            codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss);
+          } else {
+            codegen_c_base_.PrintType(input_var.dtype(), call_args_ss);
+          }
+          call_args_ss << " " << input_var->name_hint << ",";
+        }
+        for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) {
+          call_args_ss << "void* output" << i << ",";
         }
-        call_args_ss << " " << input_var->name_hint << ",";
-      }
-      for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) {
-        call_args_ss << "void* output" << i << ",";
       }
       for (const tir::Var& pool_var : metadata_->pools) {
         if (pool_var->type_annotation.defined()) {
@@ -303,12 +325,14 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
 
     {
       std::stringstream call_args_ss;
-      for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) {
-        call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,";
-      }
-      for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) {
-        int j = metadata_->inputs.size() + i;
-        call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data,";
+      if (metadata_->io_pool_allocations.empty()) {
+        for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) {
+          call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,";
+        }
+        for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) {
+          int j = metadata_->inputs.size() + i;
+          call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data,";
+        }
       }
       for (const tir::Var& pool_var : metadata_->pools) {
         if (IsInternalWorkspaceBuffer(pool_var)) {
@@ -329,15 +353,17 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
     int entrypoint_arg_count = 0;
     int run_func_arg_count = 0;
 
-    for (unsigned int i = 0; i < metadata_->inputs.size(); i++) {
-      run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count);
-      entrypoint_arg_count++;
-      run_func_arg_count++;
-    }
-    for (unsigned int i = 0; i < metadata_->outputs.size(); i++) {
-      run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count);
-      entrypoint_arg_count++;
-      run_func_arg_count++;
+    if (metadata_->io_pool_allocations.empty()) {
+      for (unsigned int i = 0; i < metadata_->inputs.size(); i++) {
+        run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count);
+        entrypoint_arg_count++;
+        run_func_arg_count++;
+      }
+      for (unsigned int i = 0; i < metadata_->outputs.size(); i++) {
+        run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count);
+        entrypoint_arg_count++;
+        run_func_arg_count++;
+      }
     }
     for (const tir::Var& pool_var : metadata_->pools) {
       if (IsInternalWorkspaceBuffer(pool_var)) {
@@ -361,8 +387,8 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
              "out_type_code, void* resource_handle) {\n";
 
     // We are creating a copy of the set of pointers
-    size_t number_of_io_tensors =
-        metadata_->inputs.size() + metadata_->outputs.size() + metadata_->pools.size();
+    size_t number_of_io_tensors = metadata_->inputs.size() + metadata_->outputs.size() +
+                                  metadata_->pools.size() - metadata_->io_pool_allocations.size();
     code_ << "TVMValue tensors[" << number_of_io_tensors << "];\n";
 
     std::unordered_map<int, ObjectRef> run_func_to_entry_point_args =
@@ -390,19 +416,33 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
   void GenerateCInterfaceEntrypoint(const std::string& entrypoint_name, const std::string& run_func,
                                     const std::string& mod_name) {
     code_ << "#include <" << mod_name << ".h>\n";
+    if (!metadata_->io_pool_allocations.empty()) {
+      const std::string input_struct_type =
+          runtime::get_name_mangled(metadata_->mod_name, "inputs");
+      Array<String> input_tensor_names;
+      for (const tir::Var& input_var : metadata_->inputs) {
+        input_tensor_names.push_back(input_var->name_hint);
+      }
+      GenerateIOWorkspaceMapFunction(input_struct_type, "map_inputs", input_tensor_names);
+      const std::string output_struct_type =
+          runtime::get_name_mangled(metadata_->mod_name, "outputs");
+      GenerateIOWorkspaceMapFunction(output_struct_type, "map_outputs", metadata_->outputs);
+    }
     code_ << "TVM_DLL int32_t " << run_func << "(";
     {
       std::stringstream call_args_ss;
-      for (const tir::Var& input_var : metadata_->inputs) {
-        if (input_var->type_annotation.defined()) {
-          codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss);
-        } else {
-          codegen_c_base_.PrintType(input_var.dtype(), call_args_ss);
+      if (metadata_->io_pool_allocations.empty()) {
+        for (const tir::Var& input_var : metadata_->inputs) {
+          if (input_var->type_annotation.defined()) {
+            codegen_c_base_.PrintType(input_var->type_annotation, call_args_ss);
+          } else {
+            codegen_c_base_.PrintType(input_var.dtype(), call_args_ss);
+          }
+          call_args_ss << " " << relay::backend::SanitizeName(input_var->name_hint) << ",";
+        }
+        for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) {
+          call_args_ss << "void* output" << i << ",";
         }
-        call_args_ss << " " << relay::backend::SanitizeName(input_var->name_hint) << ",";
-      }
-      for (unsigned int i = 0; i < metadata_->outputs.size(); ++i) {
-        call_args_ss << "void* output" << i << ",";
       }
       for (const tir::Var& pool_var : metadata_->pools) {
         if (pool_var->type_annotation.defined()) {
@@ -424,8 +464,10 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
     code_ << "int32_t " << entrypoint_name << "(";
     {
       std::stringstream call_args_ss;
-      call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "* inputs,";
-      call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "outputs") << "* outputs,";
+      if (metadata_->io_pool_allocations.empty()) {
+        call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "* inputs,";
+        call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "outputs") << "* outputs,";
+      }
       if (!metadata_->pools.empty()) {
         bool is_external_pools_present = false;
         for (tir::Var pool_var : metadata_->pools) {
@@ -452,12 +494,14 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
 
     {
       std::stringstream call_args_ss;
-      for (const auto& input : metadata_->inputs) {
-        call_args_ss << "inputs->" << relay::backend::SanitizeName(input->name_hint) << ",";
-      }
-      for (const auto& output : metadata_->outputs) {
-        call_args_ss << "outputs->" << relay::backend::SanitizeName(output);
-        call_args_ss << ",";
+      if (metadata_->io_pool_allocations.empty()) {
+        for (const auto& input : metadata_->inputs) {
+          call_args_ss << "inputs->" << relay::backend::SanitizeName(input->name_hint) << ",";
+        }
+        for (const auto& output : metadata_->outputs) {
+          call_args_ss << "outputs->" << relay::backend::SanitizeName(output);
+          call_args_ss << ",";
+        }
       }
 
       for (const tir::Var& pool_var : metadata_->pools) {
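
For reference, GenerateIOWorkspaceMapFunction above emits definitions shaped
like the following sketch (the module name, the pool field global_workspace
and the zero byte offset are illustrative stand-ins for whatever USMP
actually assigned):

    struct tvmgen_ultimate_cat_spotter_inputs tvmgen_ultimate_cat_spotter_map_inputs(
      struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools
    ){
    struct tvmgen_ultimate_cat_spotter_inputs ret = {
      .input = &((uint8_t*)workspace_pools->global_workspace)[0],
    };
    return ret;
    }

Each field of the returned struct points at the offset within the pool that
USMP chose for that tensor.
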
diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc
index 6f4642ff15..b90cfddb71 100644
--- a/src/tir/usmp/analysis/extract_buffer_info.cc
+++ b/src/tir/usmp/analysis/extract_buffer_info.cc
@@ -227,10 +227,8 @@ void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) {
       auto pool_candidates =
           Downcast<Array<PoolInfo>>(op->annotations[kPoolCandidatesAllocateAttr]);
 
-      // TODO(@manupa-arm): improve the error when the responsible component for attaching a single
-      // pool is added
       ICHECK(pool_candidates.size() > 0)
-          << "The core compiler should at least attach a single PoolInfo. If there were no "
+          << "The AssignPoolInfo pass should at least attach a single PoolInfo. If there were no "
              "user-given arguments for memory pools, the default behaviour is a single size "
              "un-restricted pool is assigned";
       PrimFunc func = scope_stack_.top().func;
@@ -241,8 +239,24 @@ void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) {
         workspace_alignment =
             executor_config.value()->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
       }
-      auto buffer_info = BufferInfo(GetUniqueBufferName(op->buffer_var->name_hint), size_bytes,
-                                    pool_candidates, workspace_alignment);
+
+      BufferInfoKind bi_kind = BufferInfoKind::kIntermediate;
+      String buffer_info_name = op->buffer_var->name_hint;
+      if (op->annotations.find(kInputTensorAllocate) != op->annotations.end()) {
+        bi_kind = BufferInfoKind::kInput;
+        // using original input name instead of the buffer_var name
+        // because this name will be used in the lowering to convey
+        // the pool allocation.
+        buffer_info_name = Downcast<String>(op->annotations[kInputTensorAllocate]);
+      } else if (op->annotations.find(kOutputTensorAllocate) != op->annotations.end()) {
+        bi_kind = BufferInfoKind::kOutput;
+        // using original output name instead of the buffer_var name
+        // because this name will be used in the lowering to convey
+        // the pool allocation.
+        buffer_info_name = Downcast<String>(op->annotations[kOutputTensorAllocate]);
+      }
+      auto buffer_info = BufferInfo(GetUniqueBufferName(buffer_info_name), size_bytes,
+                                    pool_candidates, workspace_alignment, bi_kind);
       auto allocate = GetRef<Allocate>(op);
       allocate_infos[op->buffer_var] =
           AllocateInfo{allocate, scope_stack_.top().func, scope_stack_.top().call};
diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
index ba5ab891ba..dc71e3d608 100644
--- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
+++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
@@ -168,7 +168,8 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
 };
 
 Optional<Var> PoolAllocationToOffsetConverter::GetResourceHandle(const PrimFunc& func) {
-  if (func->buffer_map.find(func->params.back()) == func->buffer_map.end()) {
+  if (!func->params.empty() &&
+      func->buffer_map.find(func->params.back()) == func->buffer_map.end()) {
     return func->params.back();
   }
   return Optional<Var>();
diff --git a/src/tir/usmp/transform/create_io_allocates.cc b/src/tir/usmp/transform/create_io_allocates.cc
new file mode 100644
index 0000000000..59eee96163
--- /dev/null
+++ b/src/tir/usmp/transform/create_io_allocates.cc
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/target/target.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+#include <tvm/tir/usmp/algorithms.h>
+#include <tvm/tir/usmp/analysis.h>
+#include <tvm/tir/usmp/transform.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <stack>
+#include <string>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+
+/*! \brief Creates Allocate nodes with special annotations
+ * for I/O tensors in the graph to be memory planned. */
+class IOAllocateCreator : public StmtExprVisitor {
+ public:
+  explicit IOAllocateCreator(const IRModule& module) {
+    main_func_ = Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
+    ICHECK(main_func_.defined()) << "main function is not in the module";
+    for (const auto& gv_func : module->functions) {
+      if (gv_func.second->IsInstance<PrimFuncNode>()) {
+        functions_.Set(gv_func.first->name_hint, Downcast<PrimFunc>(gv_func.second));
+      }
+    }
+    mod_ = module->ShallowCopy();
+  }
+  IRModule operator()();
+
+ private:
+  void VisitExpr_(const BufferLoadNode* op) override;
+  void VisitExpr_(const LoadNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
+  void VisitStmt_(const BufferStoreNode* op) override;
+  void VisitStmt_(const StoreNode* op) override;
+
+  /*! \brief Updates the aliases that buffer vars inside the primfunc refer
+   * to, in terms of the call arguments they get bound to. */
+  void UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func);
+
+  /*! \brief The IRModule that is being mutated */
+  IRModule mod_;
+  /*! \brief The main function that calls into operator subgraphs */
+  PrimFunc main_func_;
+  /*! \brief The input Vars of the main function */
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> inputs_;
+  /*! \brief The output Vars of the main function */
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> outputs_;
+  /*! \brief The buffer vars associated with the I/O Vars */
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> io_buffer_vars_;
+  /*! \brief The aliases that buffer vars inside the primfunc refer
+   * to, in terms of call arguments */
+  std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> aliases_;
+  /*!
+   * \brief The TIR main function calls PrimFuncs by name to be able to
+   * support BYOC. Therefore, this Map records the functions that are
+   * present in the IRModule by name.
+   */
+  Map<String, PrimFunc> functions_;
+};
+
+/*!
+ * \brief The function obtains the matched buffer vars for
+ * the params of the PrimFunc.
+ */
+Array<Var> static GetMatchedBuffers(const PrimFunc& func) {
+  Array<Var> buffer_vars;
+  for (unsigned int i = 0; i < func->params.size() - 1; i++) {
+    Var param = func->params[i];
+    buffer_vars.push_back(func->buffer_map[param]->data);
+  }
+  Var last_param = func->params.back();
+  // Checks whether last var is present in the buffer map
+  // because it could be the resource handle
+  if (func->buffer_map.find(last_param) != func->buffer_map.end()) {
+    buffer_vars.push_back(func->buffer_map[last_param]->data);
+  }
+  return buffer_vars;
+}
+
+/*!
+ * \brief The function updates the alias of each buffer var with its
+ * associated argument at the callsite.
+ */
+void IOAllocateCreator::UpdateAliases(const Array<PrimExpr>& args, const PrimFunc& func) {
+  auto param_buffers = GetMatchedBuffers(func);
+  // Last var could be a resource handle that does not have a Buffer
+  ICHECK(args.size() == param_buffers.size() || args.size() - 1 == param_buffers.size());
+  for (size_t i = 0; i < param_buffers.size(); i++) {
+    auto arg = args[i];
+    if (arg->IsInstance<VarNode>()) {
+      auto param_buf = param_buffers[i];
+      aliases_[param_buf] = Downcast<Var>(arg);
+    }
+  }
+}
+
+void IOAllocateCreator::VisitExpr_(const CallNode* op) {
+  if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
+    StringImm func_name = Downcast<StringImm>(op->args[0])->value;
+    if (functions_.find(func_name->value) != functions_.end()) {
+      auto func = functions_.at(func_name->value);
+      auto actual_args = Array<PrimExpr>(op->args.begin() + 1, op->args.end());
+      this->UpdateAliases(actual_args, func);
+      VisitStmt(func->body);
+      return;
+    }
+  }
+  if (op->op->IsInstance<PrimFuncNode>()) {
+    auto func = Downcast<PrimFunc>(op->op);
+    this->UpdateAliases(op->args, func);
+    VisitStmt(func->body);
+    return;
+  }
+  StmtExprVisitor::VisitExpr_(op);
+}
+
+void IOAllocateCreator::VisitExpr_(const BufferLoadNode* op) {
+  if (aliases_.find(op->buffer->data) != aliases_.end()) {
+    Var aliased_var = aliases_[op->buffer->data];
+    if (io_buffer_vars_.find(aliased_var) != io_buffer_vars_.end()) {
+      ICHECK(outputs_.find(aliased_var) == outputs_.end())
+          << "BufferLoad nodes should not be reading from output buffer vars.";
+      inputs_.insert(aliased_var);
+    }
+  }
+  StmtExprVisitor::VisitExpr_(op);
+}
+
+void IOAllocateCreator::VisitExpr_(const LoadNode* op) { LOG(FATAL) << "should not come here"; }
+
+void IOAllocateCreator::VisitStmt_(const BufferStoreNode* op) {
+  if (aliases_.find(op->buffer->data) != aliases_.end()) {
+    Var aliased_var = aliases_[op->buffer->data];
+    if (io_buffer_vars_.find(aliased_var) != io_buffer_vars_.end()) {
+      ICHECK(inputs_.find(aliased_var) == inputs_.end())
+          << "BufferStore nodes should not be writing to input buffer vars.";
+      outputs_.insert(aliased_var);
+    }
+  }
+  StmtExprVisitor::VisitStmt_(op);
+}
+
+void IOAllocateCreator::VisitStmt_(const StoreNode* op) { LOG(FATAL) << "should not come here"; }
+
+IRModule IOAllocateCreator::operator()() {
+  Array<Var> new_main_params;
+  Stmt main_body = main_func_->body;
+  for (const Var& param : main_func_->params) {
+    if (main_func_->buffer_map.find(param) != main_func_->buffer_map.end()) {
+      Var buffer_var = main_func_->buffer_map[param]->data;
+      io_buffer_vars_.insert(buffer_var);
+      aliases_[buffer_var] = buffer_var;
+    }
+  }
+  VisitStmt(main_body);
+  ICHECK(io_buffer_vars_.size() == inputs_.size() + outputs_.size())
+      << "Every IO Buffer var should be categorized either to be input or output";
+  for (const Var& param : main_func_->params) {
+    if (main_func_->buffer_map.find(param) != main_func_->buffer_map.end()) {
+      Buffer param_buffer = main_func_->buffer_map[param];
+      String io_annotation;
+      if (inputs_.find(param_buffer->data) != inputs_.end()) {
+        io_annotation = String(kInputTensorAllocate);
+      } else {
+        io_annotation = String(kOutputTensorAllocate);
+      }
+      main_body = Allocate(param_buffer->data, param_buffer->dtype, param_buffer->shape,
+                           const_true(), main_body, {{io_annotation, param->name_hint}});
+    } else {
+      new_main_params.push_back(param);
+    }
+  }
+  const GlobalVar& gv = mod_->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main);
+  mod_->Update(gv,
+               PrimFunc(new_main_params, main_body, main_func_->ret_type, main_func_->buffer_map,
+                        main_func_->preflattened_buffer_map, main_func_->attrs, main_func_->span));
+  return mod_;
+}
+
+namespace transform {
+
+tvm::transform::Pass CreateAllocatesForIO() {
+  auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
+    return IOAllocateCreator(m)();
+  };
+  return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.CreateAllocatesForIO", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.usmp.transform.CreateAllocatesForIO").set_body_typed(CreateAllocatesForIO);
+
+}  // namespace transform
+
+}  // namespace usmp
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/usmp/unified_static_memory_planner.cc b/src/tir/usmp/unified_static_memory_planner.cc
index e848440f02..ae91547390 100644
--- a/src/tir/usmp/unified_static_memory_planner.cc
+++ b/src/tir/usmp/unified_static_memory_planner.cc
@@ -23,6 +23,8 @@
  * a single composite pass.
  */
 
+#include <tvm/relay/executor.h>
+#include <tvm/relay/runtime.h>
 #include <tvm/target/target.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
@@ -37,6 +39,7 @@ namespace tvm {
 
 TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPEnableOption, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPAlgorithmOption, String);
+TVM_REGISTER_PASS_CONFIG_OPTION(kUSMPUseWorkspaceIO, Bool);
 
 namespace tir {
 namespace usmp {
@@ -49,10 +52,15 @@ static std::unordered_map<String, std::function<Map<BufferInfo, PoolAllocation>(
                {"greedy_by_conflicts", algo::GreedyByConflicts},
                {"hill_climb", algo::HillClimb}};
 
-IRModule PlanMemory(const IRModule& mod, String algo) {
+IRModule PlanMemory(const IRModule& mod, String algo, bool use_workspace_io) {
   VLOG(1) << "workspace required = " << CalculateModuleWorkspaceSize(mod);
-  PrimFunc main_func = Downcast<PrimFunc>(mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
-  BufferInfoAnalysis buffer_info_analysis = ExtractBufferInfo(main_func, mod);
+  IRModule module = mod->ShallowCopy();
+  if (use_workspace_io) {
+    module = transform::CreateAllocatesForIO()(module);
+  }
+  module = transform::AssignPoolInfo()(module);
+  PrimFunc main_func = Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
+  BufferInfoAnalysis buffer_info_analysis = ExtractBufferInfo(main_func, module);
   Array<BufferInfo> buffer_info_arr =
       CreateArrayBufferInfo(buffer_info_analysis->buffer_info_stmts);
   CHECK(algorithms.count(algo)) << "The selected USMP algorithm : " << algo
@@ -61,9 +69,14 @@ IRModule PlanMemory(const IRModule& mod, String algo) {
       algorithms[algo](buffer_info_arr, buffer_info_analysis->memory_pressure);
   Map<Stmt, PoolAllocation> stmt_pool_allocations = AssignStmtPoolAllocations(
       buffer_info_analysis->buffer_info_stmts, buffer_info_pool_allocations);
-  IRModule ret = transform::ConvertPoolAllocationsToOffsets(stmt_pool_allocations)(mod);
+  module = transform::ConvertPoolAllocationsToOffsets(stmt_pool_allocations)(module);
+  if (use_workspace_io) {
+    Map<String, PoolAllocation> io_pool_allocations =
+        GetIOPoolAllocations(buffer_info_pool_allocations);
+    module = WithAttr(module, tvm::attr::kIOTensorPoolAllocations, io_pool_allocations);
+  }
   tir::PrimFunc tir_main_func =
-      Downcast<tir::PrimFunc>(ret->Lookup(::tvm::runtime::symbol::tvm_module_main));
+      Downcast<tir::PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
   Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
       tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
   if (allocated_pool_infos) {
@@ -71,7 +84,7 @@ IRModule PlanMemory(const IRModule& mod, String algo) {
       VLOG(1) << "pool_size = " << allocated_pool_info->allocated_size;
     }
   }
-  return ret;
+  return module;
 }
 
 }  // namespace usmp
@@ -81,14 +94,25 @@ namespace transform {
 tvm::transform::Pass UnifiedStaticMemoryPlanner() {
   auto usmp_main_pass_func = [=](IRModule m, tvm::transform::PassContext ctx) {
     auto algorithm_str = ctx->GetConfig(kUSMPAlgorithmOption, String(usmp::kDefaultAlgo));
-    return Downcast<IRModule>(
-        usmp::PlanMemory(m, algorithm_str.value_or(String(usmp::kDefaultAlgo))));
+    auto use_workspace_io = ctx->GetConfig(kUSMPUseWorkspaceIO, Bool(false));
+    tvm::relay::Executor executor_config =
+        m->GetAttr<tvm::relay::Executor>(tvm::attr::kExecutor).value();
+    String interface_api = executor_config->GetAttr<String>("interface-api").value_or("packed");
+    tvm::relay::Runtime runtime_config =
+        m->GetAttr<tvm::relay::Runtime>(tvm::attr::kRuntime).value();
+    if (use_workspace_io.value()) {
+      CHECK(interface_api == "c") << kUSMPUseWorkspaceIO
+                                  << " option is only compatible with interface_api c.\n"
+                                  << "Please use interface_api c to be able to enable "
+                                  << kUSMPUseWorkspaceIO << "\n";
+    }
+    return Downcast<IRModule>(usmp::PlanMemory(m,
+                                               algorithm_str.value_or(String(usmp::kDefaultAlgo)),
+                                               use_workspace_io.value_or(Bool(false))));
   };
 
-  return tvm::transform::Sequential(
-      {tvm::tir::usmp::transform::AssignPoolInfo(),
-       tvm::transform::CreateModulePass(usmp_main_pass_func, 0,
-                                        "tir.transform.UnifiedStaticMemoryPlanner", {})});
+  return tvm::transform::CreateModulePass(usmp_main_pass_func, 0,
+                                          "tir.transform.UnifiedStaticMemoryPlanner", {});
 }
 
 TVM_REGISTER_GLOBAL("tir.transform.UnifiedStaticMemoryPlanner")
diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc
index 03fac32590..d02f0d8d33 100644
--- a/src/tir/usmp/utils.cc
+++ b/src/tir/usmp/utils.cc
@@ -37,12 +37,13 @@ namespace tir {
 namespace usmp {
 
 BufferInfo::BufferInfo(String name_hint, Integer size_bytes, Array<PoolInfo> pool_candidates,
-                       Integer alignment) {
+                       Integer alignment, BufferInfoKind kind) {
   auto bufinfo_node = make_object<BufferInfoNode>();
   bufinfo_node->name_hint = name_hint;
   bufinfo_node->size_bytes = size_bytes;
   bufinfo_node->pool_candidates = pool_candidates;
   bufinfo_node->alignment = alignment;
+  bufinfo_node->kind = kind;
   data_ = std::move(bufinfo_node);
 }
 
@@ -65,10 +66,15 @@ TVM_REGISTER_GLOBAL("tir.usmp.BufferInfoSetConflicts")
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<BufferInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
       auto* node = static_cast<const BufferInfoNode*>(ref.get());
+      std::unordered_map<BufferInfoKind, String> toString = {
+          {BufferInfoKind::kIntermediate, "kIntermediate"},
+          {BufferInfoKind::kInput, "kInput"},
+          {BufferInfoKind::kOutput, "kOutput"}};
       p->stream << "BufferInfoNode(\n"
                 << "name_hint=" << node->name_hint << ",\n  size_bytes=" << node->size_bytes
                 << ",\n  pool_candidates=" << node->pool_candidates
-                << ",\n  alignment=" << node->alignment << ")";
+                << ",\n  alignment=" << node->alignment << ",\n  kind=" << toString[node->kind]
+                << ")";
     });
 
 BufferInfoAnalysis::BufferInfoAnalysis(Map<BufferInfo, tir::Stmt> buffer_info_stmts,
@@ -161,6 +167,19 @@ Map<Stmt, PoolAllocation> AssignStmtPoolAllocations(
   return ret;
 }
 
+Map<String, PoolAllocation> GetIOPoolAllocations(
+    const Map<BufferInfo, PoolAllocation>& buffer_info_to_pool_allocation) {
+  Map<String, PoolAllocation> io_tensor_name_to_pool_allocation;
+  for (const auto& kv : buffer_info_to_pool_allocation) {
+    BufferInfo buffer_info = kv.first;
+    PoolAllocation pool_allocation = kv.second;
+    if (buffer_info->kind != BufferInfoKind::kIntermediate) {
+      io_tensor_name_to_pool_allocation.Set(buffer_info->name_hint, pool_allocation);
+    }
+  }
+  return io_tensor_name_to_pool_allocation;
+}
+
 Integer CalculateExtentsSize(const AllocateNode* op) {
   size_t element_size_bytes = op->dtype.bytes();
   size_t num_elements = 1;
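
To make the filtering above concrete, a rough Python analogue of
GetIOPoolAllocations (an illustration only, not an in-tree API): it keeps
the pool allocations whose BufferInfo is an input or output and keys them
by tensor name, which is how the I/O map functions later locate tensors
inside the workspace pools.

    def get_io_pool_allocations(buffer_info_to_pool_allocation):
        # buffer_info_to_pool_allocation: assumed dict of
        # BufferInfo -> PoolAllocation, as produced by the USMP algorithms.
        return {
            buf_info.name_hint: pool_alloc
            for buf_info, pool_alloc in buffer_info_to_pool_allocation.items()
            # kInput and kOutput entries are kept; intermediaries are dropped.
            if str(buf_info.kind) != "kIntermediate"
        }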
diff --git a/tests/cpp/target/source/interface_c_test.cc b/tests/cpp/target/source/interface_c_test.cc
index 71657a89e4..bc81d48b27 100644
--- a/tests/cpp/target/source/interface_c_test.cc
+++ b/tests/cpp/target/source/interface_c_test.cc
@@ -31,6 +31,7 @@ namespace codegen {
 
 runtime::Module InterfaceCCreate(std::string module_name, Array<String> inputs,
                                  Array<String> outputs, Array<tir::usmp::AllocatedPoolInfo> pools,
+                                 Map<String, tir::usmp::PoolAllocation> io_pool_allocations,
                                  Array<String> devices, int workspace_size);
 
 namespace {
@@ -52,7 +53,7 @@ TEST(InterfaceAPI, ContainsHeaderGuards) {
                      << "#endif // TVMGEN_ULTIMATE_CAT_SPOTTER_H_\n";
 
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(upper_header_guard.str()));
@@ -73,7 +74,7 @@ TEST(InterfaceAPI, ContainsRunFunction) {
                << ");\n";
 
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {}, 0);
   std::string header_source = test_module->GetSource();
   ASSERT_THAT(header_source, HasSubstr(run_function.str()));
 }
@@ -94,7 +95,7 @@ TEST(InterfaceAPI, ContainsRunFunctionWithDevices) {
                << ");\n";
 
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {"device"}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {"device"}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(run_function.str()));
@@ -118,13 +119,56 @@ TEST(InterfaceAPI, ContainsRunFunctionWithWorkspacePools) {
   PoolInfo pool_info = PoolInfo("my_memory_pool", {});
   tir::usmp::AllocatedPoolInfo allocated_pool_info =
       tir::usmp::AllocatedPoolInfo(pool_info, 100000);
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {allocated_pool_info}, {}, 0);
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
+                                                 {allocated_pool_info}, {}, {}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(run_function.str()));
 }
 
+TEST(InterfaceAPI, ContainsRunFunctionWithWorkspaceIO) {
+  std::stringstream run_function_with_map_functions;
+
+  run_function_with_map_functions
+      << "/*!\n"
+      << " * \\brief Maps I/O inside the workspace pools for TVM module \"ultimate_cat_spotter\"\n"
+      << " * \\param workspace_pools Workspace memory pool struct for the module \n"
+      << " * \\return I/O tensor struct for the module \n"
+      << " */\n"
+      << "struct tvmgen_ultimate_cat_spotter_inputs tvmgen_ultimate_cat_spotter_map_inputs(\n"
+      << "  struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools\n"
+      << ");\n"
+      << "\n"
+      << "/*!\n"
+      << " * \\brief Maps I/O inside the workspace pools for TVM module \"ultimate_cat_spotter\"\n"
+      << " * \\param workspace_pools Workspace memory pool struct for the module \n"
+      << " * \\return I/O tensor struct for the module \n"
+      << " */\n"
+      << "struct tvmgen_ultimate_cat_spotter_outputs tvmgen_ultimate_cat_spotter_map_outputs(\n"
+      << "  struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools\n"
+      << ");\n"
+      << "\n"
+      << "/*!\n"
+      << " * \\brief entrypoint function for TVM module \"ultimate_cat_spotter\"\n"
+      << " * \\param workspace_pools Workspace memory pool pointers for the module \n"
+      << " */\n"
+      << "int32_t tvmgen_ultimate_cat_spotter_run(\n"
+      << "  struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools\n"
+      << ");\n";
+
+  PoolInfo pool_info = PoolInfo("my_memory_pool", {});
+  tir::usmp::AllocatedPoolInfo allocated_pool_info =
+      tir::usmp::AllocatedPoolInfo(pool_info, 100000);
+  tir::usmp::PoolAllocation pool_allocation_input{pool_info, 1000};
+  tir::usmp::PoolAllocation pool_allocation_output{pool_info, 2000};
+  runtime::Module test_module = InterfaceCCreate(
+      "ultimate_cat_spotter", {"input"}, {"output"}, {allocated_pool_info},
+      {{"input", pool_allocation_input}, {"output", pool_allocation_output}}, {}, 0);
+  std::string header_source = test_module->GetSource();
+  ASSERT_THAT(header_source, HasSubstr(run_function_with_map_functions.str()));
+}
+
 TEST(InterfaceAPI, ContainsInputStructSingle) {
   std::stringstream input_struct;
 
@@ -136,7 +180,7 @@ TEST(InterfaceAPI, ContainsInputStructSingle) {
                << "};\n\n";
 
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(input_struct.str()));
@@ -151,7 +195,7 @@ TEST(InterfaceAPI, ContainsInputStructMany) {
                << "};\n\n";
 
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input1", "input2"}, {"output"}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input1", "input2"}, {"output"}, {}, {}, {}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(input_struct.str()));
@@ -166,7 +210,7 @@ TEST(InterfaceAPI, ContainsInputStructSanitised) {
                << "};\n\n";
 
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input+1", "input+2"}, {"output"}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input+1", "input+2"}, {"output"}, {}, {}, {}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(input_struct.str()));
@@ -174,7 +218,7 @@ TEST(InterfaceAPI, ContainsInputStructSanitised) {
 
 TEST(InterfaceAPI, ContainsInputStructClash) {
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input+", "input-"}, {"output"}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input+", "input-"}, {"output"}, {}, {}, {}, 0);
   ASSERT_THROW(test_module->GetSource(), InternalError);
 }
 
@@ -189,7 +233,7 @@ TEST(InterfaceAPI, ContainsOutputStructSingle) {
                 << "};\n\n";
 
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(output_struct.str()));
@@ -204,7 +248,7 @@ TEST(InterfaceAPI, ContainsOutputStructMany) {
                 << "};\n\n";
 
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output1", "output2"}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output1", "output2"}, {}, {}, {}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(output_struct.str()));
@@ -219,7 +263,7 @@ TEST(InterfaceAPI, ContainsOutputStructSanitised) {
                 << "};\n\n";
 
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+1", "output-2"}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+1", "output-2"}, {}, {}, {}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(output_struct.str()));
@@ -227,7 +271,7 @@ TEST(InterfaceAPI, ContainsOutputStructSanitised) {
 
 TEST(InterfaceAPI, ContainsOutputStructClash) {
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+", "output-"}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+", "output-"}, {}, {}, {}, 0);
   ASSERT_THROW(test_module->GetSource(), InternalError);
 }
 
@@ -241,7 +285,7 @@ TEST(InterfaceAPI, NoDeviceAPIStructIfNoDevices) {
                 << "};\n\n";
 
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, Not(HasSubstr(device_struct.str())));
@@ -258,7 +302,7 @@ TEST(InterfaceAPI, ContainsDeviceStructSingle) {
                 << "};\n\n";
 
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {"device"}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {"device"}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(device_struct.str()));
@@ -273,7 +317,7 @@ TEST(InterfaceAPI, ContainsDeviceStructMany) {
                 << "};\n\n";
 
   runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {},
-                                                 {"device1", "device2"}, 0);
+                                                 {}, {"device1", "device2"}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(device_struct.str()));
@@ -288,7 +332,7 @@ TEST(InterfaceAPI, ContainsDeviceStructSanitised) {
                 << "};\n\n";
 
   runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {},
-                                                 {"device+1", "device+2"}, 0);
+                                                 {}, {"device+1", "device+2"}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(device_struct.str()));
@@ -296,13 +340,13 @@ TEST(InterfaceAPI, ContainsDeviceStructSanitised) {
 
 TEST(InterfaceAPI, ContainsDeviceStructClash) {
   runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {},
-                                                 {"device+", "device-"}, 0);
+                                                 {}, {"device+", "device-"}, 0);
   ASSERT_THROW(test_module->GetSource(), InternalError);
 }
 
 TEST(InterfaceAPI, ContainsWorkspaceSize) {
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 765432);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, {}, 765432);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source,
@@ -327,8 +371,8 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructSingle) {
       << "  void* my_memory_pool;\n"
       << "};\n\n";
 
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {allocated_pool_info}, {}, 0);
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
+                                                 {allocated_pool_info}, {}, {}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(workspace_struct.str()));
@@ -362,7 +406,7 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructMany) {
 
   runtime::Module test_module =
       InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
-                       {allocated_pool_info1, allocated_pool_info2}, {}, 0);
+                       {allocated_pool_info1, allocated_pool_info2}, {}, {}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(workspace_struct.str()));
@@ -397,8 +441,8 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructSanitized) {
       << "  void* my_memory_pool_1;\n"
       << "};\n\n";
 
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {allocated_pool_info}, {}, 0);
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
+                                                 {allocated_pool_info}, {}, {}, 0);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(workspace_struct.str()));
@@ -421,7 +465,7 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructClash) {
 
   runtime::Module test_module =
       InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
-                       {allocated_pool_info1, allocated_pool_info2}, {}, 0);
+                       {allocated_pool_info1, allocated_pool_info2}, {}, {}, 0);
   ASSERT_THROW(test_module->GetSource(), InternalError);
 }
 
diff --git a/tests/micro/zephyr/test_utils.py b/tests/micro/zephyr/test_utils.py
index ea17ac9a35..e0aad7c3c6 100644
--- a/tests/micro/zephyr/test_utils.py
+++ b/tests/micro/zephyr/test_utils.py
@@ -210,7 +210,7 @@ def generate_project(
                         model_files_path, arcname=os.path.relpath(model_files_path, tar_temp_dir)
                     )
                 header_path = generate_c_interface_header(
-                    lowered.libmod_name, ["input_1"], ["Identity"], [], [], 0, model_files_path
+                    lowered.libmod_name, ["input_1"], ["Identity"], [], {}, [], 0, model_files_path
                 )
                 tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir))
 
diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py
index 3318473a83..2c4262a3d2 100644
--- a/tests/python/relay/aot/aot_test_utils.py
+++ b/tests/python/relay/aot/aot_test_utils.py
@@ -169,6 +169,16 @@ AOT_USMP_CORSTONE300_RUNNER = AOTTestRunner(
     },
 )
 
+NP_TYPE_TO_C = {
+    "int8": "int8_t",
+    "uint8": "uint8_t",
+    "int16": "int16_t",
+    "uint16": "uint16_t",
+    "int32": "int32_t",
+    "uint32": "uint32_t",
+    "float32": "float",
+}
+
 
 def mangle_name(mod_name, name):
     mod_name = mangle_module_name(mod_name)
@@ -429,11 +439,14 @@ def emit_main_data_setup(main_file, input_map, output_map, mod_name):
     main_file.write("};\n")
 
 
-def emit_main_c_interface_call(main_file, devices, workspace_pool_names, mod_name):
+def emit_main_c_interface_call(
+    main_file, devices, workspace_pool_names, mod_name, use_workspace_io
+):
     sub_strings = list()
     sub_strings.append(f'{mangle_name(mod_name,"run")}(')
-    sub_strings.append(f'&{mangle_name(mod_name,"inputs")}, ')
-    sub_strings.append(f'&{mangle_name(mod_name,"outputs")}, ')
+    if not use_workspace_io:
+        sub_strings.append(f'&{mangle_name(mod_name,"inputs")}, ')
+        sub_strings.append(f'&{mangle_name(mod_name,"outputs")}, ')
     if workspace_pool_names:
         sub_strings.append(f'&{mangle_name(mod_name,"workspace_pools")}, ')
     if devices:
@@ -500,10 +513,9 @@ def emit_main_packed_call(main_file, input_map, output_list, mod_name):
     main_file.write("\n")
 
 
-def emit_main_compare(main_file, outputs, output_tolerance, mod_name):
+def emit_main_compare(main_file, outputs, output_tolerance, mod_name, use_interface_c=False):
     for key in outputs:
         sanitized_tensor_name = re.sub(r"\W", "_", key)
-        actual_data_name = mangle_name(mod_name, f"output_data_{sanitized_tensor_name}")
         expected_data_name = mangle_name(mod_name, f"expected_output_data_{sanitized_tensor_name}")
         is_float_dtype = outputs[key].dtype == "float32"
 
@@ -513,9 +525,19 @@ def emit_main_compare(main_file, outputs, output_tolerance, mod_name):
             comparison_function = "fabs"
             tolerance = output_tolerance or 0.001
 
+        data_length_var_name = (
+            mangle_name(mod_name, f"output_data_{sanitized_tensor_name}") + "_len"
+        )
+        if use_interface_c:
+            c_type = NP_TYPE_TO_C[str(outputs[key].dtype)]
+            actual_data_name = f"(({c_type}*)" + mangle_name(
+                mod_name, f"outputs.{sanitized_tensor_name})"
+            )
+        else:
+            actual_data_name = mangle_name(mod_name, f"output_data_{sanitized_tensor_name}")
         main_file.write(
             f"""
-            for (int i = 0; i<{actual_data_name}_len; i++) {{
+            for (int i = 0; i<{data_length_var_name}; i++) {{
                 if ({comparison_function}({actual_data_name}[i]-{expected_data_name}[i]) > {tolerance}) {{
                     printf("{AOT_FAILURE_TOKEN}\\n");
                     return -1;
@@ -563,6 +585,7 @@ def create_main(
     interface_api,
     workspace_bytes,
     use_stack_allocator=True,
+    use_workspace_io=False,
 ):
     file_path = pathlib.Path(f"{output_path}/" + test_name).resolve()
     # create header file
@@ -605,9 +628,12 @@ def create_main(
                         if not allocated_pool.pool_info.is_internal
                     ]
                 emit_main_device_structs(main_file, devices, model.name)
-                emit_main_workspace_pool_structs(main_file, workspace_pool_names, model.name)
-                emit_main_data_structs(main_file, model.inputs, model.outputs, model.name)
-                emit_main_c_interface_call(main_file, devices, workspace_pool_names, model.name)
+                if not use_workspace_io:
+                    emit_main_workspace_pool_structs(main_file, workspace_pool_names, model.name)
+                    emit_main_data_structs(main_file, model.inputs, model.outputs, model.name)
+                emit_main_c_interface_call(
+                    main_file, devices, workspace_pool_names, model.name, use_workspace_io
+                )
         else:
             emit_main_fake_packed_values(main_file)
             for compiled_model in compiled_models:
@@ -617,7 +643,9 @@ def create_main(
 
         for compiled_model in compiled_models:
             model = compiled_model.model
-            emit_main_compare(main_file, model.outputs, model.output_tolerance, model.name)
+            emit_main_compare(
+                main_file, model.outputs, model.output_tolerance, model.name, interface_api == "c"
+            )
         emit_main_epilogue(main_file, custom_epilogue)
 
 
@@ -627,15 +655,6 @@ def create_header_file(tensor_name, npy_data, output_path, data_linkage):
     It is used to capture the tensor data (for both inputs and expected outputs) to be bundled into the standalone application.
     """
     file_path = pathlib.Path(f"{output_path}/" + tensor_name).resolve()
-    np_type_to_c = {
-        "int8": "int8_t",
-        "uint8": "uint8_t",
-        "int16": "int16_t",
-        "uint16": "uint16_t",
-        "int32": "int32_t",
-        "uint32": "uint32_t",
-        "float32": "float",
-    }
     # create header file
     raw_path = file_path.with_suffix(".h").resolve()
     with open(raw_path, "w") as header_file:
@@ -646,7 +665,7 @@ def create_header_file(tensor_name, npy_data, output_path, data_linkage):
 
         emit_data_linkage(header_file, data_linkage)
 
-        header_file.write(f"{np_type_to_c[str(npy_data.dtype)]} {tensor_name}[] =")
+        header_file.write(f"{NP_TYPE_TO_C[str(npy_data.dtype)]} {tensor_name}[] =")
 
         header_file.write("{")
         for i in np.ndindex(npy_data.shape):
@@ -726,6 +745,7 @@ def run_and_check(
     data_linkage: AOTDataLinkage = None,
     test_dir: str = None,
     verbose: bool = False,
+    use_workspace_io: bool = False,
 ):
     """
     This method uses the original test data and compiled runtime.Modules
@@ -805,6 +825,7 @@ def run_and_check(
             interface_api,
             workspace_bytes,
             use_stack_allocator,
+            use_workspace_io,
         )
 
         # Verify that compiles fine
@@ -931,11 +952,8 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"):
         main = mod
     else:
         main = mod["main"]
-    if main.attrs == None or main.attrs["output_tensor_names"] == None:
-        if output_count == 1:
-            output_tensor_names = ["output"]
-        else:
-            output_tensor_names = [f"output{i}" for i in range(output_count)]
+    if main.attrs is None or main.attrs["output_tensor_names"] is None:
+        output_tensor_names = ["output" if i == 0 else f"output{i+1}" for i in range(output_count)]
     else:
         output_tensor_names = main.attrs["output_tensor_names"]
 
diff --git a/tests/python/relay/aot/test_c_device_api.py b/tests/python/relay/aot/test_c_device_api.py
index d547b52e85..f9fa0c6ead 100644
--- a/tests/python/relay/aot/test_c_device_api.py
+++ b/tests/python/relay/aot/test_c_device_api.py
@@ -20,6 +20,7 @@ from collections import OrderedDict
 
 import numpy as np
 import pytest
+import re
 
 from tvm import relay
 from tvm.ir.module import IRModule
@@ -133,7 +134,6 @@ def non_device_api_main_func():
 def test_device_api_hooks_unpacked_api(device_api_main_func):
     """Check for Device API hooks with unpacked internal calls"""
     main_func = device_api_main_func(interface_api="c", use_unpacked_api=True)
-    input_name = main_func.params[0].name
 
     # Activate Device
     assert (
@@ -151,12 +151,12 @@ def test_device_api_hooks_unpacked_api(device_api_main_func):
         + " device_context_ethos_u))\n"
     )
     # Device Call
-    assert (
-        str(main_func.body[1][0][0][1])
-        == "tir.tvm_check_return(0, -1, tir.call_extern("
-        + '"tvmgen_default_ethos_u_main_0",'
-        + f" {input_name}_buffer_var, output_buffer_var, device_context_ethos_u))\n"
+    # We don't need to check the exact input and output variable names in this test,
+    # so we use a regex to cover any legal I/O name.
+    regex = re.compile(
+        r'tir\.tvm_check_return\(0, -1, tir\.call_extern\("tvmgen_default_ethos_u_main_0", \w+, \w+, device_context_ethos_u\)\)'
     )
+    assert regex.match(str(main_func.body[1][0][0][1]))
     # Close Device
     assert (
         str(main_func.body[1][0][0][2])
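
The regex above can be checked standalone against a representative device
call string (the buffer variable names here are illustrative):

    import re

    regex = re.compile(
        r'tir\.tvm_check_return\(0, -1, tir\.call_extern\('
        r'"tvmgen_default_ethos_u_main_0", \w+, \w+, device_context_ethos_u\)\)'
    )
    sample = (
        'tir.tvm_check_return(0, -1, tir.call_extern('
        '"tvmgen_default_ethos_u_main_0", input_buffer_var, '
        'output_buffer_var, device_context_ethos_u))'
    )
    assert regex.match(sample)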
diff --git a/tests/python/relay/aot/test_crt_aot_usmp.py b/tests/python/relay/aot/test_crt_aot_usmp.py
index 77ff99fd6d..23283392ee 100644
--- a/tests/python/relay/aot/test_crt_aot_usmp.py
+++ b/tests/python/relay/aot/test_crt_aot_usmp.py
@@ -18,6 +18,7 @@
 
 from collections import OrderedDict
 import sys
+import re
 
 import numpy as np
 import pytest
@@ -278,6 +279,11 @@ def _get_workspace_size_define_macro(pool_name: str, model_name="default") -> st
     return prefix + pool_name.upper() + postfix
 
 
+def _add_module_prefix(suffix: str, model_name="default") -> str:
+    """A helper function create struct types"""
+    return "tvmgen_" + model_name + "_" + suffix
+
+
 @pytest.mark.parametrize(
     "model_url, usmp_algo",
     [
@@ -458,3 +464,173 @@ def test_tflite_model_u2_usecase_two_models_with_a_single_external_pool(model_ur
         runner=test_runner,
         interface_api=interface_api,
     )
+
+
+@pytest.mark.parametrize(
+    "model_url, usmp_algo",
+    [
+        (MOBILENET_V1_URL, "greedy_by_size"),
+    ],
+)
+def test_tflite_model_u4_usecase_single_external_pool(model_url, usmp_algo):
+    """This checks for inference with USMP using external pool placed in the application"""
+    pytest.importorskip("tflite")
+
+    import tvm.relay.testing.tf as tf_testing
+
+    use_unpacked_api = True
+    interface_api = "c"
+
+    pool_name = "my_memory_pool"
+    target = tvm.target.Target("c")
+    workspace_memory_pools = WorkspaceMemoryPools(
+        [PoolInfo(pool_name, {target: PoolInfo.READ_WRITE_ACCESS})]
+    )
+
+    tflite_model_file = tf_testing.get_workload_official(
+        model_url[0],
+        model_url[1],
+    )
+    mod, inputs, params = create_relay_module_and_inputs_from_tflite_file(tflite_model_file)
+    output_list = generate_ref_data(mod, inputs, params)
+
+    input_name, input_data = list(inputs.items())[0]
+    input_size_bytes = input_data.size * input_data.itemsize
+    test_runner = AOTTestRunner(
+        pass_config={
+            "tir.usmp.enable": True,
+            "tir.usmp.algorithm": usmp_algo,
+            "tir.usmp.use_workspace_io": True,
+        },
+        prologue=f"""
+        #include <string.h>
+        __attribute__((section(".data.tvm"), aligned(16)))
+        static uint8_t {pool_name}[{_get_workspace_size_define_macro(pool_name)}];
+        struct {_add_module_prefix("workspace_pools")} {_add_module_prefix("workspace_pools")} = {{
+            .{pool_name} = {pool_name}
+        }};
+        struct {_add_module_prefix("inputs")} {_add_module_prefix("inputs")} = {_add_module_prefix("map_inputs")}(&{_add_module_prefix("workspace_pools")});
+        memcpy({_add_module_prefix("inputs")}.{input_name}, tvmgen_default_input_data_input, {input_size_bytes});
+        struct {_add_module_prefix("outputs")} {_add_module_prefix("outputs")} = {_add_module_prefix("map_outputs")}(&{_add_module_prefix("workspace_pools")});
+        """,
+    )
+
+    compiled_test_mods = compile_models(
+        AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params),
+        interface_api=interface_api,
+        use_unpacked_api=use_unpacked_api,
+        pass_config=test_runner.pass_config,
+        workspace_memory_pools=workspace_memory_pools,
+        target=target,
+    )
+
+    for compiled_model in compiled_test_mods:
+        check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
+
+    run_and_check(
+        models=compiled_test_mods,
+        runner=test_runner,
+        interface_api=interface_api,
+        use_workspace_io=True,
+    )
+
+
+@pytest.mark.parametrize(
+    "model_url, usmp_algo",
+    [
+        (MOBILENET_V1_URL, "greedy_by_size"),
+    ],
+)
+def test_tflite_model_u4_usecase_two_external_pools(model_url, usmp_algo):
+    """This checks for inference with USMP using external pool placed in the application"""
+    pytest.importorskip("tflite")
+
+    import tvm.relay.testing.tf as tf_testing
+
+    use_unpacked_api = True
+    interface_api = "c"
+
+    target = tvm.target.Target("c")
+    workspace_memory_pools = WorkspaceMemoryPools(
+        [
+            PoolInfo(
+                "my_memory_pool_1", {target: PoolInfo.READ_WRITE_ACCESS}, size_hint_bytes=2500000
+            ),
+            PoolInfo("my_memory_pool_2", {target: PoolInfo.READ_WRITE_ACCESS}),
+        ]
+    )
+
+    tflite_model_file = tf_testing.get_workload_official(
+        model_url[0],
+        model_url[1],
+    )
+    mod, inputs, params = create_relay_module_and_inputs_from_tflite_file(tflite_model_file)
+    output_list = generate_ref_data(mod, inputs, params)
+
+    input_name, input_data = list(inputs.items())[0]
+    input_size_bytes = input_data.size * input_data.itemsize
+    test_runner = AOTTestRunner(
+        pass_config={
+            "tir.usmp.enable": True,
+            "tir.usmp.algorithm": usmp_algo,
+            "tir.usmp.use_workspace_io": True,
+        },
+        prologue=f"""
+        #include <string.h>
+        __attribute__((section(".data.tvm"), aligned(16)))
+        static uint8_t my_memory_pool_1[{_get_workspace_size_define_macro("my_memory_pool_1")}];
+        __attribute__((section(".data.tvm"), aligned(16)))
+        static uint8_t my_memory_pool_2[{_get_workspace_size_define_macro("my_memory_pool_2")}];
+        struct {_add_module_prefix("workspace_pools")} {_add_module_prefix("workspace_pools")} = {{
+            .my_memory_pool_1 = my_memory_pool_1,
+            .my_memory_pool_2 = my_memory_pool_2,
+        }};
+        struct {_add_module_prefix("inputs")} {_add_module_prefix("inputs")} = {_add_module_prefix("map_inputs")}(&{_add_module_prefix("workspace_pools")});
+        memcpy({_add_module_prefix("inputs")}.{input_name}, tvmgen_default_input_data_input, {input_size_bytes});
+        struct {_add_module_prefix("outputs")} {_add_module_prefix("outputs")} = {_add_module_prefix("map_outputs")}(&{_add_module_prefix("workspace_pools")});
+        """,
+    )
+
+    compiled_test_mods = compile_models(
+        AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params),
+        interface_api=interface_api,
+        use_unpacked_api=use_unpacked_api,
+        pass_config=test_runner.pass_config,
+        workspace_memory_pools=workspace_memory_pools,
+        target=target,
+    )
+
+    for compiled_model in compiled_test_mods:
+        check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib)
+
+    run_and_check(
+        models=compiled_test_mods,
+        runner=test_runner,
+        interface_api=interface_api,
+        use_workspace_io=True,
+    )
+
+
+def test_u4_usecase_incompatible_interface_api_errors():
+    """Checks that tir.usmp.use_workspace_io is rejected with the packed interface API"""
+    mod, params = tvm.relay.testing.synthetic.get_workload()
+    target = "c"
+    runtime = Runtime("crt")
+    executor = Executor(
+        "aot",
+        {
+            "interface-api": "packed",
+        },
+    )
+
+    with pytest.raises(
+        tvm.TVMError,
+        match=re.escape(
+            "tir.usmp.use_workspace_io option is only compatible with interface_api c.\n"
+            "Please use interface_api c to be able to enable tir.usmp.use_workspace_io"
+        ),
+    ):
+        with tvm.transform.PassContext(
+            opt_level=3,
+            config={"tir.usmp.enable": True, "tir.usmp.use_workspace_io": True},
+        ):
+            tvm.relay.build(mod, target, executor=executor, runtime=runtime, params=params)
diff --git a/tests/python/unittest/test_tir_usmp_transform_create_io_allocates.py b/tests/python/unittest/test_tir_usmp_transform_create_io_allocates.py
new file mode 100644
index 0000000000..d72cb7f72e
--- /dev/null
+++ b/tests/python/unittest/test_tir_usmp_transform_create_io_allocates.py
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+from typing import NamedTuple, List
+
+import tvm
+from tvm.script import tir as T
+
+
+# fmt: off
+@tvm.script.ir_module
+class SingleInputSingleOutput:
+    @T.prim_func
+    def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
+        placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        for ax0_ax1_fused_1 in T.serial(0, 224):
+            for ax2_1, ax3_inner_1 in T.grid(224, 3):
+                T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0])
+
+    @T.prim_func
+    def __tvm_main__(input: T.handle, output: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "__tvm_main__", "runner_function": True})
+        input_buffer_var = T.match_buffer(input, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        output_buffer_var = T.match_buffer(output, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input_buffer_var.data, T.lookup_param("p0", dtype="handle"), output_buffer_var.data, dtype="int32"))
+# fmt: on
+
+
+# fmt: off
+@tvm.script.ir_module
+class TwoInputSingleOutput:
+    @T.prim_func
+    def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
+        placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        for ax0_ax1_fused_1 in T.serial(0, 224):
+            for ax2_1, ax3_inner_1 in T.grid(224, 3):
+                T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0])
+
+    @T.prim_func
+    def __tvm_main__(input1: T.handle, input2: T.handle, output: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "__tvm_main__", "runner_function": True})
+        input1_buffer_var = T.match_buffer(input1, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        input2_buffer_var = T.match_buffer(input2, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        output_buffer_var = T.match_buffer(output, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input1_buffer_var.data, input2_buffer_var.data, output_buffer_var.data, dtype="int32"))
+# fmt: on
+
+
+# fmt: off
+@tvm.script.ir_module
+class TwoInputTwoOutput:
+    @T.prim_func
+    def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
+        placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        for ax0_ax1_fused_1 in T.serial(0, 224):
+            for ax2_1, ax3_inner_1 in T.grid(224, 3):
+                T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0])
+
+    @T.prim_func
+    def __tvm_main__(input1: T.handle, input2: T.handle, output1: T.handle, output2: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "__tvm_main__", "runner_function": True})
+        input1_buffer_var = T.match_buffer(input1, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        input2_buffer_var = T.match_buffer(input2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        output1_buffer_var = T.match_buffer(output1, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        output2_buffer_var = T.match_buffer(output2, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input1_buffer_var.data, T.lookup_param("p0", dtype="handle"), output1_buffer_var.data, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input2_buffer_var.data, T.lookup_param("p1", dtype="handle"), output2_buffer_var.data, dtype="int32"))
+# fmt: on
+
+
+# fmt: off
+@tvm.script.ir_module
+class SingleInputTwoOutput:
+    @T.prim_func
+    def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True})
+        placeholder_4 = T.match_buffer(placeholder_2, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        placeholder_5 = T.match_buffer(placeholder_3, [1], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        T_subtract_1 = T.match_buffer(T_subtract, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        for ax0_ax1_fused_1 in T.serial(0, 224):
+            for ax2_1, ax3_inner_1 in T.grid(224, 3):
+                T_subtract_1[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)] = (T.cast(placeholder_4[(((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)], "int16") - placeholder_5[0])
+
+    @T.prim_func
+    def __tvm_main__(input: T.handle, output1: T.handle, output2: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "__tvm_main__", "runner_function": True})
+        input_buffer_var = T.match_buffer(input, [150528], dtype="uint8", elem_offset=0, align=128, offset_factor=1)
+        output1_buffer_var = T.match_buffer(output1, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        output2_buffer_var = T.match_buffer(output2, [452], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input_buffer_var.data, T.lookup_param("p0", dtype="handle"), output1_buffer_var.data, dtype="int32"))
+        T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input_buffer_var.data, T.lookup_param("p1", dtype="handle"), output2_buffer_var.data, dtype="int32"))
+# fmt: on
+
+
+class IOInfo(NamedTuple):
+    """A data structure to hold test outputs per I/O tensor"""
+
+    name: str
+    shape: list
+    dtype: str
+
+
+def check_io_allocations(mod: tvm.IRModule, inputs: List[IOInfo], outputs: List[IOInfo]):
+    """This function checks whether outer most allocates correspond to I/O tensors"""
+    found_non_io_allocate_node = False
+
+    input_name_to_info = {}
+    for input in inputs:
+        input_name_to_info[input.name] = input
+    output_name_to_info = {}
+    for output in outputs:
+        output_name_to_info[output.name] = output
+
+    def _visit(stmt):
+        nonlocal found_non_io_allocate_node
+        if isinstance(stmt, tvm.tir.Allocate) and not found_non_io_allocate_node:
+            allocate = stmt
+            if dict(allocate.annotations).get("input_tensor"):
+                input_tensor_name = str(dict(allocate.annotations).get("input_tensor"))
+                assert input_tensor_name in input_name_to_info.keys()
+                assert input_name_to_info[input_tensor_name].shape == list(allocate.extents)
+                assert input_name_to_info[input_tensor_name].dtype == str(allocate.dtype)
+                del input_name_to_info[input_tensor_name]
+            if dict(allocate.annotations).get("output_tensor"):
+                output_tensor_name = str(dict(allocate.annotations).get("output_tensor"))
+                assert output_tensor_name in output_name_to_info.keys()
+                assert output_name_to_info[output_tensor_name].shape == list(allocate.extents)
+                assert output_name_to_info[output_tensor_name].dtype == str(allocate.dtype)
+                del output_name_to_info[output_tensor_name]
+        else:
+            found_non_io_allocate_node = True
+
+    main = mod["__tvm_main__"]
+    tvm.tir.stmt_functor.ir_transform(main.body, _visit, None, ["tir.Allocate", "tir.Call"])
+    assert len(input_name_to_info) == 0
+    assert len(output_name_to_info) == 0
+
+
+@pytest.mark.parametrize(
+    "test_mod, input_names, output_names",
+    [
+        (
+            SingleInputSingleOutput,
+            [IOInfo("input", [150528], "uint8")],
+            [IOInfo("output", [452], "int16")],
+        ),
+        (
+            SingleInputTwoOutput,
+            [IOInfo("input", [150528], "uint8")],
+            [IOInfo("output1", [452], "int16"), IOInfo("output2", [452], "int16")],
+        ),
+        (
+            TwoInputSingleOutput,
+            [IOInfo("input1", [150528], "uint8"), IOInfo("input2", [1], "int16")],
+            [IOInfo("output", [452], "int16")],
+        ),
+        (
+            TwoInputTwoOutput,
+            [IOInfo("input1", [150528], "uint8"), IOInfo("input2", [150528], "uint8")],
+            [IOInfo("output1", [452], "int16"), IOInfo("output2", [452], "int16")],
+        ),
+    ],
+)
+def test_mobilenet_subgraph(test_mod, input_names, output_names):
+    """Checks that CreateAllocatesForIO produces I/O Allocate nodes for each main function variant"""
+    CreateAllocatesForIO = tvm.get_global_func("tir.usmp.transform.CreateAllocatesForIO")
+    test_mod = CreateAllocatesForIO()(test_mod)
+    check_io_allocations(test_mod, input_names, output_names)
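
A minimal standalone sketch of exercising the new pass outside the test
harness, assuming one of the modules defined above (here
SingleInputSingleOutput): running the pass and printing the transformed
main shows the new I/O Allocate nodes with their "input_tensor" and
"output_tensor" annotations.

    import tvm

    CreateAllocatesForIO = tvm.get_global_func("tir.usmp.transform.CreateAllocatesForIO")
    transformed = CreateAllocatesForIO()(SingleInputSingleOutput)
    print(transformed["__tvm_main__"])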