You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/05/01 11:32:32 UTC

[tvm] branch main updated: [OpenCL] Refactor cl_program generation (#7834)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2215d73  [OpenCL] Refactor cl_program generation (#7834)
2215d73 is described below

commit 2215d734339ab918a53b8d4f3a770a6c65495784
Author: Chris Sullivan <cs...@octoml.ai>
AuthorDate: Sat May 1 04:32:12 2021 -0700

    [OpenCL] Refactor cl_program generation (#7834)
    
    * Refactor OpenCL runtime module to build separate cl_programs
    for each kernel. This can avoid pathological bugs in the
    vendor-specific OpenCL compiler that may be triggered
    with large programs.
    
    * clang-format
    
    * Remove check on program size when destructing.
    
    * Refactor into SplitKernels method.
    
    * Limit number of loops for kernel parsing
    
    * Add return doc for SplitKernels per CR.
---
 src/runtime/opencl/opencl_common.h  | 16 ++++++--
 src/runtime/opencl/opencl_module.cc | 80 ++++++++++++++++++++++++++++---------
 src/target/source/codegen_opencl.cc | 18 +++++----
 3 files changed, 85 insertions(+), 29 deletions(-)

diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h
index 64a9f2c..93420fe 100644
--- a/src/runtime/opencl/opencl_common.h
+++ b/src/runtime/opencl/opencl_common.h
@@ -326,6 +326,14 @@ class OpenCLModuleNode : public ModuleNode {
   cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
                           const std::string& func_name, const KTRefEntry& e);
 
+  /*!
+   * \brief Split the serialized OpenCL source into one source string per kernel primitive.
+   * \param source The serialized program source (fmt: cl); each kernel is preceded by a
+   *        "// Function: <name>" delimiter line emitted by the OpenCL codegen.
+   * \return Mapping from primitive name to kernel source (empty if source is empty).
+   */
+  std::unordered_map<std::string, std::string> SplitKernels(std::string source) const;
+
  private:
   // The workspace, need to keep reference to use it in destructor.
   // In case of static destruction order problem.
@@ -340,14 +348,14 @@ class OpenCLModuleNode : public ModuleNode {
   std::mutex build_lock_;
   // The OpenCL source.
   std::string source_;
-  // the binary data
-  cl_program program_{nullptr};
-  // build info
-  std::vector<bool> device_built_flag_;
+  // Mapping from primitive name to cl program for each device.
+  std::unordered_map<std::string, std::vector<cl_program>> programs_;
   // kernel id cache
   std::unordered_map<std::string, KTRefEntry> kid_map_;
   // kernels build so far.
   std::vector<cl_kernel> kernels_;
+  // Mapping from primitive name to its OpenCL source, produced by SplitKernels.
+  std::unordered_map<std::string, std::string> parsed_kernels_;
 };
 
 }  // namespace runtime
diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc
index 8c22c3c..6543b1d 100644
--- a/src/runtime/opencl/opencl_module.cc
+++ b/src/runtime/opencl/opencl_module.cc
@@ -105,8 +105,13 @@ OpenCLModuleNode::~OpenCLModuleNode() {
   for (cl_kernel k : kernels_) {
     OPENCL_CALL(clReleaseKernel(k));
   }
-  if (program_) {
-    OPENCL_CALL(clReleaseProgram(program_));
+  // Release every program created by InstallKernel; slots that were never built stay nullptr.
+  for (auto& kv : programs_) {
+    for (auto& program : kv.second) {
+      if (program) {
+        OPENCL_CALL(clReleaseProgram(program));
+      }
+    }
   }
 }
 
@@ -166,7 +171,6 @@ std::string OpenCLModuleNode::GetSource(const std::string& format) {
 void OpenCLModuleNode::Init() {
   workspace_ = GetGlobalWorkspace();
   workspace_->Init();
-  device_built_flag_.resize(workspace_->devices.size(), false);
   // initialize the kernel id, need to lock global table.
   std::lock_guard<std::mutex> lock(workspace_->mu);
   for (const auto& kv : fmap_) {
@@ -181,28 +185,43 @@ void OpenCLModuleNode::Init() {
     e.version = workspace_->timestamp++;
     kid_map_[key] = e;
   }
+
+  // split into source artifacts for each kernel
+  parsed_kernels_ = SplitKernels(GetSource("cl"));
+  // zero initialize cl_program pointers for each device kernel
+  for (auto& kv : parsed_kernels_) {
+    programs_.insert({kv.first, std::vector<cl_program>(workspace_->devices.size(), nullptr)});
+  }
 }
 
 cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
                                           const std::string& func_name, const KTRefEntry& e) {
   std::lock_guard<std::mutex> lock(build_lock_);
   int device_id = t->device.device_id;
-  if (!device_built_flag_[device_id]) {
+  // Init only creates program slots for kernels found in the split cl source; create them
+  // on demand otherwise, since indexing an empty vector would be out of bounds.
+  std::vector<cl_program>& programs = programs_[func_name];
+  if (programs.empty()) {
+    programs.resize(w->devices.size(), nullptr);
+  }
+  if (programs[device_id] == nullptr) {
     // create program
     if (fmt_ == "cl") {
-      if (program_ == nullptr) {
-        const char* s = data_.c_str();
-        size_t len = data_.length();
-        cl_int err;
-        program_ = clCreateProgramWithSource(w->context, 1, &s, &len, &err);
-        OPENCL_CHECK_ERROR(err);
-      }
+      // Look up without inserting: operator[] would insert and compile an empty source.
+      auto it = parsed_kernels_.find(func_name);
+      ICHECK(it != parsed_kernels_.end())
+          << "No OpenCL source was found for kernel " << func_name;
+      const char* s = it->second.c_str();
+      size_t len = it->second.length();
+      cl_int err;
+      programs[device_id] = clCreateProgramWithSource(w->context, 1, &s, &len, &err);
+      OPENCL_CHECK_ERROR(err);
     } else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") {
       const unsigned char* s = (const unsigned char*)data_.c_str();
       size_t len = data_.length();
       cl_int err;
       cl_device_id dev = w->devices[device_id];
-      program_ = clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, NULL, &err);
+      programs[device_id] = clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, NULL, &err);
       OPENCL_CHECK_ERROR(err);
     } else {
       LOG(FATAL) << "Unknown OpenCL format " << fmt_;
@@ -210,20 +229,20 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre
     // build program
     cl_int err;
     cl_device_id dev = w->devices[device_id];
-    err = clBuildProgram(program_, 1, &dev, nullptr, nullptr, nullptr);
+    err = clBuildProgram(programs[device_id], 1, &dev, nullptr, nullptr, nullptr);
     if (err != CL_SUCCESS) {
       size_t len;
       std::string log;
-      clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len);
+      clGetProgramBuildInfo(programs[device_id], dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len);
       log.resize(len);
-      clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr);
-      LOG(FATAL) << "OpenCL build error for device=" << dev << log;
+      clGetProgramBuildInfo(programs[device_id], dev, CL_PROGRAM_BUILD_LOG, len, &log[0],
+                            nullptr);
+      LOG(FATAL) << "OpenCL build error for device=" << dev << "\n" << log;
     }
-    device_built_flag_[device_id] = true;
   }
   // build kernel
   cl_int err;
-  cl_kernel kernel = clCreateKernel(program_, func_name.c_str(), &err);
+  cl_kernel kernel = clCreateKernel(programs[device_id], func_name.c_str(), &err);
   OPENCL_CHECK_ERROR(err);
   t->kernel_table[e.kernel_id].kernel = kernel;
   t->kernel_table[e.kernel_id].version = e.version;
@@ -231,6 +250,44 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre
   return kernel;
 }
 
