You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ar...@apache.org on 2021/06/29 23:16:47 UTC

[tvm] branch main updated: Decoupling AOT from graph memory planner (#8096)

This is an automated email from the ASF dual-hosted git repository.

areusch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b803bab  Decoupling AOT from graph memory planner (#8096)
b803bab is described below

commit b803bab4a3e072d00ca08d0dcf9a5585194b98d4
Author: Giuseppe Rossini <gi...@arm.com>
AuthorDate: Wed Jun 30 00:16:29 2021 +0100

    Decoupling AOT from graph memory planner (#8096)
    
    * Fix an issue with storage-rewrite pass and packed functions
    
    Change-Id: I13888471d4b8927a4012d6a8e749fb7a8935dd77
    
    * Rebasing
    
    Change-Id: I7aa12e0217b8a2e1ff2a97a7c5fdda6b7597ae64
    
    * Addressing comments
    
    Change-Id: If9f1ee190690f9a810fe41eb1933d736f1eb4ec3
    
    * Add a pass to legalize packed calls
    
    Change-Id: I8aa43d3a1b837b03a5cf3c6b32fc760bd78d3436
    
    * Add a unit test for the legalization pass
    
    Change-Id: I5b0d75380ff660dd5a0acf5b14fa84bb992fbec4
    
    * rebasing
    
    Change-Id: I52ceab5cf6e9b54390cb36c18dbb8e22505d8e18
    
    * Use common StorageInfo
    
    Change-Id: Ia8b7de1373f167ca7d0d69a99846d417405bbe48
---
 include/tvm/tir/transform.h                        |   5 +
 python/tvm/tir/transform/transform.py              |  11 +
 src/relay/backend/aot_executor_codegen.cc          | 354 ++++++++++++++-------
 src/tir/transforms/ir_utils.h                      |  23 ++
 src/tir/transforms/legalize_packed_calls.cc        | 121 +++++++
 src/tir/transforms/lower_tvm_builtin.cc            |  10 -
 tests/python/relay/aot/aot_test_utils.py           |  47 ++-
 tests/python/relay/aot/test_crt_aot.py             |  48 +++
 .../unittest/test_aot_legalize_packed_call.py      |  80 +++++
 9 files changed, 573 insertions(+), 126 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 2113d58..5ee847e 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -419,6 +419,11 @@ TVM_DLL Pass ConvertBlocksToOpaque();
 TVM_DLL Pass CompactBufferAllocation();
 
 /*!
+ * \brief Legalize packed calls by wrapping their arguments into TVMValues.
+ */
+TVM_DLL Pass LegalizePackedCalls();
+
+/*!
  * \brief Flatten the multi-dimensional BufferLoad and BufferStore
  *        to single dimensional Load/Store. Also remove Block to
  *        ensure that the flattened TIR can not be scheduled again.
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 8a32a7e..51330f8 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -451,6 +451,17 @@ def LowerTVMBuiltin():
     return _ffi_api.LowerTVMBuiltin()
 
 
+def LegalizePackedCalls():
+    """Legalize packed calls so that their arguments are wrapped in TVMValues
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LegalizePackedCalls()
+
+
 def LowerIntrin():
     """Lower target specific intrinsic calls.
 
diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc
index 93935af..9b495ad 100644
--- a/src/relay/backend/aot_executor_codegen.cc
+++ b/src/relay/backend/aot_executor_codegen.cc
@@ -31,6 +31,7 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/stmt.h>
+#include <tvm/tir/transform.h>
 
 #include <algorithm>
 #include <list>
@@ -46,50 +47,175 @@ namespace backend {
 
 using IntegerArray = Array<Integer>;
 using TargetsMap = std::unordered_map<int, Target>;
+using StorageMap =
+    std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
 
-class AotReturnSidVisitor : public ExprVisitor {
+/**
+ * On-demand storage allocator for AOT: every value produced by an operation
+ * gets its own fresh storage id (sid); no storage is reused by this visitor.
+ */
+class AOTOnDemandAllocator : public ExprVisitor {
  public:
-  explicit AotReturnSidVisitor(Map<Expr, Array<IntegerArray>> storage_device_map)
-      : storage_device_map_{storage_device_map}, return_sid_{-1} {}
+  // Allocate storage for the function parameters first, then for everything in the body.
+  void Run(const Function& func) {
+    node_device_map_ = CollectDeviceInfo(func);

-  IntegerArray FindReturnSid(Function func) {
-    VisitExpr(func->body);
-    return return_sid_;
+    for (Expr param : func->params) {
+      CreateStorage(param.operator->());
+    }
+
+    GetStorage(func->body);
   }

- protected:
-  void AssignReturnSid(Expr e) {
-    auto iter = storage_device_map_.find(e);
-    if (iter != storage_device_map_.end()) {
-      return_sid_ = (*iter).second[0];
+  std::vector<int> GetReturnIds() const { return return_ids_; }
+
+  StorageMap GetStorageMap() const { return storage_device_map_; }
+
+  void VisitExpr_(const ConstantNode* op) final {
+    CreateStorage(op);
+    AssignReturnSid(GetRef<Expr>(op));
+  }
+
+  void VisitExpr_(const CallNode* op) final {
+    // Allocate the call's own output storage, then make sure every argument has storage.
+    CreateStorage(op);
+    for (Expr arg : op->args) {
+      GetStorage(arg);
     }
+    AssignReturnSid(GetRef<Expr>(op));
   }

-  void VisitExpr_(const ConstantNode* cn) override {
-    ExprVisitor::VisitExpr_(cn);
-    AssignReturnSid(GetRef<Expr>(cn));
+  void VisitExpr_(const VarNode* op) final {
+    ExprVisitor::VisitExpr_(op);
+    AssignReturnSid(GetRef<Expr>(op));
   }

-  void VisitExpr_(const VarNode* vn) override {
-    ExprVisitor::VisitExpr_(vn);
-    AssignReturnSid(GetRef<Expr>(vn));
+  void VisitExpr_(const FunctionNode* op) final {
+    // Do not recurse into nested functions.
   }

-  void VisitExpr_(const CallNode* cn) override {
-    ExprVisitor::VisitExpr_(cn);
-    AssignReturnSid(GetRef<Expr>(cn));
+  void VisitExpr_(const GlobalVarNode* op) final {
+    // Do nothing.
   }

-  void VisitExpr_(const LetNode* op) override { VisitExpr(op->body); }
+  void VisitExpr_(const OpNode* op) final {
+    // Do nothing.
+  }

-  void VisitExpr_(const TupleNode* tn) override {
-    ExprVisitor::VisitExpr_(tn);
-    AssignReturnSid(GetRef<Expr>(tn));
+  void VisitExpr_(const TupleNode* op) final {
+    std::vector<int64_t> storage_ids;
+    std::vector<DLDeviceType> device_types;
+    std::vector<int64_t> storage_sizes_in_bytes;
+    Expr expr = GetRef<Expr>(op);
+    for (Expr field : op->fields) {
+      auto sid = GetStorage(field);
+      storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end());
+      device_types.insert(device_types.end(), sid->device_types.begin(), sid->device_types.end());
+      storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(),
+                                    sid->storage_sizes_in_bytes.begin(),
+                                    sid->storage_sizes_in_bytes.end());
+    }
+    storage_device_map_[expr] = StorageInfo(storage_ids, device_types, storage_sizes_in_bytes);
+    AssignReturnSid(expr);
   }

+  void VisitExpr_(const TupleGetItemNode* op) final {
+    Expr expr = GetRef<Expr>(op);
+    auto sids = GetStorage(op->tuple);
+    ICHECK_LT(static_cast<size_t>(op->index), sids->storage_ids.size());
+    storage_device_map_[expr] =
+        StorageInfo({sids->storage_ids[op->index]}, {sids->device_types[op->index]},
+                    {sids->storage_sizes_in_bytes[op->index]});
+    AssignReturnSid(expr);
+  }
+
+  void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; }
+
+  void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "let is not supported."; }
+
  private:
-  Map<Expr, Array<IntegerArray>> storage_device_map_;
-  IntegerArray return_sid_;
+  void AssignReturnSid(Expr e) {
+    auto it = storage_device_map_.find(e);
+    if (it != storage_device_map_.end()) {
+      return_ids_.clear();
+      for (auto sid : it->second->storage_ids) {
+        return_ids_.push_back(sid);
+      }
+    }
+  }
+  /*!
+   * \brief ceil(size/word_size) to get number of words.
+   * \param size The original size.
+   * \param word_size The element size.
+   */
+  static size_t DivRoundUp(size_t size, size_t word_size) {
+    return (size + word_size - 1) / word_size;
+  }
+  /*!
+   * \brief Get the memory requirement of a statically shaped tensor.
+   * \param ttype The tensor type whose storage size is computed.
+   * \return The required memory size in bytes.
+   */
+  static size_t GetMemorySizeBytes(const TensorTypeNode* ttype) {
+    ICHECK(ttype != nullptr);
+    size_t size = 1;
+    for (IndexExpr dim : ttype->shape) {
+      const int64_t* pval = tir::as_const_int(dim);
+      ICHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << ttype->shape;
+      ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape " << *pval;
+      size *= static_cast<size_t>(pval[0]);
+    }
+    size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8);
+    return size;
+  }
+  /*!
+   * \brief Get the necessary storage for the expression.
+   * \param expr The expression.
+   * \return The StorageInfo recorded for the expression after visiting it.
+   */
+  StorageInfo GetStorage(const Expr& expr) {
+    this->VisitExpr(expr);
+    auto it = storage_device_map_.find(expr);
+    ICHECK(it != storage_device_map_.end());
+    return it->second;
+  }
+
+  /*!
+   * \brief Create storage for the expression.
+   * \param op The expression node; one new sid is created per tensor in its checked type.
+   */
+  void CreateStorage(const ExprNode* op) {
+    std::vector<int64_t> storage_ids;
+    std::vector<DLDeviceType> device_types;
+    std::vector<int64_t> storage_sizes_in_bytes;
+    Expr expr = GetRef<Expr>(op);
+    int device_type_int =
+        node_device_map_.count(expr) ? node_device_map_[expr]->value : 0;
+    if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
+      for (Type t : tuple_type->fields) {
+        const auto* ttype = t.as<TensorTypeNode>();
+        ICHECK(ttype);
+        storage_ids.push_back(next_available_sid_++);
+        storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype));
+        device_types.push_back(DLDeviceType(device_type_int));
+      }
+    } else {
+      const auto* ttype = op->checked_type().as<TensorTypeNode>();
+      ICHECK(ttype);
+      storage_ids.push_back(next_available_sid_++);
+      storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype));
+      device_types.push_back(DLDeviceType(device_type_int));
+    }
+    storage_device_map_[expr] = StorageInfo(storage_ids, device_types, storage_sizes_in_bytes);
+  }
+  /*! \brief mapping of expression -> storageInfo*/
+  StorageMap storage_device_map_;
+  /*! \brief mapping of expression -> device type*/
+  Map<Expr, Integer> node_device_map_;
+  /*! \brief next storage id to hand out */
+  int next_available_sid_{0};
+  /*! \brief storage ids of the value produced by the function body */
+  std::vector<int> return_ids_;
};
 
 /*! \brief Code generator for AOT executor */
