Posted to commits@tvm.apache.org by ma...@apache.org on 2020/09/17 10:09:01 UTC

[incubator-tvm] branch master updated: Add PT OD tutorial (#6500)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 8f524f8  Add PT OD tutorial (#6500)
8f524f8 is described below

commit 8f524f83066f1b221652cdfed834220157bfaf44
Author: Yao Wang <ke...@gmail.com>
AuthorDate: Thu Sep 17 03:08:52 2020 -0700

    Add PT OD tutorial (#6500)
---
 .../frontend/deploy_object_detection_pytorch.py    | 154 +++++++++++++++++++++
 1 file changed, 154 insertions(+)

diff --git a/tutorials/frontend/deploy_object_detection_pytorch.py b/tutorials/frontend/deploy_object_detection_pytorch.py
new file mode 100644
index 0000000..6408685
--- /dev/null
+++ b/tutorials/frontend/deploy_object_detection_pytorch.py
@@ -0,0 +1,154 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Compile PyTorch Object Detection Models
+=======================================
+This article is an introductory tutorial to deploy PyTorch object
+detection models with Relay VM.
+
+To begin, PyTorch should be installed.
+TorchVision is also required, since we will be using it as our model zoo.
+
+A quick solution is to install both via pip:
+
+.. code-block:: bash
+
+    pip install torch==1.4.0
+    pip install torchvision==0.5.0
+
+or refer to the official site:
+https://pytorch.org/get-started/locally/
+
+PyTorch versions should be backwards compatible, but each should be used
+with the matching TorchVision version.
+
+Currently, TVM supports PyTorch 1.4 and 1.3. Other versions may
+be unstable.
+"""
+
+import tvm
+from tvm import relay
+from tvm.runtime.vm import VirtualMachine
+from tvm.contrib.download import download
+
+import numpy as np
+import cv2
+
+# PyTorch imports
+import torch
+import torchvision
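+
+# Optional sanity check: this tutorial targets torch 1.4 with torchvision 0.5
+# (see above); printing the installed versions can save debugging time if the
+# conversion fails.
+print("torch:", torch.__version__, "torchvision:", torchvision.__version__)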
+
+######################################################################
+# Load pre-trained maskrcnn from torchvision and do tracing
+# ---------------------------------------------------------
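+# torchvision detection models return a list of dicts (one per image), while
+# torch.jit.trace only supports tensor or tuple-of-tensor outputs. The
+# TraceWrapper class below unpacks the dict for the single input image into a
+# plain tuple so the model can be traced.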
+in_size = 300
+
+input_shape = (1, 3, in_size, in_size)
+
+
+def do_trace(model, inp):
+    model_trace = torch.jit.trace(model, inp)
+    model_trace.eval()
+    return model_trace
+
+
+def dict_to_tuple(out_dict):
+    if "masks" in out_dict:
+        return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
+    return out_dict["boxes"], out_dict["scores"], out_dict["labels"]
+
+
+class TraceWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, inp):
+        out = self.model(inp)
+        return dict_to_tuple(out[0])
+
+
+model_func = torchvision.models.detection.maskrcnn_resnet50_fpn
+model = TraceWrapper(model_func(pretrained=True))
+
+model.eval()
+inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=input_shape))
+
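+# Tracing executes the model, so run everything under no_grad() to avoid
+# autograd bookkeeping; the eager call also serves as a quick sanity check
+# that the wrapper works before tracing.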
+with torch.no_grad():
+    out = model(inp)
+    script_module = do_trace(model, inp)
+
+######################################################################
+# Download a test image and pre-process
+# -------------------------------------
+img_path = "test_street_small.jpg"
+img_url = (
+    "https://raw.githubusercontent.com/dmlc/web-data/" "master/gluoncv/detection/street_small.jpg"
+)
+download(img_url, img_path)
+
+img = cv2.imread(img_path).astype("float32")
+img = cv2.resize(img, (in_size, in_size))
+img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the model expects RGB
+img = np.transpose(img / 255.0, [2, 0, 1])  # scale to [0, 1] and convert HWC -> CHW
+img = np.expand_dims(img, axis=0)  # add the batch dimension -> NCHW
+
+######################################################################
+# Import the graph to Relay
+# -------------------------
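+# from_pytorch takes the traced module and a list of (input name, shape)
+# pairs. The name "input0" is arbitrary; it only needs to match the key
+# passed to set_input() below.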
+input_name = "input0"
+shape_list = [(input_name, input_shape)]
+mod, params = relay.frontend.from_pytorch(script_module, shape_list)
+
+######################################################################
+# Compile with Relay VM
+# ---------------------
+# Note: Currently only the CPU target is supported. For x86 targets, it is
+# highly recommended to build TVM with Intel MKL and Intel OpenMP to get the
+# best performance, due to the large dense operators in torchvision rcnn
+# models.
+
+# Add "-libs=mkl" to get the best performance on x86 targets.
+# For x86 machines that support AVX512, the complete target is
+# "llvm -mcpu=skylake-avx512 -libs=mkl"
+target = "llvm"
+
+with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
+    vm_exec = relay.vm.compile(mod, target=target, params=params)
+
+######################################################################
+# Inference with Relay VM
+# -----------------------
+ctx = tvm.cpu()
+vm = VirtualMachine(vm_exec, ctx)
+vm.set_input("main", **{input_name: img})
+tvm_res = vm.run()
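+
+# tvm_res mirrors the tuple produced by the traced model:
+# tvm_res[0] = boxes, tvm_res[1] = scores, tvm_res[2] = labels,
+# tvm_res[3] = masks.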
+
+######################################################################
+# Get boxes with score larger than 0.9
+# ------------------------------------
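+# torchvision detection models return detections sorted by descending score,
+# so the loop can stop at the first score below the threshold.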
+score_threshold = 0.9
+boxes = tvm_res[0].asnumpy().tolist()
+valid_boxes = []
+for i, score in enumerate(tvm_res[1].asnumpy().tolist()):
+    if score > score_threshold:
+        valid_boxes.append(boxes[i])
+    else:
+        break
+
+print("Got {} valid boxes".format(len(valid_boxes)))
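+
+######################################################################
+# Draw the valid boxes (optional)
+# -------------------------------
+# A minimal sketch of how the detections could be visualized with OpenCV.
+# Box coordinates are in the resized (in_size x in_size) image space, and
+# the output filename "detections.jpg" is our own choice.
+vis = cv2.cvtColor((img[0].transpose(1, 2, 0) * 255.0).astype("uint8"), cv2.COLOR_RGB2BGR)
+for box in valid_boxes:
+    x1, y1, x2, y2 = [int(v) for v in box]
+    cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
+cv2.imwrite("detections.jpg", vis)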