You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/05/30 03:08:36 UTC

[tvm] branch main updated: [VM] Avoid round-trip Target->str->Target conversions (#8161)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e535ec8  [VM] Avoid round-trip Target->str->Target conversions (#8161)
e535ec8 is described below

commit e535ec8a3bf6dee8f8370c169e707295eaae9222
Author: Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Sat May 29 20:08:24 2021 -0700

    [VM] Avoid round-trip Target->str->Target conversions (#8161)
    
    Currently, in some cases this round-trip cannot be completed, for
    example, when an Integer attribute holds a value outside the 32-bit
    signed integer range, or when a String attribute contains spaces.
    
    Co-authored-by: Eric Lunderberg <el...@octoml.ai>
---
 python/tvm/relay/backend/vm.py   | 24 +++++++++++++++---------
 src/relay/backend/vm/compiler.cc | 19 +++++++------------
 2 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py
index 0b6d137..363ff89 100644
--- a/python/tvm/relay/backend/vm.py
+++ b/python/tvm/relay/backend/vm.py
@@ -198,20 +198,26 @@ class VMCompiler(object):
         target = target if target else tvm.target.Target.current()
         if target is None:
             raise ValueError("Target is not set in env or passed as argument.")
-        tgts = {}
-        if isinstance(target, (str, tvm.target.Target)):
-            dev_type = tvm.tir.IntImm("int32", tvm.nd.device(str(target)).device_type)
-            tgts[dev_type] = tvm.target.Target(target)
-        elif isinstance(target, dict):
-            for dev, tgt in target.items():
-                dev_type = tvm.tir.IntImm("int32", tvm.nd.device(dev).device_type)
-                tgts[dev_type] = tvm.target.Target(tgt)
-        else:
+
+        if isinstance(target, str):
+            target = {target: target}
+        elif isinstance(target, tvm.target.Target):
+            target = {target.kind.name: target}
+        elif not isinstance(target, dict):
             raise TypeError(
                 "target is expected to be str, tvm.target.Target, "
                 + "or dict of str to str/tvm.target.Target, but received "
                 + "{}".format(type(target))
             )
+
+        tgts = {}
+        for dev, tgt in target.items():
+            dev_type = tvm.tir.IntImm("int32", tvm.nd.device(dev).device_type)
+            if isinstance(tgt, str):
+                tgt = tvm.target.Target(tgt)
+
+            tgts[dev_type] = tgt
+
         return tgts
 
     def _update_target_host(self, target, target_host):
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index ad23e13..b43972d 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1156,25 +1156,24 @@ void VMCompiler::Codegen() {
   if (cached_funcs.size() == 0) {
     return;
   }
-  std::unordered_map<std::string, IRModule> funcs;
+  Map<Target, IRModule> funcs;
 
   for (auto& cfunc : cached_funcs) {
-    std::string target_str = cfunc->target->str();
+    Target target = cfunc->target;
     // NOTE: because module, is mutable, we need to make an
     // explicit copy of the IRModule.
     IRModule mod = cfunc->funcs;
     mod.CopyOnWrite();
 
-    if (target_str == "ext_dev") {
+    if (target->kind->device_type == kDLExtDev) {
       // Collect metadata in functions that are handled by external codegen.
       ICHECK(mod->ContainGlobalVar(cfunc->func_name));
       Function func = Downcast<Function>(mod->Lookup(cfunc->func_name));
       backend::UpdateConstants(func, &params_);
-      continue;
-    } else if (funcs.count(target_str) == 0) {
-      funcs.emplace(target_str, mod);
+    } else if (funcs.count(target) == 0) {
+      funcs.Set(target, mod);
     } else {
-      funcs[target_str]->Update(mod);
+      funcs[target]->Update(mod);
     }
   }
 
@@ -1182,11 +1181,7 @@ void VMCompiler::Codegen() {
   auto ext_mods = compile_engine->LowerExternalFunctions();
   runtime::Module lib;
   if (funcs.size() > 0) {
-    Map<String, IRModule> build_funcs;
-    for (const auto& i : funcs) {
-      build_funcs.Set(i.first, i.second);
-    }
-    lib = tvm::build(build_funcs, target_host_);
+    lib = tvm::build(funcs, target_host_);
   } else {
     // There is no function handled by TVM. We create a virtual main module
     // to make sure a DSO module will be also available.