Posted to commits@tvm.apache.org by zh...@apache.org on 2021/05/20 06:32:46 UTC

[tvm] branch main updated: [VM] add removeUnusedFunctions pass in vm memoryopt (#8040)

This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7c732af  [VM] add removeUnusedFunctions pass in vm memoryopt (#8040)
7c732af is described below

commit 7c732af2b57495261773e22601fe3e2a33cceb78
Author: Xingyu Zhou <zh...@amazon.com>
AuthorDate: Wed May 19 23:32:21 2021 -0700

    [VM] add removeUnusedFunctions pass in vm memoryopt (#8040)
    
    * add removeUnusedFunctions pass in vm memoryopt
    
    * fix lint
---
 src/relay/backend/vm/compiler.cc |  3 +++
 tests/python/relay/test_vm.py    | 26 ++++++++++++++++++++++++++
 2 files changed, 29 insertions(+)
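
Note (not part of the commit, added for context): the pass that MemoryOpt now
runs first is also exposed to Python as tvm.relay.transform.RemoveUnusedFunctions.
A minimal sketch of what it does at the IRModule level; the function names
used_fn and dead_fn below are made up for illustration:

    import tvm
    from tvm import relay

    mod = tvm.IRModule()

    # two helper functions; only one will be reachable from main
    x = relay.var("x", shape=(2, 2), dtype="float32")
    mod["used_fn"] = relay.Function([x], relay.multiply(x, relay.const(2.0)))
    y = relay.var("y", shape=(2, 2), dtype="float32")
    mod["dead_fn"] = relay.Function([y], relay.multiply(y, relay.const(3.0)))
    mod = relay.transform.InferType()(mod)

    # main only calls used_fn, so dead_fn is unreachable from the entry point
    z = relay.var("z", shape=(2, 2), dtype="float32")
    mod["main"] = relay.Function([z], mod.get_global_var("used_fn")(z))

    # keep everything reachable from "main", drop the rest
    mod = relay.transform.RemoveUnusedFunctions(["main"])(mod)
    print([gv.name_hint for gv in mod.get_global_vars()])  # no "dead_fn"

Pruning before ManifestAlloc means the allocation and memory-planning steps
that follow in MemoryOpt never have to process functions that the entry point
can never reach.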

diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 832cc0e..ad23e13 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -978,6 +978,9 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe
 
 transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) {
   Array<Pass> pass_seqs;
+  // Remove unused functions
+  Array<runtime::String> entry_functions{"main"};
+  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
   // Manifest the allocations.
   pass_seqs.push_back(transform::ManifestAlloc(host_target, targets));
 
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index 8f51869..9f861a2 100644
--- a/tests/python/relay/test_vm.py
+++ b/tests/python/relay/test_vm.py
@@ -29,6 +29,7 @@ from tvm.relay import testing
 from tvm.contrib import utils
 from tvm import rpc
 import tvm.testing
+from tvm.relay.transform import InferType
 
 
 def check_result(args, expected_result, mod=None):
@@ -187,6 +188,31 @@ def test_multiple_ifs():
 
 
 @tvm.testing.uses_gpu
+def test_unused_function():
+    cond = relay.const(True)
+    mod = tvm.IRModule()
+    then_name = relay.GlobalVar("times_2")
+    # define unused function
+    else_name = relay.GlobalVar("times_3")
+    t1 = relay.TensorType((2, 2), dtype="float32")
+    x1 = relay.var("x1", t1, dtype="float32")
+    x2 = relay.var("x2", t1, dtype="float32")
+    f2 = relay.multiply(x1, relay.const(2.0))
+    f3 = relay.multiply(x2, relay.const(3.0))
+    mod[then_name] = relay.Function([x1], f2)
+    mod[else_name] = relay.Function([x2], f3)
+    mod = InferType()(mod)
+    x3 = relay.var("x3", t1, dtype="float32")
+    # put unused function in else branch
+    f = relay.If(cond, then_name(x3), else_name(x3))
+    mod["main"] = relay.Function([x3], f)
+    x_data = np.random.rand(2, 2).astype("float32")
+    y_data = x_data * 2
+
+    check_result([x_data], y_data, mod=mod)
+
+
+@tvm.testing.uses_gpu
 def test_simple_call():
     mod = tvm.IRModule({})
     sum_up = relay.GlobalVar("sum_up")
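
Note (not part of the commit, added for context): a rough way to observe the
effect end to end, mirroring the new test above but calling only times_2 from
main so the dead function is obvious. This sketch assumes the relay.vm.compile
entry point and the globals property of the compiled VM executable:

    import tvm
    from tvm import relay

    mod = tvm.IRModule()
    x1 = relay.var("x1", shape=(2, 2), dtype="float32")
    mod["times_2"] = relay.Function([x1], relay.multiply(x1, relay.const(2.0)))
    x2 = relay.var("x2", shape=(2, 2), dtype="float32")
    mod["times_3"] = relay.Function([x2], relay.multiply(x2, relay.const(3.0)))
    mod = relay.transform.InferType()(mod)

    # main never reaches times_3, so the pass added in this commit should
    # drop it while the module is lowered for the VM
    x3 = relay.var("x3", shape=(2, 2), dtype="float32")
    mod["main"] = relay.Function([x3], mod.get_global_var("times_2")(x3))

    exe = relay.vm.compile(mod, target="llvm")
    print(exe.globals)  # expected to include main and times_2, not times_3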