You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/12 22:56:39 UTC

[GitHub] [incubator-tvm] zhiics commented on a change in pull request #6905: [TRT][BYOC] handling dynamism in TensorRT to support OD models

zhiics commented on a change in pull request #6905:
URL: https://github.com/apache/incubator-tvm/pull/6905#discussion_r522473247



##########
File path: python/tvm/relay/op/contrib/tensorrt.py
##########
@@ -152,13 +153,50 @@ def partition_for_tensorrt(
     with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
         mod = seq(mod)
         mod = prune_tensorrt_subgraphs(mod)
+
     return mod, config
 
 
+def check_dynamism(args, op_name):
+    """
+    Check for dynamism inside any of the args in the op.
+
+    Parameters
+    ----------
+    args : tvm.ir.container.Array
+        Arguments of the op. Each of the argument shape is checked for presence of dynamic
+        components.
+    op_name: str
+        Name of the op for debugging purposes only.
+    Returns
+    ----------
+    ret : bool
+        True if dynamism is present, False otherwise
+    """
+    for arg in args:
+        if isinstance(arg, (Call, Var, Constant, TupleGetItem)):
+            for dim_shape in arg.checked_type.shape:

Review comment:
       nit: the checked type might be a (possibly nested) tuple type, which has no `.shape` attribute. If it is a tuple, we may want to flatten it and check the shape of each field, as done here: https://github.com/apache/incubator-tvm/blob/main/python/tvm/relay/op/memory/memory.py#L70

##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -219,40 +226,106 @@ def test_tensorrt_not_compatible():
     mod = tvm.IRModule()
     mod["main"] = f
     mod, config = tensorrt.partition_for_tensorrt(mod)
-    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
-        graph, lib, params = relay.build(mod, "cuda")
-    if skip_runtime_test():
-        return
-    mod = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
-    x_data = np.random.uniform(-1, 1, xshape).astype(dtype)
-    mod.run(x=x_data)
-    results = [mod.get_output(i).asnumpy() for i in range(mod.get_num_outputs())]
+    for mode in ["graph", "vm"]:
+        with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
+            exec = relay.create_executor(mode, mod=mod, ctx=tvm.gpu(0), target="cuda")
+            if not skip_runtime_test():
+                results = exec.evaluate()(x_data)
 
 
-def test_tensorrt_serialize():
+def test_tensorrt_serialize_graph_runtime(data_shape=(1, 3, 224, 224), data_type="float32"):
     if skip_codegen_test():
         return
-    import mxnet
-    from mxnet.gluon.model_zoo.vision import get_model
 
+    i_data = np.random.uniform(0, 1, data_shape).astype(data_type)
     block = get_model("resnet18_v1", pretrained=True)
-    mod, params = relay.frontend.from_mxnet(
-        block, shape={"data": (1, 3, 224, 224)}, dtype="float32"
-    )
-    # Compile
-    mod, config = tensorrt.partition_for_tensorrt(mod, params)
-    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
-        lib = relay.build(mod, "cuda", params=params)
-    # Serialize
-    lib.export_library("compiled.so")
-    # Deserialize
-    loaded_lib = tvm.runtime.load_module("compiled.so")
-    # Run
-    if skip_runtime_test():
+    mod, params = relay.frontend.from_mxnet(block, shape={"data": data_shape}, dtype=data_type)
+    mod, config = tensorrt.partition_for_tensorrt(mod)
+
+    def compile_graph(mod, params):
+        with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
+            graph, lib, params = relay.build(mod, params=params, target="cuda")
+            params = relay.save_param_dict(params)
+        return graph, lib, params
+
+    def run_graph(graph, lib, params):
+        mod_ = graph_runtime.create(graph, lib, ctx=tvm.gpu(0))
+        mod_.load_params(params)
+        mod_.run(data=i_data)
+        res = mod_.get_output(0)
+        return res
+
+    def save_graph(graph, lib, params):
+        # Serialize
+        with open("compiled.json", "w") as f_graph_json:
+            f_graph_json.write(graph)
+        with open("compiled.params", "wb") as f_params:
+            f_params.write(params)
+        lib.export_library("compiled.so")
+
+    def load_graph():
+        # Deserialize
+        with open("compiled.json", "r") as f_graph_json:
+            graph = f_graph_json.read()
+        with open("compiled.params", "rb") as f_params:
+            params = bytearray(f_params.read())
+        lib = tvm.runtime.load_module("compiled.so")
+        return graph, lib, params
+
+    # Test serialization with graph runtime
+    graph, lib, graph_params = compile_graph(mod, params)
+    save_graph(graph, lib, graph_params)
+    loaded_graph, loaded_lib, loaded_params = load_graph()
+
+    if not skip_runtime_test():
+        result_dict = dict()
+        result_dict["graph"] = run_graph(graph, lib, graph_params)
+        result_dict["graph_ref"] = run_graph(loaded_graph, loaded_lib, loaded_params)
+        assert_result_dict_holds(result_dict)
+
+
+def test_tensorrt_serialize_vm(data_shape=(1, 3, 224, 224), data_type="float32"):
+    if skip_codegen_test():
         return
-    gen_module = tvm.contrib.graph_runtime.GraphModule(loaded_lib["default"](tvm.gpu(0)))
-    i_data = np.random.uniform(0, 1, (1, 3, 224, 224)).astype("float32")
-    gen_module.run(data=i_data)
+
+    i_data = np.random.uniform(0, 1, data_shape).astype(data_type)
+    block = get_model("resnet18_v1", pretrained=True)
+    mod, params = relay.frontend.from_mxnet(block, shape={"data": data_shape}, dtype=data_type)
+    mod, config = tensorrt.partition_for_tensorrt(mod)
+
+    def compile_vm(mod, params):
+        with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
+            vm_exec = relay.vm.compile(mod, target="cuda", params=params)
+            code, lib = vm_exec.save()
+        return code, lib
+
+    def run_vm(code, lib):
+        vm_exec = tvm.runtime.vm.Executable.load_exec(code, lib)
+        vm = VirtualMachine(vm_exec, tvm.gpu(0))
+        result = vm.invoke("main", data=i_data)
+        return result
+
+    def save_vm(code, lib):
+        # save and load the code and lib file.
+        lib.export_library("path_lib.so")
+        with open("path_code.ro", "wb") as fo:

Review comment:
       It's probably better to use a temporary directory, because you may not always have write permission in the current directory.

##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -874,13 +960,82 @@ def test_densenet121():
     run_and_verify_model("densenet121")
 
 
+def test_tensorrt_integration():
+    # Integration tests
+    test_alexnet()
+    test_resnet18_v1()
+    test_resnet18_v2()
+    test_squeezenet()
+    test_mobilenet()
+    test_mobilenet_v2()
+    test_vgg11()
+    test_densenet121()
+
+
+def test_dynamic_offload(data_shape=(1, 32, 8, 8), k_shape=(1, 32, 3, 3)):
+    """
+    This test checks for proper dynamic offloading of relay graphs. An addition between
+    the outputs of two conv2d's is performed, one of them having all static args whereas
+    the other has a arg with dynamic shape. It is expected for the TRT partitioner to
+    offload the conv2d with dynamic arg to TVM while running the other in TRT.
+    """
+    x = relay.var("x", shape=(data_shape[0], data_shape[1], Any(), Any()), dtype="float32")
+    y = relay.var("y", shape=(data_shape), dtype="float32")
+    kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
+
+    def get_expected():
+        def set_func_attr(func, compile_name, symbol_name):
+            func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+            func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+            func = func.with_attr("Compiler", compile_name)
+            func = func.with_attr("global_symbol", symbol_name)
+            return func
+
+        # Create a nested TRT function that matches the expected output
+        mod = tvm.IRModule()
+        var1 = relay.var("tensorrt_0_i0", shape=(data_shape), dtype="float32")
+        kernel_trt = relay.var("tensorrt_0_i1", shape=(k_shape), dtype="float32")
+        out1 = relay.nn.conv2d(var1, kernel_trt, channels=k_shape[0], kernel_size=k_shape[2:4])
+        f1 = GlobalVar("tensorrt_0")
+        func = relay.Function([var1, kernel_trt], out1)
+        func = set_func_attr(func, "tensorrt", "tensorrt_0")
+        mod[f1] = func
+        mod = relay.transform.InferType()(mod)
+
+        # Create the main function
+        out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
+        out = relay.add(out1, f1(y, kernel))
+        f = relay.Function([x, y, kernel], out)
+        mod["main"] = f
+        mod = relay.transform.InferType()(mod)
+        return mod
+
+    # Create relay function that will be offloaded to TRT
+    out1 = relay.nn.conv2d(x, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
+    out2 = relay.nn.conv2d(y, kernel, channels=k_shape[0], kernel_size=k_shape[2:4])
+    out = relay.add(out1, out2)
+    f = relay.Function([x, y, kernel], out)
+
+    # Pass the function to TRT compilation
+    mod = tvm.IRModule()
+    mod["main"] = f
+    mod = relay.transform.InferType()(mod)
+    mod_trt, config = tensorrt.partition_for_tensorrt(mod, params={})
+
+    # Get the expected relay graph and compare
+    mod_exp = get_expected()
+    tvm.ir.assert_structural_equal(mod_trt, mod_exp, map_free_vars=True)
+    return
+
+
 if __name__ == "__main__":
     test_tensorrt_not_compatible()
     test_tensorrt_simple()
     test_tensorrt_simple_cpu_io()
-    test_tensorrt_serialize()
-
-    # Op tests
+    test_tensorrt_serialize_graph_runtime()

Review comment:
       You can remove all of these explicit test_xx() calls and just use `pytest.main([__file__])`, since pytest is used for testing.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org