@@ -120,65 +246,24 @@ class AOTExecutorCodegen : public ExprVisitor {
    * \brief Return a vector of variables that represents the sids for the given Relay Expr
    */
   std::vector<tir::Var> PackSid(Expr expr) {
-    Array<IntegerArray> sids = storage_device_map_[expr];
-    std::vector<tir::Var> sid_vars;
+    std::vector<tir::Var> buffer_vars;
+    const StorageInfo& sinfo = storage_device_map_.at(expr);

     // Note that an expression can have multiple sids associated with it
     // e.g., returning multiple values from a function
-    for (const auto& sid : sids[0]) {
+    for (auto sid : sinfo->storage_ids) {
       // Determine if an sid is an output buffer
-      int sid_int = static_cast<int>((sid.as<IntImmNode>())->value);
-      auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid_int);
+      auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid);
       if (output_iter != return_sid_.end()) {
         int output_index = std::distance(return_sid_.begin(), output_iter);
-        sid_vars.push_back(main_signature_[input_vars_.size() + output_index]);
+        buffer_vars.push_back(main_signature_[input_vars_.size() + output_index]);
         continue;
       }
-      // Pack the sid inside the TVMValue
-      auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle());
-      auto sid_value = sids_table_[sid];

-      if (!use_unpacked_api_) {
-        tvm::PrimExpr set_tensor =
-            tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
-                           {sid_array, 0, tir::builtin::kArrData, sid_value});
-        stmts_.push_back(
-            tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor)));
-      } else {
-        stmts_.push_back(tir::LetStmt(sid_array, sid_value, tir::Evaluate(0)));
-      }
-
-      sid_vars.push_back(sid_array);
+      // Intermediate buffer: use the raw var bound to this sid (packing happens later)
+      buffer_vars.push_back(sids_table_.at(sid));
     }
