You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/14 08:53:34 UTC

[GitHub] [tvm] mikepapadim commented on a change in pull request #9103: [Core][Build] Move build module transformations and utilities to C++

mikepapadim commented on a change in pull request #9103:
URL: https://github.com/apache/tvm/pull/9103#discussion_r728774603



##########
File path: src/driver/driver_api.cc
##########
@@ -373,97 +394,96 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
       return LowerSchedule(std::move(sch), args, name, c_binds, simple_mode);
     });
 
-std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target& target_arg,
-                                                const Target& target_host_arg,
-                                                const transform::PassContext& pass_ctx) {
+/**
+ * This function takes the input module that contains both the device and host opts.
+ * Then, it applies transformations on the original module before splitting it into separate
+ * modules for device and host. Then it also applies transformations on the newly split modules.
+ */
+std::pair<IRModule, IRModule> SplitMixedModule(IRModule mod_mixed, const Target& target_arg,
+                                               const Target& target_host_arg) {
   Target target = target_arg, target_host = target_host_arg;
   CheckAndUpdateHostConsistency(&target, &target_host);
-  Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
-                                                 tir::transform::VerifyMemory()};
 
-  mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
-  if (pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value()) {
-    mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
-  }
-  mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
-  mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
-  mixed_pass_list.push_back(tir::transform::InferFragment());
-  mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
+  ICHECK(mod_mixed.defined()) << "This module must be defined";
 
-  if (target->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {
-    mixed_pass_list.push_back(tir::transform::MakeUnpackedAPI());
-  } else {
-    mixed_pass_list.push_back(tir::transform::MakePackedAPI(-1));
-  }
+  mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
 
-  mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+  IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host));
 
-  auto opt_mixed = transform::Sequential(mixed_pass_list);
-  mod_mixed = opt_mixed(std::move(mod_mixed));
-
-  // We make an assumption here that the overriden host target
-  // can be used alongside the default host codegen based on device type
-  // this is so the correct code generator is used later instead of overriding the target.
-  // We need better support for inserting multiple kDLCPU targets as our current options
-  // are kDeviceKernelLaunch or not
-  Target overriden_host_target = target_host;
-  if (target->kind->device_type == target_host->kind->device_type) {
-    overriden_host_target = target;
-  }
-  auto host_pass_list = {
-      Filter([](const tir::PrimFunc& f) {
-        return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) !=
-               CallingConv::kDeviceKernelLaunch;
-      }),
-      BindTarget(overriden_host_target),
-      tir::transform::LowerTVMBuiltin(),
-      tir::transform::LowerCustomDatatypes(),
-      tir::transform::LowerIntrin(),
-      tir::transform::LowerDeviceStorageAccessInfo(),
-      tir::transform::CombineContextCall(),
-  };
-  auto opt_host = transform::Sequential(host_pass_list);
-  ICHECK(mod_mixed.defined()) << "This module must be defined";
-  auto mhost = opt_host(mod_mixed);
-
-  // device pipeline
-  auto device_pass_list = {
-      Filter([](const tir::PrimFunc& f) {
-        return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
-               CallingConv::kDeviceKernelLaunch;
-      }),
-      BindTarget(target),
-      tir::transform::LowerWarpMemory(),
-      tir::transform::Simplify(),
-      tir::transform::LowerCustomDatatypes(),
-      tir::transform::LowerIntrin(),
-      tir::transform::LowerDeviceStorageAccessInfo(),
-  };
-  auto opt_device = transform::Sequential(device_pass_list);
-  auto mdevice = opt_device(mod_mixed);
+  IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target));
 
-  // some final misc checks.
   auto keys = target->GetKeys();
+
+  CheckAndUpdateHostConsistency(&target, &target_host);
+
   bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end();
-  if (target_is_gpu && mdevice->functions.size() == 0) {
-    LOG(WARNING) << "Specified target " << target->str()
-                 << " but cannot find device code. Did you forget to bind?";
+  if (target_is_gpu && device_mod->functions.size() == 0) {
+    DLOG(WARNING) << "Specified target " << target->str()
+                  << " but cannot find device code. Did you forget to bind?";
+  }
+
+  return {host_mod, device_mod};
+}
+
+runtime::Module FinalizeModule(const Map<Target, IRModule>& inputs_arg, const Target& host_target) {

Review comment:
       Sure, IMO there is some more cleanup to be done in `build_module`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org