You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/11/08 15:50:58 UTC

[incubator-tvm] branch main updated: [RELAY][OP] roi_pool operator alter layout (#6516)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 82f3e4b  [RELAY][OP] roi_pool operator alter layout (#6516)
82f3e4b is described below

commit 82f3e4b3d6aa34a925cb011d575799d65a21629b
Author: Honghua Cao <49...@users.noreply.github.com>
AuthorDate: Sun Nov 8 23:50:42 2020 +0800

    [RELAY][OP] roi_pool operator alter layout (#6516)
    
    Co-authored-by: honghua.cao <ho...@streamcomputing.com>
---
 python/tvm/relay/op/vision/_rcnn.py               | 44 +++++++++++++++++-
 src/relay/op/vision/rcnn_op.cc                    | 30 +++++++++++--
 tests/python/relay/test_pass_convert_op_layout.py | 54 +++++++++++++++++++++++
 3 files changed, 123 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/op/vision/_rcnn.py b/python/tvm/relay/op/vision/_rcnn.py
index 46eb3cb..4686974 100644
--- a/python/tvm/relay/op/vision/_rcnn.py
+++ b/python/tvm/relay/op/vision/_rcnn.py
@@ -69,11 +69,53 @@ def convert_roi_align(attrs, inputs, tinfos, desired_layouts):
     raise ValueError("Layout %s is not yet supported." % desired_data_layout)
 
 
+@reg.register_convert_op_layout("vision.roi_pool")
+def convert_roi_pool(attrs, inputs, tinfos, desired_layouts):
+    """Convert Layout pass registration for roi_pool op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current roi_pool
+    inputs : list of tvm.relay.Expr
+        The args (data, rois) of the Relay expr whose layout is converted
+    tinfos : list of types
+        List of input and output types
+    desired_layouts : list of layout strings
+        Desired layouts for the data and rois inputs, respectively. Data must
+        be "NCHW" or "NHWC"; the rois layout must be "default".
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        A new roi_pool expr whose layout attribute is the desired data layout
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm import relay
+
+    data, rois = inputs
+    new_attrs = dict(attrs)
+    assert (
+        len(desired_layouts) == 2
+    ), "A desired layout is expected for both of vision.roi_pool's inputs"
+
+    desired_data_layout, desired_rois_layout = map(str, desired_layouts)
+    assert desired_data_layout != "default", "Data layout cannot be default"
+    assert desired_rois_layout == "default", "Rois layout must be default"
+
+    new_attrs["layout"] = desired_data_layout
+    # The rois input is left unchanged; only the data layout is converted.
+    if desired_data_layout in ["NCHW", "NHWC"]:
+        return relay.vision.roi_pool(data, rois, **new_attrs)
+
+    raise ValueError("Layout %s is not yet supported." % desired_data_layout)
+
+
 # roi_pool
 @reg.register_compute("vision.roi_pool")
 def compute_roi_pool(attrs, inputs, _):
     """Compute definition of roi_pool"""
-    assert attrs.layout == "NCHW"
+    assert attrs.layout == "NCHW", "only support nchw for now"
     return [
         topi.vision.rcnn.roi_pool_nchw(
             inputs[0],
diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc
index 8be38d0..f7bbf37 100644
--- a/src/relay/op/vision/rcnn_op.cc
+++ b/src/relay/op/vision/rcnn_op.cc
@@ -119,14 +119,35 @@ bool ROIPoolRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   ICHECK(roi_pool_attrs);
   ICHECK_EQ(dshape.size(), 4) << "Input data should be 4-D.";
   ICHECK_EQ(rshape.size(), 2) << "Input rois should be 2-D.";
-  ICHECK_EQ(roi_pool_attrs->layout, "NCHW") << "ROI Pool only supports NCHW layout";
   // assign output type
-  std::vector<IndexExpr> oshape(
-      {rshape[0], dshape[1], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1]});
+  std::vector<IndexExpr> oshape;
+  if (roi_pool_attrs->layout == "NCHW") {
+    oshape = {rshape[0], dshape[1], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1]};
+  } else if (roi_pool_attrs->layout == "NHWC") {
+    oshape = {rshape[0], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1], dshape[3]};
+  } else {
+    LOG(FATAL) << "vision.roi_pool does not support " << roi_pool_attrs->layout << " layout";
+  }
+
   reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
 
+template <typename T>
+Array<Array<Layout> > ROIPoolInferCorrectLayout(const Attrs& attrs,
+                                                const Array<Layout>& new_in_layouts,
+                                                const Array<Layout>& old_in_layouts,
+                                                const Array<tvm::relay::Type>& old_in_types) {
+  const T* params = attrs.as<T>();
+  ICHECK(params != nullptr) << "ROIPoolInferCorrectLayout: unexpected attribute type";
+  Layout data_layout = params->layout;
+
+  // Layout inference must define a layout for every input and for the output.
+  // For roi_pool, the second input (rois) is a 2-D tensor of shape [num_roi, 5]
+  // with no spatial axes to convert, so it is given the fixed layout "N5".
+  return Array<Array<Layout> >{{data_layout, Layout("N5")}, {data_layout}};
+}
+
 Expr MakeROIPool(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spatial_scale,
                  String layout) {
   auto attrs = make_object<ROIPoolAttrs>();
@@ -153,7 +174,8 @@ RELAY_REGISTER_OP("vision.roi_pool")
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("rois", "Tensor", "The input rois")
     .set_support_level(5)
-    .add_type_rel("ROIPool", ROIPoolRel);
+    .add_type_rel("ROIPool", ROIPoolRel)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ROIPoolInferCorrectLayout<ROIPoolAttrs>);
 
 TVM_REGISTER_NODE_TYPE(ProposalAttrs);
 
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 1fc5d39..7fc896a 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -966,6 +966,59 @@ def test_conv_strided_slice_convert_layout():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_roi_pool_convert_layout():
+    def before():
+        data = relay.var("x", shape=(1, 64, 56, 56))
+        kernel = relay.var("weight1", shape=(64, 64, 3, 3))
+        feat = relay.nn.conv2d(
+            data,
+            kernel,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
+        rois = relay.var("rois", shape=(32, 5))
+        pooled = relay.vision.roi_pool(
+            feat, rois, pooled_size=(14, 14), spatial_scale=0.0625, layout="NCHW"
+        )
+        func = relay.Function(analysis.free_vars(pooled), pooled)
+        return func
+
+    def expected():
+        data = relay.var("x", shape=(1, 64, 56, 56))
+        kernel = relay.var("weight1", shape=(64, 64, 3, 3))
+        data_nhwc = relay.layout_transform(data, "NCHW", "NHWC")
+        kernel_hwio = relay.layout_transform(kernel, "OIHW", "HWIO")
+        feat = relay.nn.conv2d(
+            data_nhwc,
+            kernel_hwio,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        rois = relay.var("rois", shape=(32, 5))
+        pooled = relay.vision.roi_pool(
+            feat, rois, pooled_size=(14, 14), spatial_scale=0.0625, layout="NHWC"
+        )
+        out = relay.layout_transform(pooled, "NHWC", "NCHW")
+        func = relay.Function(analysis.free_vars(out), out)
+        return func
+
+    actual = before()
+    desired_layouts = {
+        "nn.conv2d": ["NHWC", "HWIO"],
+        "vision.roi_pool": ["NHWC", "default"],
+    }
+    actual = run_opt_pass(actual, transform.ConvertLayout(desired_layouts))
+    reference = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(actual, reference), "Actual = \n" + str(actual)
+
+
 def test_default_keyword():
     """ Check that the default keyword selects correct TVM default layout. """
 
@@ -1253,6 +1306,7 @@ if __name__ == "__main__":
     test_conv_convert_kernel_layout()
     test_conv_transpose_convert_layout()
     test_conv_roi_align_convert_layout()
+    test_conv_roi_pool_convert_layout()
     test_conv_strided_slice_convert_layout()
     test_default_keyword()
     test_different_ops_convert_layout()