You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2021/01/27 23:09:55 UTC
[tvm] branch main updated: [Torch] More graph rewrites for Faster
RCNN / MaskRCNN (#7346)
This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4006bde [Torch] More graph rewrites for Faster RCNN / MaskRCNN (#7346)
4006bde is described below
commit 4006bde68e32daeaac5de11d9fc331a28ff55706
Author: masahi <ma...@gmail.com>
AuthorDate: Thu Jan 28 08:09:43 2021 +0900
[Torch] More graph rewrites for Faster RCNN / MaskRCNN (#7346)
* add post nms topk to max_out_size rewrite
* add argsort conversion
* scatter pattern first cut
* matching seems to be working
* dup matching fixed
* add converter
* conversion seems working
* add reshape, use take
* remove pytorch argsort converter
* update test
* add doc
---
python/tvm/relay/frontend/pytorch_utils.py | 258 +++++++++++++++++++--
.../frontend/pytorch/test_object_detection.py | 18 +-
2 files changed, 261 insertions(+), 15 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py
index 6fc5a6a..248f535 100644
--- a/python/tvm/relay/frontend/pytorch_utils.py
+++ b/python/tvm/relay/frontend/pytorch_utils.py
@@ -16,13 +16,16 @@
# under the License.
# pylint: disable=import-outside-toplevel, unused-argument, invalid-name
""" Common utilities used by PyTorch frontend """
+from .. import expr
from .. import op
from ..dataflow_pattern import (
+ wildcard,
is_constant,
is_op,
rewrite,
is_tuple,
- wildcard,
+ is_tuple_get_item,
+ is_if,
DFPatternCallback,
)
@@ -36,6 +39,19 @@ def is_version_greater_than(ver):
)
+def dyn_strided_slice_pattern(inp, end):
+ """A pattern to detect dynamic strided slice op."""
+ zero = is_constant()
+ cast_like = is_op("cast_like")(zero, is_constant())
+ less = is_op("less")(is_constant(), cast_like)
+ shape_of = is_op("shape_of")(inp)
+ cast_like = is_op("cast_like")(shape_of, is_constant())
+ add = is_op("add")(is_constant(), cast_like)
+ where = is_op("where")(less, add, is_constant())
+
+ return is_op("dyn.strided_slice")(inp, where, end, is_constant())
+
+
def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices):
"""A pattern to detect batched_nms function in torchvision
@@ -73,7 +89,6 @@ def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices):
"""
one = is_constant()
- zero = is_constant()
# Equivelent PyTorch code from above snippet
# offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
@@ -84,17 +99,10 @@ def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices):
# The following doesn't appear in the above Relay snippet. It is required for dynamic
# stride_slice handling
- cast_like = is_op("cast_like")(zero, is_constant())
- less = is_op("less")(is_constant(), cast_like)
- shape_of = is_op("shape_of")(mul)
- cast_like = is_op("cast_like")(shape_of, is_constant())
- add = is_op("add")(is_constant(), cast_like)
- where = is_op("where")(less, add, is_constant())
shape_of = is_op("shape_of")(mul)
cast = is_op("cast")(shape_of)
-
# This corresponds to offsets[:, None], where offsets is the result of multiplication
- dyn_strided_slice = is_op("dyn.strided_slice")(mul, where, cast, is_constant())
+ dyn_strided_slice = dyn_strided_slice_pattern(mul, cast)
# Add offsets to the boxes
expand_dims = is_op("expand_dims")(dyn_strided_slice)
@@ -112,8 +120,49 @@ def batched_nms_pattern(boxes, scores, idxs, iou_threshold, num_boxes, indices):
)
-class NMSRewrite(DFPatternCallback):
- """A callback to rewrite nms and restore batched nms"""
+def topk_after_batch_nms_pattern(cond, true_branch, data, valid_count, indices, iou_threshold):
+ """
+ Detect the following pattern used in torchvision detection models.
+
+ def batched_nms(...):
+ if boxes.numel() == 0:
+ return torch.empty((0,), dtype=torch.int64, device=boxes.device)
+ else:
+ ...
+ return nms(boxes_for_nms, scores, iou_threshold)
+
+ keep = batched_nms(boxes, scores, lvl, self.nms_thresh)
+ keep = keep[:post_nms_top_k] # keep only topk scoring predictions
+
+ An equivalent Relay subgraph:
+
+ %1184 = if (%1117) {
+ ...
+ } else {
+ ...
+ %1172 = vision.non_max_suppression(%1167, %1168, %1171, -1, 0.7f, ...);
+ ...
+ %1183 = dyn.strided_slice(%1174, %1180, %1182, ...);
+ cast(%1183, dtype="int64")
+ };
+ %1185 = strided_slice(%1184, begin=[0], end=[1000], strides=[1]);
+
+ """
+ nms = is_op("vision.non_max_suppression")(
+ data, valid_count, indices, is_constant(), iou_threshold
+ )
+ indices = is_op("squeeze")(is_tuple_get_item(nms, 0))
+ size = is_op("squeeze")(is_tuple_get_item(nms, 1))
+ dyn_strided_slice = dyn_strided_slice_pattern(indices, size)
+ cast_i64 = is_op("cast")(dyn_strided_slice)
+
+ batched_nms_result = is_if(cond, true_branch, cast_i64)
+
+ return is_op("strided_slice")(batched_nms_result)
+
+
+class MulticlassNMSRewrite(DFPatternCallback):
+ """A callback to rewrite nms and restore batched nms."""
def __init__(self):
super().__init__()
@@ -169,10 +218,193 @@ class NMSRewrite(DFPatternCallback):
return self.convert_batched_nms(boxes, scores, idxs, iou_thres, num_boxes, indices)
+class PostNMSTopKRewrite(DFPatternCallback):
+ """A callback to rewrite nms to exploit max_out_size parameter."""
+
+ def __init__(self):
+ super().__init__()
+ self.cond = wildcard()
+ self.true_branch = wildcard()
+ self.data = wildcard()
+ self.valid_count = wildcard()
+ self.indices = wildcard()
+ self.iou_threshold = wildcard()
+
+ self.pattern = topk_after_batch_nms_pattern(
+ self.cond,
+ self.true_branch,
+ self.data,
+ self.valid_count,
+ self.indices,
+ self.iou_threshold,
+ )
+
+ def rewrite_batch_nms_with_max_out_size(
+ self, cond, true_branch, data, valid_count, indices, iou_threshold, post_nms_topk
+ ):
+ """Use the detected post NMS topk parameter in NMS op."""
+ nms_ret = op.vision.non_max_suppression(
+ data=data,
+ valid_count=valid_count,
+ indices=indices,
+ max_output_size=post_nms_topk,
+ iou_threshold=iou_threshold,
+ force_suppress=False,
+ top_k=-1,
+ coord_start=2,
+ score_index=1,
+ id_index=0,
+ return_indices=True,
+ invalid_to_bottom=False,
+ )
+
+ size = op.squeeze(nms_ret[1], axis=[1])
+ data_slice = op.squeeze(nms_ret[0], axis=[0])
+
+ ret = op.strided_slice(data_slice, begin=expr.const([0]), end=size, slice_mode="size")
+
+ nms_result = op.cast(ret, "int64")
+
+ return expr.If(cond, true_branch, nms_result)
+
+ def callback(self, pre, post, node_map):
+ post_nms_topk = post.attrs.end[0].value
+ return self.rewrite_batch_nms_with_max_out_size(
+ node_map[self.cond][0],
+ node_map[self.true_branch][0],
+ node_map[self.data][0],
+ node_map[self.valid_count][0],
+ node_map[self.indices][0],
+ node_map[self.iou_threshold][0],
+ post_nms_topk,
+ )
+
+
+def scatter_roi_align_result_pattern(levels, roi_align_results, num_scales):
+ """Detect the Relay subgraph corresponding to the following PyTorch code
+
+ first_result = roi_align_results[0]
+ dtype, device = first_result.dtype, first_result.device
+ res = torch.zeros((levels.size(0), first_result.size(1),
+ first_result.size(2), first_result.size(3)),
+ dtype=dtype, device=device)
+ for level in range(len(roi_align_results)):
+ index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
+ index = index.expand(index.size(0),
+ roi_align_results[level].size(1),
+ roi_align_results[level].size(2),
+ roi_align_results[level].size(3))
+ res = res.scatter(0, index, roi_align_results[level])
+ return res
+ """
+
+ def do_where(levels, _):
+ idx_in_level = is_op("argwhere")(is_op("equal")(levels, is_constant()))
+ idx_in_level = is_op("split")(idx_in_level)
+ idx_in_level = is_tuple_get_item(idx_in_level, 0)
+ idx_in_level = is_op("squeeze")(idx_in_level)
+ idx_in_level = is_tuple_get_item(is_tuple([idx_in_level]), 0)
+ return idx_in_level
+
+ scatter_res = wildcard()
+
+ for i in range(num_scales):
+ # index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
+ scatter_indices = do_where(levels, i)
+ scatter_indices = is_op("reshape")(scatter_indices)
+
+ # index = index.expand(index.size(0),
+ # unmerged_results[level].size(1),
+ # unmerged_results[level].size(2),
+ # unmerged_results[level].size(3))
+ scatter_indices = is_op("repeat")(scatter_indices)
+ scatter_indices = is_op("repeat")(scatter_indices)
+ scatter_indices = is_op("repeat")(scatter_indices)
+
+ scatter_res = is_op("scatter")(scatter_res, scatter_indices, roi_align_results[i])
+
+ return is_op("reshape")(scatter_res)
+
+
+class ScatterRewrite(DFPatternCallback):
+ """A callback to rewrite repeated scatters with a batched gather."""
+
+ def __init__(self, num_scales):
+ super().__init__()
+ self.num_scales = num_scales
+ self.levels = wildcard()
+ self.roi_align_results = []
+ for _ in range(num_scales):
+ self.roi_align_results.append(wildcard())
+
+ self.pattern = scatter_roi_align_result_pattern(
+ self.levels, self.roi_align_results, num_scales
+ )
+
+ def convert_scatter_to_gather(self, levels, roi_align_results):
+ """Replace the detected scatter loop with the following PyTorch code
+
+ indices_per_level = []
+ for level in range(num_scales):
+ idx_in_level = torch.where(levels == level)[0]
+ indices_per_level.append(idx_in_level)
+
+ stacked_features = torch.cat(roi_align_results, dim=0)
+ stacked_indices = torch.cat(indices_per_level, dim=0)
+ argsort_indices = torch.argsort(stacked_indices)
+ return stacked_features[argsort_indices, :]
+ """
+
+ # Collect indices and concat them
+ indices_per_level = []
+ for i in range(self.num_scales):
+ equal = op.equal(levels, expr.const(i, dtype="int64"))
+ argwhere = op.argwhere(equal)
+ split = op.split(argwhere, indices_or_sections=1, axis=1)
+ squeeze = op.squeeze(split[0], axis=[1])
+ indices = op.cast(squeeze, dtype="int64")
+ indices_per_level.append(indices)
+
+ indices_concat = op.concatenate(indices_per_level, 0)
+
+ # Concat roi align results per level, and argsort indices
+ # To prepare for a batched gather
+ roi_align_results_concat = op.concatenate(roi_align_results, 0)
+ argsort_indices = op.cast(op.argsort(indices_concat), dtype="int64")
+
+ # Permute rows by argsorted indices
+ permuted = op.take(roi_align_results_concat, argsort_indices, axis=0)
+
+ return op.reshape(permuted, [0, -1, 1, 1])
+
+ def callback(self, pre, post, node_map):
+ levels = node_map[self.levels][0]
+ roi_align_results = [node_map[feat][0] for feat in self.roi_align_results]
+ return self.convert_scatter_to_gather(levels, roi_align_results)
+
+
def rewrite_nms_to_batched_nms(mod):
"""Rewrite the input graph to replace non maximum surpression
in torchvision that does not take class id into account with the one
that avoids IOU tests between different classes.
"""
- mod["main"] = rewrite(NMSRewrite(), mod["main"])
+ mod["main"] = rewrite(MulticlassNMSRewrite(), mod["main"])
+ return mod
+
+
+def rewrite_batched_nms_with_max_out_size(mod):
+ """Rewrite the input graph to detect slicing after batched nms and
+ use the slicing size as the parameter max_out_size in NMS.
+ """
+ mod["main"] = rewrite(PostNMSTopKRewrite(), mod["main"])
+ return mod
+
+
+def rewrite_scatter_to_gather(mod, num_scales):
+ """Rewrite the input graph to replace a repeated scatter loop with
+ a batched gather. The scatter loop is used in torchvision MultiScaleRoIAlign
+ to merge roi_align results for all scales. The scatter is used to emulate
+ inplace updates.
+ """
+ mod["main"] = rewrite(ScatterRewrite(num_scales), mod["main"])
return mod
diff --git a/tests/python/frontend/pytorch/test_object_detection.py b/tests/python/frontend/pytorch/test_object_detection.py
index 2c32377..fd33dd1 100644
--- a/tests/python/frontend/pytorch/test_object_detection.py
+++ b/tests/python/frontend/pytorch/test_object_detection.py
@@ -26,7 +26,11 @@ import tvm
import tvm.testing
from tvm import relay
from tvm.runtime.vm import VirtualMachine
-from tvm.relay.frontend.pytorch_utils import rewrite_nms_to_batched_nms
+from tvm.relay.frontend.pytorch_utils import (
+ rewrite_nms_to_batched_nms,
+ rewrite_batched_nms_with_max_out_size,
+ rewrite_scatter_to_gather,
+)
from tvm.contrib.download import download
@@ -72,7 +76,7 @@ def generate_jit_model(index):
]
model_func = model_funcs[index]
- model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=200))
+ model = TraceWrapper(model_func(pretrained=True, rpn_pre_nms_top_n_test=1000))
model.eval()
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))
@@ -141,6 +145,16 @@ def test_detection_models():
after = mod["main"]
assert not tvm.ir.structural_equal(after, before)
+ before = mod["main"]
+ mod = rewrite_batched_nms_with_max_out_size(mod)
+ after = mod["main"]
+ assert not tvm.ir.structural_equal(after, before)
+
+ before = mod["main"]
+ mod = rewrite_scatter_to_gather(mod, 4) # num_scales is 4 for maskrcnn_resnet50_fpn
+ after = mod["main"]
+ assert not tvm.ir.structural_equal(after, before)
+
tvm_res_after_rewrite = compile_and_run_vm(mod, params, data_np, "llvm")
# Results should be equivalent after rewriting