You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/02/13 10:40:20 UTC
[tvm] branch main updated: [VM] Move param bind to OptimizeModule (#7451)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0aa90b0 [VM] Move param bind to OptimizeModule (#7451)
0aa90b0 is described below
commit 0aa90b093fd7e842eb88fa8e9994f70f24ba2bbf
Author: masahi <ma...@gmail.com>
AuthorDate: Sat Feb 13 19:40:00 2021 +0900
[VM] Move param bind to OptimizeModule (#7451)
* [VM] Move param bind to OptimizeModule
* add test to verify the number of free vars after opt
* remove const from OptimizeModule
---
src/relay/backend/vm/compiler.cc | 20 ++++++++++----------
src/relay/backend/vm/compiler.h | 3 +--
tests/python/relay/test_vm.py | 4 ++++
3 files changed, 15 insertions(+), 12 deletions(-)
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 7861502..7697b59 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -892,15 +892,6 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
}
void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) {
- if (params_.size()) {
- BaseFunc base_func = mod->Lookup("main");
- ICHECK(base_func->IsInstance<FunctionNode>())
- << "VM compiler expects to compile relay::Function";
- auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
- auto gvar = mod->GetGlobalVar("main");
- mod->Add(gvar, f);
- }
-
exec_ = make_object<Executable>();
targets_ = targets;
target_host_ = target_host;
@@ -1005,8 +996,17 @@ transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) {
return transform::Sequential(pass_seqs);
}
-IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets,
+IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets,
const Target& target_host) {
+ if (params_.size()) {
+ BaseFunc base_func = mod->Lookup("main");
+ ICHECK(base_func->IsInstance<FunctionNode>())
+ << "VM compiler expects to compile relay::Function";
+ auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
+ auto gvar = mod->GetGlobalVar("main");
+ mod->Add(gvar, f);
+ }
+
Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h
index 56965c5..615a818 100644
--- a/src/relay/backend/vm/compiler.h
+++ b/src/relay/backend/vm/compiler.h
@@ -125,8 +125,7 @@ class VMCompiler : public runtime::ModuleNode {
*
* \return The optimized IRModule.
*/
- IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets,
- const Target& target_host);
+ IRModule OptimizeModule(IRModule mod, const TargetsMap& targets, const Target& target_host);
/*!
* \brief Populate the global function names in a map where the value is used
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 6958010..975070a 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -678,6 +678,10 @@ def test_vm_optimize():
comp = relay.vm.VMCompiler()
opt_mod, _ = comp.optimize(mod, target="llvm", params=params)
+ free_vars = relay.analysis.free_vars(opt_mod["main"].body)
+ # Paremeters should all be bound, so the only free var is data
+ assert len(free_vars) == 1
+
@tvm.testing.uses_gpu
def test_loop_free_var():