You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/18 11:00:13 UTC

[tvm] branch main updated: Fix eltwise alter op layout for broadcast axis (#11337)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7f1c54f96a Fix eltwise alter op layout for broadcast axis (#11337)
7f1c54f96a is described below

commit 7f1c54f96ae4099c178f45402f3c156a565dedce
Author: Andrey Malyshev <el...@gmail.com>
AuthorDate: Wed May 18 14:00:07 2022 +0300

    Fix eltwise alter op layout for broadcast axis (#11337)
    
    * Fix eltwise alter op layout for broadcast axis
    
    * Add tests on broadcast blocking over already blocked layout
---
 src/relay/transforms/infer_layout_utils.cc      |   3 +-
 tests/python/relay/test_pass_alter_op_layout.py | 200 ++++++++++++++++++++++++
 2 files changed, 202 insertions(+), 1 deletion(-)

diff --git a/src/relay/transforms/infer_layout_utils.cc b/src/relay/transforms/infer_layout_utils.cc
index 32838e09a4..efe886c29d 100644
--- a/src/relay/transforms/infer_layout_utils.cc
+++ b/src/relay/transforms/infer_layout_utils.cc
@@ -64,7 +64,8 @@ Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layo
 
         // 4) a) Check if this shape element is 1.
         if (auto* shape_int = shape_val.as<IntImmNode>()) {
-          if (shape_int->value == 1) {
+          // We can treat 1 as broadcast only if axis was not split before
+          if (shape_int->value == 1 && old_layout.IndexOf(LayoutAxis::Get(axis)) == -1) {
             new_layout += "1";
             is_shape_one = true;
           }
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index cffc33b0bc..5aff77ad36 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -1602,6 +1602,206 @@ def test_alter_layout_nonscalar_broadcast():
     np.testing.assert_allclose(res.numpy(), res1.numpy())
 
 
+def test_alter_layout_blocked_no_broadcast():
+    """Test add on already blocked NCHW4c layout where bias is not broadcast over 4c"""
+
+    def before():
+        dtype = "float32"
+        input_shape = (1, 8, 16, 16, 4)
+        filter_shape = (1, 8, 4, 4, 4, 4)
+        bias_shape = (1, 1, 1, 1, 4)
+        A = relay.var("data", shape=input_shape, dtype=dtype)
+        B = relay.var("weight", shape=filter_shape, dtype=dtype)
+        C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+        conv = relay.nn.conv2d(
+            A,
+            B,
+            data_layout="NCHW4c",
+            kernel_layout="OIHW4i4o",
+            padding=[3, 3, 0, 0],
+            strides=[2, 2],
+            out_dtype=dtype,
+            channels=4,
+            kernel_size=(4, 4),
+        )
+        bias = relay.op.add(conv, C)
+        bias = relay.Function(analysis.free_vars(bias), bias)
+        return bias
+
+    def expected():
+        return before()  # alter_conv2d re-sets the same layouts, so no change is expected
+
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs["data_layout"] = "NCHW4c"
+        new_attrs["kernel_layout"] = "OIHW4i4o"
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = run_opt_pass(before(), transform.AlterOpLayout())
+        b = run_opt_pass(expected(), transform.InferType())
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
+
+    inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
+    weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
+    z = np.random.uniform(size=(1, 1, 1, 1, 4)).astype(np.float32)
+    mod = tvm.IRModule.from_expr(before())
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        with tvm.transform.PassContext(opt_level=4):
+            res = relay.build_module.create_executor(
+                "graph", mod, target="llvm", device=tvm.cpu()
+            ).evaluate()(inp, weight, z)
+    with tvm.transform.PassContext(opt_level=0):
+        res1 = relay.build_module.create_executor(
+            "debug", mod, target="llvm", device=tvm.cpu()
+        ).evaluate()(inp, weight, z)
+    np.testing.assert_allclose(res.numpy(), res1.numpy())
+
+
+def test_alter_layout_blocked_broadcast():
+    """Test add on already blocked NCHW4c layout where bias is broadcast over 4c"""
+
+    def before():
+        dtype = "float32"
+        input_shape = (1, 8, 16, 16, 4)
+        filter_shape = (1, 8, 4, 4, 4, 4)
+        bias_shape = (1, 1, 1, 1, 1)
+        A = relay.var("data", shape=input_shape, dtype=dtype)
+        B = relay.var("weight", shape=filter_shape, dtype=dtype)
+        C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+        conv = relay.nn.conv2d(
+            A,
+            B,
+            data_layout="NCHW4c",
+            kernel_layout="OIHW4i4o",
+            padding=[3, 3, 0, 0],
+            strides=[2, 2],
+            out_dtype=dtype,
+            channels=4,
+            kernel_size=(4, 4),
+        )
+        bias = relay.op.add(conv, C)
+        bias = relay.Function(analysis.free_vars(bias), bias)
+        return bias
+
+    def expected():
+        return before()  # alter_conv2d re-sets the same layouts, so no change is expected
+
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs["data_layout"] = "NCHW4c"
+        new_attrs["kernel_layout"] = "OIHW4i4o"
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = run_opt_pass(before(), transform.AlterOpLayout())
+        b = run_opt_pass(expected(), transform.InferType())
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
+
+    inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
+    weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
+    z = np.random.uniform(size=(1, 1, 1, 1, 1)).astype(np.float32)
+    mod = tvm.IRModule.from_expr(before())
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        with tvm.transform.PassContext(opt_level=4):
+            res = relay.build_module.create_executor(
+                "graph", mod, target="llvm", device=tvm.cpu()
+            ).evaluate()(inp, weight, z)
+    with tvm.transform.PassContext(opt_level=0):
+        res1 = relay.build_module.create_executor(
+            "debug", mod, target="llvm", device=tvm.cpu()
+        ).evaluate()(inp, weight, z)
+    np.testing.assert_allclose(res.numpy(), res1.numpy())
+
+
+def test_alter_layout_re_blocking_broadcast():
+    """Test re-blocking (NCHW4c -> NCHW2c) with a broadcast add operator"""
+
+    def before():
+        dtype = "float32"
+        input_shape = (1, 8, 16, 16, 4)
+        filter_shape = (1, 8, 4, 4, 4, 4)
+        bias_shape = (1, 1, 1, 1, 4)
+        A = relay.var("data", shape=input_shape, dtype=dtype)
+        B = relay.var("weight", shape=filter_shape, dtype=dtype)
+        C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+        conv = relay.nn.conv2d(
+            A,
+            B,
+            data_layout="NCHW4c",
+            kernel_layout="OIHW4i4o",
+            padding=[3, 3, 0, 0],
+            strides=[2, 2],
+            out_dtype=dtype,
+            channels=4,
+            kernel_size=(4, 4),
+        )
+        bias = relay.op.add(conv, C)
+        bias = relay.Function(analysis.free_vars(bias), bias)
+        return bias
+
+    def expected():
+        dtype = "float32"
+        input_shape = (1, 8, 16, 16, 4)
+        filter_shape = (1, 8, 4, 4, 4, 4)
+        bias_shape = (1, 1, 1, 1, 4)
+        A = relay.var("data", shape=input_shape, dtype=dtype)
+        B = relay.var("weight", shape=filter_shape, dtype=dtype)
+        C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+        A = relay.layout_transform(A, src_layout="NCHW4c", dst_layout="NCHW2c")
+        B = relay.layout_transform(B, src_layout="OIHW4i4o", dst_layout="OIHW2i2o")
+
+        conv = relay.nn.conv2d(
+            A,
+            B,
+            data_layout="NCHW2c",
+            kernel_layout="OIHW2i2o",
+            padding=[3, 3, 0, 0],
+            strides=[2, 2],
+            out_dtype=dtype,
+            channels=4,
+            kernel_size=(4, 4),
+        )
+        C = relay.layout_transform(C, src_layout="NCHW4c", dst_layout="NCHW2c")  # re-block bias too
+        bias = relay.op.add(conv, C)
+        bias = relay.layout_transform(bias, src_layout="NCHW2c", dst_layout="NCHW4c")
+        bias = relay.Function(analysis.free_vars(bias), bias)
+        return bias
+
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs["data_layout"] = "NCHW2c"
+        new_attrs["kernel_layout"] = "OIHW2i2o"
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = run_opt_pass(before(), transform.AlterOpLayout())
+        b = run_opt_pass(expected(), transform.InferType())
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
+
+    inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
+    weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
+    z = np.random.uniform(size=(1, 1, 1, 1, 4)).astype(np.float32)
+    mod = tvm.IRModule.from_expr(before())
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        with tvm.transform.PassContext(opt_level=4):
+            res = relay.build_module.create_executor(
+                "graph", mod, target="llvm", device=tvm.cpu()
+            ).evaluate()(inp, weight, z)
+    with tvm.transform.PassContext(opt_level=0):
+        res1 = relay.build_module.create_executor(
+            "debug", mod, target="llvm", device=tvm.cpu()
+        ).evaluate()(inp, weight, z)
+    np.testing.assert_allclose(res.numpy(), res1.numpy(), rtol=1e-5, atol=1e-5)
+
+
 def test_broadcast_non_adaptable():
     """NCHW4c + [x, x, 4] and NCHW4c is being altered to NCHW"""