-    return sid_vars;
-  }
-
-  /*!
-   * \brief Utility function to return a parameter associated with an expression
-   * \param expr Relay Expression associated with the parameter
-   * \return Variable that represents the DLTensor associated with the parameters
-   */
-  tir::Var PackParam(Expr expr) {
-    int param_sid = param_storage_ids_[params_by_expr_[expr]];
-    auto param_array = te::Var(MakeString("param_", param_sid, "_array"), DataType::Handle());
-
-    // Compose the lookup_call using a local stack
-    Array<tir::Stmt> lookup_call;
-    // Set the param to the value returned by lookup_call
-    auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
-                                       {tir::StringImm(params_by_expr_[expr])});
-
-    if (!use_unpacked_api_) {
-      tvm::PrimExpr set_param_array =
-          tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
-                         {param_array, 0, tir::builtin::kArrData, param_handle});
-      stmts_.push_back(
-          tir::LetStmt(param_array, StackAlloca("arg_value", 1), tir::Evaluate(set_param_array)));
-    } else {
-      stmts_.push_back(tir::LetStmt(param_array, param_handle, tir::Evaluate(0)));
-    }
-
-    return param_array;
+    return buffer_vars;
   }
 
   /*!
@@ -190,9 +275,6 @@ class AOTExecutorCodegen : public ExprVisitor {
       // Input variable
       int main_index = std::distance(input_vars_.begin(), input_iter);
       return {main_signature_[main_index]};
-    } else if (params_by_expr_.find(arg) != params_by_expr_.end()) {
-      // Parameter of the network
-      return {PackParam(arg)};
     } else {
       // Storage identifier (i.e., intermediate memory)
       return PackSid(arg);
@@ -208,8 +290,14 @@ class AOTExecutorCodegen : public ExprVisitor {
 
     // Pack the inputs
     for (Expr arg : call->args) {
-      auto var_arg = FindExpr(arg);
-      args.push_back(var_arg[0]);
+      if (params_by_expr_.find(arg) != params_by_expr_.end()) {
+        auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
+                                           {tir::StringImm(params_by_expr_[arg])});
+        args.push_back(param_handle);
+      } else {
+        auto var_arg = FindExpr(arg);
+        args.push_back(var_arg[0]);
+      }
     }
 
     auto ret_expr = Downcast<Expr>(call);
@@ -237,7 +325,7 @@ class AOTExecutorCodegen : public ExprVisitor {
    * TODO(giuseros): we should try to avoid unnecessary copy to the output, e.g., in a
    * copy-on-write fashion.
    */
