You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2021/01/27 23:09:55 UTC

[tvm] branch main updated: [Torch] More graph rewrites for Faster RCNN / MaskRCNN (#7346)

This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4006bde  [Torch] More graph rewrites for Faster RCNN / MaskRCNN (#7346)
4006bde is described below

commit 4006bde68e32daeaac5de11d9fc331a28ff55706
Author: masahi <ma...@gmail.com>
AuthorDate: Thu Jan 28 08:09:43 2021 +0900

    [Torch] More graph rewrites for Faster RCNN / MaskRCNN (#7346)
    
    * add post nms topk to max_out_size rewrite
    
    * add argsort conversion
    
    * scatter pattern first cut
    
    * matching seems to be working
    
    * dup matching fixed
    
    * add converter
    
    * conversion seems to be working
    
    * add reshape, use take
    
    * remove pytorch argsort converter
    
    * update test
    
    * add doc
---
 python/tvm/relay/frontend/pytorch_utils.py         | 258 +++++++++++++++++++--
 .../frontend/pytorch/test_object_detection.py      |  18 +-
 2 files changed, 261 insertions(+), 15 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py
index 6fc5a6a..248f535 100644
--- a/python/tvm/relay/frontend/pytorch_utils.py
+++ b/python/tvm/relay/frontend/pytorch_utils.py
@@ -16,13 +16,16 @@
 # under the License.
 # pylint: disable=import-outside-toplevel, unused-argument, invalid-name
 """ Common utilities used by PyTorch frontend """
+from .. import expr
 from .. import op
 from ..dataflow_pattern import (
+    wildcard,
     is_constant,
     is_op,
     rewrite,
     is_tuple,
-    wildcard,
+    is_tuple_get_item,
+    is_if,
     DFPatternCallback,
 )
 
@@ -36,6 +39,19 @@ def is_version_greater_than(ver):
     )
 
 
+def dyn_strided_slice_pattern(inp, end):
+    """A pattern to detect dynamic strided slice op."""
+    zero = is_constant()
+    cast_like = is_op("cast_like")(zero, is_constant())
+    less = is_op("less")(is_constant(), cast_like)
+    shape_of = is_op("shape_of")(inp)
+    cast_like = is_op("cast_like")(shape_of, is_constant())
+    add = is_op("add")(is_constant(), cast_like)
+    where = is_op("where")(less, add, is_constant())
+
+    return is_op("dyn.strided_slice")(inp, where, end, is_constant())
+
+
 def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices):
     """A pattern to detect batched_nms function in torchvision
 
@@ -73,7 +89,6 @@ def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices):
 
     """
     one = is_constant()
-    zero = is_constant()
 
     # Equivelent PyTorch code from above snippet
     # offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
@@ -84,17 +99,10 @@ def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices):
 
     # The following doesn't appear in the above Relay snippet. It is required for dynamic
     # stride_slice handling
-    cast_like = is_op("cast_like")(zero, is_constant())
-    less = is_op("less")(is_constant(), cast_like)
-    shape_of = is_op("shape_of")(mul)
-    cast_like = is_op("cast_like")(shape_of, is_constant())
-    add = is_op("add")(is_constant(), cast_like)
-    where = is_op("where")(less, add, is_constant())
     shape_of = is_op("shape_of")(mul)
     cast = is_op("cast")(shape_of)
-
     # This corresponds to offsets[:, None], where offsets is the result of multiplication
-    dyn_strided_slice = is_op("dyn.strided_slice")(mul, where, cast, is_constant())
+    dyn_strided_slice = dyn_strided_slice_pattern(mul, cast)
 
     # Add offsets to the boxes
     expand_dims = is_op("expand_dims")(dyn_strided_slice)
@@ -112,8 +120,49 @@ def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices):
     )
 
 
-class NMSRewrite(DFPatternCallback):
-    """A callback to rewrite nms and restore batched nms"""
+def topk_after_batch_nms_pattern(cond, true_branch, data, valid_count, indices, iou_threshold):
+    """
+    Detect the following pattern used in torchvision detection models.
+
+    def batched_nms(...):
+        if boxes.numel() == 0:
+            return torch.empty((0,), dtype=torch.int64, device=boxes.device)
+        else:
+            ...
+            return nms(boxes_for_nms, scores, iou_threshold)
+
+    keep = batched_nms(boxes, scores, lvl, self.nms_thresh)
+    keep = keep[:post_nms_top_k] # keep only topk scoring predictions
+
+    An equivalent Relay subgraph:
+
+    %1184 = if (%1117) {
+      ...
+    } else {
+      ...
+      %1172 = vision.non_max_suppression(%1167, %1168, %1171, -1, 0.7f, ...);
+      ...
+      %1183 = dyn.strided_slice(%1174, %1180, %1182, ...);
+      cast(%1183, dtype="int64")
+    };
+    %1185 = strided_slice(%1184, begin=[0], end=[1000], strides=[1]);
+
+    """
+    nms = is_op("vision.non_max_suppression")(
+        data, valid_count, indices, is_constant(), iou_threshold
+    )
+    indices = is_op("squeeze")(is_tuple_get_item(nms, 0))
+    size = is_op("squeeze")(is_tuple_get_item(nms, 1))
+    dyn_strided_slice = dyn_strided_slice_pattern(indices, size)
+    cast_i64 = is_op("cast")(dyn_strided_slice)
+
+    batched_nms_result = is_if(cond, true_branch, cast_i64)
+
+    return is_op("strided_slice")(batched_nms_result)
+
+
+class MulticlassNMSRewrite(DFPatternCallback):
+    """A callback to rewrite nms and restore batched nms."""
 
     def __init__(self):
         super().__init__()
