Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/18 11:00:13 UTC
[tvm] branch main updated: Fix eltwise alter op layout for broadcast axis (#11337)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7f1c54f96a Fix eltwise alter op layout for broadcast axis (#11337)
7f1c54f96a is described below
commit 7f1c54f96ae4099c178f45402f3c156a565dedce
Author: Andrey Malyshev <el...@gmail.com>
AuthorDate: Wed May 18 14:00:07 2022 +0300
Fix eltwise alter op layout for broadcast axis (#11337)
* Fix eltwise alter op layout for broadcast axis
* Add tests for broadcast blocking over already blocked layouts
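For context, the failing pattern looks roughly like the sketch below (hedged: it condenses the shapes and layouts used by the new tests; it is not an exact repro harness). A conv2d produces a blocked NCHW4c tensor, and an elementwise add combines it with a bias whose trailing dimension is the already split 4c block axis; the size-1 outer dimensions broadcast, but the inner block axis must not be re-derived as a broadcast axis.

    # A minimal sketch of the pattern this commit fixes (assumes the Relay
    # Python API; shapes and layouts mirror the tests added below).
    import tvm
    from tvm import relay
    from tvm.relay import analysis

    dtype = "float32"
    data = relay.var("data", shape=(1, 8, 16, 16, 4), dtype=dtype)       # NCHW4c
    weight = relay.var("weight", shape=(1, 8, 4, 4, 4, 4), dtype=dtype)  # OIHW4i4o
    bias = relay.var("bias", shape=(1, 1, 1, 1, 4), dtype=dtype)         # trailing 4 = 4c block

    conv = relay.nn.conv2d(
        data, weight,
        data_layout="NCHW4c", kernel_layout="OIHW4i4o",
        padding=[3, 3, 0, 0], strides=[2, 2],
        out_dtype=dtype, channels=4, kernel_size=(4, 4),
    )
    # The bias broadcasts over the size-1 N, C, H, W dimensions, but its last
    # axis is the already split 4c block and must keep the blocked layout.
    out = relay.op.add(conv, bias)
    func = relay.Function(analysis.free_vars(out), out)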
---
src/relay/transforms/infer_layout_utils.cc | 3 +-
tests/python/relay/test_pass_alter_op_layout.py | 200 ++++++++++++++++++++++++
2 files changed, 202 insertions(+), 1 deletion(-)
diff --git a/src/relay/transforms/infer_layout_utils.cc b/src/relay/transforms/infer_layout_utils.cc
index 32838e09a4..efe886c29d 100644
--- a/src/relay/transforms/infer_layout_utils.cc
+++ b/src/relay/transforms/infer_layout_utils.cc
@@ -64,7 +64,8 @@ Layout AdjustSubordinateFactors(const Layout& src_layout, const Layout& old_layo
// 4) a) Check if this shape element is 1.
if (auto* shape_int = shape_val.as<IntImmNode>()) {
- if (shape_int->value == 1) {
+ // We can treat 1 as broadcast only if the axis was not split before
+ if (shape_int->value == 1 && old_layout.IndexOf(LayoutAxis::Get(axis)) == -1) {
new_layout += "1";
is_shape_one = true;
}
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index cffc33b0bc..5aff77ad36 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -1602,6 +1602,206 @@ def test_alter_layout_nonscalar_broadcast():
np.testing.assert_allclose(res.numpy(), res1.numpy())
+def test_alter_layout_blocked_no_broadcast():
+ """Test boradcast operators working on already blocked layout"""
+
+ def before():
+ dtype = "float32"
+ input_shape = (1, 8, 16, 16, 4)
+ filter_shape = (1, 8, 4, 4, 4, 4)
+ bias_shape = (1, 1, 1, 1, 4)
+ A = relay.var("data", shape=input_shape, dtype=dtype)
+ B = relay.var("weight", shape=filter_shape, dtype=dtype)
+ C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+ conv = relay.nn.conv2d(
+ A,
+ B,
+ data_layout="NCHW4c",
+ kernel_layout="OIHW4i4o",
+ padding=[3, 3, 0, 0],
+ strides=[2, 2],
+ out_dtype=dtype,
+ channels=4,
+ kernel_size=(4, 4),
+ )
+ bias = relay.op.add(conv, C)
+ bias = relay.Function(analysis.free_vars(bias), bias)
+ return bias
+
+ def expected():
+ return before()
+
+ def alter_conv2d(attrs, inputs, tinfos, out_type):
+ data, weight = inputs
+ new_attrs = dict(attrs)
+ new_attrs["data_layout"] = "NCHW4c"
+ new_attrs["kernel_layout"] = "OIHW4i4o"
+ return relay.nn.conv2d(data, weight, **new_attrs)
+
+ with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+ a = run_opt_pass(before(), transform.AlterOpLayout())
+ b = run_opt_pass(expected(), transform.InferType())
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
+
+ inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
+ weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
+ z = np.random.uniform(size=(1, 1, 1, 1, 4)).astype(np.float32)
+ mod = tvm.IRModule.from_expr(before())
+ with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+ with tvm.transform.PassContext(opt_level=4):
+ res = relay.build_module.create_executor(
+ "graph", mod, target="llvm", device=tvm.cpu()
+ ).evaluate()(inp, weight, z)
+ with tvm.transform.PassContext(opt_level=0):
+ res1 = relay.build_module.create_executor(
+ "debug", mod, target="llvm", device=tvm.cpu()
+ ).evaluate()(inp, weight, z)
+ np.testing.assert_allclose(res.numpy(), res1.numpy())
+
+
+def test_alter_layout_blocked_broadcast():
+ """Test boradcast operators working on already blocked layout"""
+
+ def before():
+ dtype = "float32"
+ input_shape = (1, 8, 16, 16, 4)
+ filter_shape = (1, 8, 4, 4, 4, 4)
+ bias_shape = (1, 1, 1, 1, 1)
+ A = relay.var("data", shape=input_shape, dtype=dtype)
+ B = relay.var("weight", shape=filter_shape, dtype=dtype)
+ C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+ conv = relay.nn.conv2d(
+ A,
+ B,
+ data_layout="NCHW4c",
+ kernel_layout="OIHW4i4o",
+ padding=[3, 3, 0, 0],
+ strides=[2, 2],
+ out_dtype=dtype,
+ channels=4,
+ kernel_size=(4, 4),
+ )
+ bias = relay.op.add(conv, C)
+ bias = relay.Function(analysis.free_vars(bias), bias)
+ return bias
+
+ def expected():
+ return before()
+
+ def alter_conv2d(attrs, inputs, tinfos, out_type):
+ data, weight = inputs
+ new_attrs = dict(attrs)
+ new_attrs["data_layout"] = "NCHW4c"
+ new_attrs["kernel_layout"] = "OIHW4i4o"
+ return relay.nn.conv2d(data, weight, **new_attrs)
+
+ with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+ a = run_opt_pass(before(), transform.AlterOpLayout())
+ b = run_opt_pass(expected(), transform.InferType())
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
+
+ inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
+ weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
+ z = np.random.uniform(size=(1, 1, 1, 1, 1)).astype(np.float32)
+ mod = tvm.IRModule.from_expr(before())
+ with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+ with tvm.transform.PassContext(opt_level=4):
+ res = relay.build_module.create_executor(
+ "graph", mod, target="llvm", device=tvm.cpu()
+ ).evaluate()(inp, weight, z)
+ with tvm.transform.PassContext(opt_level=0):
+ res1 = relay.build_module.create_executor(
+ "debug", mod, target="llvm", device=tvm.cpu()
+ ).evaluate()(inp, weight, z)
+ np.testing.assert_allclose(res.numpy(), res1.numpy())
+
+
+def test_alter_layout_re_blocking_broadcast():
+ """Test of re-blocking shapes with boradcast operators"""
+
+ def before():
+ dtype = "float32"
+ input_shape = (1, 8, 16, 16, 4)
+ filter_shape = (1, 8, 4, 4, 4, 4)
+ bias_shape = (1, 1, 1, 1, 4)
+ A = relay.var("data", shape=input_shape, dtype=dtype)
+ B = relay.var("weight", shape=filter_shape, dtype=dtype)
+ C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+ conv = relay.nn.conv2d(
+ A,
+ B,
+ data_layout="NCHW4c",
+ kernel_layout="OIHW4i4o",
+ padding=[3, 3, 0, 0],
+ strides=[2, 2],
+ out_dtype=dtype,
+ channels=4,
+ kernel_size=(4, 4),
+ )
+ bias = relay.op.add(conv, C)
+ bias = relay.Function(analysis.free_vars(bias), bias)
+ return bias
+
+ def expected():
+ dtype = "float32"
+ input_shape = (1, 8, 16, 16, 4)
+ filter_shape = (1, 8, 4, 4, 4, 4)
+ bias_shape = (1, 1, 1, 1, 4)
+ A = relay.var("data", shape=input_shape, dtype=dtype)
+ B = relay.var("weight", shape=filter_shape, dtype=dtype)
+ C = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+ A = relay.layout_transform(A, src_layout="NCHW4c", dst_layout="NCHW2c")
+ B = relay.layout_transform(B, src_layout="OIHW4i4o", dst_layout="OIHW2i2o")
+
+ conv = relay.nn.conv2d(
+ A,
+ B,
+ data_layout="NCHW2c",
+ kernel_layout="OIHW2i2o",
+ padding=[3, 3, 0, 0],
+ strides=[2, 2],
+ out_dtype=dtype,
+ channels=4,
+ kernel_size=(4, 4),
+ )
+ C = relay.layout_transform(C, src_layout="NCHW4c", dst_layout="NCHW2c")
+ bias = relay.op.add(conv, C)
+ bias = relay.layout_transform(bias, src_layout="NCHW2c", dst_layout="NCHW4c")
+ bias = relay.Function(analysis.free_vars(bias), bias)
+ return bias
+
+ def alter_conv2d(attrs, inputs, tinfos, out_type):
+ data, weight = inputs
+ new_attrs = dict(attrs)
+ new_attrs["data_layout"] = "NCHW2c"
+ new_attrs["kernel_layout"] = "OIHW2i2o"
+ return relay.nn.conv2d(data, weight, **new_attrs)
+
+ with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+ a = run_opt_pass(before(), transform.AlterOpLayout())
+ b = run_opt_pass(expected(), transform.InferType())
+ assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\nExpected = \n" + str(b)
+
+ inp = np.random.uniform(size=(1, 8, 16, 16, 4)).astype(np.float32)
+ weight = np.random.uniform(size=(1, 8, 4, 4, 4, 4)).astype(np.float32)
+ z = np.random.uniform(size=(1, 1, 1, 1, 4)).astype(np.float32)
+ mod = tvm.IRModule.from_expr(before())
+ with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+ with tvm.transform.PassContext(opt_level=4):
+ res = relay.build_module.create_executor(
+ "graph", mod, target="llvm", device=tvm.cpu()
+ ).evaluate()(inp, weight, z)
+ with tvm.transform.PassContext(opt_level=0):
+ res1 = relay.build_module.create_executor(
+ "debug", mod, target="llvm", device=tvm.cpu()
+ ).evaluate()(inp, weight, z)
+ np.testing.assert_allclose(res.numpy(), res1.numpy(), rtol=1e-5, atol=1e-5)
+
+
def test_broadcast_non_adaptable():
"""NCHW4c + [x, x, 4] and NCHW4c is being altered to NCHW"""