You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/11/06 19:23:07 UTC

[incubator-tvm] branch v0.7 updated: [Backport][Bugfix][Module] Fix recursive GetFunction in runtime::Module (#6866)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch v0.7
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/v0.7 by this push:
     new 5eb56a9  [Backport][Bugfix][Module] Fix recursive GetFunction in runtime::Module (#6866)
5eb56a9 is described below

commit 5eb56a99c0fabae02671298f10c6222e75966bb1
Author: Junru Shao <ju...@gmail.com>
AuthorDate: Fri Nov 6 11:22:52 2020 -0800

    [Backport][Bugfix][Module] Fix recursive GetFunction in runtime::Module (#6866)
---
 src/runtime/module.cc                              |  3 +++
 .../test_runtime_module_based_interface.py         | 30 ++++++++++++++++++++++
 2 files changed, 33 insertions(+)

diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index 98b0b3a..e50ea1c 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -68,6 +68,9 @@ PackedFunc ModuleNode::GetFunction(const std::string& name, bool query_imports)
   if (query_imports) {
     for (Module& m : self->imports_) {
       pf = m.operator->()->GetFunction(name, query_imports);
+      if (pf != nullptr) {
+        return pf;
+      }
     }
   }
   return pf;
diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index 1d682d2..a2e2e59 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -538,6 +538,35 @@ def test_debug_graph_runtime():
     tvm.testing.assert_allclose(out, verify(data), atol=1e-5)
 
 
+def test_multiple_imported_modules():  # query_imports lookup across two imported modules
+    def make_func(symbol):  # returns a PrimFunc whose "global_symbol" attr is `symbol`
+        n = tvm.te.size_var("n")
+        Ab = tvm.tir.decl_buffer((n,), dtype="float32")
+        i = tvm.te.var("i")
+        stmt = tvm.tir.For(
+            i,
+            0,
+            n - 1,
+            0,
+            0,
+            tvm.tir.Store(Ab.data, tvm.tir.Load("float32", Ab.data, i) + 1, i + 1),
+        )
+        return tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", symbol)
+
+    def make_module(mod):  # dict of PrimFuncs -> IRModule -> module built for "llvm"
+        mod = tvm.IRModule(mod)
+        mod = tvm.driver.build(mod, target="llvm")
+        return mod
+
+    module_main = make_module({"main": make_func("main")})
+    module_a = make_module({"func_a": make_func("func_a")})
+    module_b = make_module({"func_b": make_func("func_b")})
+    module_main.import_module(module_a)
+    module_main.import_module(module_b)
+    module_main.get_function("func_a", query_imports=True)  # defined in module_a (imported first)
+    module_main.get_function("func_b", query_imports=True)  # defined in module_b (imported second)
+
+
 if __name__ == "__main__":
     test_legacy_compatibility()
     test_cpu()
@@ -545,3 +574,4 @@ if __name__ == "__main__":
     test_mod_export()
     test_remove_package_params()
     test_debug_graph_runtime()
+    test_multiple_imported_modules()