You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/11/08 15:50:58 UTC
[incubator-tvm] branch main updated: [RELAY][OP] roi_pool operator alter layout (#6516)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 82f3e4b [RELAY][OP] roi_pool operator alter layout (#6516)
82f3e4b is described below
commit 82f3e4b3d6aa34a925cb011d575799d65a21629b
Author: Honghua Cao <49...@users.noreply.github.com>
AuthorDate: Sun Nov 8 23:50:42 2020 +0800
[RELAY][OP] roi_pool operator alter layout (#6516)
Co-authored-by: honghua.cao <ho...@streamcomputing.com>
---
python/tvm/relay/op/vision/_rcnn.py | 44 +++++++++++++++++-
src/relay/op/vision/rcnn_op.cc | 30 +++++++++++--
tests/python/relay/test_pass_convert_op_layout.py | 54 +++++++++++++++++++++++
3 files changed, 123 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relay/op/vision/_rcnn.py b/python/tvm/relay/op/vision/_rcnn.py
index 46eb3cb..4686974 100644
--- a/python/tvm/relay/op/vision/_rcnn.py
+++ b/python/tvm/relay/op/vision/_rcnn.py
@@ -69,11 +69,53 @@ def convert_roi_align(attrs, inputs, tinfos, desired_layouts):
raise ValueError("Layout %s is not yet supported." % desired_data_layout)
+@reg.register_convert_op_layout("vision.roi_pool")
+def convert_roi_pool(attrs, inputs, tinfos, desired_layouts):
+ """Convert Layout pass registration for roi_pool op.
+
+ Parameters
+ ----------
+ attrs : tvm.ir.Attrs
+ Attributes of current roi_pool
+ inputs : list of tvm.relay.Expr
+ The args of the Relay expr to be legalized
+ tinfos : list of types
+ List of input and output types
+ desired_layouts : list of layout strings
+ List of layouts defining our desired
+ layout for the data and rois inputs respectively.
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The transformed expr
+ """
+ # pylint: disable=import-outside-toplevel
+ from tvm import relay
+
+ data, rois = inputs
+ new_attrs = dict(attrs)
+ assert (
+ len(desired_layouts) == 2
+ ), "A desired layout is expected for both of vision.roi_pool's inputs"
+
+ desired_data_layout, desired_rois_layout = map(str, desired_layouts)
+ assert desired_data_layout != "default", "Data layout cannot be default"
+ assert desired_rois_layout == "default", "Rois layout must be default"
+
+ new_attrs["layout"] = desired_data_layout
+ # rois layout not change
+ if desired_data_layout in ["NCHW", "NHWC"]:
+ return relay.vision.roi_pool(data, rois, **new_attrs)
+
+ raise ValueError("Layout %s is not yet supported." % desired_data_layout)
+
+
# roi_pool
@reg.register_compute("vision.roi_pool")
def compute_roi_pool(attrs, inputs, _):
"""Compute definition of roi_pool"""
- assert attrs.layout == "NCHW"
+ assert attrs.layout == "NCHW", "only support nchw for now"
return [
topi.vision.rcnn.roi_pool_nchw(
inputs[0],
diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc
index 8be38d0..f7bbf37 100644
--- a/src/relay/op/vision/rcnn_op.cc
+++ b/src/relay/op/vision/rcnn_op.cc
@@ -119,14 +119,35 @@ bool ROIPoolRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
ICHECK(roi_pool_attrs);
ICHECK_EQ(dshape.size(), 4) << "Input data should be 4-D.";
ICHECK_EQ(rshape.size(), 2) << "Input rois should be 2-D.";
- ICHECK_EQ(roi_pool_attrs->layout, "NCHW") << "ROI Pool only supports NCHW layout";
// assign output type
- std::vector<IndexExpr> oshape(
- {rshape[0], dshape[1], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1]});
+ std::vector<IndexExpr> oshape;
+ if (roi_pool_attrs->layout == "NCHW") {
+ oshape = {rshape[0], dshape[1], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1]};
+ } else if (roi_pool_attrs->layout == "NHWC") {
+ oshape = {rshape[0], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1], dshape[3]};
+ } else {
+ LOG(FATAL) << "vision.roi_pool does not support " << roi_pool_attrs->layout << " layout";
+ }
+
reporter->Assign(types[2], TensorType(oshape, data->dtype));
return true;
}
+template <typename T>
+Array<Array<Layout> > ROIPoolInferCorrectLayout(const Attrs& attrs,
+                                                const Array<Layout>& new_in_layouts,
+                                                const Array<Layout>& old_in_layouts,
+                                                const Array<tvm::relay::Type>& old_in_types) {
+  // The attributes are only read here, so no const_cast is needed.
+  const T* params = attrs.as<T>();
+  ICHECK(params != nullptr) << "vision.roi_pool: unexpected attribute type";
+  Layout data_layout = params->layout;
+
+  // Every input and the output need a layout. rois is a 2-D [num_roi, 5]
+  // tensor, described by the pseudo-layout "N5"; the output follows data.
+  return Array<Array<Layout> >{{data_layout, Layout("N5")}, {data_layout}};
+}
+
Expr MakeROIPool(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spatial_scale,
String layout) {
auto attrs = make_object<ROIPoolAttrs>();
@@ -153,7 +174,8 @@ RELAY_REGISTER_OP("vision.roi_pool")
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("rois", "Tensor", "The input rois")
.set_support_level(5)
- .add_type_rel("ROIPool", ROIPoolRel);
+ .add_type_rel("ROIPool", ROIPoolRel)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ROIPoolInferCorrectLayout<ROIPoolAttrs>);
TVM_REGISTER_NODE_TYPE(ProposalAttrs);
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 1fc5d39..7fc896a 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -966,6 +966,59 @@ def test_conv_strided_slice_convert_layout():
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+def test_conv_roi_pool_convert_layout():
+ def before():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight1 = relay.var("weight1", shape=(64, 64, 3, 3))
+ y = relay.nn.conv2d(
+ x,
+ weight1,
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ )
+ rois = relay.var("rois", shape=(32, 5))
+ y = relay.vision.roi_pool(
+ y, rois, pooled_size=(14, 14), spatial_scale=0.0625, layout="NCHW"
+ )
+ y = relay.Function(analysis.free_vars(y), y)
+ return y
+
+ def expected():
+ x = relay.var("x", shape=(1, 64, 56, 56))
+ weight1 = relay.var("weight1", shape=(64, 64, 3, 3))
+ x = relay.layout_transform(x, "NCHW", "NHWC")
+ weight1 = relay.layout_transform(weight1, "OIHW", "HWIO")
+ y = relay.nn.conv2d(
+ x,
+ weight1,
+ channels=64,
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ data_layout="NHWC",
+ kernel_layout="HWIO",
+ )
+ rois = relay.var("rois", shape=(32, 5))
+ y = relay.vision.roi_pool(
+ y, rois, pooled_size=(14, 14), spatial_scale=0.0625, layout="NHWC"
+ )
+ ret = relay.layout_transform(y, "NHWC", "NCHW")
+ y = relay.Function(analysis.free_vars(ret), ret)
+ return y
+
+ a = before()
+ desired_layouts = {
+ "nn.conv2d": ["NHWC", "HWIO"],
+ "vision.roi_pool": ["NHWC", "default"],
+ }
+ a = run_opt_pass(a, transform.ConvertLayout(desired_layouts))
+ b = run_opt_pass(expected(), transform.InferType())
+
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
def test_default_keyword():
""" Check that the default keyword selects correct TVM default layout. """
@@ -1253,6 +1306,7 @@ if __name__ == "__main__":
test_conv_convert_kernel_layout()
test_conv_transpose_convert_layout()
test_conv_roi_align_convert_layout()
+ test_conv_roi_pool_convert_layout()
test_conv_strided_slice_convert_layout()
test_default_keyword()
test_different_ops_convert_layout()