You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by co...@apache.org on 2021/01/05 01:18:11 UTC

[tvm] branch main updated: [ConvertLayout] slice_like support (#7184)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d052752  [ConvertLayout] slice_like support (#7184)
d052752 is described below

commit d05275298d9e630af6d8ff958753fd010759935c
Author: Cody Yu <co...@gmail.com>
AuthorDate: Mon Jan 4 17:17:53 2021 -0800

    [ConvertLayout] slice_like support (#7184)
---
 src/relay/op/tensor/transform.cc                  | 41 +++++++++++++
 tests/python/relay/test_pass_convert_op_layout.py | 70 +++++++++++++++++++++++
 2 files changed, 111 insertions(+)

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 19ca612..1ff428c 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2752,6 +2752,46 @@ Expr MakeSliceLike(Expr data, Expr shape_like, Array<Integer> axes) {
   return Call(op, {data, shape_like}, Attrs(attrs), {});
 }
 
+// Remap slice_like's `axes` into the new layout; otherwise fall back to the old layouts.
+Array<Array<Layout>> SliceLikeInferCorrectLayout(const Attrs& attrs,
+                                                 const Array<Layout>& new_in_layouts,
+                                                 const Array<Layout>& old_in_layouts,
+                                                 const Array<tvm::relay::Type>& old_in_types) {
+  Array<Integer> new_axes;  // `axes` re-expressed as indices into the new layout.
+  if (old_in_layouts.defined() && new_in_layouts.defined()) {
+    ICHECK_EQ(new_in_layouts.size(), 2);
+    ICHECK_EQ(new_in_layouts[0]->name, new_in_layouts[1]->name);
+    ICHECK_EQ(old_in_layouts.size(), 2);
+    ICHECK_EQ(old_in_layouts[0]->name, old_in_layouts[1]->name);
+    auto old_layout = old_in_layouts[0];  // Both inputs share it (checked above).
+    auto new_layout = new_in_layouts[0];
+    const int64_t ndim = static_cast<int64_t>(old_layout.ndim());
+    // This hook cannot return new attrs, so the axes are rewritten in place (const dropped).
+    auto* params = const_cast<SliceLikeAttrs*>(attrs.as<SliceLikeAttrs>());
+    ICHECK(params != nullptr);
+    // Undefined axes leave new_axes empty, so the old layouts are kept below.
+    for (auto axis : params->axes.defined() ? params->axes : Array<Integer>()) {
+      // slice_like accepts negative axes; normalize before indexing the old layout.
+      const int64_t old_axis = axis->value < 0 ? axis->value + ndim : axis->value;
+      auto new_axis = new_layout.IndexOf(old_layout[old_axis]);
+      if (new_axis == -1) {  // Axis absent from the new layout: keep the old layouts.
+        new_axes.clear();
+        break;
+      }
+      new_axes.push_back(new_axis);
+    }
+    if (!new_axes.empty()) {
+      params->axes = std::move(new_axes);
+      return Array<Array<Layout>>({{new_layout, new_layout}, {new_layout}});
+    }
+  }
+  if (old_in_layouts.defined()) {
+    ICHECK_EQ(old_in_layouts.size(), 2);
+    return {{old_in_layouts[0], old_in_layouts[1]}, {old_in_layouts[1]}};
+  }
+  return Array<Array<Layout>>({{Layout::Undef(), Layout::Undef()}, {Layout::Undef()}});
+}
+
 Array<te::Tensor> SliceLikeCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                    const Type& out_type) {
   const auto* param = attrs.as<SliceLikeAttrs>();
@@ -2801,6 +2841,7 @@ RELAY_REGISTER_OP("slice_like")
     .set_support_level(10)
     .add_type_rel("SliceLike", SliceLikeRel)
     .set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", SliceLikeInferCorrectLayout)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
 // relay.layout_transform
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 6765d1f..4c4bb9d 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -499,6 +499,75 @@ def test_bn_convert_layout():
     assert len(has_lt) == 1
 
 
+def test_slice_like_convert_layout():
+    def verify_slice_like(after, expected_axes):
+        # Verify if the slice_like after the convert layout has the expected axes.
+        has_expected = list()
+        checker = lambda x: has_expected.append(
+            isinstance(x, tvm.relay.expr.Call)
+            and x.op.name == "slice_like"
+            and str(x.attrs.axes) == str(expected_axes)
+        )
+        relay.analysis.post_order_visit(after, checker)
+        assert any(has_expected)
+
+    def func_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
+        y = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        out = relay.slice_like(y, y, axes=[1, 2])
+        return relay.Function(analysis.free_vars(out), out)
+
+    after = run_opt_pass(func_nhwc(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
+    verify_slice_like(after, [2, 3])
+
+    def func_nchw():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var("weight1", shape=(32, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
+        out = relay.slice_like(y, y, axes=[2, 3])
+        return relay.Function(analysis.free_vars(out), out)
+
+    after = run_opt_pass(func_nchw(), transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]}))
+    verify_slice_like(after, [1, 2])
+
+    def func_vars():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
+        y = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        # z has no layout information so convert layout won't happen.
+        z = relay.var("y", shape=(1, 56, 56, 32))
+        out = relay.slice_like(y, z, axes=[1, 2])
+        return relay.Function(analysis.free_vars(out), out)
+
+    after = run_opt_pass(func_vars(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]}))
+    verify_slice_like(after, [1, 2])
+
+
 def test_resnet_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
@@ -1412,6 +1481,7 @@ if __name__ == "__main__":
     test_conv_concat_convert_layout()
     test_dual_path_convert_layout()
     test_bn_convert_layout()
+    test_slice_like_convert_layout()
     test_resnet_convert_layout()
     test_scalar_convert_layout()
     test_conv_bn_convert_layout()