You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/03/04 22:10:29 UTC
[tvm] branch main updated: [BYOC][TensorRT] Make TRT runtime robust to empty or weird subgraphs (#7581)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3fbb0a3 [BYOC][TensorRT] Make TRT runtime robust to empty or weird subgraphs (#7581)
3fbb0a3 is described below
commit 3fbb0a3d749c45d121ed213d4741c4e8e8041320
Author: Trevor Morris <tr...@amazon.com>
AuthorDate: Thu Mar 4 14:10:15 2021 -0800
[BYOC][TensorRT] Make TRT runtime robust to empty or weird subgraphs (#7581)
* Prevent TRT runtime crash for duplicate inputs and outputs
* Add empty subgraph unit test
---
src/runtime/contrib/tensorrt/tensorrt_builder.cc | 8 +++++
tests/python/contrib/test_tensorrt.py | 42 ++++++++++++++++++++----
2 files changed, 43 insertions(+), 7 deletions(-)
diff --git a/src/runtime/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
index ee47e67..09b36d7 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -99,6 +99,14 @@ void TensorRTBuilder::AddOutput(const JSONGraphNodeEntry& node, uint32_t entry_i
ICHECK(it != node_output_map_.end()) << "Output was not found.";
auto out_tensor = it->second[node.index_].tensor;
std::string name = "tensorrt_output_" + std::to_string(network_output_names_.size());
+ // If the network is already marked as an input or output, make a copy to avoid TRT crash.
+ if (out_tensor->isNetworkOutput()) {
+ LOG(WARNING) << name << " is a duplicate output.";
+ out_tensor = network_->addIdentity(*out_tensor)->getOutput(0);
+ } else if (out_tensor->isNetworkInput()) {
+ LOG(WARNING) << name << " is both an input and an output.";
+ out_tensor = network_->addIdentity(*out_tensor)->getOutput(0);
+ }
out_tensor->setName(name.c_str());
network_->markOutput(*out_tensor);
network_output_names_.push_back(name);
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index 7ddc4e7..60d6b2a 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -71,6 +71,14 @@ def assert_result_dict_holds(result_dict):
tvm.testing.assert_allclose(r1, r2, rtol=1e-3, atol=1e-3)
+def set_func_attr(func, compile_name, symbol_name):
+ func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+ func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+ func = func.with_attr("Compiler", compile_name)
+ func = func.with_attr("global_symbol", symbol_name)
+ return func
+
+
def run_and_verify_func(config, target="cuda"):
"""Test a Relay func by compiling, running, and comparing TVM and TRT outputs.
@@ -1109,13 +1117,6 @@ def test_dynamic_offload():
kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
def get_expected():
- def set_func_attr(func, compile_name, symbol_name):
- func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
- func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
- func = func.with_attr("Compiler", compile_name)
- func = func.with_attr("global_symbol", symbol_name)
- return func
-
# Create a nested TRT function that matches the expected output
mod = tvm.IRModule()
var1 = relay.var("tensorrt_0_i0", shape=(data_shape), dtype="float32")
@@ -1331,5 +1332,32 @@ def test_maskrcnn_resnet50() -> None:
)
+def test_empty_subgraph():
+ if skip_codegen_test():
+ return
+ x_shape = (1, 3, 5)
+ mod = tvm.IRModule()
+ # Empty tensorrt subgraph.
+ var1 = relay.var("tensorrt_0_i0", shape=(x_shape), dtype="float32")
+ f1 = GlobalVar("tensorrt_0")
+ func = relay.Function([var1], var1)
+ func = set_func_attr(func, "tensorrt", "tensorrt_0")
+ mod[f1] = func
+ mod = relay.transform.InferType()(mod)
+
+ # Create the main function
+ x = relay.var("x", shape=x_shape, dtype="float32")
+ out = f1(relay.nn.relu(x))
+ f = relay.Function([x], out)
+ mod["main"] = f
+
+ x_data = np.random.uniform(-1, 1, x_shape).astype("float32")
+ for mode in ["graph", "vm"]:
+ with tvm.transform.PassContext(opt_level=3):
+ exec = relay.create_executor(mode, mod=mod, ctx=tvm.gpu(0), target="cuda")
+ if not skip_runtime_test():
+ results = exec.evaluate()(x_data)
+
+
if __name__ == "__main__":
pytest.main([__file__])