@@ -169,10 +218,193 @@ class NMSRewrite(DFPatternCallback):
         return self.convert_batched_nms(boxes, scores, idxs, iou_thres, num_boxes, indices)
 
 
+class PostNMSTopKRewrite(DFPatternCallback):
+    """A callback to rewrite nms to exploit the max_output_size parameter."""
+
+    def __init__(self):
+        super().__init__()
+        self.cond = wildcard()
+        self.true_branch = wildcard()
+        self.data = wildcard()
+        self.valid_count = wildcard()
+        self.indices = wildcard()
+        self.iou_threshold = wildcard()
+
+        self.pattern = topk_after_batch_nms_pattern(
+            self.cond,
+            self.true_branch,
+            self.data,
+            self.valid_count,
+            self.indices,
+            self.iou_threshold,
+        )
+
+    def rewrite_batch_nms_with_max_out_size(
+        self, cond, true_branch, data, valid_count, indices, iou_threshold, post_nms_topk
+    ):
+        """Use the detected post NMS topk parameter in NMS op."""
+        nms_ret = op.vision.non_max_suppression(
+            data=data,
+            valid_count=valid_count,
+            indices=indices,
+            max_output_size=post_nms_topk,
+            iou_threshold=iou_threshold,
+            force_suppress=False,
+            top_k=-1,
+            coord_start=2,
+            score_index=1,
+            id_index=0,
+            return_indices=True,
+            invalid_to_bottom=False,
+        )
+
+        size = op.squeeze(nms_ret[1], axis=[1])
+        data_slice = op.squeeze(nms_ret[0], axis=[0])
+
+        ret = op.strided_slice(data_slice, begin=expr.const([0]), end=size, slice_mode="size")
+
+        nms_result = op.cast(ret, "int64")
+
+        return expr.If(cond, true_branch, nms_result)
+
+    def callback(self, pre, post, node_map):
+        post_nms_topk = post.attrs.end[0].value
+        return self.rewrite_batch_nms_with_max_out_size(
+            node_map[self.cond][0],
+            node_map[self.true_branch][0],
+            node_map[self.data][0],
+            node_map[self.valid_count][0],
+            node_map[self.indices][0],
+            node_map[self.iou_threshold][0],
+            post_nms_topk,
+        )
+
+
+def scatter_roi_align_result_pattern(levels, roi_align_results, num_scales):
+    """Detect the Relay subgraph corresponding to the following PyTorch code
+
+    first_result = roi_align_results[0]
+    dtype, device = first_result.dtype, first_result.device
+    res = torch.zeros((levels.size(0), first_result.size(1),
+                       first_result.size(2), first_result.size(3)),
+                      dtype=dtype, device=device)
+    for level in range(len(roi_align_results)):
+        index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
+        index = index.expand(index.size(0),
+                             roi_align_results[level].size(1),
+                             roi_align_results[level].size(2),
+                             roi_align_results[level].size(3))
+        res = res.scatter(0, index, roi_align_results[level])
+    return res
+    """
+
+    def do_where(levels, _):
+        idx_in_level = is_op("argwhere")(is_op("equal")(levels, is_constant()))
+        idx_in_level = is_op("split")(idx_in_level)
+        idx_in_level = is_tuple_get_item(idx_in_level, 0)
+        idx_in_level = is_op("squeeze")(idx_in_level)
+        idx_in_level = is_tuple_get_item(is_tuple([idx_in_level]), 0)
+        return idx_in_level
+
+    scatter_res = wildcard()
+
+    for i in range(num_scales):
+        # index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
+        scatter_indices = do_where(levels, i)
+        scatter_indices = is_op("reshape")(scatter_indices)
+
+        # index = index.expand(index.size(0),
+        #                      roi_align_results[level].size(1),
+        #                      roi_align_results[level].size(2),
+        #                      roi_align_results[level].size(3))
+        scatter_indices = is_op("repeat")(scatter_indices)
+        scatter_indices = is_op("repeat")(scatter_indices)
+        scatter_indices = is_op("repeat")(scatter_indices)
+
+        scatter_res = is_op("scatter")(scatter_res, scatter_indices, roi_align_results[i])
+
+    return is_op("reshape")(scatter_res)
+
+
+class ScatterRewrite(DFPatternCallback):
+    """A callback to rewrite repeated scatters with a batched gather."""
+
+    def __init__(self, num_scales):
+        super().__init__()
+        self.num_scales = num_scales
+        self.levels = wildcard()
+        self.roi_align_results = []
+        for _ in range(num_scales):
+            self.roi_align_results.append(wildcard())
+
+        self.pattern = scatter_roi_align_result_pattern(
+            self.levels, self.roi_align_results, num_scales
+        )
+
+    def convert_scatter_to_gather(self, levels, roi_align_results):
+        """Replace the detected scatter loop with the following PyTorch code
+
+        indices_per_level = []
+        for level in range(num_scales):
+            idx_in_level = torch.where(levels == level)[0]
+            indices_per_level.append(idx_in_level)
+
+        stacked_features = torch.cat(roi_align_results, dim=0)
+        stacked_indices = torch.cat(indices_per_level, dim=0)
+        argsort_indices = torch.argsort(stacked_indices)
+        return stacked_features[argsort_indices, :]
+        """
+
+        # Collect indices and concat them
+        indices_per_level = []
+        for i in range(self.num_scales):
+            equal = op.equal(levels, expr.const(i, dtype="int64"))
+            argwhere = op.argwhere(equal)
+            split = op.split(argwhere, indices_or_sections=1, axis=1)
+            squeeze = op.squeeze(split[0], axis=[1])
+            indices = op.cast(squeeze, dtype="int64")
+            indices_per_level.append(indices)
+
+        indices_concat = op.concatenate(indices_per_level, 0)
+
+        # Concat roi align results per level, and argsort indices
+        # To prepare for a batched gather
+        roi_align_results_concat = op.concatenate(roi_align_results, 0)
+        argsort_indices = op.cast(op.argsort(indices_concat), dtype="int64")
+
+        # Permute rows by argsorted indices
+        permuted = op.take(roi_align_results_concat, argsort_indices, axis=0)
+
+        return op.reshape(permuted, [0, -1, 1, 1])
+
+    def callback(self, pre, post, node_map):
+        levels = node_map[self.levels][0]
+        roi_align_results = [node_map[feat][0] for feat in self.roi_align_results]
+        return self.convert_scatter_to_gather(levels, roi_align_results)
+
+
 def rewrite_nms_to_batched_nms(mod):
     """Rewrite the input graph to replace non maximum surpression
     in torchvision that does not take class id into account with the one
     that avoids IOU tests between different classes.
     """