+std::unordered_map<std::string, std::string> OpenCLModuleNode::SplitKernels(
+    std::string source) const {
+  std::unordered_map<std::string, std::string> split_kernels;
+  if (source.size()) {
+    std::string del{"// Function: "};
+    size_t end;
+    size_t begin = source.find(del);
+    ICHECK(begin != std::string::npos) << "The OpenCL module expects a kernel delimited "
+                                       << "source from code generation, but no kernel "
+                                       << "delimiter was found.";
+    // The loop is bounded by the global workspace's kernel count as a safeguard;
+    // it normally stops at the last delimiter (npos below).
+    for (size_t num_kernels = 0; num_kernels < workspace_->num_registered_kernels; num_kernels++) {
+      begin += del.size();
+      end = source.find('\n', begin);
+      // Without this check ++end would wrap npos to 0 and restart parsing at the top.
+      ICHECK(end != std::string::npos)
+          << "Kernel delimiter is not followed by a kernel body in the OpenCL source.";
+      std::string func_name = source.substr(begin, end - begin);
+      begin = ++end;
+      // std::string::find returns either the start of the next kernel's delimiter
+      // or std::string::npos; in the latter case substr returns all characters
+      // until the end of the source string.
+      end = source.find(del, begin);
+      std::string func_source =
+          source.substr(begin, (end == std::string::npos) ? end : end - begin);
+      split_kernels.emplace(std::move(func_name), std::move(func_source));
+      begin = end;
+      if (end == std::string::npos) {
+        break;
+      }
+    }
+  }
+  // No exact-count check: num_registered_kernels belongs to the global workspace, not to
+  // this module. A kernel missing here is reported when InstallKernel looks it up.
+  return split_kernels;
+}
+
 Module OpenCLModuleCreate(std::string data, std::string fmt,
                           std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
   auto n = make_object<OpenCLModuleNode>(data, fmt, fmap, source);
diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc
index f72f3f2..edb614d 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -283,23 +283,32 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) {  // N
 runtime::Module BuildOpenCL(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
-  CodeGenOpenCL cg;
-  cg.Init(output_ssa);
 
+  std::stringstream code;
+  const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc");
   for (auto kv : mod->functions) {
     ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
+    // Delimiter parsed by OpenCLModuleNode::SplitKernels to build one cl_program per kernel;
+    // its format must stay in sync with the runtime.
+    code << "// Function: " << kv.first->name_hint << '\n';
+    // A fresh code generator per kernel keeps each kernel's source self-contained.
+    CodeGenOpenCL cg;
+    cg.Init(output_ssa);
     auto f = Downcast<PrimFunc>(kv.second);
     auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
     ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
         << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
     cg.AddFunction(f);
+    std::string fsource = cg.Finish();
+    // The post-processing callback is applied to each kernel, not to the whole program.
+    if (fpostproc) {
+      fsource = (*fpostproc)(fsource).operator std::string();
+    }
+    code << fsource;
   }
 
-  std::string code = cg.Finish();
-  if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) {
-    code = (*f)(code).operator std::string();
-  }
-  return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(mod), code);
+  std::string source = code.str();
+  return OpenCLModuleCreate(source, "cl", ExtractFuncInfo(mod), source);
 }
 
 TVM_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL);