You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/05/01 11:32:32 UTC
[tvm] branch main updated: [OpenCL] Refactor cl_program generation
(#7834)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2215d73 [OpenCL] Refactor cl_program generation (#7834)
2215d73 is described below
commit 2215d734339ab918a53b8d4f3a770a6c65495784
Author: Chris Sullivan <cs...@octoml.ai>
AuthorDate: Sat May 1 04:32:12 2021 -0700
[OpenCL] Refactor cl_program generation (#7834)
* Refactor the OpenCL runtime module to build a separate cl_program
for each kernel. This can avoid pathological bugs in
vendor-specific OpenCL compilers that may be triggered
by large programs.
* clang-format
* Remove check on program size in the destructor.
* Refactor into SplitKernels method.
* Limit the number of loop iterations when parsing kernels
* Add return-value documentation for SplitKernels, per code review.
---
src/runtime/opencl/opencl_common.h | 16 ++++++--
src/runtime/opencl/opencl_module.cc | 80 ++++++++++++++++++++++++++++---------
src/target/source/codegen_opencl.cc | 18 +++++----
3 files changed, 85 insertions(+), 29 deletions(-)
diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h
index 64a9f2c..93420fe 100644
--- a/src/runtime/opencl/opencl_common.h
+++ b/src/runtime/opencl/opencl_common.h
@@ -326,6 +326,14 @@ class OpenCLModuleNode : public ModuleNode {
cl_kernel InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
const std::string& func_name, const KTRefEntry& e);
+ /*!
+ * \brief Split serialized OpenCL C source into one source string per kernel primitive,
+ * cutting at the "// Function: <name>" delimiter lines written by the OpenCL codegen.
+ * \param source The serialized program source file (fmt: cl)
+ * \return Mapping from primitive name (text after the delimiter) to that kernel's source
+ */
+ std::unordered_map<std::string, std::string> SplitKernels(std::string source) const;
+
private:
// The workspace, need to keep reference to use it in destructor.
// In case of static destruction order problem.
@@ -340,14 +348,14 @@ class OpenCLModuleNode : public ModuleNode {
std::mutex build_lock_;
// The OpenCL source.
std::string source_;
- // the binary data
- cl_program program_{nullptr};
- // build info
- std::vector<bool> device_built_flag_;
+ // Mapping from primitive name to its cl_program per device id; nullptr until built.
+ std::unordered_map<std::string, std::vector<cl_program>> programs_;
// kernel id cache
std::unordered_map<std::string, KTRefEntry> kid_map_;
// kernels build so far.
std::vector<cl_kernel> kernels_;
+ // Per-primitive OpenCL C source produced by SplitKernels() in Init().
+ std::unordered_map<std::string, std::string> parsed_kernels_;
};
} // namespace runtime
diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc
index 8c22c3c..6543b1d 100644
--- a/src/runtime/opencl/opencl_module.cc
+++ b/src/runtime/opencl/opencl_module.cc
@@ -105,8 +105,13 @@ OpenCLModuleNode::~OpenCLModuleNode() {
for (cl_kernel k : kernels_) {
OPENCL_CALL(clReleaseKernel(k));
}
- if (program_) {
- OPENCL_CALL(clReleaseProgram(program_));
+ // Release every per-kernel, per-device cl_program; slots never built are still nullptr.
+ for (auto& kv : programs_) {
+ for (auto& program : kv.second) {
+ if (program) {
+ OPENCL_CALL(clReleaseProgram(program));
+ }
+ }
}
}
@@ -166,7 +171,6 @@ std::string OpenCLModuleNode::GetSource(const std::string& format) {
void OpenCLModuleNode::Init() {
workspace_ = GetGlobalWorkspace();
workspace_->Init();
- device_built_flag_.resize(workspace_->devices.size(), false);
// initialize the kernel id, need to lock global table.
std::lock_guard<std::mutex> lock(workspace_->mu);
for (const auto& kv : fmap_) {
@@ -181,28 +185,34 @@ void OpenCLModuleNode::Init() {
e.version = workspace_->timestamp++;
kid_map_[key] = e;
}
+
+ // Only the "cl" format compiles from source; binary formats build from data_ instead.
+ if (fmt_ == "cl") parsed_kernels_ = SplitKernels(GetSource("cl"));
+ // A nullptr slot per device for every kernel in fmap_, whatever the format.
+ for (const auto& kv : fmap_) {
+ programs_.insert({kv.first, std::vector<cl_program>(workspace_->devices.size(), nullptr)});
+ }
}
cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThreadEntry* t,
const std::string& func_name, const KTRefEntry& e) {
std::lock_guard<std::mutex> lock(build_lock_);
int device_id = t->device.device_id;
- if (!device_built_flag_[device_id]) {
+ if (programs_[func_name][device_id] == nullptr) { // NOTE(review): operator[] inserts an empty vector if Init() made no slot for func_name, leaving [device_id] out of range
// create program
if (fmt_ == "cl") {
- if (program_ == nullptr) {
- const char* s = data_.c_str();
- size_t len = data_.length();
- cl_int err;
- program_ = clCreateProgramWithSource(w->context, 1, &s, &len, &err);
- OPENCL_CHECK_ERROR(err);
- }
+ const char* s = parsed_kernels_[func_name].c_str(); // only this kernel's source, split out by SplitKernels() in Init()
+ size_t len = parsed_kernels_[func_name].length();
+ cl_int err;
+ programs_[func_name][device_id] = clCreateProgramWithSource(w->context, 1, &s, &len, &err);
+ OPENCL_CHECK_ERROR(err);
} else if (fmt_ == "xclbin" || fmt_ == "awsxclbin" || fmt_ == "aocx") {
const unsigned char* s = (const unsigned char*)data_.c_str();
size_t len = data_.length();
cl_int err;
cl_device_id dev = w->devices[device_id];
- program_ = clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, NULL, &err);
+ programs_[func_name][device_id] = // NOTE(review): each kernel creates its own program from the whole binary data_
+ clCreateProgramWithBinary(w->context, 1, &dev, &len, &s, NULL, &err);
OPENCL_CHECK_ERROR(err);
} else {
LOG(FATAL) << "Unknown OpenCL format " << fmt_;
@@ -210,20 +220,21 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre
// build program
cl_int err;
cl_device_id dev = w->devices[device_id];
- err = clBuildProgram(program_, 1, &dev, nullptr, nullptr, nullptr);
+ err = clBuildProgram(programs_[func_name][device_id], 1, &dev, nullptr, nullptr, nullptr);
if (err != CL_SUCCESS) {
size_t len;
std::string log;
- clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, 0, nullptr, &len);
+ clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, 0, nullptr,
+ &len);
log.resize(len);
- clGetProgramBuildInfo(program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr);
- LOG(FATAL) << "OpenCL build error for device=" << dev << log;
+ clGetProgramBuildInfo(programs_[func_name][device_id], dev, CL_PROGRAM_BUILD_LOG, len,
+ &log[0], nullptr);
+ LOG(FATAL) << "OpenCL build error for device=" << dev << "\n" << log;
}
- device_built_flag_[device_id] = true;
}
// build kernel
cl_int err;
- cl_kernel kernel = clCreateKernel(program_, func_name.c_str(), &err);
+ cl_kernel kernel = clCreateKernel(programs_[func_name][device_id], func_name.c_str(), &err);
OPENCL_CHECK_ERROR(err);
t->kernel_table[e.kernel_id].kernel = kernel;
t->kernel_table[e.kernel_id].version = e.version;
@@ -231,6 +242,39 @@ cl_kernel OpenCLModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenCLThre
return kernel;
}
+std::unordered_map<std::string, std::string> OpenCLModuleNode::SplitKernels(
+ std::string source) const {
+ std::unordered_map<std::string, std::string> split_kernels;
+ if (!source.empty()) {
+ const std::string del{"// Function: "};
+ size_t begin = source.find(del);
+ ICHECK(begin != std::string::npos) << "The OpenCL module expects a kernel delimited "
+ << "source from code generation, but no kernel "
+ << "delimiter was found.";
+ // Bound the scan by this module's fmap_: num_registered_kernels belongs to the
+ // global workspace (GetGlobalWorkspace()) and is not a per-module kernel count.
+ for (size_t n = 0; n < fmap_.size() && begin != std::string::npos; ++n) {
+ begin += del.size();
+ size_t end = source.find('\n', begin);
+ std::string func_name = source.substr(begin, end - begin);
+ if (end == std::string::npos) {
+ // Delimiter on the last line with no body; do not wrap npos + 1 around to 0.
+ split_kernels.insert({func_name, ""});
+ break;
+ }
+ begin = end + 1;
+ // find() gives the next delimiter or npos; substr() clamps an oversized count,
+ // so the last kernel keeps every remaining character of the source.
+ end = source.find(del, begin);
+ split_kernels.insert({func_name, source.substr(begin, end - begin)});
+ begin = end;
+ }
+ }
+ ICHECK_EQ(fmap_.size(), split_kernels.size())
+ << "The number of registered kernels does not match number of parsed kernel sources";
+ return split_kernels;
+}
+
Module OpenCLModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
auto n = make_object<OpenCLModuleNode>(data, fmt, fmap, source);
diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc
index f72f3f2..edb614d 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -283,23 +283,27 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N
runtime::Module BuildOpenCL(IRModule mod, Target target) {
using tvm::runtime::Registry;
bool output_ssa = false;
- CodeGenOpenCL cg;
- cg.Init(output_ssa);
+ std::stringstream code;
+ const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc");
for (auto kv : mod->functions) {
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc";
+ code << "// Function: " << kv.first->name_hint << '\n'; // delimiter that OpenCLModuleNode::SplitKernels cuts at
+ CodeGenOpenCL cg;
+ cg.Init(output_ssa);
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
cg.AddFunction(f);
+ std::string fsource = cg.Finish();
+ if (fpostproc) { // the postproc callback now receives one kernel's source at a time
+ fsource = (*fpostproc)(fsource).operator std::string();
+ }
+ code << fsource;
}
- std::string code = cg.Finish();
- if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) {
- code = (*f)(code).operator std::string();
- }
- return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(mod), code);
+ return OpenCLModuleCreate(code.str(), "cl", ExtractFuncInfo(mod), code.str());
}
TVM_REGISTER_GLOBAL("target.build.opencl").set_body_typed(BuildOpenCL);