You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/06/08 08:54:06 UTC

[GitHub] [tvm] Hzfengsy commented on a change in pull request #8110: Unify Python and C++ TIR lower API

Hzfengsy commented on a change in pull request #8110:
URL: https://github.com/apache/tvm/pull/8110#discussion_r647245911



##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +174,208 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);

Review comment:
      Please spell out the types explicitly wherever possible. Overusing `auto` makes it harder for other readers to understand the code.

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +174,208 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != nullptr) {
+        for (auto kv : binds) {
+          c_binds.insert({kv.first, kv.second});
+        }
+      }
+      IRModule mod = ScheduleToModule(sch, args, name, c_binds);
+      return mod;
+    });
+
+IRModule LowerModule(IRModule mod, bool simple_mode) {
+  auto pass_list = CreatePassList(simple_mode, false);
+  return LowerWithPassList(mod, pass_list);

Review comment:
       ```suggestion
     return LowerWithPassList(std::move(mod), pass_list);
   ```

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +174,208 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != nullptr) {
+        for (auto kv : binds) {
+          c_binds.insert({kv.first, kv.second});
+        }
+      }
+      IRModule mod = ScheduleToModule(sch, args, name, c_binds);
+      return mod;
+    });
+
+IRModule LowerModule(IRModule mod, bool simple_mode) {
+  auto pass_list = CreatePassList(simple_mode, false);
+  return LowerWithPassList(mod, pass_list);

Review comment:
      Please apply `std::move` in all of the following FFI functions as well.

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +174,208 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);
+  bool compact = te::VerifyCompactBuffer(stmt);
+
+  Map<te::Tensor, tir::Buffer> out_binds;
+  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
+
+  // Build the function
+  // At this point binds is only te::Tensors
+  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
+  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
+
+  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
+
+  if (noalias) {
+    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  }
+  return IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
+}
+
+TVM_REGISTER_GLOBAL("driver.schedule_to_module")
+    .set_body_typed([](te::Schedule sch, const Array<ObjectRef>& args, const String& name,
+                       const Map<te::Tensor, tir::Buffer>& binds) {
+      std::unordered_map<te::Tensor, tir::Buffer> c_binds;
+      // Check to make sure binds is not null before doing the conversion;
+      if (binds.get() != nullptr) {
+        for (auto kv : binds) {
+          c_binds.insert({kv.first, kv.second});
+        }
+      }
+      IRModule mod = ScheduleToModule(sch, args, name, c_binds);

Review comment:
       ```suggestion
         IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds);
   ```

##########
File path: src/driver/driver_api.cc
##########
@@ -128,63 +174,208 @@ transform::Pass Filter(FCond fcond) {
   return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
 }
 
-IRModule lower(te::Schedule sch, const Array<te::Tensor>& args, const std::string& name,
-               const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
-  Array<ObjectRef> out_arg_list;
+Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition, bool for_te_schedule) {
   auto pass_ctx = transform::PassContext::Current();
 
-  sch = sch.normalize();
-
-  // Before TIR transformation.
-  auto bounds = te::InferBound(sch);
-  auto stmt = te::ScheduleOps(sch, bounds, false);
-  bool compact = te::VerifyCompactBuffer(stmt);
-
-  Map<te::Tensor, tir::Buffer> out_binds;
-  GetBinds(args, compact, binds, &out_binds, &out_arg_list);
-
-  // build the function
-  tir::PrimFunc f = te::SchedulePostProcToPrimFunc(out_arg_list, std::move(stmt), out_binds);
-  f = WithAttr(std::move(f), "global_symbol", runtime::String(name));
-
-  bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
   bool disable_vectorize = pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
   bool instrument_bound_checkers =
       pass_ctx->GetConfig<Bool>("tir.instrument_bound_checkers", Bool(false)).value();
 
-  if (noalias) {
-    f = WithAttr(std::move(f), "tir.noalias", Bool(true));
+  // Get any user-added passes
+  auto add_lower_pass =
+      pass_ctx->GetConfig<Array<Array<ObjectRef>>>("tir.add_lower_pass", Array<Array<ObjectRef>>())
+          .value();
+
+  auto user_lower_phase0 = Array<tvm::transform::Pass>();
+  auto user_lower_phase1 = Array<tvm::transform::Pass>();
+  auto user_lower_phase2 = Array<tvm::transform::Pass>();
+  auto user_lower_phase3 = Array<tvm::transform::Pass>();
+
+  // phase pasees is of the form
+  // [[phase_number, pass], [phase_number, pass]... ]
+  for (auto phase_pass : add_lower_pass) {
+    auto phase_num = phase_pass[0].as<IntImmNode>();
+    ICHECK(phase_num)
+        << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer";
+    int phase_num_val = phase_num->value;
+
+    CHECK_GE(phase_num_val, 0);
+
+    auto pass_node = phase_pass[1].as<tvm::transform::PassNode>();
+    auto pass = GetRef<tvm::transform::Pass>(pass_node);
+    // Copy the pass into the correct phase
+    if (phase_num_val == 0) {
+      user_lower_phase0.push_back(pass);
+    } else if (phase_num_val == 1) {
+      user_lower_phase1.push_back(pass);
+    } else if (phase_num_val == 2) {
+      user_lower_phase2.push_back(pass);
+    } else if (phase_num_val >= 3) {
+      user_lower_phase3.push_back(pass);
+    }
   }
 
-  auto mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(name), f}}));
-  auto pass_list = Array<tvm::transform::Pass>();
+  // Construct the pass list, inserting the user provided passes at the end of the phase
+
+  // PHASE 0
+  auto pass_list = user_lower_phase0;
 
