Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/05/21 22:52:37 UTC

[GitHub] [tvm] CircleSpin opened a new pull request #8110: [WIP] Unify Python and C++ TIR lower API

CircleSpin opened a new pull request #8110:
URL: https://github.com/apache/tvm/pull/8110


   In this PR we removed the Python implementation of the `lower` function and replaced it with a call to the C++ version. We also ported to C++ some functionality that previously existed only in the Python version.

   This is a work in progress while we see which unit tests fail.
   @electriclilies @jroesch 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] tqchen commented on pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#issuecomment-860057276


   Thanks @CircleSpin and @electriclilies ! Thanks @Hzfengsy @tkonolige @csullivan @YuchenJin @manupa-arm for reviewing





[GitHub] [tvm] junrushao1994 edited a comment on pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
junrushao1994 edited a comment on pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#issuecomment-856162984


   CC @Hzfengsy: @electriclilies and @CircleSpin have done great work unifying the Python and C++ lowering/build APIs! Would you like to take a look and make sure it works for TensorIR? Thanks a lot!





[GitHub] [tvm] manupa-arm commented on a change in pull request #8110: [WIP] Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r639553994



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {

Review comment:
       Clarification: what does `legacy_te_pass` mean here?

##########
File path: include/tvm/driver/driver_api.h
##########
@@ -48,10 +49,39 @@ namespace tvm {
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param simple_mode Skips the LoopPartition pass if true. Defaults to false.
  * \return The result module.
  */
 TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-                       const std::unordered_map<te::Tensor, tir::Buffer>& binds);
+                       const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                       bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param args The arguments to the function.
+ * \param name The name of the lowered function.
+ * \param binds Buffer assignments.
+ * \param simple_mode Skips the LoopPartition pass if true. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule lower(IRModule mod, const Array<te::Tensor>& args, const std::string& name,
+                       const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                       bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds

Review comment:
       Typo: should this be a PrimFunc?

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();

Review comment:
       Clarification: why are the passes grouped into phases, and what do the phases represent?
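
For readers following the question above, here is a minimal Python sketch of what the diff appears to do with `tir.add_lower_pass`: each user entry is a `[phase_number, pass]` pair, and the pass is appended after the built-in passes of that phase, with phase numbers of 3 or more all landing in the last bucket. Passes are plain strings here, and `build_pass_list` and the built-in phase contents are hypothetical stand-ins, not TVM APIs.

```python
from collections import defaultdict


def build_pass_list(builtin_phases, user_passes):
    """Merge built-in passes with user passes tagged by phase number.

    builtin_phases: dict mapping phase number (0..3) to a list of pass names.
    user_passes: list of (phase_number, pass_name) pairs.
    """
    buckets = defaultdict(list)
    for phase, user_pass in user_passes:
        if phase < 0:
            raise ValueError("phase number must be non-negative")
        # Any phase number >= 3 is treated as the final phase.
        buckets[min(phase, 3)].append(user_pass)

    pass_list = []
    for phase in range(4):
        # Built-in passes of the phase run first, user passes after them.
        pass_list.extend(builtin_phases.get(phase, []))
        pass_list.extend(buckets[phase])
    return pass_list


builtin = {
    1: ["StorageFlatten", "Simplify"],
    2: ["LoopPartition"],
    3: ["RemoveNoOp"],
}
ordered = build_pass_list(builtin, [(1, "MyPass"), (5, "Late"), (0, "Early")])
print(ordered)
```

In this sketch the phase number only decides where in the fixed pipeline a user pass is spliced in; it carries no other meaning.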

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase passes are of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    // TODO(electriclilies): is there a cleaner way to do this?
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+  // TODO(electriclilies): I'm not sure if they should go at the beginning or the end of the phase.
+  // The code is inconsistent with what passes are in which phase as well. For now I have copied the
+  // python behavior exactly.
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (legacy_te_pass) {

Review comment:
       @jroesch @CircleSpin 
   A general design question: what do you think of relocating the pass pipeline closer to the target registry? That way we could have a function (possibly parameterized by target args) that describes the pass pipeline, rather than mandating the passes here.
   
   Here, we could just query the pass pipeline based on the target and run it. I think this could remove the need to support user-defined custom passes at this level.
   
   Thoughts?
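
One way to picture the proposal above is a small registry keyed by target kind, where each entry is a function that returns the pass pipeline for that target. The Python sketch below is purely illustrative: `register_pipeline`, `query_pipeline`, the `"vectorize"` argument, and the pass names are all hypothetical and do not correspond to any existing TVM API.

```python
from typing import Callable, Dict, List, Optional

PipelineFn = Callable[[dict], List[str]]

# Hypothetical registry: target kind -> function describing its pass pipeline.
_PIPELINES: Dict[str, PipelineFn] = {}


def register_pipeline(target_kind: str):
    """Decorator that records a pipeline-building function for a target kind."""
    def decorator(fn: PipelineFn) -> PipelineFn:
        _PIPELINES[target_kind] = fn
        return fn
    return decorator


@register_pipeline("llvm")
def llvm_pipeline(target_args: dict) -> List[str]:
    # The pipeline is parameterized by target args instead of being hard-coded.
    passes = ["StorageFlatten", "Simplify"]
    if target_args.get("vectorize", True):
        passes.append("VectorizeLoop")
    passes.append("RemoveNoOp")
    return passes


def query_pipeline(target_kind: str, target_args: Optional[dict] = None) -> List[str]:
    """Look up the pipeline for a target; the driver would then just run it."""
    if target_kind not in _PIPELINES:
        raise ValueError(f"no pass pipeline registered for target {target_kind!r}")
    return _PIPELINES[target_kind](target_args or {})


default_passes = query_pipeline("llvm")
scalar_passes = query_pipeline("llvm", {"vectorize": False})
```

Under this design a target that needs a custom pass would add it in its own pipeline function, which is why the comment suggests the driver might no longer need a separate user-pass hook.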







[GitHub] [tvm] YuchenJin commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
YuchenJin commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r647079954



##########
File path: python/tvm/driver/build_module.py
##########
@@ -136,7 +98,7 @@ def lower(
 
     Parameters
     ----------
-    input : Union[schedule.Schedule, PrimFunc, IRModule]
+    inputs : Union[schedule.Schedule, PrimFunc, IRModule]

Review comment:
       Hi @electriclilies, I guess even if `input` is a builtin function in python, it can still be a parameter name? For example this code should work:
   ```
   def f(input=2):
       print(input)
   ```
   Please correct me if I am wrong. :)







[GitHub] [tvm] electriclilies commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r647699158



##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param func The PrimFunc to lower
+ * \param name The name of the lowered function.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
+                               bool simple_mode = false);
+
 /*!
  * \brief Build an IRModule given a schedule, args and binds
  * \param sch The schedule to lower.
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
  * \return The result module.
  */
-TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-                       const std::unordered_map<te::Tensor, tir::Buffer>& binds);
 
+TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
+                               const std::string& name,
+                               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                               bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a schedule, args and binds
+ * \param sch The schedule to lower.
+ * \param args The arguments to the function (Array of Union of Tensor, Buffer and Vars)
+ * \param name The name of the lowered function.
+ * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
+                               const std::string& name,
+                               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                               bool simple_mode = false);
+
+/*!
+ * \brief Create an IRModule out of a Schedule
+ * \param sch The schedule
+ * \param args The arguments to the function.
+ * \param name The name of the lowered function.
+ * \param binds Buffer assignments.
+ * \return The result module.
+ */
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,

Review comment:
       @tkonolige does this look good to you?







[GitHub] [tvm] junrushao1994 edited a comment on pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
junrushao1994 edited a comment on pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#issuecomment-856162984


   CC @Hzfengsy: Hey Siyuan, @electriclilies, @CircleSpin and Jared have done great work unifying the Python and C++ lowering/build APIs! Would you like to take a look and make sure it works for TensorIR? Thanks a lot!





[GitHub] [tvm] tkonolige commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r647820965



##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,67 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given an input IRModule

Review comment:
       ```suggestion
    * \brief Lower an IRModule
   ```
   
   Could you maybe also add a little more detail on what it means to lower an IRModule?







[GitHub] [tvm] electriclilies commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r647699649



##########
File path: python/tvm/driver/build_module.py
##########
@@ -136,7 +98,7 @@ def lower(
 
     Parameters
     ----------
-    input : Union[schedule.Schedule, PrimFunc, IRModule]
+    inputs : Union[schedule.Schedule, PrimFunc, IRModule]

Review comment:
       oh, thanks! Well I just changed it to `inp`, so I think I'll leave it that way. 







[GitHub] [tvm] jroesch commented on pull request #8110: [WIP] Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
jroesch commented on pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#issuecomment-847475877


   cc @manupa-arm @giuseros: this is worth paying attention to. We have been working on cleaning up the internal APIs and moving everything to C++.





[GitHub] [tvm] electriclilies commented on a change in pull request #8110: [WIP] Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r643358453



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +154,219 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool enable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase passes are of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+  // TODO(electriclilies): I'm not sure if they should go at the beginning or the end of the phase.
+  // The code is inconsistent with what passes are in which phase as well. For now I have copied the
+  // python behavior exactly.
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
+
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (enable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  stmt = te::SchedulePostProcRewriteForTensorCore(
+      stmt, sch,
+      binds);  // TODO(electriclilies): Should this be in here? Was in python but not C++ version.

Review comment:
       Seems like it is causing issues in VM profiler tests. Removing for now. 







[GitHub] [tvm] YuchenJin commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
YuchenJin commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r647661298



##########
File path: python/tvm/driver/build_module.py
##########
@@ -136,7 +98,7 @@ def lower(
 
     Parameters
     ----------
-    input : Union[schedule.Schedule, PrimFunc, IRModule]
+    inputs : Union[schedule.Schedule, PrimFunc, IRModule]

Review comment:
       Thanks for the detail! :) It's a pylint warning so you can disable it using `# pylint: disable=redefined-builtin`. One example is here: https://github.com/apache/tvm/blob/main/python/tvm/topi/image/dilation2d.py#L27.
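
As a small illustration of the point above, the Python sketch below uses `input` as a parameter name and silences the pylint `redefined-builtin` warning inline. The `lower` function here is a toy stand-in written for this example, not TVM's `lower`.

```python
import builtins


def lower(input, simple_mode=False):  # pylint: disable=redefined-builtin
    """Toy stand-in: `input` is an ordinary parameter inside this function."""
    # Within the body, `input` names the argument and shadows the builtin.
    return {"lowered": input, "simple_mode": simple_mode}


result = lower("my_schedule")

# At module scope the name still resolves to the builtin function,
# because the shadowing is local to the function above.
builtin_still_visible = input is builtins.input
```

The shadowing is legal Python; the only cost is that the builtin `input()` cannot be called by that name inside the function body.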







[GitHub] [tvm] electriclilies commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r646927080



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();

Review comment:
       I'm not sure; it would be good to get clarification from someone else on this. Maybe @junrushao1994 knows?

##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param func The PrimFunc to lower
+ * \param name The name of the lowered function.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
+                               bool simple_mode = false);
+
 /*!
  * \brief Build an IRModule given a schedule, args and binds
  * \param sch The schedule to lower.
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.

Review comment:
       I'm not completely sure. This logic was already in the code base, so I just duplicated it and made it more explicit. (For context, when I started this refactor, Jared said he wanted to get it in quickly, so I should just naively duplicate the existing logic, but he hasn't reviewed it yet, so I'm not sure what the timeline is now.)
   
   I can try to remove it, but I'm not sure what the best way to go about this is, since there are a few tests that call `lower` directly.

##########
File path: python/tvm/driver/build_module.py
##########
@@ -136,7 +98,7 @@ def lower(
 
     Parameters
     ----------
-    input : Union[schedule.Schedule, PrimFunc, IRModule]
+    inputs : Union[schedule.Schedule, PrimFunc, IRModule]

Review comment:
       The word `input` is a Python built-in function, so you can't name variables `input`. Initially the name of the parameter to the function was `inputs` and the name of the parameter in the documentation was `input`. I just changed the documentation to match the function signature, because I don't like that they are different. I agree that `inputs` is not an ideal name, though, and I'm open to suggestions for other names.

##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param func The PrimFunc to lower
+ * \param name The name of the lowered function.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
+                               bool simple_mode = false);
+
 /*!
  * \brief Build an IRModule given a schedule, args and binds
  * \param sch The schedule to lower.
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
  * \return The result module.
  */
-TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-                       const std::unordered_map<te::Tensor, tir::Buffer>& binds);
 
+TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
+                               const std::string& name,
+                               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                               bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a schedule, args and binds
+ * \param sch The schedule to lower.
+ * \param args The arguments to the function (Array of Union of Tensor, Buffer and Vars)
+ * \param name The name of the lowered function.
+ * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
+                               const std::string& name,
+                               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                               bool simple_mode = false);
+
+/*!
+ * \brief Create an IRModule out of a Schedule
+ * \param sch The schedule
+ * \param args The arguments to the function.
+ * \param name The name of the lowered function.
+ * \param binds Buffer assignments.
+ * \return The result module.
+ */
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,

Review comment:
       Initially this was a function called `form_irmodule` in Python. I translated it to C++ and renamed it `ScheduleToModule`. Unfortunately, `form_irmodule` is called by some tests, so to preserve that behavior I had to split the functions apart again and register them separately in the FFI.
   The difference between them is that `ScheduleToModule` only converts the schedule to a module that hasn't yet been lowered, whereas `LowerSchedule` converts the schedule into a module and then applies the lowering passes.
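
The split described in this comment can be pictured with a toy model. The sketch below does not use TVM: `Module`, `schedule_to_module`, `lower_schedule`, and `make_pass` are hypothetical stand-ins, chosen only to show that one function stops after conversion while the other also runs the pass pipeline.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Module:
    """Toy stand-in for an IRModule: a name plus a log of applied passes."""
    name: str
    applied: List[str] = field(default_factory=list)


def schedule_to_module(schedule: str, name: str) -> Module:
    """Convert a (toy) schedule into an unlowered module; no passes run."""
    return Module(name=name)


def lower_schedule(schedule: str, name: str,
                   passes: List[Callable[[Module], Module]]) -> Module:
    """Convert the schedule, then apply every pass in order."""
    mod = schedule_to_module(schedule, name)
    for run_pass in passes:
        mod = run_pass(mod)
    return mod


def make_pass(pass_name: str) -> Callable[[Module], Module]:
    """Build a toy pass that records its own name on a copy of the module."""
    def run(mod: Module) -> Module:
        return Module(mod.name, mod.applied + [pass_name])
    return run


unlowered = schedule_to_module("sch", "main")
lowered = lower_schedule("sch", "main", [make_pass("Simplify"), make_pass("RemoveNoOp")])
```

Tests that only need the unlowered form can call the first function directly, which is the reason given above for keeping the two entry points separate.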







[GitHub] [tvm] electriclilies commented on a change in pull request #8110: [WIP] Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r639988925



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase passes are of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GT(phase_num_val, 0);
+
+    // TODO(electriclilies): is there a cleaner way to do this?
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+  // TODO(electriclilies): I'm not sure if they should go at the beginning or the end of the phase.
+  // The code is inconsistent with what passes are in which phase as well. For now I have copied the
+  // python behavior exactly.
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (legacy_te_pass) {

Review comment:
       To clarify, are you saying that we would construct the pass list in the registration of `driver.lower`, then pass it into a function like `LowerWithPassList`?
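
One way to picture the arrangement being asked about, as a plain-Python sketch rather than TVM code. The built-in pass names are borrowed from this PR, but `create_pass_list`, `lower_with_pass_list`, the string-based passes, and the `user_passes` argument are hypothetical stand-ins for `CreatePassList`, `LowerWithPassList`, and the `tir.add_lower_pass` config. The sketch models only the TE-schedule branch and leaves out the bound-checker option.

```python
from typing import Dict, List, Sequence, Tuple


def create_pass_list(user_passes: Sequence[Tuple[int, str]],
                     disable_loop_partition: bool = False) -> List[str]:
    """Return pass names in execution order, user passes at the end of each phase."""
    by_phase: Dict[int, List[str]] = {0: [], 1: [], 2: [], 3: []}
    for phase, name in user_passes:
        if phase < 0:
            raise ValueError("phase number must be non-negative")
        # Phases above 3 are folded into phase 3.
        by_phase[min(phase, 3)].append(name)

    passes: List[str] = []
    # Phase 0: user passes only.
    passes += by_phase[0]
    # Phase 1
    passes += ["InjectPrefetch", "StorageFlatten", "BF16Legalize",
               "NarrowDataType", "Simplify"]
    passes += by_phase[1]
    # Phase 2
    if not disable_loop_partition:
        passes.append("LoopPartition")
    passes += ["VectorizeLoop", "InjectVirtualThread", "InjectDoubleBuffer",
               "StorageRewrite", "UnrollLoop"]
    passes += by_phase[2]
    # Phase 3
    passes += ["Simplify", "RemoveNoOp", "RewriteUnsafeSelect", "HoistIfThenElse"]
    passes += by_phase[3]
    return passes


def lower_with_pass_list(module: List[str], passes: Sequence[str]) -> List[str]:
    """Toy 'lowering': append each pass name to the module's history."""
    return module + list(passes)


pass_list = create_pass_list([(1, "MyPhase1Pass"), (5, "MyLatePass")])
lowered = lower_with_pass_list([], pass_list)
```

Keeping list construction separate from execution means a caller could, in principle, build or adjust the list once and hand it to the lowering step unchanged.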







[GitHub] [tvm] electriclilies commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r647713947



##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,

Review comment:
       `Array<te::Tensor>` is not a subtype of `Array<ObjectRef>`, so C++ can't upcast the args and just call the other overload. So I won't remove the signature, but I'll have the overload that takes `Array<te::Tensor>` upcast each element to `ObjectRef` and then call the other `GetBinds`.







[GitHub] [tvm] junrushao1994 edited a comment on pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
junrushao1994 edited a comment on pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#issuecomment-856162984


   CC @Hzfengsy: Hey, @electriclilies, @CircleSpin and Jared have done great work unifying the Python and C++ side lowering/build APIs!! Would you love to take a look and make sure it works for TensorIR? Thanks a lot!





[GitHub] [tvm] csullivan commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
csullivan commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r646943807



##########
File path: src/relay/backend/compile_engine.cc
##########
@@ -770,7 +770,8 @@ class CompileEngineImpl : public CompileEngineNode {
       With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
 
       std::unordered_map<te::Tensor, tir::Buffer> binds;
-      cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds);
+      cache_node->funcs =
+          tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds);

Review comment:
       We shouldn't need to route through the `relay.backend.lower` packed function any longer, right? I'd suggest either removing it and the above conditional (line 766), or reproducing its functionality in C++ if we do need it as part of this PR. This way we avoid an unnecessary round trip through Python.







[GitHub] [tvm] junrushao1994 edited a comment on pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
junrushao1994 edited a comment on pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#issuecomment-856162984


   CC @Hzfengsy: Hey Siyuan, @electriclilies, @CircleSpin and Jared have done great work unifying the Python and C++ side lowering/build APIs!! So we finally have those two APIs unified. Would you love to take a look and make sure it works for TensorIR? Thanks a lot!





[GitHub] [tvm] junrushao1994 commented on pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
junrushao1994 commented on pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#issuecomment-856162984


   CC @Hzfengsy: Hey, Lily has done great work unifying the Python and C++ side lowering/build APIs!! Would you love to take a look and make sure it works for TensorIR? Thanks a lot!





[GitHub] [tvm] electriclilies commented on pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#issuecomment-860838359


   Thanks @tqchen!





[GitHub] [tvm] electriclilies commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r647647392



##########
File path: python/tvm/driver/build_module.py
##########
@@ -136,7 +98,7 @@ def lower(
 
     Parameters
     ----------
-    input : Union[schedule.Schedule, PrimFunc, IRModule]
+    inputs : Union[schedule.Schedule, PrimFunc, IRModule]

Review comment:
       The problem is actually that the linter doesn't like it -- I guess it just flags everything called input for some reason.
   https://ci.tlcpack.ai/blue/organizations/jenkins/tvm/detail/PR-8110/14/pipeline
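
For context, a small sketch of why a linter may object to a parameter named `input`: the name is legal Python, but it shadows the built-in inside that scope. The function names here are made up for illustration. pylint reports this pattern as `redefined-builtin` (W0622); whether that is the exact check that fired in this CI run is an assumption.

```python
import builtins


def lower_shadowing(input):  # legal Python, but pylint flags it as redefined-builtin
    """Inside this function, `input` refers to the argument, not builtins.input."""
    return input is builtins.input


def lower_renamed(inputs):
    """With a different parameter name, the built-in stays reachable."""
    return input is builtins.input


shadowed = lower_shadowing("schedule")   # the argument hides the built-in
preserved = lower_renamed("schedule")    # the built-in is still visible
```

Renaming the parameter, as the docstring change does, avoids the warning without needing a linter suppression.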







[GitHub] [tvm] electriclilies commented on a change in pull request #8110: [WIP] Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r643488713



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +154,219 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool enable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase passes are of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GT(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+  // TODO(electriclilies): I'm not sure if they should go at the beginning or the end of the phase.
+  // The code is inconsistent with what passes are in which phase as well. For now I have copied the
+  // python behavior exactly.
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
+
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (enable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  stmt = te::SchedulePostProcRewriteForTensorCore(
+      stmt, sch,
+      binds);  // TODO(electriclilies): Should this be in here? Was in python but not C++ version.

Review comment:
       TQ confirmed that this code is no longer used







[GitHub] [tvm] electriclilies commented on pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#issuecomment-857293701


   Thanks @tkonolige! @Hzfengsy @csullivan can you take another look and approve?





[GitHub] [tvm] csullivan commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
csullivan commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r646912412



##########
File path: python/tvm/autotvm/feature.py
##########
@@ -39,7 +39,7 @@ def ana_lower(sch, args, binds=None, simple_mode=True):
     """Do lower while keeping all axes in IR
     i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads
     """
-    binds, _ = build_module.get_binds(args, binds)
+    binds, _ = build_module.get_binds(args, False, binds)

Review comment:
       ```suggestion
       binds, _ = build_module.get_binds(args, compact=False, binds=binds)
   ```
   I prefer keyword arguments at the call site for readability.
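
A short sketch of the readability point, using a hypothetical stand-in for `get_binds` (the real function's body is not reproduced here, and the keyword-only variant is an illustration, not something this PR proposes). It shows that the positional and keyword calls are equivalent, and how a `*` in the signature could enforce keyword use.

```python
def get_binds(args, compact=False, binds=None):
    """Hypothetical stand-in: returns its arguments so the calls can be compared."""
    return {"args": list(args), "compact": compact, "binds": dict(binds or {})}


def get_binds_kwonly(args, *, compact=False, binds=None):
    """Same stand-in, but `compact` and `binds` must be passed by keyword."""
    return get_binds(args, compact=compact, binds=binds)


# A bare False at the call site says nothing about what it controls.
positional = get_binds(["A", "B"], False, {"A": "buf_a"})
# The keyword form documents itself.
keyword = get_binds(["A", "B"], compact=False, binds={"A": "buf_a"})

try:
    get_binds_kwonly(["A", "B"], False, {"A": "buf_a"})
    rejected = False
except TypeError:
    # Positional use of keyword-only parameters raises TypeError.
    rejected = True
```

The keyword form also keeps call sites correct if a parameter is later inserted into the middle of the signature, which is what happened to `get_binds` in this change.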

##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param func The PrimFunc to lower
+ * \param name The name of the lowered function.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
+                               bool simple_mode = false);
+
 /*!
  * \brief Build an IRModule given a schedule, args and binds
  * \param sch The schedule to lower.
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
  * \return The result module.
  */
-TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-                       const std::unordered_map<te::Tensor, tir::Buffer>& binds);
 
+TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
+                               const std::string& name,
+                               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                               bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a schedule, args and binds
+ * \param sch The schedule to lower.
+ * \param args The arguments to the function (Array of Union of Tensor, Buffer and Vars)

Review comment:
       ```suggestion
    * \param args The arguments to the function (Array of Tensor, Buffer and Vars)
   ```

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +173,209 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase passes are of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, kv.second));
+        }
+      }
+      IRModule mod = ScheduleToModule(sch, args, name, c_binds);
+      return mod;
+    });
+
+IRModule LowerModule(IRModule mod, bool simple_mode) {
+  auto pass_list = CreatePassList(simple_mode, false);
+  return LowerWithPassList(mod, pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) {
+  return LowerModule(mod, simple_mode);
+});
+
+IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool simple_mode) {
+  auto pass_ctx = transform::PassContext::Current();
+  auto f = WithAttr(std::move(func), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+
+  // Get the pass list
+  auto pass_list = CreatePassList(simple_mode, false);
+  return LowerWithPassList(mod, pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower_primfunc")
+    .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) {
+      return LowerPrimFunc(func, name, simple_mode);
+    });
+
+IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
+                       const std::unordered_map<te::Tensor, tir::Buffer>& binds, bool simple_mode) {
+  Array<ObjectRef> ref_args;
+  for (auto x : args) {
+    ref_args.push_back(x);
+  }
+  return LowerSchedule(sch, ref_args, name, binds);
+}
+
+IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                       const std::unordered_map<te::Tensor, tir::Buffer>& binds, bool simple_mode) {
+  IRModule mod = ScheduleToModule(sch, args, name, binds);
+  // Get the legacy TE pass list
+  auto pass_list = CreatePassList(simple_mode, true);
+  return LowerWithPassList(mod, pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower_schedule")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds, bool simple_mode) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {

Review comment:
       ```suggestion
         if (binds.get() != nullptr) {
   ```
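
A minimal sketch of why `nullptr` is usually preferred over `NULL` in C++11 and later: `nullptr` has its own type, `std::nullptr_t`, which converts to any pointer type but never to an integer, so it always selects the pointer overload. The `Describe` overloads and `DescribeNullptr` are hypothetical names used only for illustration; they are not part of TVM.

```cpp
#include <cstddef>
#include <string>

// Hypothetical overload set, used only to show which overload a null constant selects.
std::string Describe(int) { return "int"; }
std::string Describe(const void*) { return "pointer"; }

// nullptr has type std::nullptr_t, so only the pointer overload is viable here.
std::string DescribeNullptr() { return Describe(nullptr); }

// A literal 0 is an exact match for the int overload, which is the kind of
// surprise an integer-like NULL can cause on some implementations.
std::string DescribeZero() { return Describe(0); }
```

In the comparison `binds.get() != nullptr` the behavior is the same as with `NULL`; the change is about type safety and consistency with modern C++ style.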

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();

Review comment:
       It would be nice to take this refactoring opportunity to add some documentation on this, though I completely understand that you have mirrored the previous lower implementation.
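
As one possible starting point for that documentation, here is a minimal sketch of how the `tir.add_lower_pass` entries are grouped, based on the later revision of this diff: each entry has the form `[phase_number, pass]`, phases 0, 1 and 2 get their own bucket, and any phase of 3 or more goes into the last bucket. The names `PassName` and `BucketByPhase` are hypothetical, and plain strings stand in for `tvm::transform::Pass` so the sketch compiles without TVM.

```cpp
#include <array>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Stand-in for a pass; the real code stores tvm::transform::Pass objects.
using PassName = std::string;

// Groups (phase, pass) entries into four buckets. Phases 0, 1 and 2 map to
// their own bucket; any phase >= 3 lands in the last bucket.
std::array<std::vector<PassName>, 4> BucketByPhase(
    const std::vector<std::pair<int, PassName>>& add_lower_pass) {
  std::array<std::vector<PassName>, 4> buckets;
  for (const auto& entry : add_lower_pass) {
    // Mirrors the non-negativity check on the phase number in the diff.
    if (entry.first < 0) {
      throw std::invalid_argument("phase number must be non-negative");
    }
    const int index = entry.first >= 3 ? 3 : entry.first;
    buckets[index].push_back(entry.second);
  }
  return buckets;
}
```

Within a bucket, passes keep the order in which the user supplied them; in the diff, each bucket is then appended after the built-in passes of its phase.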

##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds

Review comment:
       ```suggestion
    * \brief Build an IRModule given a module
   ```
   Updated the comment to reflect the API. Please check the docs for the other APIs that don't take args/binds as well.

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +173,209 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse

Review comment:
       ```suggestion
   ```
   Either add more detail or remove this comment; as written, it isn't needed.

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +173,209 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, kv.second));

Review comment:
       ```suggestion
             c_binds.insert({kv.first, kv.second});
   ```
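
A minimal sketch of the suggested idiom using only standard containers: a braced pair passed to `insert` builds the map's `value_type` directly, so the element type does not have to be spelled out. `ToUnorderedMap` is a hypothetical helper, and `std::map<std::string, int>` stands in for the TVM `Map<te::Tensor, tir::Buffer>` being converted.

```cpp
#include <map>
#include <string>
#include <unordered_map>

// Copies every entry of an ordered map into a hash map.
std::unordered_map<std::string, int> ToUnorderedMap(const std::map<std::string, int>& src) {
  std::unordered_map<std::string, int> dst;
  for (const auto& kv : src) {
    // Same effect as dst.insert(std::pair<std::string, int>(kv.first, kv.second)).
    dst.insert({kv.first, kv.second});
  }
  return dst;
}
```

Iterating with `const auto&` also avoids copying each key/value pair, which `for (auto kv : ...)` does on every iteration.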

##########
File path: python/tvm/driver/build_module.py
##########
@@ -37,92 +37,54 @@
 from tvm.tir.buffer import Buffer
 from tvm.tir.expr import Var
 
+from . import _ffi_api as ffi
+
 
 def get_binds(args, compact=False, binds=None):
     """Internal function to get binds and arg_list given arguments.
-
     Parameters
     ----------
     args : list of Buffer or Tensor or Var
         The argument lists to the function.
-
     compact : bool
         If the statement has already bound to a compact buffer.
-
     binds : dict of :any:`Tensor` to :any:`Buffer`, optional
         Dictionary that maps the Tensor to Buffer which specified the data layout
         requirement of the function. By default, a new compact buffer is created
         for each tensor in the argument.
-
     Returns
     -------
     binds: dict
         The bind specification
-
     arg_list: list
         The list of symbolic buffers of arguments.
     """
-    binds = {} if binds is None else binds.copy()
-    arg_list = []
-    for x in args:
-        if isinstance(x, tensor.Tensor):
-            any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
-            buffer_type = "auto_broadcast" if any_dim and not compact else ""
-            if x not in binds:
-                buf = tvm.tir.decl_buffer(
-                    x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type
-                )
-                binds[x] = buf
-                arg_list.append(buf)
-            else:
-                arg_list.append(binds[x])
-        elif isinstance(x, schedule.Buffer):
-            arg_list.append(x)
-        elif isinstance(x, tvm.tir.Var):
-            arg_list.append(x)
-        else:
-            raise ValueError("args must be Tensor, Buffer or Var")
-    return binds, arg_list
-
-
-def form_irmodule(sch, args, name, binds):
-    """According to the given schedule, form a function.
+    out_arr = ffi.get_binds(args, compact, binds)
+    return out_arr[0], out_arr[1]
+

Review comment:
       We should probably update the docstring to match the returned variable names. The reverse would actually be preferable, i.e. `binds, arg_list = ffi.get_binds(...)`, but I understand you are limited here by the object system.

##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,

Review comment:
       Is there a reason to pass the outputs as out-parameters other than needing multiple return values? I couldn't see one from the use of GetBinds in your PR, but maybe I missed it.
   
   If not, we could return a std::pair or std::tuple, unpack it with std::tie, and later upgrade to structured bindings if we move to C++17.
   ```
   // now
   Array<ObjectRef> out_arg_list;
   Map<te::Tensor, tir::Buffer> out_binds;
   std::tie(out_binds, out_arg_list) = GetBinds(args, compact, binds);
   
   // c++17
   auto [out_binds, out_arg_list] = GetBinds(args, compact, binds);
   ```
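
A self-contained sketch of the `std::tie` variant, assuming a simplified stand-in for `GetBinds`: `GetBindsSketch`, `ArgListViaTie`, `Binds` and `ArgList` are hypothetical names, strings stand in for tensors and buffers, and the `"_buf"` suffix is made up for illustration.

```cpp
#include <map>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

using Binds = std::map<std::string, std::string>;  // tensor name -> buffer name
using ArgList = std::vector<std::string>;          // buffer names, in argument order

// Returns both results as a pair instead of writing them through output pointers.
std::pair<Binds, ArgList> GetBindsSketch(const std::vector<std::string>& args, Binds binds) {
  ArgList arg_list;
  for (const std::string& arg : args) {
    auto it = binds.find(arg);
    if (it == binds.end()) {
      // No user-provided binding: create a new buffer for this argument.
      it = binds.emplace(arg, arg + "_buf").first;
    }
    arg_list.push_back(it->second);
  }
  return {std::move(binds), std::move(arg_list)};
}

// C++11/14 call site: unpack the pair into existing variables with std::tie.
ArgList ArgListViaTie(const std::vector<std::string>& args) {
  Binds out_binds;
  ArgList out_arg_list;
  std::tie(out_binds, out_arg_list) = GetBindsSketch(args, {});
  return out_arg_list;
}
```

With this shape, the C++17 upgrade only touches call sites: the two declarations and the `std::tie` line collapse into a single structured binding.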

##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,
+              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
+  *out_binds = binds;
+
+  for (const ObjectRef& x : args) {
+    if (const auto* tensor_node = x.as<te::TensorNode>()) {
+      auto x_ref = GetRef<te::Tensor>(tensor_node);
+      if (out_binds->find(x_ref) == out_binds->end()) {
+        auto buf =
+            BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact);
+        out_binds->Set(x_ref, buf);
+        out_arg_list->push_back(buf);
+      } else {
+        out_arg_list->push_back((*out_binds)[x_ref]);
+      }
+    } else if (x.as<te::BufferNode>() || x.as<tir::VarNode>()) {
+      out_arg_list->push_back(x);
+    } else {
+      ICHECK(false)
+          << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var";
+    }
+  }
+}
+
+TVM_REGISTER_GLOBAL("driver.get_binds")
+    .set_body_typed([](const Array<ObjectRef>& args, bool compact,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {

Review comment:
       ```suggestion
         if (binds.get() != nullptr) {
   ```
   

##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,
+              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
+  *out_binds = binds;
+
+  for (const ObjectRef& x : args) {
+    if (const auto* tensor_node = x.as<te::TensorNode>()) {
+      auto x_ref = GetRef<te::Tensor>(tensor_node);
+      if (out_binds->find(x_ref) == out_binds->end()) {
+        auto buf =
+            BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact);
+        out_binds->Set(x_ref, buf);
+        out_arg_list->push_back(buf);
+      } else {
+        out_arg_list->push_back((*out_binds)[x_ref]);
+      }
+    } else if (x.as<te::BufferNode>() || x.as<tir::VarNode>()) {
+      out_arg_list->push_back(x);
+    } else {
+      ICHECK(false)
+          << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var";
+    }
+  }
+}
+
+TVM_REGISTER_GLOBAL("driver.get_binds")
+    .set_body_typed([](const Array<ObjectRef>& args, bool compact,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, kv.second));
+        }
+      }
+      Map<te::Tensor, tir::Buffer> out_binds;
+      Array<ObjectRef> out_arg_list;
+      GetBinds(args, compact, c_binds, &out_binds, &out_arg_list);
+
+      // TVM object system doesn't have a pair object, so we'll put both ret values in an array
+      // and return that.
+      Array<ObjectRef> out_arr = {out_binds, out_arg_list};

Review comment:
       I'm fine with the array implementation, but do you know if there is interest in adding a pair type? I would have guessed there was one already, given that we have Map. `Map<String, ObjectRef>` could work, albeit at the cost of a string lookup.
   ```
   out = driver.get_binds(...)
   return out["binds"], out["arg_list"]
   ```

##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,
+              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
+  *out_binds = binds;
+
+  for (const ObjectRef& x : args) {
+    if (const auto* tensor_node = x.as<te::TensorNode>()) {
+      auto x_ref = GetRef<te::Tensor>(tensor_node);
+      if (out_binds->find(x_ref) == out_binds->end()) {
+        auto buf =
+            BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact);
+        out_binds->Set(x_ref, buf);
+        out_arg_list->push_back(buf);
+      } else {
+        out_arg_list->push_back((*out_binds)[x_ref]);
+      }
+    } else if (x.as<te::BufferNode>() || x.as<tir::VarNode>()) {
+      out_arg_list->push_back(x);
+    } else {
+      ICHECK(false)
+          << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var";
+    }
+  }
+}
+
+TVM_REGISTER_GLOBAL("driver.get_binds")
+    .set_body_typed([](const Array<ObjectRef>& args, bool compact,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, kv.second));

Review comment:
       ```suggestion
             c_binds.insert({kv.first, kv.second});
   ```

##########
File path: src/relay/backend/compile_engine.cc
##########
@@ -770,7 +770,8 @@ class CompileEngineImpl : public CompileEngineNode {
       With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
 
       std::unordered_map<te::Tensor, tir::Buffer> binds;
-      cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds);
+      cache_node->funcs =
+          tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds);

Review comment:
       We shouldn't need to route through the `relay.backend.lower` packed function any longer, right? I'd suggest either removing it and the above conditional (line 766), or reproducing its functionality in C++ if we do need it as part of this PR. This way we avoid an unnecessary round trip through Python.

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +173,209 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, kv.second));
+        }
+      }
+      IRModule mod = ScheduleToModule(sch, args, name, c_binds);
+      return mod;
+    });
+
+IRModule LowerModule(IRModule mod, bool simple_mode) {
+  auto pass_list = CreatePassList(simple_mode, false);
+  return LowerWithPassList(mod, pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) {
+  return LowerModule(mod, simple_mode);
+});
+
+IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, bool simple_mode) {
+  auto pass_ctx = transform::PassContext::Current();
+  auto f = WithAttr(std::move(func), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+
+  // Get the pass list
+  auto pass_list = CreatePassList(simple_mode, false);
+  return LowerWithPassList(mod, pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower_primfunc")
+    .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) {
+      return LowerPrimFunc(func, name, simple_mode);
+    });
+
+IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
+                       const std::unordered_map<te::Tensor, tir::Buffer>& binds, bool simple_mode) {
+  Array<ObjectRef> ref_args;
+  for (auto x : args) {
+    ref_args.push_back(x);
+  }
+  return LowerSchedule(sch, ref_args, name, binds);
+}
+
+IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                       const std::unordered_map<te::Tensor, tir::Buffer>& binds, bool simple_mode) {
+  IRModule mod = ScheduleToModule(sch, args, name, binds);
+  // Get the legacy TE pass list
+  auto pass_list = CreatePassList(simple_mode, true);
+  return LowerWithPassList(mod, pass_list);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower_schedule")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds, bool simple_mode) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, kv.second));

Review comment:
       ```suggestion
             c_binds.insert({kv.first, kv.second});
   ```

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +173,209 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {

Review comment:
       ```suggestion
         if (binds.get() != nullptr) {
   ```







[GitHub] [tvm] tkonolige commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r646870692



##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param func The PrimFunc to lower
+ * \param name The name of the lowered function.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
+                               bool simple_mode = false);
+
 /*!
  * \brief Build an IRModule given a schedule, args and binds
  * \param sch The schedule to lower.
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.

Review comment:
       Why is there a special flag to disable the loop partition pass? Shouldn't the existing `PassContext` infrastructure handle this?
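
       To make the trade-off in this question concrete, here is a minimal plain-Python sketch. It does not import TVM: passes are modeled as strings, and `build_with_flag`, `build_with_context`, and the `disabled` set are illustrative stand-ins rather than TVM APIs. It only contrasts a dedicated boolean flag with a context-level set of disabled passes.

```python
# Illustrative model only: passes are plain strings, not tvm.transform.Pass objects.
PHASE2_PASSES = ["LoopPartition", "VectorizeLoop", "InjectVirtualThread"]


def build_with_flag(simple_mode=False):
    """Mirror the diff: a boolean flag decides whether LoopPartition is added."""
    return [p for p in PHASE2_PASSES if not (simple_mode and p == "LoopPartition")]


def build_with_context(disabled=frozenset()):
    """Alternative raised in review: filter by a context-level set of disabled passes."""
    return [p for p in PHASE2_PASSES if p not in disabled]


flag_result = build_with_flag(simple_mode=True)
ctx_result = build_with_context(disabled={"LoopPartition"})
print(flag_result == ctx_result)
```

       In this toy model both routes yield the same pass list, so the choice is mostly about where the knob lives: in every `Lower*` signature or in the ambient context.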

##########
File path: python/tvm/driver/build_module.py
##########
@@ -37,92 +37,54 @@
 from tvm.tir.buffer import Buffer
 from tvm.tir.expr import Var
 
+from . import _ffi_api as ffi
+
 
 def get_binds(args, compact=False, binds=None):
     """Internal function to get binds and arg_list given arguments.
-
     Parameters
     ----------
     args : list of Buffer or Tensor or Var
         The argument lists to the function.
-
     compact : bool
         If the statement has already bound to a compact buffer.
-
     binds : dict of :any:`Tensor` to :any:`Buffer`, optional
         Dictionary that maps the Tensor to Buffer which specified the data layout
         requirement of the function. By default, a new compact buffer is created
         for each tensor in the argument.
-
     Returns
     -------
     binds: dict
         The bind specification
-
     arg_list: list
         The list of symbolic buffers of arguments.
     """
-    binds = {} if binds is None else binds.copy()
-    arg_list = []
-    for x in args:
-        if isinstance(x, tensor.Tensor):
-            any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
-            buffer_type = "auto_broadcast" if any_dim and not compact else ""
-            if x not in binds:
-                buf = tvm.tir.decl_buffer(
-                    x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type
-                )
-                binds[x] = buf
-                arg_list.append(buf)
-            else:
-                arg_list.append(binds[x])
-        elif isinstance(x, schedule.Buffer):
-            arg_list.append(x)
-        elif isinstance(x, tvm.tir.Var):
-            arg_list.append(x)
-        else:
-            raise ValueError("args must be Tensor, Buffer or Var")
-    return binds, arg_list
-
-
-def form_irmodule(sch, args, name, binds):
-    """According to the given schedule, form a function.
+    out_arr = ffi.get_binds(args, compact, binds)
+    return out_arr[0], out_arr[1]
+
 
+def schedule_to_module(
+    sch: schedule.Schedule,
+    args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
+    name: str = "main",
+    binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
+) -> IRModule:
+    """According to the given schedule, form a function.
     Parameters
     ----------
     sch : tvm.te.schedule.Schedule
         The given scheduler to form the raw body
-
     args : list of Buffer or Tensor or Var
         The argument lists to the function.
-
     name : str
         The name of result function.

Review comment:
       Can you add the default name to this documentation?

##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param func The PrimFunc to lower
+ * \param name The name of the lowered function.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
+                               bool simple_mode = false);
+
 /*!
  * \brief Build an IRModule given a schedule, args and binds
  * \param sch The schedule to lower.
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
  * \return The result module.
  */
-TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-                       const std::unordered_map<te::Tensor, tir::Buffer>& binds);
 
+TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
+                               const std::string& name,
+                               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                               bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a schedule, args and binds

Review comment:
       Maybe specify that this is a TE schedule?

##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds

Review comment:
       Should this be primfunc instead of module?

##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,
+              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
+  *out_binds = binds;
+
+  for (const ObjectRef& x : args) {
+    if (const auto* tensor_node = x.as<te::TensorNode>()) {
+      auto x_ref = GetRef<te::Tensor>(tensor_node);
+      if (out_binds->find(x_ref) == out_binds->end()) {
+        auto buf =
+            BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact);
+        out_binds->Set(x_ref, buf);
+        out_arg_list->push_back(buf);
+      } else {
+        out_arg_list->push_back((*out_binds)[x_ref]);
+      }
+    } else if (x.as<te::BufferNode>() || x.as<tir::VarNode>()) {
+      out_arg_list->push_back(x);
+    } else {
+      ICHECK(false)

Review comment:
       Use `LOG(FATAL)` instead of `ICHECK(false)`

##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,
+              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
+  *out_binds = binds;
+
+  for (const ObjectRef& x : args) {
+    if (const auto* tensor_node = x.as<te::TensorNode>()) {
+      auto x_ref = GetRef<te::Tensor>(tensor_node);
+      if (out_binds->find(x_ref) == out_binds->end()) {
+        auto buf =
+            BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact);
+        out_binds->Set(x_ref, buf);
+        out_arg_list->push_back(buf);
+      } else {
+        out_arg_list->push_back((*out_binds)[x_ref]);
+      }
+    } else if (x.as<te::BufferNode>() || x.as<tir::VarNode>()) {
+      out_arg_list->push_back(x);
+    } else {
+      ICHECK(false)
+          << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var";

Review comment:
       Can you print their actual types here?
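
       As a rough illustration of this request, the sketch below reworks the type check from the removed Python `get_binds` so that the error names the offending type. `Tensor`, `Buffer`, `Var`, and `check_arg` are hypothetical stand-ins defined here so the snippet runs without TVM. On the C++ side the same information would presumably come from the object's runtime type key.

```python
# Stand-in classes so the sketch runs without TVM installed.
class Tensor:
    pass


class Buffer:
    pass


class Var:
    pass


def check_arg(x):
    """Return x if it has an accepted type; otherwise report the actual type."""
    if isinstance(x, (Tensor, Buffer, Var)):
        return x
    raise ValueError(
        "args must be Tensor, Buffer or Var, but got " + type(x).__name__
    )
```

       With this shape, a caller who passes a plain integer sees the word `int` in the message instead of only the list of expected types.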

##########
File path: python/tvm/driver/build_module.py
##########
@@ -136,7 +98,7 @@ def lower(
 
     Parameters
     ----------
-    input : Union[schedule.Schedule, PrimFunc, IRModule]
+    inputs : Union[schedule.Schedule, PrimFunc, IRModule]

Review comment:
       Why is the name `inputs` when it is a singular schedule, primfunc, or irmodule?

##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param func The PrimFunc to lower
+ * \param name The name of the lowered function.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
+                               bool simple_mode = false);
+
 /*!
  * \brief Build an IRModule given a schedule, args and binds
  * \param sch The schedule to lower.
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
  * \return The result module.
  */
-TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-                       const std::unordered_map<te::Tensor, tir::Buffer>& binds);
 
+TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
+                               const std::string& name,
+                               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                               bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a schedule, args and binds
+ * \param sch The schedule to lower.
+ * \param args The arguments to the function (Array of Union of Tensor, Buffer and Vars)
+ * \param name The name of the lowered function.
+ * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
+                               const std::string& name,
+                               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                               bool simple_mode = false);
+
+/*!
+ * \brief Create an IRModule out of a Schedule
+ * \param sch The schedule
+ * \param args The arguments to the function.
+ * \param name The name of the lowered function.
+ * \param binds Buffer assignments.
+ * \return The result module.
+ */
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,

Review comment:
       This looks almost exactly the same as `LowerSchedule`.
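
       For readers comparing the two declarations, here is a small plain-Python model of how the diff relates them: `LowerSchedule` first calls `ScheduleToModule` and then runs a pass list, while `ScheduleToModule` stops before any passes run. The dictionaries and string pass names are assumptions made for illustration only, and the pass list is abbreviated.

```python
def schedule_to_module(sch, name="main"):
    """Model of ScheduleToModule: wrap the schedule, run no passes."""
    return {"name": name, "body": sch, "passes_run": []}


def lower_with_pass_list(mod, pass_list):
    """Model of LowerWithPassList: record which passes were applied."""
    out = dict(mod)
    out["passes_run"] = mod["passes_run"] + list(pass_list)
    return out


def lower_schedule(sch, name="main", simple_mode=False):
    """Model of LowerSchedule: schedule_to_module followed by the pass list."""
    mod = schedule_to_module(sch, name)
    passes = ["StorageFlatten", "Simplify"]
    if not simple_mode:
        passes.append("LoopPartition")
    return lower_with_pass_list(mod, passes)


unlowered = schedule_to_module("my_schedule")
lowered = lower_schedule("my_schedule")
```

       In this model the signatures are nearly identical and only the returned module differs, which matches the observation above.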







[GitHub] [tvm] electriclilies commented on a change in pull request #8110: [WIP] Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r639975806



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GT(phase_num_val, 0);
+
+    // TODO(electriclilies): is there a cleaner way to do this?
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+  // TODO(electriclilies): I'm not sure if they should go at the beginning or the end of the phase.
+  // The code is inconsistent with what passes are in which phase as well. For now I have coped the
+  // python behavior exactly.
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (legacy_te_pass) {

Review comment:
       I think doing something like this is a good idea. I also want to simplify the signature for Lower.
   
    Right now we have a lot of function signatures for Lower because the Python version allows multiple input types for each argument. 
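
    One way to picture the "multiple input types" problem is a single Python entry point that dispatches on the input's type to three narrower functions, roughly matching `LowerModule`, `LowerPrimFunc`, and `LowerSchedule`. The sketch below uses `functools.singledispatch` with stand-in classes. It is an assumption about how such a front end could look, not the code in this PR.

```python
from functools import singledispatch


# Stand-in classes so the sketch runs without TVM installed.
class IRModule:
    pass


class PrimFunc:
    pass


class Schedule:
    pass


@singledispatch
def lower(inp, name="main", simple_mode=False):
    """Fallback: reject anything that is not one of the three accepted inputs."""
    raise ValueError(
        "Expected Schedule, PrimFunc or IRModule, but got " + type(inp).__name__
    )


@lower.register(IRModule)
def _lower_module(inp, name="main", simple_mode=False):
    return ("lower_module", simple_mode)


@lower.register(PrimFunc)
def _lower_primfunc(inp, name="main", simple_mode=False):
    return ("lower_primfunc", name, simple_mode)


@lower.register(Schedule)
def _lower_schedule(inp, name="main", simple_mode=False):
    return ("lower_schedule", name, simple_mode)
```

    Each branch then needs only the arguments that make sense for its input kind, which is what keeps the C++ signatures narrow.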







[GitHub] [tvm] electriclilies commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r646953614



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();

Review comment:
       I'm happy to add documentation; I just don't actually know what the answer to this question is.
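
       Until that documentation exists, the bucketing behavior visible in the diff can be summarized with a short plain-Python model. `bucket_user_passes` is a hypothetical helper name, and passes are modeled as strings. The sketch follows the quoted logic: each `[phase_number, pass]` entry goes to phase 0, 1, or 2, and any phase number of 3 or more lands in phase 3.

```python
def bucket_user_passes(add_lower_pass):
    """Group [phase_number, pass] entries into four phase buckets."""
    phases = {0: [], 1: [], 2: [], 3: []}
    for phase_num, user_pass in add_lower_pass:
        if not isinstance(phase_num, int):
            raise TypeError("phase number must be an integer")
        if phase_num < 0:
            raise ValueError("phase number must be non-negative")
        # Phase numbers of 3 and above all share the last bucket.
        phases[min(phase_num, 3)].append(user_pass)
    return phases


buckets = bucket_user_passes([(0, "a"), (5, "b"), (1, "c")])
```

       The model says nothing about whether user passes should run at the start or the end of a phase; the diff appends them at the end.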

##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,
+              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
+  *out_binds = binds;
+
+  for (const ObjectRef& x : args) {
+    if (const auto* tensor_node = x.as<te::TensorNode>()) {
+      auto x_ref = GetRef<te::Tensor>(tensor_node);
+      if (out_binds->find(x_ref) == out_binds->end()) {
+        auto buf =
+            BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact);
+        out_binds->Set(x_ref, buf);
+        out_arg_list->push_back(buf);
+      } else {
+        out_arg_list->push_back((*out_binds)[x_ref]);
+      }
+    } else if (x.as<te::BufferNode>() || x.as<tir::VarNode>()) {
+      out_arg_list->push_back(x);
+    } else {
+      ICHECK(false)
+          << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var";
+    }
+  }
+}
+
+TVM_REGISTER_GLOBAL("driver.get_binds")
+    .set_body_typed([](const Array<ObjectRef>& args, bool compact,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, kv.second));
+        }
+      }
+      Map<te::Tensor, tir::Buffer> out_binds;
+      Array<ObjectRef> out_arg_list;
+      GetBinds(args, compact, c_binds, &out_binds, &out_arg_list);
+
+      // TVM object system doesn't have a pair object, so we'll put both ret values in an array
+      // and return that.
+      Array<ObjectRef> out_arr = {out_binds, out_arg_list};

Review comment:
       I've used the Map as a workaround for this before as well. I do think having some way to unpack multiple values from an FFI function is important since people do this in Python so often, so adding either a Pair object or Tuple object would be helpful. 
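
       For illustration, this is the unpacking pattern being discussed, modeled in plain Python. `get_binds_packed` is a hypothetical stand-in for an FFI function that can return only one container, so it packs both results into a two-element list that the caller unpacks.

```python
def get_binds_packed(args):
    """Stand-in for an FFI call: pack two results into one list."""
    binds = {a: "buf_" + a for a in args}
    arg_list = [binds[a] for a in args]
    return [binds, arg_list]


# The caller unpacks by position, as a tuple return would allow directly.
binds, arg_list = get_binds_packed(["A", "B"])
```

       A dedicated Pair or Tuple object would let the callee express the same intent without relying on positional conventions.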

##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,

Review comment:
       Originally I did this because I was making the signature as similar as possible to the other GetBinds (which was originally in C++), which takes in an `Array<te::Tensor>` for the args instead of an `Array<ObjectRef>`. 
   I'm actually not sure whether I should try to get rid of one of the versions of GetBinds, keep them as similar as possible, or just make two versions with different signatures. What do you think? 

##########
File path: src/relay/backend/compile_engine.cc
##########
@@ -770,7 +770,8 @@ class CompileEngineImpl : public CompileEngineNode {
       With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
 
       std::unordered_map<te::Tensor, tir::Buffer> binds;
-      cache_node->funcs = tvm::lower(cfunc->schedule, all_args, cache_node->func_name, binds);
+      cache_node->funcs =
+          tvm::LowerSchedule(cfunc->schedule, all_args, cache_node->func_name, binds);

Review comment:
       Yeah I don't think we do. I'll remove it (I didn't look too closely at what this call was doing, oops!)

##########
File path: python/tvm/driver/build_module.py
##########
@@ -37,92 +37,54 @@
 from tvm.tir.buffer import Buffer
 from tvm.tir.expr import Var
 
+from . import _ffi_api as ffi
+
 
 def get_binds(args, compact=False, binds=None):
     """Internal function to get binds and arg_list given arguments.
-
     Parameters
     ----------
     args : list of Buffer or Tensor or Var
         The argument lists to the function.
-
     compact : bool
         If the statement has already bound to a compact buffer.
-
     binds : dict of :any:`Tensor` to :any:`Buffer`, optional
         Dictionary that maps the Tensor to Buffer which specified the data layout
         requirement of the function. By default, a new compact buffer is created
         for each tensor in the argument.
-
     Returns
     -------
     binds: dict
         The bind specification
-
     arg_list: list
         The list of symbolic buffers of arguments.
     """
-    binds = {} if binds is None else binds.copy()
-    arg_list = []
-    for x in args:
-        if isinstance(x, tensor.Tensor):
-            any_dim = any(isinstance(i, tvm.tir.Var) for i in x.shape)
-            buffer_type = "auto_broadcast" if any_dim and not compact else ""
-            if x not in binds:
-                buf = tvm.tir.decl_buffer(
-                    x.shape, dtype=x.dtype, name=x.name, buffer_type=buffer_type
-                )
-                binds[x] = buf
-                arg_list.append(buf)
-            else:
-                arg_list.append(binds[x])
-        elif isinstance(x, schedule.Buffer):
-            arg_list.append(x)
-        elif isinstance(x, tvm.tir.Var):
-            arg_list.append(x)
-        else:
-            raise ValueError("args must be Tensor, Buffer or Var")
-    return binds, arg_list
-
-
-def form_irmodule(sch, args, name, binds):
-    """According to the given schedule, form a function.
+    out_arr = ffi.get_binds(args, compact, binds)
+    return out_arr[0], out_arr[1]
+

Review comment:
       Yup can do this
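
   For readers following the hunk above: the Python body of `get_binds` is replaced by a call into C++, and the removed lines show what that logic did. The sketch below restates it in plain Python so the behavior is visible without a TVM build. The `Var`, `Tensor` and `Buffer` classes and the name `get_binds_sketch` are hypothetical stand-ins introduced only for illustration; they are not TVM types, and the real buffer creation goes through `tvm.tir.decl_buffer`.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for tvm.tir.Var (illustration only)."""
    name: str


@dataclass(frozen=True)
class Tensor:
    """Hypothetical stand-in for te.Tensor (illustration only)."""
    name: str
    shape: Tuple[Union[int, Var], ...]
    dtype: str = "float32"


@dataclass(frozen=True)
class Buffer:
    """Hypothetical stand-in for tir.Buffer (illustration only)."""
    name: str
    shape: Tuple[Union[int, Var], ...]
    dtype: str
    buffer_type: str = ""


def get_binds_sketch(args, compact=False, binds=None):
    """Mirror the removed Python logic: return (binds, arg_list)."""
    # Copy so the caller's dictionary is never mutated.
    binds = {} if binds is None else dict(binds)
    arg_list = []
    for x in args:
        if isinstance(x, Tensor):
            if x not in binds:
                # Symbolic dimensions need broadcasting unless already compact.
                any_dim = any(isinstance(i, Var) for i in x.shape)
                buffer_type = "auto_broadcast" if any_dim and not compact else ""
                binds[x] = Buffer(x.name, x.shape, x.dtype, buffer_type)
            arg_list.append(binds[x])
        elif isinstance(x, (Buffer, Var)):
            # Buffers and Vars are passed through unchanged.
            arg_list.append(x)
        else:
            raise ValueError("args must be Tensor, Buffer or Var")
    return binds, arg_list
```

   The two-element return value is what the new wrapper unpacks from `out_arr[0]` and `out_arr[1]`.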







[GitHub] [tvm] electriclilies commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r646984639



##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds

Review comment:
       Updated!







[GitHub] [tvm] electriclilies commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r641221024



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +154,219 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool enable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GT(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+  // TODO(electriclilies): I'm not sure if they should go at the beginning or the end of the phase.
+  // The code is inconsistent with what passes are in which phase as well. For now I have coped the
+  // python behavior exactly.
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
+
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (enable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  stmt = te::SchedulePostProcRewriteForTensorCore(
+      stmt, sch,
+      binds);  // TODO(electriclilies): Should this be in here? Was in python but not C++ version.

Review comment:
       What is this doing, and should it be here? It was in the Python version but not the C++ version.

Review comment:
       Seems like it is causing issues in VM profiler tests. Removing for now. 

Review comment:
       TQ confirmed that this code is no longer used
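
   The loop over `tir.add_lower_pass` in the hunk above sorts user passes into four phase buckets. A minimal pure-Python sketch of that bucketing follows; `bucket_user_passes` is a hypothetical name used only for illustration, and passes are represented by arbitrary objects. Note that the hunk as quoted checks `CHECK_GT(phase_num_val, 0)`, which would reject phase 0 before the `== 0` branch is reached; a later revision quoted in this thread uses `CHECK_GE`, and the sketch assumes that behavior.

```python
def bucket_user_passes(add_lower_pass):
    """Group [phase_number, pass] pairs into four per-phase lists.

    Phase numbers of 3 or more all land in the last bucket, as in the
    C++ loop quoted above.
    """
    phases = [[], [], [], []]
    for phase_num, lower_pass in add_lower_pass:
        if not isinstance(phase_num, int):
            raise TypeError("expected the first entry of each pair to be an integer")
        if phase_num < 0:
            # Assumes the CHECK_GE form of the check, so phase 0 is allowed.
            raise ValueError("phase number must be non-negative")
        phases[min(phase_num, 3)].append(lower_pass)
    return phases
```

   Within a bucket, passes keep the order in which the user listed them.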







[GitHub] [tvm] tkonolige commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
tkonolige commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r647013617



##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,52 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,
+              const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+              Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list) {
+  *out_binds = binds;
+
+  for (const ObjectRef& x : args) {
+    if (const auto* tensor_node = x.as<te::TensorNode>()) {
+      auto x_ref = GetRef<te::Tensor>(tensor_node);
+      if (out_binds->find(x_ref) == out_binds->end()) {
+        auto buf =
+            BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, x_ref->op->name, -1, 0, compact);
+        out_binds->Set(x_ref, buf);
+        out_arg_list->push_back(buf);
+      } else {
+        out_arg_list->push_back((*out_binds)[x_ref]);
+      }
+    } else if (x.as<te::BufferNode>() || x.as<tir::VarNode>()) {
+      out_arg_list->push_back(x);
+    } else {
+      LOG(FATAL)
+          << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var, "
+          << "but got a " << typeid(x).name();

Review comment:
       ```suggestion
             << "but got a " << x->GetTypeKey();
   ```
   `GetTypeKey` has a confusing name but it is what you want.

##########
File path: python/tvm/driver/build_module.py
##########
@@ -136,7 +98,7 @@ def lower(
 
     Parameters
     ----------
-    input : Union[schedule.Schedule, PrimFunc, IRModule]
+    inputs : Union[schedule.Schedule, PrimFunc, IRModule]

Review comment:
       Maybe use `inp` instead? I can't really think of a good name either.

##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param func The PrimFunc to lower
+ * \param name The name of the lowered function.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
+                               bool simple_mode = false);
+
 /*!
  * \brief Build an IRModule given a schedule, args and binds
  * \param sch The schedule to lower.
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
  * \return The result module.
  */
-TVM_DLL IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-                       const std::unordered_map<te::Tensor, tir::Buffer>& binds);
 
+TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<te::Tensor>& args,
+                               const std::string& name,
+                               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                               bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a schedule, args and binds
+ * \param sch The schedule to lower.
+ * \param args The arguments to the function (Array of Union of Tensor, Buffer and Vars)
+ * \param name The name of the lowered function.
+ * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerSchedule(te::Schedule sch, const Array<ObjectRef>& args,
+                               const std::string& name,
+                               const std::unordered_map<te::Tensor, tir::Buffer>& binds,
+                               bool simple_mode = false);
+
+/*!
+ * \brief Create an IRModule out of a Schedule
+ * \param sch The schedule
+ * \param args The arguments to the function.
+ * \param name The name of the lowered function.
+ * \param binds Buffer assignments.
+ * \return The result module.
+ */
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,

Review comment:
       Can you be a little more explicit about this in the documentation?

##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param func The PrimFunc to lower
+ * \param name The name of the lowered function.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name,
+                               bool simple_mode = false);
+
 /*!
  * \brief Build an IRModule given a schedule, args and binds
  * \param sch The schedule to lower.
  * \param args The arguments to the function.
  * \param name The name of the lowered function.
  * \param binds Buffer assignments.
+ * \param simple_mode Disables the loop partition pass. Defaults to false.

Review comment:
       It is probably fine for now.
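
   To make the effect of `simple_mode` concrete, the sketch below assembles the pass order as a list of names, following the phases in the `CreatePassList` hunks quoted in this thread. It is an illustration only: `pass_order` and its parameters are hypothetical, the names are plain strings rather than TVM pass objects, and the real list may differ in other revisions of the PR.

```python
from typing import List, Sequence


def pass_order(
    simple_mode: bool = False,
    for_te_schedule: bool = True,
    instrument_bound_checkers: bool = False,
    user_phases: Sequence[Sequence[str]] = ((), (), (), ()),
) -> List[str]:
    """Return pass names in the order the quoted C++ code appends them."""
    # Phase 0: only user passes.
    order = list(user_phases[0])

    # Phase 1: the first passes depend on where the input came from.
    if for_te_schedule:
        order += ["InjectPrefetch", "StorageFlatten"]
    else:
        order += [
            "LowerInitBlock",
            "PlanAndUpdateBufferAllocationLocation",
            "ConvertBlocksToOpaque",
            "CompactBufferAllocation",
            "FlattenBuffer",
        ]
    order += ["BF16Legalize", "NarrowDataType", "Simplify"]
    order += list(user_phases[1])

    # Phase 2: simple_mode skips loop partitioning and nothing else.
    if not simple_mode:
        order.append("LoopPartition")
    order += [
        "VectorizeLoop",
        "InjectVirtualThread",
        "InjectDoubleBuffer",
        "StorageRewrite",
        "UnrollLoop",
    ]
    order += list(user_phases[2])

    # Phase 3: cleanup, then user passes, then optional bound checks.
    order += ["Simplify", "RemoveNoOp", "RewriteUnsafeSelect", "HoistIfThenElse"]
    order += list(user_phases[3])
    if instrument_bound_checkers:
        order.append("InstrumentBoundCheckers")
    return order
```

   Comparing `pass_order()` with `pass_order(simple_mode=True)` shows that the two lists differ only by the `LoopPartition` entry, which matches the parameter description in the header.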







[GitHub] [tvm] jroesch commented on pull request #8110: [WIP] Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
jroesch commented on pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#issuecomment-847475482


   @CircleSpin @electriclilies ping me again when tests are green. I left one comment that jumped out at me; otherwise this looks good if we can get it to pass the test suites.





[GitHub] [tvm] Hzfengsy commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
Hzfengsy commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r647245911



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +174,208 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);

Review comment:
       Please spell out types explicitly wherever possible. `auto` hides the type, which makes the code harder for other readers to follow.
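
A minimal sketch of the readability point, using only the standard library. `InferBoundsSketch` and the two `Extent...` helpers are hypothetical names invented for illustration; they are not TVM functions and only stand in for a call such as `te::InferBound` whose return type is not obvious at the call site.

```cpp
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a function whose return type is not obvious from
// its name. The name and the return type are assumptions, not TVM's API.
std::map<std::string, std::vector<int>> InferBoundsSketch() {
  return {{"i", {0, 16}}, {"j", {0, 32}}};
}

// With `auto`, the reader has to look up InferBoundsSketch to learn what
// `bounds` is and which operations it supports.
int ExtentWithAuto(const std::string& axis) {
  auto bounds = InferBoundsSketch();
  return bounds.at(axis)[1] - bounds.at(axis)[0];
}

// With the type spelled out, the declaration documents itself.
int ExtentWithExplicitType(const std::string& axis) {
  std::map<std::string, std::vector<int>> bounds = InferBoundsSketch();
  const std::vector<int>& range = bounds.at(axis);
  return range[1] - range[0];
}
```

Both helpers compute the same value; the difference is only in how much a reviewer can tell from the declaration alone.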

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +174,208 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != nullptr) {
+        for (auto kv : binds) {
+          c_binds.insert({kv.first, kv.second});
+        }
+      }
+      IRModule mod = ScheduleToModule(sch, args, name, c_binds);
+      return mod;
+    });
+
+IRModule LowerModule(IRModule mod, bool simple_mode) {
+  auto pass_list = CreatePassList(simple_mode, false);
+  return LowerWithPassList(mod, pass_list);

Review comment:
       ```suggestion
     return LowerWithPassList(std::move(mod), pass_list);
   ```
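
A small self-contained sketch of why `std::move` matters when forwarding into a function that takes its argument by value, as `LowerWithPassList` does in the diff. `CountedModule`, `Consume`, `ForwardByCopy` and `ForwardByMove` are hypothetical names for illustration only. In TVM the handle types are reference counted, so the saving there is presumably a reference-count update rather than a deep copy; the sketch simply counts copy constructions to make the effect visible.

```cpp
#include <utility>

// Hypothetical stand-in for a handle type such as IRModule. It counts copy
// constructions through an external counter. Not TVM's actual class.
struct CountedModule {
  int* copies;
  explicit CountedModule(int* counter) : copies(counter) {}
  CountedModule(const CountedModule& other) : copies(other.copies) { ++*copies; }
  CountedModule(CountedModule&& other) noexcept : copies(other.copies) {}
};

// Takes its argument by value, like LowerWithPassList in the diff.
// Returning a by-value parameter moves it, so no copy happens here.
CountedModule Consume(CountedModule mod) { return mod; }

// Passing the named parameter as an lvalue copy-constructs Consume's argument.
CountedModule ForwardByCopy(CountedModule mod) { return Consume(mod); }

// std::move turns the parameter into an rvalue, so the move constructor is used.
CountedModule ForwardByMove(CountedModule mod) { return Consume(std::move(mod)); }
```

After `std::move(mod)` the local `mod` must not be used again, which is safe here because the function returns immediately.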

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +174,208 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != nullptr) {
+        for (auto kv : binds) {
+          c_binds.insert({kv.first, kv.second});
+        }
+      }
+      IRModule mod = ScheduleToModule(sch, args, name, c_binds);
+      return mod;
+    });
+
+IRModule LowerModule(IRModule mod, bool simple_mode) {
+  auto pass_list = CreatePassList(simple_mode, false);
+  return LowerWithPassList(mod, pass_list);

Review comment:
       Please also apply `std::move` in all of the FFI functions that follow.

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +174,208 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != nullptr) {
+        for (auto kv : binds) {
+          c_binds.insert({kv.first, kv.second});
+        }
+      }
+      IRModule mod = ScheduleToModule(sch, args, name, c_binds);

Review comment:
       ```suggestion
         IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds);
   ```
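
For readers following the `CreatePassList` hunk quoted above, the phase-bucketing logic can be sketched in isolation. This is a simplified model under stated assumptions, not the TVM implementation: passes are plain strings, the built-in passes of each phase are collapsed into one placeholder name each, and `BuildPassListSketch` is a hypothetical function name. It mirrors three behaviours visible in the diff: phase numbers must be non-negative, any phase number of 3 or more lands in the last bucket, and user passes are appended at the end of their phase.

```cpp
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical sketch: each user pass is a (phase_number, pass_name) pair and
// the built-in passes of a phase are abbreviated to a single placeholder.
std::vector<std::string> BuildPassListSketch(
    const std::vector<std::pair<int, std::string>>& user_passes) {
  // One bucket per phase, 0 through 3.
  std::vector<std::vector<std::string>> user(4);
  for (const std::pair<int, std::string>& entry : user_passes) {
    if (entry.first < 0) {
      throw std::invalid_argument("phase number must be >= 0");
    }
    // Phase numbers of 3 or more all go into the final bucket.
    const int phase = entry.first >= 3 ? 3 : entry.first;
    user[phase].push_back(entry.second);
  }

  // Phase 0 holds user passes only; later phases run the built-in passes
  // first and the user passes for that phase afterwards.
  std::vector<std::string> pass_list = user[0];
  pass_list.push_back("builtin_phase1");
  pass_list.insert(pass_list.end(), user[1].begin(), user[1].end());
  pass_list.push_back("builtin_phase2");
  pass_list.insert(pass_list.end(), user[2].begin(), user[2].end());
  pass_list.push_back("builtin_phase3");
  pass_list.insert(pass_list.end(), user[3].begin(), user[3].end());
  return pass_list;
}
```

The order in which user passes are registered does not affect which phase they run in, only their relative order within the same phase.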

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +174,208 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);

Review comment:
       Please pass by const reference or use `std::move` to avoid unnecessary copies.
   ```suggestion
     auto stmt = te::ScheduleOps(sch, std::move(bounds), false);
   ```







[GitHub] [tvm] manupa-arm commented on a change in pull request #8110: [WIP] Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
manupa-arm commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r640833354



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GT(phase_num_val, 0);
+
+    // TODO(electriclilies): is there a cleaner way to do this?
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+  // TODO(electriclilies): I'm not sure if they should go at the beginning or the end of the phase.
+  // The code is inconsistent with what passes are in which phase as well. For now I have coped the
+  // python behavior exactly.
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (legacy_te_pass) {

Review comment:
       Yes, OK. Having looked at what's being removed, maybe this is out of scope for the PR.
   
   My original question was about the motivation for custom passes, specifically the four possible locations (known as phases) where they get inserted. If the phases have some meaning, it might be worth adding a comment.
   
   My next question concerned the custom-pass feature itself: I was thinking we could instead have a custom pass pipeline registered or provided close to the target. That way, someone could create a new target and reuse or reorganize the passes that need to run.
   
   Again, I now feel that's out of scope for this PR, as this essentially mimics what the Python lower is doing. :)
   
   
   

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {

Review comment:
       Maybe we could be more specific in the name? For example, something like `bool from_te_schedule`?
   
   Also, I see that `simple_mode` just toggles LoopPartition; could we use `bool enable_loop_partition` instead?







[GitHub] [tvm] electriclilies commented on a change in pull request #8110: [WIP] Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r641219409



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +154,219 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool enable_loop_partition, bool for_te_schedule) {

Review comment:
       Is there a better way to handle the creation of the pass list? Should I split this function into two, one for te schedules and one for IRModules and PrimFuncs?

##########
File path: src/driver/driver_api.cc
##########
@@ -93,6 +93,7 @@ tir::Buffer BufferWithOffsetAlignment(Array<PrimExpr> shape, DataType dtype, std
                      offset_factor, buffer_type);
 }
 
+// comment to try to remove this
 void GetBinds(const Array<te::Tensor>& args, bool compact,

Review comment:
       Is having two copies of GetBinds OK? 

##########
File path: python/tvm/driver/build_module.py
##########
@@ -160,90 +99,15 @@ def lower(
     m : IRModule
        The result IRModule
     """
-    # config setup
-    pass_ctx = PassContext.current()
-    instrument_bound_checkers = bool(pass_ctx.config.get("tir.instrument_bound_checkers", False))
-    disable_vectorize = bool(pass_ctx.config.get("tir.disable_vectorize", False))
-    add_lower_pass = pass_ctx.config.get("tir.add_lower_pass", [])
-
-    lower_phase0 = [x[1] for x in add_lower_pass if x[0] == 0]
-    lower_phase1 = [x[1] for x in add_lower_pass if x[0] == 1]
-    lower_phase2 = [x[1] for x in add_lower_pass if x[0] == 2]
-    lower_phase3 = [x[1] for x in add_lower_pass if x[0] > 2]
-
-    # Phase 0
-    pass_list = lower_phase0
-    is_legacy_te_schedule: bool = False
-
-    if isinstance(inputs, schedule.Schedule):
-        if args is None:
-            raise ValueError("args must be given for lowering from TE schedule")
-        mod = form_irmodule(inputs, args, name, binds)
-        is_legacy_te_schedule = True
-    elif isinstance(inputs, PrimFunc):
-        func = inputs.with_attr("global_symbol", name)
-        if pass_ctx.config.get("tir.noalias", True):
-            func = func.with_attr("tir.noalias", True)
-        mod = tvm.IRModule({name: func})
-    elif isinstance(inputs, IRModule):
-        mod = inputs
-    else:
-        raise TypeError(
-            f"tvm.lower expected te.Schedule, PrimFunc or IRModule, but got {type(inputs)}"
-        )
-
-    # Phase 1
-    if is_legacy_te_schedule:
-        pass_list += [
-            tvm.tir.transform.InjectPrefetch(),
-            tvm.tir.transform.StorageFlatten(64, instrument_bound_checkers),
-        ]
-    else:
-        pass_list += [
-            tvm.tir.transform.LowerInitBlock(),
-            tvm.tir.transform.PlanAndUpdateBufferAllocationLocation(),
-            tvm.tir.transform.ConvertBlocksToOpaque(),
-            tvm.tir.transform.CompactBufferAllocation(),
-            tvm.tir.transform.FlattenBuffer(),
-        ]
-    pass_list += [
-        tvm.tir.transform.BF16Legalize(),
-        tvm.tir.transform.NarrowDataType(32),
-        tvm.tir.transform.Simplify(),
-    ]
-
-    pass_list += lower_phase1
-
-    # Phase 2
-    if not simple_mode:
-        pass_list += [(tvm.tir.transform.LoopPartition())]
-
-    pass_list += [
-        tvm.tir.transform.VectorizeLoop(not disable_vectorize),
-        tvm.tir.transform.InjectVirtualThread(),
-        tvm.tir.transform.InjectDoubleBuffer(),
-        tvm.tir.transform.StorageRewrite(),
-        tvm.tir.transform.UnrollLoop(),
-    ]
-    pass_list += lower_phase2
-
-    # Phase 3
-    pass_list += [
-        tvm.tir.transform.Simplify(),
-        tvm.tir.transform.RemoveNoOp(),
-    ]
-
-    pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
-    pass_list += [tvm.tir.transform.HoistIfThenElse()]
-    pass_list += lower_phase3
-
-    # Instrument BoundCheckers
-    if instrument_bound_checkers:
-        pass_list += [tvm.tir.transform.InstrumentBoundCheckers()]
-
-    optimize = tvm.transform.Sequential(pass_list)
-    mod = optimize(mod)
-    return mod
+    if isinstance(input, IRModule):
+        return ffi.lower_module(input)
+    if isinstance(input, PrimFunc):
+        return ffi.lower_primfunc(input)

Review comment:
       Should I split the Python lower function into 3 functions as well, or leave it as a single API for now?

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +154,219 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool enable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GT(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+  // TODO(electriclilies): I'm not sure if they should go at the beginning or the end of the phase.
+  // The code is inconsistent with what passes are in which phase as well. For now I have coped the
+  // python behavior exactly.
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
+
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (enable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  // HoistIfThenElse
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  stmt = te::SchedulePostProcRewriteForTensorCore(
+      stmt, sch,
+      binds);  // TODO(electriclilies): Should this be in here? Was in python but not C++ version.

Review comment:
       What is this doing, and should it be here? It was in the Python version but not the C++ version.

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +154,219 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool enable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GT(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+  // TODO(electriclilies): I'm not sure if they should go at the beginning or the end of the phase.

Review comment:
       What is the significance of the phases, and where are the boundaries between them? This was inconsistent between the C++ and Python versions. For now I have copied the Python behavior exactly, but I'm not sure if that is correct.
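
   For reference, the bucketing behavior copied from the removed Python code can be sketched in isolation as below. This is a simplified standalone model, not the TVM implementation: passes are plain strings, and `PassName` and `BuildPassList` are hypothetical names. User passes go at the end of their phase, and any phase number above 2 lands in phase 3, matching the `x[0] > 2` filter in the removed Python. Note that `CHECK_GT(phase_num_val, 0)` in the quoted diff rejects phase 0, so the `== 0` branch looks unreachable as written; the sketch accepts any phase `>= 0`, which is an assumption about the intent.

   ```cpp
   #include <array>
   #include <stdexcept>
   #include <string>
   #include <utility>
   #include <vector>

   // Hypothetical simplification: a pass is identified by its name only.
   using PassName = std::string;

   // built_in[i] holds the fixed passes of phase i (0..3).
   // user_passes holds [phase_number, pass] pairs, as in tir.add_lower_pass.
   std::vector<PassName> BuildPassList(
       const std::array<std::vector<PassName>, 4>& built_in,
       const std::vector<std::pair<int, PassName>>& user_passes) {
     std::array<std::vector<PassName>, 4> user;
     for (const auto& entry : user_passes) {
       if (entry.first < 0) {
         throw std::invalid_argument("phase number must be >= 0");
       }
       // Phase numbers above 3 are folded into the last phase.
       int bucket = entry.first > 3 ? 3 : entry.first;
       user[bucket].push_back(entry.second);
     }

     std::vector<PassName> result;
     for (int phase = 0; phase < 4; ++phase) {
       // Built-in passes first, then the user passes for the same phase.
       result.insert(result.end(), built_in[phase].begin(), built_in[phase].end());
       result.insert(result.end(), user[phase].begin(), user[phase].end());
     }
     return result;
   }
   ```

   Keeping the buckets in an array indexed by phase avoids the four separately named lists and makes the "end of phase" insertion rule explicit in one loop.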







[GitHub] [tvm] YuchenJin commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
YuchenJin commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r647700600



##########
File path: python/tvm/driver/build_module.py
##########
@@ -136,7 +98,7 @@ def lower(
 
     Parameters
     ----------
-    input : Union[schedule.Schedule, PrimFunc, IRModule]
+    inputs : Union[schedule.Schedule, PrimFunc, IRModule]

Review comment:
       Sounds good!







[GitHub] [tvm] electriclilies commented on a change in pull request #8110: [WIP] Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r639952006



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +128,192 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool simple_mode, bool legacy_te_pass) {

Review comment:
       `legacy_te_pass` means that the object we are lowering is a te::Schedule. I've copied the logic and phrasing directly from the Python version; I don't know the details of why it is still here, but without it some tests fail.







[GitHub] [tvm] electriclilies commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r646907873



##########
File path: include/tvm/driver/driver_api.h
##########
@@ -42,17 +43,64 @@
 #include <vector>
 
 namespace tvm {
+
+/*!
+ * \brief Build an IRModule given a module, args and binds
+ * \param mod The IRmodule to lower
+ * \param simple_mode Disables the loop partition pass. Defaults to false.
+ * \return The result module.
+ */
+TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false);
+
+/*!
+ * \brief Build an IRModule given a module, args and binds

Review comment:
       Yes







[GitHub] [tvm] electriclilies commented on pull request #8110: [WIP] Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#issuecomment-850056461


   @manupa-arm Yeah, I'm not exactly sure what the meaning of each "phase" is. Again, I'm just trying to duplicate the Python code in C++. @junrushao1994 Could you weigh in on the significance of each phase?
   
   Also, thanks for the naming suggestions. I'm going to go through and clean things up later, and I will take them into account then. Ideally I'd like to remove the flags if possible.





[GitHub] [tvm] tqchen merged pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
tqchen merged pull request #8110:
URL: https://github.com/apache/tvm/pull/8110


   





[GitHub] [tvm] csullivan commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
csullivan commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r647017947



##########
File path: src/driver/driver_api.cc
##########
@@ -109,6 +109,51 @@ void GetBinds(const Array<te::Tensor>& args, bool compact,
   }
 }
 
+void GetBinds(const Array<ObjectRef>& args, bool compact,

Review comment:
       Oh, I didn't notice that other GetBinds. The one you've added here handles a superset of input types, right? In that case my vote would be to remove the old one for the sake of common code. Else, you know someone will end up updating only one of them... 
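
   If the narrower signature still needs to exist for some callers, one option is to have it forward to the general one so there is only a single body to maintain. The following is a rough standalone sketch of that idea, not the actual GetBinds code: `Tensor`, `Buffer`, `Arg`, and `CollectNames` are hypothetical stand-ins.

   ```cpp
   #include <string>
   #include <vector>

   // Hypothetical stand-ins for the argument kinds; not TVM types.
   struct Tensor { std::string name; };
   struct Buffer { std::string name; };

   // General argument type accepting either kind.
   struct Arg {
     enum Kind { kTensor, kBuffer } kind;
     std::string name;
     Arg(const Tensor& t) : kind(kTensor), name(t.name) {}
     Arg(const Buffer& b) : kind(kBuffer), name(b.name) {}
   };

   // General version: the only place that holds real logic.
   std::vector<std::string> CollectNames(const std::vector<Arg>& args) {
     std::vector<std::string> names;
     for (const Arg& a : args) {
       names.push_back(a.name);
     }
     return names;
   }

   // Narrow version: converts its input and forwards to the general version.
   std::vector<std::string> CollectNames(const std::vector<Tensor>& args) {
     std::vector<Arg> general(args.begin(), args.end());
     return CollectNames(general);
   }
   ```

   With this shape, an update to the general version automatically applies to the narrow one.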







[GitHub] [tvm] jroesch commented on a change in pull request #8110: [WIP] Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
jroesch commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r638402304



##########
File path: src/driver/driver_api.cc
##########
@@ -185,6 +170,59 @@ IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::strin
   return mod;
 }
 
+IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
+               const std::unordered_map<te::Tensor, tir::Buffer>& binds, bool simple_mode) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // build the function
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+  return lower(mod, args, name, binds, simple_mode);
+}
+
+TVM_REGISTER_GLOBAL("driver.lower")
+    .set_body_typed([](ObjectRef obj, const Array<te::Tensor>& args, const String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds, bool simple_mode) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != NULL) {
+        for (auto kv : binds) {
+          c_binds.insert(std::pair<te::Tensor, tir::Buffer>(kv.first, kv.second));
+        }
+      }
+
+      if (const auto* p_mod = obj.as<IRModuleNode>()) {
+        IRModule mod = GetRef<IRModule>(p_mod);
+        return lower(mod, args, name, c_binds, simple_mode);
+      } else if (const auto* p_sch = obj.as<te::ScheduleNode>()) {
+        te::Schedule sch = GetRef<te::Schedule>(p_sch);
+        return lower(sch, args, name, c_binds, simple_mode);
+      } else {
+        ICHECK(false) << "driver.lower expects the first argument to be a te::Schedule or "
+                      << "IRModule";
+        throw;

Review comment:
       Don't add a throw here; instead, return an empty IRModule(...).
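
   A standalone sketch of the suggested shape follows. It is not the driver code: `Module`, `Check`, and `Dispatch` are hypothetical stand-ins, and `Check` is assumed to behave like a failing check that raises an exception. Since the check already throws, the trailing statement is never reached; it exists only so that every path returns a value. A bare `throw;` executed with no exception being handled would call `std::terminate`, which is one reason to prefer the plain return.

   ```cpp
   #include <stdexcept>
   #include <string>

   // Hypothetical stand-in for the returned module type.
   struct Module { std::string name; };

   // Hypothetical check helper: throws when the condition is false.
   void Check(bool cond, const std::string& msg) {
     if (!cond) throw std::runtime_error(msg);
   }

   // Dispatches on the kind of input; unknown kinds fail the check.
   Module Dispatch(int kind) {
     if (kind == 0) return Module{"from_module"};
     if (kind == 1) return Module{"from_schedule"};
     Check(false, "expected a schedule or a module");
     return Module{};  // not reached; keeps every path returning a value
   }
   ```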







[GitHub] [tvm] electriclilies commented on pull request #8110: [WIP] Unify Python and C++ TIR lower API

Posted by GitBox <gi...@apache.org>.
electriclilies commented on pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#issuecomment-853362985


   @jroesch @junrushao1994 Tests are green, let me know if there's anything I should change

