You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/03/23 07:18:25 UTC
[tvm] branch main updated: Fix GraphModule.load_params to allow
passing parameters that are not an expected input (#7665)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4c66fb2 Fix GraphModule.load_params to allow passing parameters that are not an expected input (#7665)
4c66fb2 is described below
commit 4c66fb2e4b99e376fbaec15d975e4e4d1d8321ab
Author: Jorn Tuyls <jt...@users.noreply.github.com>
AuthorDate: Tue Mar 23 07:18:04 2021 +0000
Fix GraphModule.load_params to allow passing parameters that are not an expected input (#7665)
---
src/runtime/graph/graph_runtime.cc | 4 +-
tests/python/relay/test_external_codegen.py | 59 +++++++++++++++++++++--------
tests/python/unittest/test_runtime_graph.py | 24 +++++++++++-
3 files changed, 69 insertions(+), 18 deletions(-)
diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index 5c7b756..b11a573 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -201,7 +201,9 @@ void GraphRuntime::LoadParams(const std::string& param_blob) {
void GraphRuntime::LoadParams(dmlc::Stream* strm) {
Map<String, NDArray> params = ::tvm::runtime::LoadParams(strm);
for (auto& p : params) {
- uint32_t eid = this->entry_id(input_nodes_[GetInputIndex(p.first)], 0);
+ int in_idx = GetInputIndex(p.first);
+ if (in_idx < 0) continue;  // not an expected graph input (e.g. a constant of an external function): skip instead of indexing input_nodes_ with a negative index
+ uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
data_entry_[eid].CopyFrom(p.second);
}
}
diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py
index 0d729b7..ab6695e 100644
--- a/tests/python/relay/test_external_codegen.py
+++ b/tests/python/relay/test_external_codegen.py
@@ -23,9 +23,29 @@ import tvm
from tvm import te
import tvm.relay.testing
import tvm.relay.transform
+
from tvm import relay
from tvm import runtime
+from tvm.relay import transform
from tvm.contrib import utils
+from tvm.relay.build_module import bind_params_by_name
+from tvm.relay.op.annotation import compiler_begin, compiler_end
+
+
+def update_lib(lib):  # export `lib` to a shared library on disk and return the reloaded module
+ test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
+ source_dir = os.path.join(test_dir, "..", "..", "..")  # three levels up from tests/python/relay, i.e. the repo root
+ contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
+
+ kwargs = {}
+ kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path]  # build options; -I puts src/runtime/contrib on the include path
+ tmp_path = utils.tempdir()
+ lib_name = "lib.so"
+ lib_path = tmp_path.relpath(lib_name)
+ lib.export_library(lib_path, fcompile=False, **kwargs)
+ lib = tvm.runtime.load_module(lib_path)  # replace the in-memory module with the one built on disk
+
+ return lib
def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ctx=tvm.cpu()):
@@ -33,21 +53,6 @@ def check_result(mod, map_inputs, out_shape, result, tol=1e-5, target="llvm", ct
print("Skip test on Windows for now")
return
- def update_lib(lib):
- test_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))
- source_dir = os.path.join(test_dir, "..", "..", "..")
- contrib_path = os.path.join(source_dir, "src", "runtime", "contrib")
-
- kwargs = {}
- kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path]
- tmp_path = utils.tempdir()
- lib_name = "lib.so"
- lib_path = tmp_path.relpath(lib_name)
- lib.export_library(lib_path, fcompile=False, **kwargs)
- lib = tvm.runtime.load_module(lib_path)
-
- return lib
-
def check_vm_result():
with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
exe = relay.vm.compile(mod, target=target)
@@ -329,6 +334,29 @@ def test_extern_dnnl_const():
check_result(mod, {"data0": i_data}, (1, 32, 14, 14), ref_res.asnumpy(), tol=1e-5)
+def test_load_params_with_constants_in_ext_codegen():
+ # After binding params and partitioning the graph, graph_module.get_params()
+ # may contain parameters that are not graph runtime inputs, for example
+ # constants used inside an external function.
+ y_in = np.ones((1,)).astype("float32")
+ params = {"y": y_in}
+ mod = tvm.IRModule()
+ x = relay.var("x", shape=(1, 10))
+ y = relay.var("y", shape=(1,))
+ xcb = compiler_begin(x, "ccompiler")
+ ycb = compiler_begin(y, "ccompiler")
+ z = relay.add(xcb, ycb)
+ zce = compiler_end(z, "ccompiler")
+ mod["main"] = relay.Function([x, y], zce)
+ mod["main"] = bind_params_by_name(mod["main"], params)  # bind "y" from `params` into main
+ mod = transform.PartitionGraph()(mod)
+
+ graph_module = relay.build(mod, target="llvm", params=params)
+ lib = update_lib(graph_module.get_lib())
+ rt_mod = tvm.contrib.graph_runtime.create(graph_module.get_json(), lib, tvm.cpu(0))
+ rt_mod.load_params(runtime.save_param_dict(graph_module.get_params()))  # test passes as long as this does not raise
+
+
if __name__ == "__main__":
test_multi_node_subgraph()
test_extern_gcc_single_op()
@@ -337,3 +365,4 @@ if __name__ == "__main__":
test_extern_gcc_consts()
test_extern_dnnl()
test_extern_dnnl_const()
+ test_load_params_with_constants_in_ext_codegen()
diff --git a/tests/python/unittest/test_runtime_graph.py b/tests/python/unittest/test_runtime_graph.py
index 16e9db4..fe33c0f 100644
--- a/tests/python/unittest/test_runtime_graph.py
+++ b/tests/python/unittest/test_runtime_graph.py
@@ -20,6 +20,7 @@ from tvm import te, runtime
import numpy as np
import json
from tvm import rpc
+from tvm import relay
from tvm.contrib import utils, graph_runtime
@@ -82,8 +83,6 @@ def test_graph_simple():
np.testing.assert_equal(out.asnumpy(), a + 1)
def check_sharing():
- from tvm import relay
-
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
z = relay.add(x, y)
@@ -120,5 +119,26 @@ def test_graph_simple():
check_sharing()
+def test_load_unexpected_params():
+ # Test whether graph_runtime.load_params works if parameters
+ # are provided that are not an expected input.
+ mod = tvm.IRModule()
+ params = {}  # nothing is bound at build time
+ x = relay.var("x", shape=(1, 10))
+ y = relay.var("y", shape=(1, 10))
+ z = relay.add(x, y)
+ mod["main"] = relay.Function([x, y], z)
+
+ graph_module = relay.build(mod, target="llvm", params=params)
+ rt_mod = tvm.contrib.graph_runtime.create(
+ graph_module.get_json(), graph_module.get_lib(), tvm.cpu(0)
+ )
+
+ new_params = graph_module.get_params()
+ new_params.update({"y_unknown": np.ones((1,)).astype("float32")})  # "y_unknown" is not a parameter of main (only x, y are)
+ rt_mod.load_params(runtime.save_param_dict(new_params))  # test passes as long as this does not raise
+
+
if __name__ == "__main__":
test_graph_simple()
+ test_load_unexpected_params()