You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@tvm.apache.org by co...@apache.org on 2021/03/23 07:18:25 UTC

[tvm] branch main updated: Fix GraphModule.load_params to allow passing parameters that are not an expected input (#7665)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4c66fb2  Fix GraphModule.load_params to allow passing parameters that are not an expected input (#7665)
4c66fb2 is described below

commit 4c66fb2e4b99e376fbaec15d975e4e4d1d8321ab
Author: Jorn Tuyls <jt...@users.noreply.github.com>
AuthorDate: Tue Mar 23 07:18:04 2021 +0000

    Fix GraphModule.load_params to allow passing parameters that are not an expected input (#7665)
---
 src/runtime/graph/graph_runtime.cc          |  4 +-
 tests/python/relay/test_external_codegen.py | 59 +++++++++++++++++++++--------
 tests/python/unittest/test_runtime_graph.py | 24 +++++++++++-
 3 files changed, 69 insertions(+), 18 deletions(-)

diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index 5c7b756..b11a573 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -201,7 +201,9 @@ void GraphRuntime::LoadParams(const std::string& param_blob) {
 void GraphRuntime::LoadParams(dmlc::Stream* strm) {
   Map<String, NDArray> params = ::tvm::runtime::LoadParams(strm);
   for (auto& p : params) {
-    uint32_t eid = this->entry_id(input_nodes_[GetInputIndex(p.first)], 0);
+    int in_idx = GetInputIndex(p.first);  // < 0 signals p.first is not a graph input
+    if (in_idx < 0) continue;  // skip it instead of indexing input_nodes_ with a negative index
+    uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
     data_entry_[eid].CopyFrom(p.second);
   }
 }
diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py
index 0d729b7..ab6695e 100644
--- a/tests/python/relay/test_external_codegen.py
+++ b/tests/python/relay/test_external_codegen.py
@@ -23,9 +23,29 @@ import tvm
 from tvm import te
 import tvm.relay.testing
 import tvm.relay.transform
+
 from tvm import relay
 from tvm import runtime
+from tvm.relay import transform
 from tvm.contrib import utils
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.op.annotation import compiler_begin, compiler_end
+
+
+def update_lib(lib):  # export lib as lib.so in a temp dir and load it back
+    test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
+    source_dir = os.path.join(test_dir, "..", "..", "..")  # repo root (from tests/python/relay)
+    contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
+
+    kwargs = {}
+    kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path]  # -I adds contrib headers
+    tmp_path = utils.tempdir()
+    lib_name = "lib.so"
+    lib_path = tmp_path.relpath(lib_name)
+    lib.export_library(lib_path, fcompile=False, **kwargs)
+    lib = tvm.runtime.load_module(lib_path)  # reload the exported shared library
+
+    return lib
 
 
 def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ctx=tvm.cpu()):
@@ -33,21 +53,6 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ct
         print("Skip test on Windows for now")
         return
 
-    def update_lib(lib):
-        test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
-        source_dir = os.path.join(test_dir, "..", "..", "..")
-        contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
-
-        kwargs = {}
-        kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path]
-        tmp_path = utils.tempdir()
-        lib_name = "lib.so"
-        lib_path = tmp_path.relpath(lib_name)
-        lib.export_library(lib_path, fcompile=False, **kwargs)
-        lib = tvm.runtime.load_module(lib_path)
-
-        return lib
-
     def check_vm_result():
         with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
             exe = relay.vm.compile(mod, target=target)
@@ -329,6 +334,29 @@ def test_extern_dnnl_const():
     check_result(mod, {"data0": i_data}, (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
 
 
+def test_load_params_with_constants_in_ext_codegen():
+    # After binding params and partitioning, graph_module.get_params()
+    # might contain parameters that are not graph runtime inputs but,
+    # for example, constants in an external function.
+    y_in = np.ones((1,)).astype("float32")
+    params = {"y": y_in}
+    mod = tvm.IRModule()
+    x = relay.var("x", shape=(1, 10))
+    y = relay.var("y", shape=(1,))
+    xcb = compiler_begin(x, "ccompiler")
+    ycb = compiler_begin(y, "ccompiler")
+    z = relay.add(xcb, ycb)
+    zce = compiler_end(z, "ccompiler")
+    mod["main"] = relay.Function([x, y], zce)
+    mod["main"] = bind_params_by_name(mod["main"], params)  # binds "y" as a constant
+    mod = transform.PartitionGraph()(mod)
+
+    graph_module = relay.build(mod, target="llvm", params=params)
+    lib = update_lib(graph_module.get_lib())  # NOTE(review): no Windows skip, unlike check_result
+    rt_mod = tvm.contrib.graph_runtime.create(graph_module.get_json(), lib, tvm.cpu(0))
+    rt_mod.load_params(runtime.save_param_dict(graph_module.get_params()))  # passes if no raise
+
+
 if __name__ == "__main__":
     test_multi_node_subgraph()
     test_extern_gcc_single_op()
@@ -337,3 +365,4 @@ if __name__ == "__main__":
     test_extern_gcc_consts()
     test_extern_dnnl()
     test_extern_dnnl_const()
+    test_load_params_with_constants_in_ext_codegen()
diff --git a/tests/python/unittest/test_runtime_graph.py b/tests/python/unittest/test_runtime_graph.py
index 16e9db4..fe33c0f 100644
--- a/tests/python/unittest/test_runtime_graph.py
+++ b/tests/python/unittest/test_runtime_graph.py
@@ -20,6 +20,7 @@ from tvm import te, runtime
 import numpy as np
 import json
 from tvm import rpc
+from tvm import relay
 from tvm.contrib import utils, graph_runtime
 
 
@@ -82,8 +83,6 @@ def test_graph_simple():
         np.testing.assert_equal(out.asnumpy(), a + 1)
 
     def check_sharing():
-        from tvm import relay
-
         x = relay.var("x", shape=(1, 10))
         y = relay.var("y", shape=(1, 10))
         z = relay.add(x, y)
@@ -120,5 +119,26 @@ def test_graph_simple():
     check_sharing()
 
 
+def test_load_unexpected_params():
+    # Test whether graph_runtime.load_params works if parameters
+    # are provided that are not expected inputs.
+    mod = tvm.IRModule()
+    params = {}
+    x = relay.var("x", shape=(1, 10))
+    y = relay.var("y", shape=(1, 10))
+    z = relay.add(x, y)
+    mod["main"] = relay.Function([x, y], z)
+
+    graph_module = relay.build(mod, target="llvm", params=params)
+    rt_mod = tvm.contrib.graph_runtime.create(
+        graph_module.get_json(), graph_module.get_lib(), tvm.cpu(0)
+    )
+
+    new_params = graph_module.get_params()
+    new_params.update({"y_unknown": np.ones((1,)).astype("float32")})  # not an input of main
+    rt_mod.load_params(runtime.save_param_dict(new_params))  # passes if this does not raise
+
+
 if __name__ == "__main__":
     test_graph_simple()
+    test_load_unexpected_params()