-  void CopyToOutput(te::Var out, te::Var in, size_t size) {
+  void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) {
     // Define intermediate DLTensor to load/store the data
     auto tmp0 = te::Var("tmp0", DataType::Handle());
     auto tmp1 = te::Var("tmp1", DataType::Handle());
@@ -249,10 +337,15 @@ class AOTExecutorCodegen : public ExprVisitor {
     PrimExpr tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
                                       {out, 0, tir::builtin::kArrData});
     if (use_unpacked_api_) {
-      retval_get = in;
       tostore = out;
     }
 
+    // Do not pack the input if the flag is set or the caller
+    // explicitly asked to do so (e.g., copying a param to the output)
+    if (use_unpacked_api_ || !pack_input) {
+      retval_get = in;
+    }
+
     // Copy the variable from the input to the output
     tir::Stmt copy = tir::For(
         loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial,
@@ -390,8 +483,8 @@ class AOTExecutorCodegen : public ExprVisitor {
     }
 
     ICHECK_GE(storage_device_map_.count(expr), 0);
-    auto& device_type = storage_device_map_[expr][1];
-    auto call_dev_type = device_type[0]->value;
+    StorageInfo& sinfo = storage_device_map_[expr];
+    auto call_dev_type = sinfo->device_types[0];
     // Normal Relay Function
     if (targets_.size() == 1) {
       // homogeneous execution.
@@ -425,17 +518,23 @@ class AOTExecutorCodegen : public ExprVisitor {
 
   void VisitExpr_(const VarNode* op) override {
     Expr expr = GetRef<Expr>(op);
+    const StorageInfo& sinfo = storage_device_map_.at(expr);

     // If the Var node is an output node we need to copy the content of the variable to the output
     // It's safe to check the SID here because Var StorageToken are never reallocated
-    Array<IntegerArray> sids = storage_device_map_[expr];
-
-    auto output_iter = std::find(return_sid_.begin(), return_sid_.end(),
-                                 static_cast<int>((sids[0][0].as<IntImmNode>())->value));
+    auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]);
     if (output_iter != return_sid_.end()) {
       int output_index = std::distance(return_sid_.begin(), output_iter);
-      auto var_expr = FindExpr(expr);
-      CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0], sids[2][0]);
+      if (params_by_expr_.find(expr) != params_by_expr_.end()) {
+        auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
+                                           {tir::StringImm(params_by_expr_[expr])});
+        CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle,
+                     /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]);
+      } else {
+        auto var_expr = FindExpr(expr);
+        CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0],
+                     /*pack_input*/ true, sinfo->storage_sizes_in_bytes[0]);
+      }
     }
   }
 
@@ -443,19 +542,20 @@ class AOTExecutorCodegen : public ExprVisitor {
     Expr expr = GetRef<Expr>(op);
     size_t index = params_.size();
     std::string name = "p" + std::to_string(index);
-
-    param_storage_ids_[name] = storage_device_map_[expr][0][0]->value;
+    StorageInfo& sinfo = storage_device_map_[expr];
+    param_storage_ids_[name] = sinfo->storage_ids[0];
     params_[name] = op->data;
     params_by_expr_.Set(expr, name);
 
     // If the Constant node is an output node we need to copy the content of the parameter to the
     // output A Var node can only produce a single output
-    Array<IntegerArray> sids = storage_device_map_[expr];
-    auto output_iter = std::find(return_sid_.begin(), return_sid_.end(),
-                                 static_cast<int>((sids[0][0].as<IntImmNode>())->value));
+    auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]);
     if (output_iter != return_sid_.end()) {
       int output_index = std::distance(return_sid_.begin(), output_iter);
-      CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr), sids[2][0]);
+      auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
+                                         {tir::StringImm(params_by_expr_[expr])});
+      CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle, false,
+                   sinfo->storage_sizes_in_bytes[0]);
     }
   }
 