-    mod["main"] = rewrite(NMSRewrite(), mod["main"])
+    mod["main"] = rewrite(MulticlassNMSRewrite(), mod["main"])
+    return mod
+
+
+def rewrite_batched_nms_with_max_out_size(mod):
+    """Rewrite the input graph to detect slicing after batched nms and
+    use the slicing size as the max_output_size parameter of NMS.
+    """
+    mod["main"] = rewrite(PostNMSTopKRewrite(), mod["main"])
+    return mod
+
+
+def rewrite_scatter_to_gather(mod, num_scales):
+    """Rewrite the input graph to replace a repeated scatter loop with
+    a batched gather. The scatter loop is used in torchvision MultiScaleRoIAlign
+    to merge roi_align results for all scales. The scatter is used to emulate
+    inplace updates.
+    """
+    mod["main"] = rewrite(ScatterRewrite(num_scales), mod["main"])
     return mod
diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py
index 2c32377..fd33dd1 100644
--- a/tests/python/frontend/pytorch/test_object_detection.py
+++ b/tests/python/frontend/pytorch/test_object_detection.py
@@ -26,7 +26,11 @@ import tvm
 import tvm.testing
 from tvm import relay
 from tvm.runtime.vm import VirtualMachine
-from tvm.relay.frontend.pytorch_utils import rewrite_nms_to_batched_nms
+from tvm.relay.frontend.pytorch_utils import (
+    rewrite_nms_to_batched_nms,
+    rewrite_batched_nms_with_max_out_size,
+    rewrite_scatter_to_gather,
+)
 from tvm.contrib.download import download
 
 
@@ -72,7 +76,7 @@ def generate_jit_model(index):
     ]
 
     model_func = model_funcs[index]
-    model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=200))
+    model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=1000))
 
     model.eval()
     inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))
@@ -141,6 +145,16 @@ def test_detection_models():
     after = mod["main"]
     assert not tvm.ir.structural_equal(after, before)
 
+    before = mod["main"]
+    mod = rewrite_batched_nms_with_max_out_size(mod)
+    after = mod["main"]
+    assert not tvm.ir.structural_equal(after, before)
+
+    before = mod["main"]
+    mod = rewrite_scatter_to_gather(mod, 4)  # num_scales is 4 for maskrcnn_resnet50_fpn
+    after = mod["main"]
+    assert not tvm.ir.structural_equal(after, before)
+
     tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np, "llvm")
 
     # Results should be equivalent after rewriting