You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/05/30 03:08:36 UTC
[tvm] branch main updated: [VM] Avoid round-trip Target->str->Target conversions (#8161)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e535ec8 [VM] Avoid round-trip Target->str->Target conversions (#8161)
e535ec8 is described below
commit e535ec8a3bf6dee8f8370c169e707295eaae9222
Author: Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Sat May 29 20:08:24 2021 -0700
[VM] Avoid round-trip Target->str->Target conversions (#8161)
Currently, in some cases this round-trip cannot be completed: for
example, when an Integer value lies outside the 32-bit signed
integer range, or when a String value contains spaces.
Co-authored-by: Eric Lunderberg <el...@octoml.ai>
---
python/tvm/relay/backend/vm.py | 24 +++++++++++++++---------
src/relay/backend/vm/compiler.cc | 19 +++++++------------
2 files changed, 22 insertions(+), 21 deletions(-)
diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py
index 0b6d137..363ff89 100644
--- a/python/tvm/relay/backend/vm.py
+++ b/python/tvm/relay/backend/vm.py
@@ -198,20 +198,26 @@ class VMCompiler(object):
target = target if target else tvm.target.Target.current()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")
- tgts = {}
- if isinstance(target, (str, tvm.target.Target)):
- dev_type = tvm.tir.IntImm("int32", tvm.nd.device(str(target)).device_type)
- tgts[dev_type] = tvm.target.Target(target)
- elif isinstance(target, dict):
- for dev, tgt in target.items():
- dev_type = tvm.tir.IntImm("int32", tvm.nd.device(dev).device_type)
- tgts[dev_type] = tvm.target.Target(tgt)
- else:
+
+ if isinstance(target, str):
+ target = {target: target}
+ elif isinstance(target, tvm.target.Target):
+ target = {target.kind.name: target}
+ elif not isinstance(target, dict):
raise TypeError(
"target is expected to be str, tvm.target.Target, "
+ "or dict of str to str/tvm.target.Target, but received "
+ "{}".format(type(target))
)
+
+ tgts = {}
+ for dev, tgt in target.items():
+ dev_type = tvm.tir.IntImm("int32", tvm.nd.device(dev).device_type)
+ if isinstance(tgt, str):
+ tgt = tvm.target.Target(tgt)
+
+ tgts[dev_type] = tgt
+
return tgts
def _update_target_host(self, target, target_host):
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index ad23e13..b43972d 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1156,25 +1156,24 @@ void VMCompiler::Codegen() {
if (cached_funcs.size() == 0) {
return;
}
- std::unordered_map<std::string, IRModule> funcs;
+ Map<Target, IRModule> funcs;
for (auto& cfunc : cached_funcs) {
- std::string target_str = cfunc->target->str();
+ Target target = cfunc->target;
// NOTE: because module, is mutable, we need to make an
// explicit copy of the IRModule.
IRModule mod = cfunc->funcs;
mod.CopyOnWrite();
- if (target_str == "ext_dev") {
+ if (target->kind->device_type == kDLExtDev) {
// Collect metadata in functions that are handled by external codegen.
ICHECK(mod->ContainGlobalVar(cfunc->func_name));
Function func = Downcast<Function>(mod->Lookup(cfunc->func_name));
backend::UpdateConstants(func, &params_);
- continue;
- } else if (funcs.count(target_str) == 0) {
- funcs.emplace(target_str, mod);
+ } else if (funcs.count(target) == 0) {
+ funcs.Set(target, mod);
} else {
- funcs[target_str]->Update(mod);
+ funcs[target]->Update(mod);
}
}
@@ -1182,11 +1181,7 @@ void VMCompiler::Codegen() {
auto ext_mods = compile_engine->LowerExternalFunctions();
runtime::Module lib;
if (funcs.size() > 0) {
- Map<String, IRModule> build_funcs;
- for (const auto& i : funcs) {
- build_funcs.Set(i.first, i.second);
- }
- lib = tvm::build(build_funcs, target_host_);
+ lib = tvm::build(funcs, target_host_);
} else {
// There is no function handled by TVM. We create a virtual main module
// to make sure a DSO module will be also available.