@@ -495,7 +595,9 @@ class AOTExecutorCodegen : public ExprVisitor {
     throw std::invalid_argument("match case not yet implemented");
   }
 
-  // Create the main PrimFunc to execute the graph
+  // Create the main PrimFunc to execute the graph. Please note that
+  // the packed function calls don't pack their arguments. The AOT
+  // runner function needs to be legalized by the LegalizePackedCalls pass.
   tir::PrimFunc CreateMainFunc(unsigned int relay_params) {
     tir::Stmt body = tir::SeqStmt(stmts_);
 
@@ -511,9 +613,9 @@ class AOTExecutorCodegen : public ExprVisitor {
         continue;
       }
 
-      for (unsigned int i = 0; i < kv.second[0].size(); i++) {
-        int size = kv.second[2][i];
-        int sid = static_cast<int>((kv.second[0][i].as<IntImmNode>())->value);
+      for (unsigned int i = 0; i < kv.second->storage_ids.size(); i++) {
+        int size = kv.second->storage_sizes_in_bytes[i];
+        int sid = kv.second->storage_ids[i];
 
         if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) {
           continue;
@@ -523,6 +625,8 @@ class AOTExecutorCodegen : public ExprVisitor {
         // so we don't pay the price of allocation for every inference
         if (!allocated[sid]) {
           body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size}, tir::const_true(), body);
+          body = tir::AttrStmt(sids_table_[sid], tir::attr::storage_scope, tir::StringImm("global"),
+                               body);
         }
         allocated[sid] = true;
       }
@@ -578,7 +682,8 @@ class AOTExecutorCodegen : public ExprVisitor {
   std::unordered_map<std::string, int64_t> param_storage_ids_;
 
   /*! \brief plan memory of device result */
-  Map<Expr, Array<IntegerArray>> storage_device_map_;
+  StorageMap storage_device_map_;
+  /*! \brief mapping sid -> tir::Var */
   std::unordered_map<int, te::Var> sids_table_;
   /*! \brief lowered funcs */
   std::unordered_map<std::string, IRModule> lowered_funcs_;
@@ -589,7 +694,7 @@ class AOTExecutorCodegen : public ExprVisitor {
   /*! \brief the set of statements that make the program */
   std::vector<tir::Stmt> stmts_;
   /*! \brief the list of return sids (note that the function might return more then one output */
-  IntegerArray return_sid_;
+  std::vector<int> return_sid_;
   /*! \brief the module name we use to mangle the function names */
   String mod_name_;
 
@@ -602,9 +707,11 @@ class AOTExecutorCodegen : public ExprVisitor {
         compile_engine_(CompileEngine::Global()) {}
 
   LoweredOutput Codegen(relay::Function func, String mod_name) {
-    // Get the module, storage map and token sizes
-    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
-    storage_device_map_ = (*pf)(func);
+    auto aot_allocator = AOTOnDemandAllocator();
+    aot_allocator.Run(func);
+
+    // Retrieve the storage map
+    storage_device_map_ = aot_allocator.GetStorageMap();
     mod_name_ = mod_name;
 
     for (auto input : func->params) {
@@ -614,20 +721,23 @@ class AOTExecutorCodegen : public ExprVisitor {
 
     // Define the storage allocator ids
     for (auto kv : storage_device_map_) {
-      for (const auto& sid : kv.second[0]) {
-        te::Var sid_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8))));
-        sids_table_[sid] = sid_var;
+      for (auto sid : kv.second->storage_ids) {
+        te::Var buffer_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8))));
+        sids_table_[sid] = buffer_var;
       }
     }
 
-    // Find the return sid
-    return_sid_ = AotReturnSidVisitor(storage_device_map_).FindReturnSid(func);
+    // Retrieve the return sids
+    return_sid_ = aot_allocator.GetReturnIds();
     for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) {
       main_signature_.push_back(tir::Var("output", DataType::Handle()));
     }
 
     VisitExpr(func->body);
 
+    // Create the runner function. Please note that the function is not legal yet
+    // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
+    // to run the LegalizePackedCalls pass.
     auto prim_func = CreateMainFunc(func->params.size());
     UpdateMainWorkspaceSize(prim_func, func);
     LoweredOutput ret;