-  // Phase 0
-  pass_list.push_back(tir::transform::InjectPrefetch());
-  pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
-  // Phase 1
+  // PHASE 1
+  if (for_te_schedule) {
+    pass_list.push_back(tir::transform::InjectPrefetch());
+    pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers));
+  } else {
+    pass_list.push_back(tir::transform::LowerInitBlock());
+    pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
+    pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+    pass_list.push_back(tir::transform::CompactBufferAllocation());
+    pass_list.push_back(tir::transform::FlattenBuffer());
+  }
   pass_list.push_back(tir::transform::BF16Legalize());
   pass_list.push_back(tir::transform::NarrowDataType(32));
   pass_list.push_back(tir::transform::Simplify());
-  pass_list.push_back(tir::transform::LoopPartition());
+
+  // Add user-defined phase-1 passes
+  pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end());
+
+  // PHASE 2
+  if (!disable_loop_partition) {
+    pass_list.push_back(tir::transform::LoopPartition());
+  }
+
   pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize));
   pass_list.push_back(tir::transform::InjectVirtualThread());
   pass_list.push_back(tir::transform::InjectDoubleBuffer());
   pass_list.push_back(tir::transform::StorageRewrite());
   pass_list.push_back(tir::transform::UnrollLoop());
-  // Phase 2
+
+  // Add user-defined phase-2 passes
+  pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());
+
+  // PHASE 3
   pass_list.push_back(tir::transform::Simplify());
   pass_list.push_back(tir::transform::RemoveNoOp());
   pass_list.push_back(tir::transform::RewriteUnsafeSelect());
+  pass_list.push_back(tir::transform::HoistIfThenElse());
+
+  // Add user-defined phase-3 passes
+  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
+
   if (instrument_bound_checkers) {
     pass_list.push_back(tir::transform::InstrumentBoundCheckers());
   }
-  // run
+  return pass_list;
+}
+
+IRModule LowerWithPassList(IRModule mod, Array<tvm::transform::Pass> pass_list) {
   auto optimize = transform::Sequential(pass_list);
   mod = optimize(std::move(mod));
   return mod;
 }
 
+IRModule ScheduleToModule(te::Schedule sch, const Array<ObjectRef>& args, const std::string& name,
+                          const std::unordered_map<te::Tensor, tir::Buffer>& binds) {
+  // Convert te schedule to IRModule
+  Array<ObjectRef> out_arg_list;
+  auto pass_ctx = transform::PassContext::Current();
+
+  sch = sch.normalize();
+
+  // Before TIR transformation.
+  auto bounds = te::InferBound(sch);
+  auto stmt = te::ScheduleOps(sch, bounds, false);

Review comment:
      Please use a const reference or `std::move` to avoid unnecessary copies.
   ```suggestion
     auto stmt = te::ScheduleOps(sch, std::move(bounds), false);
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org