You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/02/13 10:40:20 UTC

[tvm] branch main updated: [VM] Move param bind to OptimizeModule (#7451)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0aa90b0  [VM] Move param bind to OptimizeModule (#7451)
0aa90b0 is described below

commit 0aa90b093fd7e842eb88fa8e9994f70f24ba2bbf
Author: masahi <ma...@gmail.com>
AuthorDate: Sat Feb 13 19:40:00 2021 +0900

    [VM] Move param bind to OptimizeModule (#7451)
    
    * [VM] Move param bind to OptimizeModule
    
    * add test to verify the number of free vars after opt
    
    * remove const from OptimizeModule
---
 src/relay/backend/vm/compiler.cc | 20 ++++++++++----------
 src/relay/backend/vm/compiler.h  |  3 +--
 tests/python/relay/test_vm.py    |  4 ++++
 3 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 7861502..7697b59 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -892,15 +892,6 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
 }
 
 void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) {
-  if (params_.size()) {
-    BaseFunc base_func = mod->Lookup("main");
-    ICHECK(base_func->IsInstance<FunctionNode>())
-        << "VM compiler expects to compile relay::Function";
-    auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
-    auto gvar = mod->GetGlobalVar("main");
-    mod->Add(gvar, f);
-  }
-
   exec_ = make_object<Executable>();
   targets_ = targets;
   target_host_ = target_host;
@@ -1005,8 +996,17 @@ transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) {
   return transform::Sequential(pass_seqs);
 }
 
-IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets,
+IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets,
                                     const Target& target_host) {
+  if (params_.size()) {
+    BaseFunc base_func = mod->Lookup("main");
+    ICHECK(base_func->IsInstance<FunctionNode>())
+        << "VM compiler expects to compile relay::Function";
+    auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
+    auto gvar = mod->GetGlobalVar("main");
+    mod->Add(gvar, f);
+  }
+
   Array<Pass> pass_seqs;
   Array<runtime::String> entry_functions{"main"};
   pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h
index 56965c5..615a818 100644
--- a/src/relay/backend/vm/compiler.h
+++ b/src/relay/backend/vm/compiler.h
@@ -125,8 +125,7 @@ class VMCompiler : public runtime::ModuleNode {
    *
    * \return The optimized IRModule.
    */
-  IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets,
-                          const Target& target_host);
+  IRModule OptimizeModule(IRModule mod, const TargetsMap& targets, const Target& target_host);
 
   /*!
    * \brief Populate the global function names in a map where the value is used
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 6958010..975070a 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -678,6 +678,10 @@ def test_vm_optimize():
     comp = relay.vm.VMCompiler()
     opt_mod, _ = comp.optimize(mod, target="llvm", params=params)
 
+    free_vars = relay.analysis.free_vars(opt_mod["main"].body)
+    # Parameters should all be bound, so the only free var is data
+    assert len(free_vars) == 1
+
 
 @tvm.testing.uses_gpu
 def test_loop_free_var():