@@ -649,14 +759,28 @@ class AOTExecutorCodegen : public ExprVisitor {
     }
     ret.external_mods = compile_engine_->LowerExternalFunctions();
 
+    // Build the TIR IRModule
+    Map<GlobalVar, BaseFunc> symbol_map;
+    symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
+    IRModule mod_run(symbol_map);
+
+    // Apply storage rewrite pass to the runner function to do memory planning
+    auto storage_rewrite = tir::transform::StorageRewrite();
+    mod_run = storage_rewrite(mod_run);
+
+    // Legalize AOT if needed. This means that all the packed calls
+    // need to be wrapped in TVMValues (unless use_unpacked_api is set)
+    if (!use_unpacked_api_) {
+      auto pack_calls = tir::transform::LegalizePackedCalls();
+      mod_run = pack_calls(mod_run);
+    }
+
+    // Update the lowered functions
     auto target_host_str = target_host_->str();
     if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
-      ret.lowered_funcs[target_host_str]->Add(
-          GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
+      ret.lowered_funcs[target_host_str]->Update(mod_run);
     } else {
-      Map<GlobalVar, BaseFunc> symbol_map;
-      symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
-      ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map));
+      ret.lowered_funcs.Set(target_host_str, mod_run);
     }
     ret.function_metadata = std::move(function_metadata_);
     ret.metadata = runtime::Metadata(input_vars_.size(), return_sid_.size(),
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index 3b4e693..906ff8a 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -29,6 +29,8 @@
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
 
+#include <limits>
+#include <string>
 #include <vector>
 
 namespace tvm {
@@ -162,6 +164,27 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
 }
 
 /*!
+ * \brief Create an int32 constant, checking that the value fits in an int
+ * \param index the value of the constant
+ * \return the PrimExpr that represents the constant
+ */
+inline PrimExpr ConstInt32(size_t index) {
+  ICHECK_LE(index, static_cast<size_t>(std::numeric_limits<int>::max()));
+  return make_const(DataType::Int(32), static_cast<int>(index));
+}
+
+/*!
+ * \brief Build a tvm_stack_alloca call that reserves TVMValues on the stack
+ * \param type the kind of stack object to allocate (e.g. "array", "arg_value")
+ * \param num number of TVMValues to allocate
+ * \return the tvm_stack_alloca builtin call expression
+ */
+inline PrimExpr StackAlloca(const std::string& type, size_t num) {
+  Array<PrimExpr> args = {StringImm(type), ConstInt32(num)};
+  return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args);
+}
+
+/*!
  * \brief Convert a IR node to be SSA form.
  * \param stmt The source statement to be converted.
  * \return The converted form.
diff --git a/src/tir/transforms/legalize_packed_calls.cc b/src/tir/transforms/legalize_packed_calls.cc
new file mode 100644
index 0000000..424da1e
--- /dev/null
+++ b/src/tir/transforms/legalize_packed_calls.cc
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file legalize_packed_calls.cc
+ * \brief Rewrite packed calls in AOT so that the arguments are packed
+ */
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <unordered_map>
+
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+// PrimFunc parameters (mapped to true) that are forwarded to packed calls without packing.
+using InputMap =
+    std::unordered_map<PrimExpr, bool, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
+/**
+ * This is a legalization pass only used in AOT. Traverse the TIR graph to legalize
+ * packed calls by wrapping their arguments in TVMValues (using the tvm_struct_set builtin)
+ */
+class PackedCallLegalizer : public StmtExprMutator {
+ public:
+  Stmt Legalize(const InputMap& params, tir::Stmt body) {
+    inputs_ = params;
+    return StmtExprMutator::VisitStmt(body);
+  }
+
+  Stmt VisitStmt_(const EvaluateNode* op) final {
+    if (tir::is_const_int(op->value)) return StmtExprMutator::VisitStmt_(op);
+    const CallNode* call = op->value.as<CallNode>();
+    // Given a packed call f(A,B,C) whose arguments are not PrimFunc params, emit
+    // let tvm_value_0 = stack_alloca("array", 1) (one per argument) in {
+    //   tvm_struct_set(tvm_value_0, 0, kArrData, A)  (likewise for B and C)
+    //   tvm_call_cpacked(f, tvm_value_0, tvm_value_1, tvm_value_2)
+    // }
+    std::vector<Stmt> new_stmts;
+    if (call) {
+      if (call->op.same_as(builtin::tvm_call_cpacked())) {
+        Array<PrimExpr> packed_args{call->args[0]};
+        std::vector<tir::Var> tvm_values;
+        for (unsigned i = 1; i < call->args.size(); i++) {
+          // No need to pack inputs of the prim_func (lookup must not insert into the map)
+          if (inputs_.count(call->args[i]) && inputs_.at(call->args[i])) {
+            packed_args.push_back(call->args[i]);
+          } else {
+            // Pack the argument inside a TVMValue
+            std::stringstream ss;
+            ss << "tvm_value_" << tvm_value_index_++;
+            auto sid_array = tir::Var(ss.str(), DataType::Handle());
+            tvm_values.push_back(sid_array);
+
+            new_stmts.push_back(tir::Evaluate(
+                tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
+                               {sid_array, 0, tir::builtin::kArrData, call->args[i]})));
+            packed_args.push_back(sid_array);
+          }
+        }
+        // Evaluate the packed call
+        new_stmts.push_back(tir::Evaluate(tir::Call(call->dtype, call->op, packed_args)));
+        tir::Stmt call_stmt = tir::SeqStmt(new_stmts);
+
+        // Allocate the TVMValues on the stack and define the variables
+        for (auto v : tvm_values) {
+          call_stmt = LetStmt(v, StackAlloca("array", 1), call_stmt);
+        }
+        return call_stmt;
+      }
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  InputMap inputs_;         // Store the inputs to the primfunc that don't need to be packed.
+  int tvm_value_index_{0};  // Index used to name the next tvm_value variable
+};
+
+namespace transform {
+
+Pass LegalizePackedCalls() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+
+    // Collect the PrimFunc parameters: they are passed to packed calls as-is
+    InputMap inputs;
+    for (auto i : f->params) {
+      inputs[i] = true;
+    }
+    n->body = PackedCallLegalizer().Legalize(inputs, std::move(n->body));
+    return std::move(f);
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LegalizePackedCalls").set_body_typed(LegalizePackedCalls);
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 0e2e612..8b70817 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -34,16 +34,6 @@
 namespace tvm {
 namespace tir {
 
-inline PrimExpr ConstInt32(size_t index) {
-  ICHECK_LE(index, std::numeric_limits<int>::max());
-  return make_const(DataType::Int(32), static_cast<int>(index));
-}
-
-inline PrimExpr StackAlloca(std::string type, size_t num) {
-  Array<PrimExpr> args = {StringImm(type), ConstInt32(num)};
-  return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args);
-}
-
 // Calculate the statistics of packed function.
 // These information are needed during codegen.
 class BuiltinLower : public StmtExprMutator {
diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py
index a18a0fa..836ff4b 100644
--- a/tests/python/relay/aot/aot_test_utils.py
+++ b/tests/python/relay/aot/aot_test_utils.py
@@ -42,6 +42,46 @@ def mangle_name(mod_name, name):
     return mod_name + "_" + name
 
 
+def convert_to_relay(
+    tflite_model_buf,
+    input_data,
+    input_node,
+):
+    """Convert a tflite model buffer in a Relay module"""
+
+    def convert_to_list(x):
+        if not isinstance(x, list):
+            x = [x]
+        return x
+
+    # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
+    try:
+        import tflite.Model
+
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+    except AttributeError:
+        import tflite
+
+        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+    except ImportError:
+        raise ImportError("The tflite package must be installed")
+
+    input_data = convert_to_list(input_data)
+    input_node = convert_to_list(input_node)
+
+    shape_dict = {}
+    dtype_dict = {}
+    for i, e in enumerate(input_node):
+        shape_dict[e] = input_data[i].shape
+        dtype_dict[e] = input_data[i].dtype.name
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict
+    )
+    mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params)
+    return mod, params
+
+
 def subprocess_with_stdout_and_log(cmd, cwd, logfile, stdout):
     """
     This method runs a process and logs the output to both a log file and stdout
@@ -221,6 +261,7 @@ def compile_and_run(
     params=None,
     workspace_byte_alignment=8,
     mod_name=None,
+    enable_op_fusion=True,
 ):
     """
     This method verifies the generated source
@@ -232,7 +273,11 @@ def compile_and_run(
     if not use_calculated_workspaces:
         cflags += "-DTVM_CRT_STACK_ALLOCATOR_ENABLE_LIFO_CHECK "
 
-    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+    config = {"tir.disable_vectorize": True}
+    if not enable_op_fusion:
+        config["relay.FuseOps.max_depth"] = 1
+
+    with tvm.transform.PassContext(opt_level=3, config=config):
         lib = tvm.relay.build(mod, target, target_host=target, params=params, mod_name=mod_name)
 
     tmp_path = utils.tempdir()
diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py
index 36596a4..13cbfa7 100644
--- a/tests/python/relay/aot/test_crt_aot.py
+++ b/tests/python/relay/aot/test_crt_aot.py
@@ -465,5 +465,53 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(8, 3, 5, 5),
     )
 
 
+def test_quant_mobilenet_tfl():
+    """Since in AOT we pass directly the output buffer from the user, in quantized networks sharing the output buffers is not possible.
+    This is because the output data type is int8 and the intermediate buffer are int32 or int16. We use mobilenet quantized to stress this
+    situation and verify that the output buffer sharing is disabled in AOT."""
+    pytest.importorskip("tflite")
+
+    import tvm.relay.testing.tf as tf_testing
+
+    tflite_model_file = tf_testing.get_workload_official(
+        "https://storage.googleapis.com/download.tensorflow.org/"
+        "models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
+        "mobilenet_v1_1.0_224_quant.tflite",
+    )
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+    data_shape = (1, 224, 224, 3)
+    in_min, in_max = (0, 255)
+    data = np.random.randint(in_min, high=in_max, size=data_shape, dtype="uint8")
+    mod, params = convert_to_relay(tflite_model_buf, data, "input")
+    inputs = {"input": data}
+    output_list = generate_ref_data(mod, inputs, params)
+    input_list = [inputs["input"]]
+    compile_and_run(mod, input_list, output_list, "--unpacked-api=0", True, params)
+
+
+@pytest.mark.parametrize("target_options", ["--unpacked-api=0", "--unpacked-api=1"])
+def test_transpose(target_options):
+    """Test that non-in-placeable operations (e.g., transpose) are not computed in place."""
+
+    dtype = "float32"
+    x = relay.var("x", shape=(10, 5), dtype=dtype)
+    y = relay.var("y", shape=(10, 5), dtype=dtype)
+    t = relay.var("z", shape=(), dtype=dtype)
+    a = relay.add(x, y)
+    b = relay.transpose(a)
+    z = relay.add(b, t)
+    # Build the function and random inputs; "z" is the Relay name of the scalar input `t`.
+    func = relay.Function([x, y, t], z)
+    x_data = np.random.rand(10, 5).astype(dtype)
+    y_data = np.random.rand(10, 5).astype(dtype)
+    t_data = np.random.uniform(size=()).astype(dtype)
+    inputs = {"x": x_data, "y": y_data, "z": t_data}
+    # Check result; op fusion is disabled (FuseOps.max_depth=1) so transpose stays a separate op.
+    output_list = generate_ref_data(func, inputs)
+    input_list = [inputs["x"], inputs["y"], inputs["z"]]
+    compile_and_run(func, input_list, output_list, target_options, True, enable_op_fusion=False)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/unittest/test_aot_legalize_packed_call.py b/tests/python/unittest/test_aot_legalize_packed_call.py
new file mode 100644
index 0000000..626af0c
--- /dev/null
+++ b/tests/python/unittest/test_aot_legalize_packed_call.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import tvm
+from tvm.script import ty
+from tvm import te, tir
+import numpy as np
+import tvm.testing
+import pytest
+
+
+@tvm.script.tir
+class Module:
+    def tir_packed_call() -> None:
+        A = tir.var("handle")
+        B = tir.var("handle")
+        C = tir.var("handle")
+        # body: one cpacked call whose handle args are not PrimFunc params (see Expected)
+        tir.evaluate(
+            tir.tvm_call_cpacked(
+                "tvm_test_cpacked",
+                A,
+                B,
+                C,
+                dtype="int32",
+            )
+        )
+
+
+@tvm.script.tir
+class Expected:
+    def tir_packed_call() -> None:
+        A = tir.var("handle")
+        B = tir.var("handle")
+        C = tir.var("handle")
+        # Expected legalized form: each handle argument is stored in its own stack TVMValue.
+        # Lets nest in reverse creation order (tvm_value_2 outermost); struct field 1 is kArrData.
+        tvm_value_2 = tir.var("handle")
+        tvm_value_1 = tir.var("handle")
+        tvm_value_0 = tir.var("handle")
+        with tir.let(tvm_value_2, tir.tvm_stack_alloca("array", 1, dtype="handle")):
+            with tir.let(tvm_value_1, tir.tvm_stack_alloca("array", 1, dtype="handle")):
+                with tir.let(tvm_value_0, tir.tvm_stack_alloca("array", 1, dtype="handle")):
+                    tir.evaluate(tir.tvm_struct_set(tvm_value_0, 0, 1, A, dtype="handle"))
+                    tir.evaluate(tir.tvm_struct_set(tvm_value_1, 0, 1, B, dtype="handle"))
+                    tir.evaluate(tir.tvm_struct_set(tvm_value_2, 0, 1, C, dtype="handle"))
+                    tir.evaluate(
+                        tir.tvm_call_cpacked(
+                            "tvm_test_cpacked",
+                            tvm_value_0,
+                            tvm_value_1,
+                            tvm_value_2,
+                            dtype="int32",
+                        )
+                    )
+
+
+def test_aot_packed_call():
+    mod = Module()
+    expected = Expected()
+    out = tir.transform.LegalizePackedCalls()(mod)
+    tvm.ir.assert_structural_equal(expected, out, map_free_vars=True)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])  # run this file's tests when executed as a script