Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/12 15:57:12 UTC

[GitHub] [incubator-tvm] rohanmukh opened a new pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

rohanmukh opened a new pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905


   refactoring test tensorrt code
   
   added comments to dynamic check wrapper
   
   log.warn changed to logger.info
   
   TRT codegen taking slice_mode into account
   
   TRT codegen to handle both types of stride_mode
   
   refactoring TRT codegen
   
   adding a test for dynamic offload
   
   [TRT] bug in codegen for slice_mode=end
   
   ctx determined from target in test + io test was missing
   
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] trevor-m commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

Posted by GitBox <gi...@apache.org>.
trevor-m commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r522510919



##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -874,13 +960,82 @@ def test_densenet121():
     run_and_verify_model("densenet121")
 
 
+def test_tensorrt_integration():
+    # Integration tests
+    test_alexnet()
+    test_resnet18_v1()
+    test_resnet18_v2()
+    test_squeezenet()
+    test_mobilenet()
+    test_mobilenet_v2()
+    test_vgg11()
+    test_densenet121()
+
+
+def test_dynamic_offload(data_shape=(1, 32, 8, 8), k_shape=(1, 32, 3, 3)):
+    """
+    This test checks for proper dynamic offloading of relay graphs. An addition between
+    the outputs of two conv2d's is performed, one of them having all static args whereas
+    the other has an arg with dynamic shape. It is expected for the TRT partitioner to
+    offload the conv2d with dynamic arg to TVM while running the other in TRT.
+    """
+    x = relay.var("x", shape=(data_shape[0], data_shape[1], Any(), Any()), dtype="float32")
+    y = relay.var("y", shape=(data_shape), dtype="float32")
+    kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
+
+    def get_expected():
+        def set_func_attr(func, compile_name, symbol_name):
+            func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+            func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+            func = func.with_attr("Compiler", compile_name)
+            func = func.with_attr("global_symbol", symbol_name)
+            return func
+
+        # Create a nested TRT function that matches the expected output
+        mod = tvm.IRModule()
+        var1 = relay.var("tensorrt_0_i0", shape=(data_shape), dtype="float32")
+        kernel_trt = relay.var("tensorrt_0_i1", shape=(k_shape), dtype="float32")
+        out1 = relay.nn.conv2d(var1, kernel_trt, channels=k_shape[0], kernel_size=k_shape[2:4])
+        f1 = GlobalVar("tensorrt_0")
+        func = relay.Function([var1, kernel_trt], out1)
+        func = set_func_attr(func, "tensorrt", "tensorrt_0")
+        mod[f1] = func
+        mod = relay.transform.InferType()(mod)
+
+        # Create the main function
+        out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
+        out = relay.add(out1, f1(y, kernel))
+        f = relay.Function([x, y, kernel], out)
+        mod["main"] = f
+        mod = relay.transform.InferType()(mod)
+        return mod
+
+    # Create relay function that will be offloaded to TRT
+    out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
+    out2 = relay.nn.conv2d(y, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
+    out = relay.add(out1, out2)
+    f = relay.Function([x, y, kernel], out)
+
+    # Pass the function to TRT compilation
+    mod = tvm.IRModule()
+    mod["main"] = f
+    mod = relay.transform.InferType()(mod)
+    mod_trt, config = tensorrt.partition_for_tensorrt(mod, params={})
+
+    # Get the expected relay graph and compare
+    mod_exp = get_expected()
+    tvm.ir.assert_structural_equal(mod_trt, mod_exp, map_free_vars=True)
+    return

Review comment:
       Don't need this return

##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -874,13 +960,82 @@ def test_densenet121():
     run_and_verify_model("densenet121")
 
 
+def test_tensorrt_integration():
+    # Integration tests
+    test_alexnet()
+    test_resnet18_v1()
+    test_resnet18_v2()
+    test_squeezenet()
+    test_mobilenet()
+    test_mobilenet_v2()
+    test_vgg11()
+    test_densenet121()
+
+
+def test_dynamic_offload(data_shape=(1, 32, 8, 8), k_shape=(1, 32, 3, 3)):

Review comment:
       Let's move the args to variables inside the test. Same for the serialization tests.




[GitHub] [incubator-tvm] trevor-m commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

trevor-m commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r522477235



##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -715,6 +771,34 @@ def conv3d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
         return False
     return True
 
+_register_external_dynamic_check_func("add", add_annotate_fn)

Review comment:
       I see, can we make this function a decorator then?
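A minimal sketch of the suggested decorator, written in plain Python so it runs without TVM. `ANY`, `FakeArg`, `FakeCall`, `ANNOTATORS`, and `register_with_dynamic_check` are hypothetical stand-ins, not TVM APIs; in the actual patch the registration would presumably still go through `tvm.ir.register_op_attr`, and the dynamism test would inspect each arg's `checked_type.shape` for `relay.Any()`. The point is that the check and the registration sit right next to the annotator, so nothing has to be remembered further down the file.

```python
"""Sketch: registering annotators with a built-in dynamism check."""
import functools
from dataclasses import dataclass
from typing import Callable, Dict, Tuple

# Hypothetical stand-in for a dynamic dimension (relay.Any() in TVM).
ANY = None


@dataclass
class FakeArg:
    """Hypothetical stand-in for a Relay arg; only its shape matters here."""
    shape: Tuple


@dataclass
class FakeCall:
    """Hypothetical stand-in for a Relay call expression."""
    op_name: str
    args: Tuple


# Hypothetical registry: op name -> wrapped annotator.
ANNOTATORS: Dict[str, Callable] = {}


def check_dynamism(args, op_name):
    """Return True if any arg has a dynamic (ANY) dimension.

    op_name is unused here; it is kept only for debugging/logging parity.
    """
    for arg in args:
        if any(dim is ANY for dim in arg.shape):
            return True
    return False


def register_with_dynamic_check(op_name):
    """Decorator: register the annotator and reject dynamic inputs before calling it."""
    def decorator(annotator):
        @functools.wraps(annotator)
        def wrapped(expr):
            if check_dynamism(expr.args, op_name):
                return False  # leave ops with dynamic inputs to TVM
            return annotator(expr)
        ANNOTATORS[op_name] = wrapped
        return wrapped
    return decorator


@register_with_dynamic_check("add")
def add_annotate_fn(expr):
    """Example annotator: accept any static add."""
    return True
```

Decorating `add_annotate_fn` both registers it and makes it return False for any call whose inputs carry a dynamic dimension, so such ops stay in TVM while static ones can still be offloaded.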




[GitHub] [incubator-tvm] rohanmukh commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

rohanmukh commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r522597861



##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -715,6 +771,34 @@ def conv3d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
         return False
     return True
 
+_register_external_dynamic_check_func("add", add_annotate_fn)

Review comment:
       Thanks, I did that. Also tested its efficacy with test_dynamic_offload(). Can you please cross-verify the decorator usage?




[GitHub] [incubator-tvm] trevor-m commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

trevor-m commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r523133866



##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -874,44 +972,7 @@ def test_densenet121():
     run_and_verify_model("densenet121")
 
 
-if __name__ == "__main__":
-    test_tensorrt_not_compatible()
-    test_tensorrt_simple()
-    test_tensorrt_simple_cpu_io()
-    test_tensorrt_serialize()
-
-    # Op tests
-    test_conv2d()
-    test_conv2d_nhwc()
-    test_conv2d_weights_const()
-    test_conv2d_weights_transposed()
-    test_dense()
-    test_bias_add()
-    test_pool2d()
-    test_global_pool2d()
-    test_batch_flatten()
-    test_expand_dims()
-    test_squeeze()
-    test_concatenate()
-    test_conv2d_transpose()
-    test_reshape()
-    test_transpose()
-    test_float_const()
-    test_pad()
-    test_softmax()
-    test_batch_norm()
-    test_unary()
-    test_clip()
-    test_leaky_relu()
-    test_binary()
-    test_reduce()
-    test_strided_slice()
-    test_adaptive_pool2d()
-    test_multiple_outputs()
-    test_conv3d()
-    test_pool3d()
-    test_conv3d_transpose()
-
+def test_tensorrt_integration():

Review comment:
       We can remove this function; these tests will already be run.




[GitHub] [incubator-tvm] trevor-m commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

trevor-m commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r522257464



##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -598,14 +646,26 @@ def strided_slice_annotate_fn(expr):  # pylint: disable=unused-variable
         if batch_dim_begin_modified or batch_dim_end_modified:
             logger.info("strided_slice: can't modify batch dimension.")
             return False
+
     if any([x is not None and x <= 0 for x in attrs.strides]):
         logger.info("strided_slice: stride must be positive")
         return False
+
+    for i in range(0, len(args[0].checked_type.shape)):
+        begin = int(attrs.begin[i])
+        end = (

Review comment:
       We need to take slice mode into account here also
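One way to account for slice mode when computing each axis' bounds, sketched in plain Python. It assumes Relay's documented `strided_slice` semantics: with `slice_mode="end"` the `end` values are end indices (negative values count from the back), while with `slice_mode="size"` they are slice sizes and -1 means "everything remaining on this axis". `resolve_slice_bounds` is a hypothetical helper, not part of the patch.

```python
from typing import Optional, Tuple


def resolve_slice_bounds(
    begin: int, end: Optional[int], dim: int, slice_mode: str = "end"
) -> Tuple[int, int]:
    """Return non-negative [begin, end) bounds for one axis of length `dim`.

    Hypothetical helper; assumes Relay strided_slice semantics for slice_mode.
    """
    if begin < 0:
        begin += dim  # negative begin counts from the back
    if slice_mode == "size":
        # `end` holds a slice size; a negative size (Relay uses -1) means "to the end".
        # In this mode Relay also ignores the strides.
        end = dim if end is None or end < 0 else begin + end
    elif slice_mode == "end":
        if end is None:
            end = dim
        elif end < 0:
            end += dim  # negative end counts from the back
    else:
        raise ValueError(f"unknown slice_mode: {slice_mode}")
    return begin, min(end, dim)  # clamp to the axis length
```

An annotator could then reject the op whenever `end - begin` is not positive on some axis, regardless of which mode produced the bounds.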

##########
File path: src/relay/backend/contrib/tensorrt/codegen.cc
##########
@@ -133,26 +133,34 @@ class TensorRTJSONSerializer : public backend::contrib::JSONSerializer {
     auto process_slice_index = [](Integer x, int default_value, int dim_value) {
       if (!x.defined()) return default_value;
       int value = x.as<IntImmNode>()->value;
-      if (value < 0) value += dim_value;
+      value = (value < 0 ) ? dim_value + value : value;

Review comment:
       This line is the same as the previous code, can you change it back?
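Transcribing just the visible part of the lambda into Python shows why the reviewer calls the change a no-op: both forms add `dim_value` to a negative index and leave other values untouched. `Optional[int]` with `None` stands in for the undefined `Integer` case; the rest of the lambda is not shown in this hunk and is omitted.

```python
from typing import Optional


def normalize_old(x: Optional[int], default_value: int, dim_value: int) -> int:
    """Original form: `if (value < 0) value += dim_value;`"""
    if x is None:  # stands in for `!x.defined()`
        return default_value
    value = x
    if value < 0:
        value += dim_value
    return value


def normalize_new(x: Optional[int], default_value: int, dim_value: int) -> int:
    """Proposed form: `value = (value < 0) ? dim_value + value : value;`"""
    if x is None:
        return default_value
    value = x
    value = dim_value + value if value < 0 else value
    return value
```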

##########
File path: src/runtime/contrib/tensorrt/tensorrt_ops.cc
##########
@@ -944,7 +944,7 @@ class ReduceOpConverter : public TensorRTOpConverter {
 #if TRT_VERSION_GE(5, 1, 5)
 class StridedSliceOpConverter : public TensorRTOpConverter {
  public:
-  StridedSliceOpConverter() : TensorRTOpConverter({kTensor, kWeight, kWeight, kWeight}) {}
+  StridedSliceOpConverter() : TensorRTOpConverter({kTensor}) {} // , kWeight, kWeight, kWeight}) {}

Review comment:
       Remove comment

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -715,6 +771,34 @@ def conv3d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
         return False
     return True
 
+_register_external_dynamic_check_func("add", add_annotate_fn)

Review comment:
       I'm not sure if using the wrapper is better than adding a call to check_dynamism to each annotator. With the previous method, the annotator is all in one place with the decorator. Now, we have to remember to call register down here after writing the functions.

##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -841,6 +906,39 @@ def get_graph(
     run_and_verify_func(get_graph(strides=(2, 2, 2)))
     run_and_verify_func(get_graph(strides=(2, 2, 2), output_padding=(1, 1, 1)))
 
+def test_tensorrt_ops():

Review comment:
       You can leave these in main

##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -219,40 +222,95 @@ def test_tensorrt_not_compatible():
     mod = tvm.IRModule()
     mod["main"] = f
     mod, config = tensorrt.partition_for_tensorrt(mod)
-    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
-        graph, lib, params = relay.build(mod, "cuda")
-    if skip_runtime_test():
-        return
-    mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
-    x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
-    mod.run(x=x_data)
-    results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())]
+    for mode in ["graph", "vm"]:
+        with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
+            exec = relay.create_executor(mode, mod=mod, ctx=tvm.gpu(0), target="cuda")
+            if not skip_runtime_test():
+                results = exec.evaluate()(x_data)
+
 
 
-def test_tensorrt_serialize():
+def test_tensorrt_serialize(data_shape=(1, 3, 224, 224), data_type="float32"):

Review comment:
       I think it would be good to split this into `test_tensorrt_serialize_graph_runtime` and `test_tensorrt_serialize_vm`




[GitHub] [incubator-tvm] rohanmukh commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

rohanmukh commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r522295438



##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -715,6 +771,34 @@ def conv3d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
         return False
     return True
 
+_register_external_dynamic_check_func("add", add_annotate_fn)

Review comment:
       Wouldn't that lead to the same issue, namely that we have to remember to call check_dynamism in each annotator?




[GitHub] [incubator-tvm] rohanmukh commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

rohanmukh commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r522306819



##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -134,15 +135,18 @@ def partition_for_tensorrt(
 
     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)
+
     seq = tvm.transform.Sequential(
         [
             transform.InferType(),
             RemoveDropoutPass(),
             transform.RemoveUnusedFunctions(),
             transform.ConvertLayout(
-                {"nn.conv2d": ["NCHW", "default"], "nn.conv3d": ["NCDHW", "default"]}
+                {"nn.conv2d": ["NCHW", "default"],
+                 "nn.conv3d": ["NCDHW", "default"]}
             ),
             transform.FoldConstant(),
+            transform.InferType(),

Review comment:
       Thanks for pointing that out. We don't; it was used for debugging purposes only.




[GitHub] [incubator-tvm] rohanmukh commented on pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

rohanmukh commented on pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#issuecomment-726223278


   Thanks for the comments @trevor-m. 


[GitHub] [incubator-tvm] zhiics commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

zhiics commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r522473247



##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -152,13 +153,50 @@ def partition_for_tensorrt(
     with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
         mod = seq(mod)
         mod = prune_tensorrt_subgraphs(mod)
+
     return mod, config
 
 
+def check_dynamism(args, op_name):
+    """
+    Check for dynamism inside any of the args in the op.
+
+    Parameters
+    ----------
+    args : tvm.ir.container.Array
+        Arguments of the op. Each argument's shape is checked for the presence of dynamic
+        components.
+    op_name: str
+        Name of the op for debugging purposes only.
+    Returns
+    ----------
+    ret : bool
+        True if dynamism is present, False otherwise
+    """
+    for arg in args:
+        if isinstance(arg, (Call, Var, Constant, TupleGetItem)):
+            for dim_shape in arg.checked_type.shape:

Review comment:
       nit: the type might be a nested tuple. We may want to check the flattened one if it is a tuple: https://github.com/apache/incubator-tvm/blob/main/python/tvm/relay/op/memory/memory.py#L70
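A sketch of the suggested flattening, assuming the check moves from reading `arg.checked_type.shape` directly to a helper that walks tuple types recursively. `TensorType` and `TupleType` below are local stand-in dataclasses that only mimic the `shape` and `fields` attributes of their Relay counterparts, and `ANY` stands in for `relay.Any()`; the real fix would more likely reuse the existing flattening helper the comment links to.

```python
from dataclasses import dataclass
from typing import List, Tuple

ANY = None  # hypothetical stand-in for relay.Any()


@dataclass
class TensorType:
    """Stand-in for a tensor type: only the shape is modeled."""
    shape: Tuple


@dataclass
class TupleType:
    """Stand-in for a (possibly nested) tuple type."""
    fields: List[object]


def flatten_tuple_type(ty) -> List[TensorType]:
    """Return the tensor types inside `ty` in order, descending into nested tuples."""
    if isinstance(ty, TupleType):
        out: List[TensorType] = []
        for field in ty.fields:
            out.extend(flatten_tuple_type(field))
        return out
    return [ty]


def has_dynamic_dim(ty) -> bool:
    """True if any tensor reachable from `ty` has a dynamic dimension."""
    return any(dim is ANY for t in flatten_tuple_type(ty) for dim in t.shape)
```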

##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -219,40 +226,106 @@ def test_tensorrt_not_compatible():
     mod = tvm.IRModule()
     mod["main"] = f
     mod, config = tensorrt.partition_for_tensorrt(mod)
-    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
-        graph, lib, params = relay.build(mod, "cuda")
-    if skip_runtime_test():
-        return
-    mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
-    x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
-    mod.run(x=x_data)
-    results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())]
+    for mode in ["graph", "vm"]:
+        with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
+            exec = relay.create_executor(mode, mod=mod, ctx=tvm.gpu(0), target="cuda")
+            if not skip_runtime_test():
+                results = exec.evaluate()(x_data)
 
 
-def test_tensorrt_serialize():
+def test_tensorrt_serialize_graph_runtime(data_shape=(1, 3, 224, 224), data_type="float32"):
     if skip_codegen_test():
         return
-    import mxnet
-    from mxnet.gluon.model_zoo.vision import get_model
 
+    i_data = np.random.uniform(0, 1, data_shape).astype(data_type)
     block = get_model("resnet18_v1", pretrained=True)
-    mod, params = relay.frontend.from_mxnet(
-        block, shape={"data": (1, 3, 224, 224)}, dtype="float32"
-    )
-    # Compile
-    mod, config = tensorrt.partition_for_tensorrt(mod, params)
-    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
-        lib = relay.build(mod, "cuda", params=params)
-    # Serialize
-    lib.export_library("compiled.so")
-    # Deserialize
-    loaded_lib = tvm.runtime.load_module("compiled.so")
-    # Run
-    if skip_runtime_test():
+    mod, params = relay.frontend.from_mxnet(block, shape={"data": data_shape}, dtype=data_type)
+    mod, config = tensorrt.partition_for_tensorrt(mod)
+
+    def compile_graph(mod, params):
+        with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
+            graph, lib, params = relay.build(mod, params=params, target="cuda")
+            params = relay.save_param_dict(params)
+        return graph, lib, params
+
+    def run_graph(graph, lib, params):
+        mod_ = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+        mod_.load_params(params)
+        mod_.run(data=i_data)
+        res = mod_.get_output(0)
+        return res
+
+    def save_graph(graph, lib, params):
+        # Serialize
+        with open("compiled.json", "w") as f_graph_json:
+            f_graph_json.write(graph)
+        with open("compiled.params", "wb") as f_params:
+            f_params.write(params)
+        lib.export_library("compiled.so")
+
+    def load_graph():
+        # Deserialize
+        with open("compiled.json", "r") as f_graph_json:
+            graph = f_graph_json.read()
+        with open("compiled.params", "rb") as f_params:
+            params = bytearray(f_params.read())
+        lib = tvm.runtime.load_module("compiled.so")
+        return graph, lib, params
+
+    # Test serialization with graph runtime
+    graph, lib, graph_params = compile_graph(mod, params)
+    save_graph(graph, lib, graph_params)
+    loaded_graph, loaded_lib, loaded_params = load_graph()
+
+    if not skip_runtime_test():
+        result_dict = dict()
+        result_dict["graph"] = run_graph(graph, lib, graph_params)
+        result_dict["graph_ref"] = run_graph(loaded_graph, loaded_lib, loaded_params)
+        assert_result_dict_holds(result_dict)
+
+
+def test_tensorrt_serialize_vm(data_shape=(1, 3, 224, 224), data_type="float32"):
+    if skip_codegen_test():
         return
-    gen_module = tvm.contrib.graph_runtime.GraphModule(loaded_lib["default"](tvm.gpu(0)))
-    i_data = np.random.uniform(0, 1, (1, 3, 224, 224)).astype("float32")
-    gen_module.run(data=i_data)
+
+    i_data = np.random.uniform(0, 1, data_shape).astype(data_type)
+    block = get_model("resnet18_v1", pretrained=True)
+    mod, params = relay.frontend.from_mxnet(block, shape={"data": data_shape}, dtype=data_type)
+    mod, config = tensorrt.partition_for_tensorrt(mod)
+
+    def compile_vm(mod, params):
+        with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
+            vm_exec = relay.vm.compile(mod, target="cuda", params=params)
+            code, lib = vm_exec.save()
+        return code, lib
+
+    def run_vm(code, lib):
+        vm_exec = tvm.runtime.vm.Executable.load_exec(code, lib)
+        vm = VirtualMachine(vm_exec, tvm.gpu(0))
+        result = vm.invoke("main", data=i_data)
+        return result
+
+    def save_vm(code, lib):
+        # save and load the code and lib file.
+        lib.export_library("path_lib.so")
+        with open("path_code.ro", "wb") as fo:

Review comment:
       It's probably better to use a temporary directory, because you may not always have write permission in the current directory.
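A sketch of the save/load round trip done inside a throwaway directory using Python's standard `tempfile` module, so nothing is written to the current directory. The bytes passed in stand in for the `code` returned by `vm_exec.save()`; the `lib.export_library(...)` and `tvm.runtime.load_module(...)` calls are only indicated in comments. TVM also ships its own temp-directory helper under `tvm.contrib`, which the test could use instead.

```python
import os
import tempfile
from typing import Tuple


def save_and_load_code(code: bytes) -> Tuple[bytes, str]:
    """Write serialized bytecode to a temp dir, read it back, and return it with the dir path."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        code_path = os.path.join(tmp_dir, "path_code.ro")
        # In the real test, lib.export_library(os.path.join(tmp_dir, "path_lib.so"))
        # and the matching tvm.runtime.load_module(...) call would also live here,
        # while the directory still exists.
        with open(code_path, "wb") as fo:
            fo.write(code)
        with open(code_path, "rb") as fi:
            loaded = bytes(fi.read())
    # Leaving the `with` block deletes tmp_dir and everything in it.
    return loaded, tmp_dir
```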

##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -874,13 +960,82 @@ def test_densenet121():
     run_and_verify_model("densenet121")
 
 
+def test_tensorrt_integration():
+    # Integration tests
+    test_alexnet()
+    test_resnet18_v1()
+    test_resnet18_v2()
+    test_squeezenet()
+    test_mobilenet()
+    test_mobilenet_v2()
+    test_vgg11()
+    test_densenet121()
+
+
+def test_dynamic_offload(data_shape=(1, 32, 8, 8), k_shape=(1, 32, 3, 3)):
+    """
+    This test checks for proper dynamic offloading of relay graphs. An addition between
+    the outputs of two conv2d's is performed, one of them having all static args whereas
+    the other has an arg with dynamic shape. It is expected for the TRT partitioner to
+    offload the conv2d with dynamic arg to TVM while running the other in TRT.
+    """
+    x = relay.var("x", shape=(data_shape[0], data_shape[1], Any(), Any()), dtype="float32")
+    y = relay.var("y", shape=(data_shape), dtype="float32")
+    kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
+
+    def get_expected():
+        def set_func_attr(func, compile_name, symbol_name):
+            func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+            func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+            func = func.with_attr("Compiler", compile_name)
+            func = func.with_attr("global_symbol", symbol_name)
+            return func
+
+        # Create a nested TRT function that matches the expected output
+        mod = tvm.IRModule()
+        var1 = relay.var("tensorrt_0_i0", shape=(data_shape), dtype="float32")
+        kernel_trt = relay.var("tensorrt_0_i1", shape=(k_shape), dtype="float32")
+        out1 = relay.nn.conv2d(var1, kernel_trt, channels=k_shape[0], kernel_size=k_shape[2:4])
+        f1 = GlobalVar("tensorrt_0")
+        func = relay.Function([var1, kernel_trt], out1)
+        func = set_func_attr(func, "tensorrt", "tensorrt_0")
+        mod[f1] = func
+        mod = relay.transform.InferType()(mod)
+
+        # Create the main function
+        out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
+        out = relay.add(out1, f1(y, kernel))
+        f = relay.Function([x, y, kernel], out)
+        mod["main"] = f
+        mod = relay.transform.InferType()(mod)
+        return mod
+
+    # Create relay function that will be offloaded to TRT
+    out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
+    out2 = relay.nn.conv2d(y, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
+    out = relay.add(out1, out2)
+    f = relay.Function([x, y, kernel], out)
+
+    # Pass the function to TRT compilation
+    mod = tvm.IRModule()
+    mod["main"] = f
+    mod = relay.transform.InferType()(mod)
+    mod_trt, config = tensorrt.partition_for_tensorrt(mod, params={})
+
+    # Get the expected relay graph and compare
+    mod_exp = get_expected()
+    tvm.ir.assert_structural_equal(mod_trt, mod_exp, map_free_vars=True)
+    return
+
+
 if __name__ == "__main__":
     test_tensorrt_not_compatible()
     test_tensorrt_simple()
     test_tensorrt_simple_cpu_io()
-    test_tensorrt_serialize()
-
-    # Op tests
+    test_tensorrt_serialize_graph_runtime()

Review comment:
       You can remove all these test_xx calls and just use `pytest.main([__file__])`, since pytest is used for testing.




[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

anijain2305 commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r522282005



##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -134,15 +135,18 @@ def partition_for_tensorrt(
 
     if params:
         mod["main"] = bind_params_by_name(mod["main"], params)
+
     seq = tvm.transform.Sequential(
         [
             transform.InferType(),
             RemoveDropoutPass(),
             transform.RemoveUnusedFunctions(),
             transform.ConvertLayout(
-                {"nn.conv2d": ["NCHW", "default"], "nn.conv3d": ["NCDHW", "default"]}
+                {"nn.conv2d": ["NCHW", "default"],
+                 "nn.conv3d": ["NCDHW", "default"]}
             ),
             transform.FoldConstant(),
+            transform.InferType(),

Review comment:
       Do we need this?

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -18,11 +18,12 @@
 """TensorRT supported operators."""
 import logging
 import numpy as np
+import os

Review comment:
       Do we need this?

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -152,13 +156,51 @@ def partition_for_tensorrt(
     with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
         mod = seq(mod)
         mod = prune_tensorrt_subgraphs(mod)
+
     return mod, config
 
+def check_dynamism(args, op_name):
+    """
+    This function checks for dynamism inside any of the args in the op.
+    Can be used to offload dynamic ops that are not supported by TRT to
+    be offloaded to relay VM.
+
+    Raises a NotImplementedError if the type of the arg is not of types
+    Call, Var, Constant, or TupleGetItem.
+
+    Parameters
+    ----------
+    args: a TRT array of the arguments of the op

Review comment:
       Please follow Python style for docstrings

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -715,6 +771,34 @@ def conv3d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
         return False
     return True
 
+_register_external_dynamic_check_func("add", add_annotate_fn)
+_register_external_dynamic_check_func("nn.batch_norm", batch_norm_annotate_fn)
+_register_external_dynamic_check_func("nn.softmax", softmax_annotate_fn)
+_register_external_dynamic_check_func("nn.conv2d", conv2d_annotate_fn)
+_register_external_dynamic_check_func("nn.dense", dense_annotate_fn)
+_register_external_dynamic_check_func("nn.bias_add", bias_add_annotate_fn)
+_register_external_dynamic_check_func("nn.max_pool2d", max_pool_2d_annotate_fn)
+_register_external_dynamic_check_func("nn.avg_pool2d", avg_pool_2d_annotate_fn)
+_register_external_dynamic_check_func("nn.global_max_pool2d", global_max_pool_2d_annotate_fn)
+_register_external_dynamic_check_func("nn.global_avg_pool2d", global_avg_pool_2d_annotate_fn)
+_register_external_dynamic_check_func("expand_dims", expand_dims_annotate_fn)
+_register_external_dynamic_check_func("squeeze", squeeze_annotate_fn)
+_register_external_dynamic_check_func("concatenate", concatenate_annotate_fn)
+_register_external_dynamic_check_func("nn.conv2d_transpose", conv2d_transpose_annotate_fn)
+_register_external_dynamic_check_func("transpose", transpose_annotate_fn)
+_register_external_dynamic_check_func("layout_transform", layout_transform_annotate_fn)
+_register_external_dynamic_check_func("reshape", reshape_annotate_fn)
+_register_external_dynamic_check_func("nn.pad", pad_annotate_fn)
+_register_external_dynamic_check_func("strided_slice", strided_slice_annotate_fn)
+_register_external_dynamic_check_func("nn.adaptive_max_pool2d", adaptive_max_pool2d_annotate_fn)
+_register_external_dynamic_check_func("nn.adaptive_avg_pool2d", adaptive_avg_pool2d_annotate_fn)
+_register_external_dynamic_check_func("nn.conv3d", conv3d_annotate_fn)
+_register_external_dynamic_check_func("nn.max_pool3d", max_pool_3d_annotate_fn)
+_register_external_dynamic_check_func("nn.avg_pool3d", avg_pool_3d_annotate_fn)
+_register_external_dynamic_check_func("nn.conv3d_transpose", conv3d_transpose_annotate_fn)
+
+
+

Review comment:
       remove extra spaces everywhere

##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -173,6 +215,29 @@ def _register_external_op_helper(op_name, supported=True):
     )
 
 
+def _register_external_dynamic_check_func(op_name, checker):
+    """
+    Wrapper to check dynamic shapes inside any of the args in the op
+
+    Parameters
+    ----------
+    op_name: name of the op for debugging purposes only

Review comment:
       Same as above




[GitHub] [incubator-tvm] anijain2305 merged pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

Posted by GitBox <gi...@apache.org>.
anijain2305 merged pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905


   





[GitHub] [incubator-tvm] rohanmukh commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

Posted by GitBox <gi...@apache.org>.
rohanmukh commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r522653603



##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -152,13 +153,50 @@ def partition_for_tensorrt(
     with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
         mod = seq(mod)
         mod = prune_tensorrt_subgraphs(mod)
+
     return mod, config
 
 
+def check_dynamism(args, op_name):
+    """
+    Check for dynamism inside any of the args in the op.
+
+    Parameters
+    ----------
+    args : tvm.ir.container.Array
+        Arguments of the op. The shape of each argument is checked for the presence of
+        dynamic components.
+    op_name: str
+        Name of the op for debugging purposes only.
+    Returns
+    ----------
+    ret : bool
+        True if dynamism is present, False otherwise
+    """
+    for arg in args:
+        if isinstance(arg, (Call, Var, Constant, TupleGetItem)):
+            for dim_shape in arg.checked_type.shape:

Review comment:
       `elif isinstance(arg, Tuple):`
       `    return check_dynamism(arg.fields, op_name)`
   
   Hi @zhiics, thanks for the comments. For this particular issue, the next two lines of the code are the ones shown above. If the arg is a nested Tuple, `check_dynamism` is called recursively on its fields, so they are checked for dynamism as well. Let me know if there is something else I am missing here.
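   The recursion discussed above can be sketched in plain Python. This is a hypothetical illustration, not the merged code: `TensorArg` stands in for a Call/Var/Constant/TupleGetItem argument, `TupleArg` for a Relay Tuple, and `ANY` for a dynamic dimension. One design point is worth noting: if the Tuple branch returns the recursive result directly, as the two quoted lines appear to do (assuming they sit inside the same `for arg in args` loop), a static tuple that appears before a dynamic argument ends the loop early and yields False. The sketch therefore returns early only on a positive hit.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union

ANY = object()  # hypothetical marker for a dynamic dimension


@dataclass
class TensorArg:
    """Hypothetical stand-in for a Call/Var/Constant/TupleGetItem argument."""
    shape: Tuple


@dataclass
class TupleArg:
    """Hypothetical stand-in for a Relay Tuple whose fields may nest."""
    fields: List["Arg"]


Arg = Union[TensorArg, TupleArg]


def check_dynamism(args, op_name: str = "") -> bool:
    """Return True if any (possibly nested) argument has a dynamic dimension."""
    for arg in args:
        if isinstance(arg, TensorArg):
            if any(dim is ANY for dim in arg.shape):
                return True
        elif isinstance(arg, TupleArg):
            # Recurse into the tuple fields, but only return early on a hit,
            # so later arguments are still checked when the tuple is static.
            if check_dynamism(arg.fields, op_name):
                return True
    return False


nested_static = TupleArg([TensorArg((1, 3)), TupleArg([TensorArg((2, 2))])])
print(check_dynamism([nested_static, TensorArg((ANY, 4))]))  # True
print(check_dynamism([nested_static]))                       # False
```

   With a bare `return check_dynamism(arg.fields, op_name)` in the Tuple branch, the first call above would return False, because the loop would stop at the static tuple before reaching the dynamic tensor.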







[GitHub] [incubator-tvm] rohanmukh commented on pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

Posted by GitBox <gi...@apache.org>.
rohanmukh commented on pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#issuecomment-726167838


   @trevor-m  @anijain2305 @zhiics 





[GitHub] [incubator-tvm] anijain2305 commented on pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#issuecomment-727084182


   Thanks @rohanmukh @zhiics @trevor-m. This is merged.





[GitHub] [incubator-tvm] zhiics commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

Posted by GitBox <gi...@apache.org>.
zhiics commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r523123910



##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -152,13 +153,50 @@ def partition_for_tensorrt(
     with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
         mod = seq(mod)
         mod = prune_tensorrt_subgraphs(mod)
+
     return mod, config
 
 
+def check_dynamism(args, op_name):
+    """
+    Check for dynamism inside any of the args in the op.
+
+    Parameters
+    ----------
+    args : tvm.ir.container.Array
+        Arguments of the op. The shape of each argument is checked for the presence of
+        dynamic components.
+    op_name: str
+        Name of the op for debugging purposes only.
+    Returns
+    ----------
+    ret : bool
+        True if dynamism is present, False otherwise
+    """
+    for arg in args:
+        if isinstance(arg, (Call, Var, Constant, TupleGetItem)):
+            for dim_shape in arg.checked_type.shape:

Review comment:
       oops, I didn't see that. It works. Thanks.







[GitHub] [incubator-tvm] rohanmukh commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

Posted by GitBox <gi...@apache.org>.
rohanmukh commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r523166626



##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -874,44 +972,7 @@ def test_densenet121():
     run_and_verify_model("densenet121")
 
 
-if __name__ == "__main__":
-    test_tensorrt_not_compatible()
-    test_tensorrt_simple()
-    test_tensorrt_simple_cpu_io()
-    test_tensorrt_serialize()
-
-    # Op tests
-    test_conv2d()
-    test_conv2d_nhwc()
-    test_conv2d_weights_const()
-    test_conv2d_weights_transposed()
-    test_dense()
-    test_bias_add()
-    test_pool2d()
-    test_global_pool2d()
-    test_batch_flatten()
-    test_expand_dims()
-    test_squeeze()
-    test_concatenate()
-    test_conv2d_transpose()
-    test_reshape()
-    test_transpose()
-    test_float_const()
-    test_pad()
-    test_softmax()
-    test_batch_norm()
-    test_unary()
-    test_clip()
-    test_leaky_relu()
-    test_binary()
-    test_reduce()
-    test_strided_slice()
-    test_adaptive_pool2d()
-    test_multiple_outputs()
-    test_conv3d()
-    test_pool3d()
-    test_conv3d_transpose()
-
+def test_tensorrt_integration():

Review comment:
       Thanks @trevor-m, sorry, I missed that while porting the tests to pytest.



