You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/11/30 17:48:53 UTC

[GitHub] [tvm] codeislife99 commented on a change in pull request #6955: Dynamic Batch Support for TRT

codeislife99 commented on a change in pull request #6955:
URL: https://github.com/apache/tvm/pull/6955#discussion_r532784861



##########
File path: tests/python/contrib/test_tensorrt.py
##########
@@ -1034,5 +1039,119 @@ def set_func_attr(func, compile_name, symbol_name):
     tvm.ir.assert_structural_equal(mod_trt, mod_exp, map_free_vars=True)
 
 
+def convert_traced_model_to_vm_trt(
+    traced_module: torch.jit.TopLevelTracedModule, np_sample_input: np.ndarray, target: str
+) -> tvm.runtime.vm.Executable:
+    """
+    Compile a traced pytorch model into a VM executable with TRT partitioning applied.
+
+    The model is imported with a single input named "input0" whose shape is taken from
+    ``np_sample_input``, partitioned with ``tensorrt.partition_for_tensorrt`` and then
+    compiled with ``relay.vm.compile`` for ``target``.
+    """
+    input_specs = [("input0", np_sample_input.shape)]
+    mod, params = relay.frontend.from_pytorch(traced_module, input_specs)
+    # The second value returned by the partitioner is not used by this helper.
+    mod, _trt_config = tensorrt.partition_for_tensorrt(mod, params, remove_no_mac_subgraphs=True)
+
+    build_ctx = tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"])
+    with build_ctx:
+        return relay.vm.compile(mod, target=target, params=params)
+
+
+def test_maskrcnn_resnet50() -> None:
+    """
+    Test pytorch maskrcnn (resnet50 backbone) compiled through VM + TRT against the
+    traced pytorch model.
+
+    The compared outputs are the boxes, scores, labels and masks of the first image.
+    Since the order of the compiled model's outputs is a bit different from the pytorch
+    model, both result sets are sorted by descending score and only the highest scoring
+    entries are compared, using a separate tolerance for each output.
+    """
+    if skip_codegen_test() or skip_runtime_test():
+        return
+
+    class TraceWrapper(torch.nn.Module):
+        """
+        Wrapper over a torch module that converts its outputs into a traceable form:
+        a flat tuple of tensors taken from the first element of the wrapped model's output.
+        """
+
+        def __init__(self, model: torch.nn.Module) -> None:
+            super().__init__()
+            self.model = model
+
+        def forward(
+            self, inp: torch.Tensor
+        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+            out = self.model(inp)
+            return out[0]["boxes"], out[0]["scores"], out[0]["labels"], out[0]["masks"]
+
+    def get_traced_maskrcnn_model(np_sample_input: np.ndarray) -> torch.jit.TopLevelTracedModule:
+        """
+        Build the pretrained maskrcnn model and trace it with a random input that has
+        the same shape as ``np_sample_input``.
+        """
+        model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
+        model = TraceWrapper(model_func(pretrained=True))
+        model.eval()
+        inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=np_sample_input.shape))
+
+        with torch.no_grad():
+            # Run one eager forward pass before tracing; the result itself is not needed.
+            model(inp)
+            traced_module = torch.jit.trace(model, inp)
+            traced_module.eval()
+
+        return traced_module
+
+    def get_maskrcnn_input(in_size: int) -> np.ndarray:
+        """
+        Download a real image with multiple objects of interest and return it as a
+        float32 RGB array divided by 255, with shape (1, 3, in_size, in_size).
+        """
+        img_path = "test_street_small.jpg"
+        img_url = (
+            "https://raw.githubusercontent.com/dmlc/web-data/"
+            "master/gluoncv/detection/street_small.jpg"
+        )
+        download(img_url, img_path)
+        import cv2
+
+        img = cv2.imread(img_path).astype("float32")
+        img = cv2.resize(img, (in_size, in_size))
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        # HWC -> CHW, then add the leading batch axis.
+        img = np.transpose(img / 255.0, [2, 0, 1])
+        img = np.expand_dims(img, axis=0)
+
+        return img
+
+    in_size = 300
+    np_sample_input = get_maskrcnn_input(in_size)
+    traced_module = get_traced_maskrcnn_model(np_sample_input)
+    vm_trt_exec = convert_traced_model_to_vm_trt(traced_module, np_sample_input, target="llvm")
+    ctx = tvm.cpu()
+    vm = tvm.runtime.vm.VirtualMachine(vm_trt_exec, ctx)
+    vm.set_input("main", **{"input0": np_sample_input})
+    tvm_res = vm.run()
+
+    # Descending sort by scores and get the high confidence indices. In this example 9 is chosen,
+    # because this image has 9 boxes over 0.9 confidence
+    num_high_confidence_boxes = 9
+    tvm_indices = np.argsort(-1 * tvm_res[1].asnumpy())[:num_high_confidence_boxes]
+
+    with torch.no_grad():
+        out = traced_module(torch.Tensor(np_sample_input))
+        # Descending sort by scores and get the high confidence indices
+        pt_indices = np.argsort(-1 * out[1].numpy())[:num_high_confidence_boxes]
+
+    # Because of certain ops, there are certain minor differences in TVM outputs and PT outputs.
+    # This means that the tolerance can't be 1e-4 or 1e-5 throughout. The ideal way to get around
+    # this is to test it on an entire dataset and compare mAP with the original model.
+    # However, since that is not practically possible on CI, the following compromise is made.
+    # These tolerances are chosen based on their impact or lack thereof to the mAP score, e.g.
+    # a 0.1 pixel difference of a box in a 300x300 image won't make any change.
+    tol = [1e-1, 5e-3, 1e-5, 4e-1]  # [Box Tol, Score Tol, Label Tol, Mask Tol]
+    for i, tol_val in enumerate(tol):
+        np.testing.assert_allclose(
+            tvm_res[i].asnumpy()[tvm_indices],
+            out[i].numpy()[pt_indices],
+            rtol=tol_val,
+            atol=tol_val,
+        )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Review comment:
       Oops, thanks!




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org