You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2021/02/16 04:02:29 UTC
[tvm] branch main updated: [BYOC][Verilator] Refactor Verilator runtime (#7406)
This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fc48514 [BYOC][Verilator] Refactor Verilator runtime (#7406)
fc48514 is described below
commit fc48514f1d8ccffcebd12007cb6c602506975703
Author: Luis Vega <ve...@users.noreply.github.com>
AuthorDate: Mon Feb 15 20:02:10 2021 -0800
[BYOC][Verilator] Refactor Verilator runtime (#7406)
* new experiment
* save
* refactor
* refactor library
* add profiler
* refactor
* refactor
* add docs
* update comment
* add deallocator
---
src/relay/backend/contrib/verilator/codegen.cc | 56 +++---
src/runtime/contrib/verilator/verilator_device.h | 39 +++-
src/runtime/contrib/verilator/verilator_runtime.cc | 197 +++++++++++----------
src/runtime/contrib/verilator/verilator_runtime.h | 138 +++++++++++++++
.../contrib/test_verilator/infrastructure.py | 6 +-
5 files changed, 307 insertions(+), 129 deletions(-)
diff --git a/src/relay/backend/contrib/verilator/codegen.cc b/src/relay/backend/contrib/verilator/codegen.cc
index 2f61ae5..b206288 100644
--- a/src/relay/backend/contrib/verilator/codegen.cc
+++ b/src/relay/backend/contrib/verilator/codegen.cc
@@ -34,6 +34,7 @@
#include <sstream>
#include "../../../../runtime/contrib/json/json_node.h"
+#include "../../../../runtime/contrib/verilator/verilator_runtime.h"
#include "../../utils.h"
#include "../codegen_json/codegen_json.h"
@@ -75,29 +76,34 @@ class VerilatorJSONSerializer : public backend::contrib::JSONSerializer {
}
};
-/*! \brief Attributes to store the compiler options for Verilator */
-struct VerilatorCompilerConfigNode : public tvm::AttrsNode<VerilatorCompilerConfigNode> {
-  String lib;
-
-  TVM_DECLARE_ATTRS(VerilatorCompilerConfigNode, "ext.attrs.VerilatorCompilerConfigNode") {
-    TVM_ATTR_FIELD(lib).set_default("libverilator.so");
+/*! \brief Attributes read from the "relay.ext.verilator.options" pass-config option */
+struct VerilatorOptionsNode : public tvm::AttrsNode<VerilatorOptionsNode> {
+  String lib_path;  // Verilator design library that the runtime loads
+  int reset_cycles;  // cycles passed to VerilatorReset when the runtime initializes
+  bool profiler_enable;  // accumulate device cycle counts in the Verilator profiler
+  int profiler_cycle_counter_id;  // register id read as the cycle counter
+
+  TVM_DECLARE_ATTRS(VerilatorOptionsNode, "ext.attrs.VerilatorOptionsNode") {
+    TVM_ATTR_FIELD(lib_path).describe("the design library path").set_default("libverilator.so");
+    TVM_ATTR_FIELD(reset_cycles).describe("the number of reset cycles").set_default(1);
+    TVM_ATTR_FIELD(profiler_enable).describe("enable profiler").set_default(false);
+    TVM_ATTR_FIELD(profiler_cycle_counter_id).describe("profiler cycle counter id").set_default(0);
}
};
-class VerilatorCompilerConfig : public Attrs {
+class VerilatorOptions : public Attrs {
public:
-  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VerilatorCompilerConfig, Attrs,
-                                            VerilatorCompilerConfigNode);
+  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VerilatorOptions, Attrs, VerilatorOptionsNode);
};
-TVM_REGISTER_NODE_TYPE(VerilatorCompilerConfigNode);
-TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.verilator.options", VerilatorCompilerConfig);
+TVM_REGISTER_NODE_TYPE(VerilatorOptionsNode);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.verilator.options", VerilatorOptions);
/*!
- * \brief The external compiler/codegen tool. It takes a Relay expression/module and
- * compile it into a runtime module.
+ * \brief The Verilator codegen tool. It takes a Relay function and compiles it into a
+ * VerilatorRuntime module configured from the "relay.ext.verilator.options" pass config.
*/
-runtime::Module VerilatorCompiler(const ObjectRef& ref) {
+runtime::Module VerilatorBackend(const ObjectRef& ref) {
CHECK(ref->IsInstance<FunctionNode>());
auto func = Downcast<Function>(ref);
auto func_name = GetExtSymbol(func);
@@ -106,22 +112,28 @@ runtime::Module VerilatorCompiler(const ObjectRef& ref) {
std::string graph_json = serializer.GetJSON();
auto params = serializer.GetParams();
+  // Create the runtime object from the symbol name, graph JSON and constant names
+  auto n = make_object<runtime::contrib::VerilatorRuntime>(func_name, graph_json, params);
+
// Get Verilator compiler options
auto ctx = transform::PassContext::Current();
-  auto cfg = ctx->GetConfig<VerilatorCompilerConfig>("relay.ext.verilator.options");
+  auto cfg = ctx->GetConfig<VerilatorOptions>("relay.ext.verilator.options");
if (!cfg.defined()) {
-    cfg = AttrsWithDefaultValues<VerilatorCompilerConfig>();
+    cfg = AttrsWithDefaultValues<VerilatorOptions>();
}
-  auto lib_name = cfg.value()->lib;
+  n->SetLibrary(cfg.value()->lib_path);  // NOTE(review): verify options survive module save/load
+  n->SetResetCycles(cfg.value()->reset_cycles);
+
+  if (cfg.value()->profiler_enable) {
+    n->EnableProfiler();
+    n->SetProfilerCycleCounterId(cfg.value()->profiler_cycle_counter_id);
+  }
-  const auto* pf = runtime::Registry::Get("runtime.verilator_runtime_create");
-  CHECK(pf != nullptr) << "Cannot find JSON runtime module to create";
-  auto mod = (*pf)(lib_name, func_name, graph_json, params);
-  return mod;
+  return runtime::Module(n);
}
-TVM_REGISTER_GLOBAL("relay.ext.verilator").set_body_typed(VerilatorCompiler);
+TVM_REGISTER_GLOBAL("relay.ext.verilator").set_body_typed(VerilatorBackend);
} // namespace contrib
} // namespace relay
diff --git a/src/runtime/contrib/verilator/verilator_device.h b/src/runtime/contrib/verilator/verilator_device.h
index acd91a5..298e41c 100644
--- a/src/runtime/contrib/verilator/verilator_device.h
+++ b/src/runtime/contrib/verilator/verilator_device.h
@@ -31,24 +31,51 @@ namespace tvm {
namespace runtime {
namespace contrib {
+/*! \brief Opaque handle to a Verilator device resource context */
typedef void* VerilatorHandle;
-/* allocate Verilator object */
+/*!
+ * \brief Allocate a verilator device resource handle
+ * \return The verilator device handle.
+ */
extern "C" TVM_DLL VerilatorHandle VerilatorAlloc();
-/* deallocate Verilator object */
+/*!
+ * \brief Free a verilator device handle returned by VerilatorAlloc
+ * \param handle The verilator device handle to be freed.
+ */
extern "C" TVM_DLL void VerilatorDealloc(VerilatorHandle handle);
-/* read Verilator register or memory */
+/*!
+ * \brief Read verilator register or memory
+ * \param handle The verilator device handle.
+ * \param id The register or memory identifier.
+ * \param addr The register or memory address (word-level).
+ * \return The value read from the register or memory.
+ */
extern "C" TVM_DLL int VerilatorRead(VerilatorHandle handle, int id, int addr);
-/* write Verilator register or memory */
+/*!
+ * \brief Write verilator register or memory
+ * \param handle The verilator device handle.
+ * \param id The register or memory identifier.
+ * \param addr The register or memory address (word-level).
+ * \param value The value to write to the register or memory.
+ */
extern "C" TVM_DLL void VerilatorWrite(VerilatorHandle handle, int id, int addr, int value);
-/* reset Verilator for n clock cycles */
+/*!
+ * \brief Reset Verilator for n clock cycles
+ * \param handle The verilator device handle.
+ * \param n The number of reset cycles.
+ */
extern "C" TVM_DLL void VerilatorReset(VerilatorHandle handle, int n);
-/* run Verilator for n clock cycles */
+/*!
+ * \brief Run Verilator for n clock cycles
+ * \param handle The verilator device handle.
+ * \param n The number of run cycles.
+ */
extern "C" TVM_DLL void VerilatorRun(VerilatorHandle handle, int n);
} // namespace contrib
diff --git a/src/runtime/contrib/verilator/verilator_runtime.cc b/src/runtime/contrib/verilator/verilator_runtime.cc
index 60f36e4..bc96b69 100644
--- a/src/runtime/contrib/verilator/verilator_runtime.cc
+++ b/src/runtime/contrib/verilator/verilator_runtime.cc
@@ -19,9 +19,11 @@
/*!
* \file src/runtime/contrib/verilator/verilator_runtime.cc
- * \brief A simple JSON runtime for Verilator.
+ * \brief A runtime for Verilator.
*/
+#include "verilator_runtime.h"
+
#include <dlfcn.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>
@@ -40,124 +42,123 @@ namespace tvm {
namespace runtime {
namespace contrib {
-typedef VerilatorHandle (*VerilatorAllocFunc)();
-typedef void (*VerilatorResetFunc)(VerilatorHandle, int);
-typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int);
-
using namespace tvm::runtime;
+using namespace tvm::runtime::contrib;
using namespace tvm::runtime::json;
-class VerilatorLibrary : public Library {
- public:
-  ~VerilatorLibrary() {
-    if (lib_handle_) Unload();
-  }
-  void Init(const std::string& name) { Load(name); }
-
-  void* GetSymbol(const char* name) final { return GetSymbol_(name); }
-
- private:
-  // Library handle
-  void* lib_handle_{nullptr};
-  // load the library
-  void Load(const std::string& name) {
-    lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
-    ICHECK(lib_handle_ != nullptr)
-        << "Failed to load dynamic shared library " << name << " " << dlerror();
-  }
-
-  void* GetSymbol_(const char* name) { return dlsym(lib_handle_, name); }
-
-  void Unload() {
+VerilatorLibrary::~VerilatorLibrary() {  // dlclose the handle opened by Load(), if any
+  if (lib_handle_) {
dlclose(lib_handle_);
lib_handle_ = nullptr;
}
-};
+}
-class VerilatorJSONRuntime : public JSONRuntimeBase {
- public:
-  VerilatorJSONRuntime(const std::string& symbol_name, const std::string& graph_json,
-                       const Array<String> const_names)
-      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+void VerilatorLibrary::Load(const std::string& name) {
+  lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);  // symbols are not made global
+  ICHECK(lib_handle_ != nullptr) << "Failed to load dynamic shared library " << name << " "
+                                 << dlerror();
+}
-  const char* type_key() const { return "verilator_json"; }
+void* VerilatorLibrary::GetSymbol(const char* name) { return dlsym(lib_handle_, name); }
-  void LoadLibrary(const std::string& lib_name) {
-    lib_ = new VerilatorLibrary();
-    lib_->Init(lib_name);
-  }
+void VerilatorProfiler::Clear() { cycle_counter = 0; }  // reset the accumulated cycle count
-  void Init(const Array<NDArray>& consts) override {
-    // get symbols
-    auto alloc_func = reinterpret_cast<VerilatorAllocFunc>(lib_->GetSymbol("VerilatorAlloc"));
-    ICHECK(alloc_func != nullptr);
-    auto reset_func = reinterpret_cast<VerilatorResetFunc>(lib_->GetSymbol("VerilatorReset"));
-    ICHECK(reset_func != nullptr);
-    vadd_func_ = reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));
-    ICHECK(vadd_func_ != nullptr);
+std::string VerilatorProfiler::AsJSON() {  // render the counters as a small JSON object
+  std::ostringstream os;
+  os << "{\n"
+     << " \"cycle_counter\":" << cycle_counter << "\n"
+     << "}\n";
+  return os.str();
+}
-    // alloc device
-    device_ = (*alloc_func)();
+VerilatorProfiler* VerilatorProfiler::ThreadLocal() {
+  static thread_local VerilatorProfiler inst;  // one profiler per thread, created on first use
+  return &inst;
+}
-    // reset for 10 cycles
-    (*reset_func)(device_, 10);
+VerilatorRuntime::~VerilatorRuntime() {
+  // Init() never ran (e.g. a module that was compiled but never executed), so no library was
+  // loaded and no device was allocated; dereferencing lib_ here would crash.
+  if (lib_ == nullptr) return;
+  auto dealloc = reinterpret_cast<VerilatorDeallocFunc>(lib_->GetSymbol("VerilatorDealloc"));
+  // Do not ICHECK here: a failed check cannot safely propagate out of a destructor.
+  if (dealloc != nullptr && device_ != nullptr) dealloc(device_);
+  device_ = nullptr;
+  // Destroy *and* free the library allocated with new in Init(); its destructor dlcloses it.
+  delete lib_;
+  lib_ = nullptr;
+}
-    CHECK_EQ(consts.size(), const_idx_.size())
-        << "The number of input constants must match the number of required.";
+void VerilatorRuntime::SetLibrary(const std::string& lib_path) { lib_path_ = lib_path; }
-    // Setup constants entries for weights.
-    SetupConstants(consts);
-  }
+void VerilatorRuntime::SetResetCycles(const int cycles) { reset_cycles_ = cycles; }
-  void Run() override {
-    std::vector<int*> in_ptr;
-    std::vector<int*> out_ptr;
-    for (size_t i = 0; i < input_nodes_.size(); ++i) {
-      uint32_t eid = EntryID(input_nodes_[i], 0);
-      int* data = static_cast<int*>(data_entry_[eid]->data);
-      in_ptr.push_back(data);
-    }
-    for (size_t i = 0; i < outputs_.size(); ++i) {
-      uint32_t eid = EntryID(outputs_[i]);
-      int* data = static_cast<int*>(data_entry_[eid]->data);
-      out_ptr.push_back(data);
-    }
-    for (size_t nid = 0; nid < nodes_.size(); ++nid) {
-      const auto& node = nodes_[nid];
-      if (node.GetOpType() == "kernel") {
-        CHECK_EQ(node.GetOpType(), "kernel");
-        auto op_name = node.GetOpName();
-        if ("add" == op_name) {
-          auto entry = node.GetInputs()[0];
-          auto shape = nodes_[entry.id_].GetOpShape()[entry.index_];
-          (*vadd_func_)(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
-        } else {
-          LOG(FATAL) << "Unsupported op: " << op_name;
-        }
+void VerilatorRuntime::EnableProfiler() { prof_enable_ = true; }  // read by Init() and Run()
+
+void VerilatorRuntime::SetProfilerCycleCounterId(const int id) { prof_cycle_counter_id_ = id; }
+
+void VerilatorRuntime::Init(const Array<NDArray>& consts) {
+  lib_ = new VerilatorLibrary();
+  lib_->Load(lib_path_);  // path recorded by SetLibrary()
+  auto alloc = reinterpret_cast<VerilatorAllocFunc>(lib_->GetSymbol("VerilatorAlloc"));
+  ICHECK(alloc != nullptr);
+  auto reset = reinterpret_cast<VerilatorResetFunc>(lib_->GetSymbol("VerilatorReset"));
+  ICHECK(reset != nullptr);
+  read_ = reinterpret_cast<VerilatorReadFunc>(lib_->GetSymbol("VerilatorRead"));
+  ICHECK(read_ != nullptr);
+  add_op_ = reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));  // optional
+
+  // allocate the verilator device handle
+  device_ = alloc();
+
+  // when profiling, bind to the profiler of the thread running Init()
+  if (prof_enable_) prof_ = VerilatorProfiler::ThreadLocal();
+
+  // reset the verilator device for reset_cycles_ clock cycles
+  reset(device_, reset_cycles_);
+
+  CHECK_EQ(consts.size(), const_idx_.size())
+      << "The number of input constants must match the number of required.";
+
+  // Setup constants entries for weights.
+  SetupConstants(consts);
+}
+
+void VerilatorRuntime::Run() {
+  std::vector<int*> in_ptr;  // input entry buffers, in input_nodes_ order
+  std::vector<int*> out_ptr;  // output entry buffers, in outputs_ order
+  for (size_t i = 0; i < input_nodes_.size(); ++i) {
+    uint32_t eid = EntryID(input_nodes_[i], 0);
+    int* data = static_cast<int*>(data_entry_[eid]->data);
+    in_ptr.push_back(data);
+  }
+  for (size_t i = 0; i < outputs_.size(); ++i) {
+    uint32_t eid = EntryID(outputs_[i]);
+    int* data = static_cast<int*>(data_entry_[eid]->data);
+    out_ptr.push_back(data);
+  }
+  for (size_t nid = 0; nid < nodes_.size(); ++nid) {
+    const auto& node = nodes_[nid];
+    if (node.GetOpType() == "kernel") {
+      CHECK_EQ(node.GetOpType(), "kernel");  // NOTE(review): redundant with the if above
+      auto op_name = node.GetOpName();
+      if ("add" == op_name) {  // NOTE(review): uses in_ptr[0..1]/out_ptr[0] by position
+        auto entry = node.GetInputs()[0];  // shape is taken from the first input
+        auto shape = nodes_[entry.id_].GetOpShape()[entry.index_];
+        ICHECK(add_op_ != nullptr);  // not required by Init()
+        add_op_(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
+      } else {
+        LOG(FATAL) << "Unsupported op: " << op_name;
}
}
}
-
- private:
-  /* The verilator device handle. */
-  VerilatorHandle device_{nullptr};
-  /* The verilator library handle. */
-  VerilatorLibrary* lib_{nullptr};
-  /* The verilator vadd function handle. */
-  VerilatorAddFunc vadd_func_{nullptr};
-};
-
-runtime::Module VerilatorJSONRuntimeCreate(String lib_name, String symbol_name, String graph_json,
-                                           const Array<String>& const_names) {
-  auto n = make_object<VerilatorJSONRuntime>(symbol_name, graph_json, const_names);
-  n->LoadLibrary(lib_name);
-  return runtime::Module(n);
+  if (prof_enable_) {
+    int cycles = read_(device_, prof_cycle_counter_id_, 0);  // counter register, address 0
+    prof_->cycle_counter += cycles;  // accumulates until Clear()
+  }
}
-TVM_REGISTER_GLOBAL("runtime.verilator_runtime_create").set_body_typed(VerilatorJSONRuntimeCreate);
+TVM_REGISTER_GLOBAL("verilator.profiler_clear").set_body([](TVMArgs args, TVMRetValue* rv) {
+  VerilatorProfiler::ThreadLocal()->Clear();  // clears only the calling thread's profiler
+});
-TVM_REGISTER_GLOBAL("runtime.module.loadbinary_verilator_json")
-    .set_body_typed(JSONRuntimeBase::LoadFromBinary<VerilatorJSONRuntime>);
+TVM_REGISTER_GLOBAL("verilator.profiler_status").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = VerilatorProfiler::ThreadLocal()->AsJSON();  // the calling thread's counters as JSON
+});  // NOTE(review): no "runtime.module.loadbinary_verilator" in this file; confirm
} // namespace contrib
} // namespace runtime
diff --git a/src/runtime/contrib/verilator/verilator_runtime.h b/src/runtime/contrib/verilator/verilator_runtime.h
new file mode 100644
index 0000000..acdaa3b
--- /dev/null
+++ b/src/runtime/contrib/verilator/verilator_runtime.h
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/verilator/verilator_runtime.h
+ * \brief A runtime for Verilator.
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_VERILATOR_VERILATOR_RUNTIME_H_
+#define TVM_RUNTIME_CONTRIB_VERILATOR_VERILATOR_RUNTIME_H_
+
+#include <dlfcn.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstddef>
+#include <string>
+#include <vector>
+
+#include "../../library_module.h"
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+#include "verilator_device.h"
+#include "verilator_kernel.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+/*! \brief Signatures of the C entry points looked up in the Verilator design library. */
+typedef VerilatorHandle (*VerilatorAllocFunc)();
+typedef void (*VerilatorDeallocFunc)(VerilatorHandle);
+typedef void (*VerilatorResetFunc)(VerilatorHandle, int);
+typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int);
+typedef int (*VerilatorReadFunc)(VerilatorHandle, int, int);
+
+/*! \brief A Verilator design library opened with dlopen and closed on destruction. */
+class VerilatorLibrary : public Library {
+ public:
+  ~VerilatorLibrary();
+
+  /*! \brief load the library at \p name; ICHECK-fails if it cannot be opened */
+  void Load(const std::string& name);
+
+  /*! \brief get a symbol from the library, or nullptr if it is not found */
+  void* GetSymbol(const char* name) final;
+
+ private:
+  /*! \brief the library handle (nullptr until Load succeeds) */
+  void* lib_handle_{nullptr};
+};
+
+/*! \brief Per-thread profiling counters updated by VerilatorRuntime::Run. */
+class VerilatorProfiler {
+ public:
+  /*! \brief cycle count accumulated over Run calls since the last Clear */
+  uint32_t cycle_counter{0};
+
+  /*! \brief reset the accumulated counters to zero */
+  void Clear();
+
+  /*! \brief get the profiler counters as a JSON string */
+  std::string AsJSON();
+
+  /*! \brief get the calling thread's profiler instance */
+  static VerilatorProfiler* ThreadLocal();
+};
+
+/*! \brief JSON runtime that runs supported kernels on a Verilator design library. */
+class VerilatorRuntime : public json::JSONRuntimeBase {
+ public:
+  VerilatorRuntime(const std::string& symbol_name, const std::string& graph_json,
+                   const Array<String> const_names)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+  ~VerilatorRuntime();
+
+  const char* type_key() const override { return "verilator"; }
+
+  /*! \brief set the path of the verilator library that Init loads */
+  void SetLibrary(const std::string& lib_path);
+
+  /*! \brief set the number of reset cycles applied by Init */
+  void SetResetCycles(const int cycles);
+
+  /*! \brief enable accumulating cycle counts into the thread-local profiler */
+  void EnableProfiler();
+
+  /*! \brief set the register id read as the cycle counter when profiling */
+  void SetProfilerCycleCounterId(const int id);
+
+  /*! \brief load the library, allocate and reset the device, and set up constants */
+  void Init(const Array<NDArray>& consts) override;
+
+  /*! \brief run the graph's kernel nodes on the device */
+  void Run() override;
+
+ private:
+  /*! \brief the verilator library path */
+  String lib_path_;
+  /*! \brief the verilator device */
+  VerilatorHandle device_{nullptr};
+  /*! \brief the verilator library (nullptr until Init) */
+  VerilatorLibrary* lib_{nullptr};
+  /*! \brief profiler of the thread that called Init (set only when profiling is enabled) */
+  VerilatorProfiler* prof_{nullptr};
+  /*! \brief the verilator read function */
+  VerilatorReadFunc read_{nullptr};
+  /*! \brief the verilator add op function (optional symbol) */
+  VerilatorAddFunc add_op_{nullptr};
+  /*! \brief the verilator reset cycles */
+  int reset_cycles_{1};
+  /*! \brief whether the profiler is enabled */
+  bool prof_enable_{false};
+  /*! \brief the verilator profiler cycle counter id */
+  int prof_cycle_counter_id_{0};
+};
+
+} // namespace contrib
+} // namespace runtime
+} // namespace tvm
+#endif // TVM_RUNTIME_CONTRIB_VERILATOR_VERILATOR_RUNTIME_H_
diff --git a/tests/python/contrib/test_verilator/infrastructure.py b/tests/python/contrib/test_verilator/infrastructure.py
index e8fd943..7e4c297 100644
--- a/tests/python/contrib/test_verilator/infrastructure.py
+++ b/tests/python/contrib/test_verilator/infrastructure.py
@@ -102,9 +102,9 @@ def compile_module(mod):
if not os.path.isfile(lib):
compile_hardware()
- with tvm.transform.PassContext(
- opt_level=3, config={"relay.ext.verilator.options": {"lib": lib}}
- ):
+ opts = {"lib_path": lib}
+
+ with tvm.transform.PassContext(opt_level=3, config={"relay.ext.verilator.options": opts}):
exe = relay.vm.compile(mod, target="llvm", params=None)
code, lib = exe.save()
return runtime.vm.Executable.load_exec(code, lib)