Posted to commits@tvm.apache.org by jw...@apache.org on 2021/04/09 19:17:24 UTC

[tvm] 03/09: [AMD:ONNXRT:TVM] Include input shapes during compilation.

This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch checkpoint
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 388e1df3aeab41fa18bec37a80d76c24f5210983
Author: Chris Sullivan <cs...@octoml.ai>
AuthorDate: Thu Aug 20 22:18:58 2020 -0700

    [AMD:ONNXRT:TVM] Include input shapes during compilation.
---
 include/tvm/driver/jit_interface.h    |  6 ++++--
 python/tvm/relay/frontend/jit/onnx.py |  7 +++----
 src/driver/driver_api.cc              | 19 ++++++++++++++-----
 3 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/include/tvm/driver/jit_interface.h b/include/tvm/driver/jit_interface.h
index e0906f1..e9203ee 100644
--- a/include/tvm/driver/jit_interface.h
+++ b/include/tvm/driver/jit_interface.h
@@ -2,7 +2,7 @@
 
 #ifdef __cplusplus
 extern "C" {
-    EXPORT_DLL tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target, const std::string& target_host, int opt_level);
-    EXPORT_DLL void TVMRun(tvm::runtime::Module& mod,  std::vector<DLTensor> inputs, std::vector<DLTensor> outputs, tvm::runtime::TVMRetValue* ret);
+    EXPORT_DLL tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target, const std::string& target_host, int opt_level, const std::vector<std::vector<int64_t>>& input_shapes);
+    EXPORT_DLL void TVMRun(tvm::runtime::Module& mod, std::vector<DLTensor>& inputs, std::vector<DLTensor>& outputs, tvm::runtime::TVMRetValue* ret);
 }  // TVM_EXTERN_C
 #endif
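
With this change, TVMCompile takes one shape vector per graph input, so the caller (here, the ONNX Runtime side of the bridge) has to gather the static dimensions of every model input before invoking it. A minimal Python sketch of collecting those shapes with the standard onnx package (the model path is hypothetical; symbolic dimensions come back as 0 and would need to be resolved by the caller):

    import onnx

    # Read the static dimensions declared for each graph input of an ONNX model.
    model = onnx.load("model.onnx")  # hypothetical path
    input_shapes = [
        [d.dim_value for d in inp.type.tensor_type.shape.dim]
        for inp in model.graph.input
    ]
    # e.g. [[1, 3, 224, 224]] for a model with a single image input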
diff --git a/python/tvm/relay/frontend/jit/onnx.py b/python/tvm/relay/frontend/jit/onnx.py
index 9545395..3672bbe 100644
--- a/python/tvm/relay/frontend/jit/onnx.py
+++ b/python/tvm/relay/frontend/jit/onnx.py
@@ -19,13 +19,12 @@ import tvm
 import tvm.relay
 
 @tvm.register_func("tvm_onnx_import_and_compile")
-def onnx_compile(model_string, target, target_host, opt_level):
+def onnx_compile(model_string, target, target_host, opt_level, input_shapes):
     model = onnx.load_model_from_string(bytes(model_string))
 
-    # input shape from data
-    input_shape = {model.graph.input[0].name: (6,)}
+    input_shapes = {i.name: shape for (i, shape) in zip(model.graph.input, input_shapes)}
 
-    irmod, params = tvm.relay.frontend.from_onnx(model, input_shape, opset=11)
+    irmod, params = tvm.relay.frontend.from_onnx(model, input_shapes, opset=11)
     with tvm.relay.build_config(opt_level=opt_level):
         graph, lib, params = tvm.relay.build(irmod, target_host=target_host, target=target, params=params)
 
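The comprehension above builds a name-to-shape dict by zipping the graph inputs with the caller-supplied shapes in declaration order. A small self-contained sketch of the same mapping, with made-up names and shapes:

    # Hypothetical stand-ins for the graph input names and caller-supplied shapes.
    names = ["data", "mask"]
    input_shapes = [[1, 3, 224, 224], [1, 224]]

    # Equivalent to the comprehension in the patch: input name -> shape, in order.
    shape_dict = dict(zip(names, input_shapes))
    assert shape_dict == {"data": [1, 3, 224, 224], "mask": [1, 224]}
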
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index b876c38..d55c0ae 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -338,16 +338,25 @@ runtime::Module build(const IRModule& funcs, const Target& target, const Target&
 
 }  // namespace tvm
 
-
-tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target, const std::string& target_host, int opt_level)
+tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target, const std::string& target_host, int opt_level, const std::vector<std::vector<int64_t>>& input_shapes)
 {
+  tvm::Array<tvm::Array<tvm::Integer>> shapes;
+  for (size_t i = 0; i < input_shapes.size(); i++)
+  {
+    tvm::Array<tvm::Integer> shape;
+    for (auto& dim : input_shapes[i])
+    {
+      shape.push_back(tvm::Integer(dim));
+    }
+    shapes.push_back(shape);
+  }
+
   const tvm::PackedFunc* compile = tvm::runtime::Registry::Get("tvm_onnx_import_and_compile");
-  tvm::runtime::Module mod = (*compile)(TVMByteArray{onnx_txt.data(), onnx_txt.size()}, target, target_host, opt_level);
+  tvm::runtime::Module mod = (*compile)(TVMByteArray{onnx_txt.data(), onnx_txt.size()}, target, target_host, opt_level, shapes);
   return mod;
-
 }
 
-void TVMRun(tvm::runtime::Module& mod, std::vector<DLTensor> inputs, std::vector<DLTensor> outputs, tvm::runtime::TVMRetValue* ret)
+void TVMRun(tvm::runtime::Module& mod, std::vector<DLTensor>& inputs, std::vector<DLTensor>& outputs, tvm::runtime::TVMRetValue* ret)
 {
   tvm::PackedFunc set_input = mod.GetFunction("set_input_zero_copy", false);
   for (size_t i = 0; i < inputs.size(); i++)
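
TVMRun feeds the graph runtime through set_input_zero_copy, so the caller's DLTensors are consumed in place rather than copied. For orientation, a rough Python sketch of the equivalent flow, assuming graph, lib, and params come from the tvm.relay.build call in onnx.py above (the input name and data are illustrative, and the Python wrapper's set_input copies rather than using the zero-copy entry point):

    import numpy as np
    import tvm
    from tvm.contrib import graph_runtime

    # Assumes graph, lib, params were produced by tvm.relay.build(...) as above.
    module = graph_runtime.create(graph, lib, tvm.cpu(0))
    module.set_input(**params)  # bind the compiled weights
    module.set_input("data", np.zeros((1, 3, 224, 224), dtype="float32"))  # hypothetical input
    module.run()
    out = module.get_output(0).asnumpy()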