You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by mo...@apache.org on 2020/10/05 23:54:33 UTC

[incubator-tvm] branch master updated: Fix a bug with Alter Op Layout (#6626)

This is an automated email from the ASF dual-hosted git repository.

moreau pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 311eca4  Fix a bug with Alter Op Layout (#6626)
311eca4 is described below

commit 311eca49a696f137e8cac6f4b9ba485b80bda0ee
Author: Matthew Brookhart <mb...@octoml.ai>
AuthorDate: Mon Oct 5 17:54:21 2020 -0600

    Fix a bug with Alter Op Layout (#6626)
    
    * Regression test for a Scalar type issue in Alter Op Layout
    
    * fix the regression test by avoiding the Scalar optimization if types aren't defined
---
 src/relay/transforms/transform_layout.h         |  2 +-
 tests/python/relay/test_pass_alter_op_layout.py | 89 +++++++++++++++++++++++++
 2 files changed, 90 insertions(+), 1 deletion(-)

diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h
index 19632de..bf9bcb9 100644
--- a/src/relay/transforms/transform_layout.h
+++ b/src/relay/transforms/transform_layout.h
@@ -126,7 +126,7 @@ class TransformMemorizer : public ObjectRef {
     if (src_layout.ndim_primal() < dst_layout.ndim_primal()) {
       // If scalar, then no need of layout transformation as scalar can be broadcasted easily even
       // if the other operand has a transformed layout.
-      if (IsScalar(input_expr)) {
+      if (input_expr->checked_type_.defined() && IsScalar(input_expr)) {
         return raw;
       }
       int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal();
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 7b242c4..4d50840 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -463,6 +463,95 @@ def test_alter_layout_scalar():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_alter_layout_scalar_regression():
+    """regression test where scalar fails"""
+
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 16))
+        bias = relay.var("bias", shape=(1, 1, 1, 16))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=16,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.add(y, bias)
+        mean = relay.mean(y, axis=3, exclude=True)
+        var = relay.variance(y, axis=3, exclude=True)
+        gamma = relay.var("gamma")
+        beta = relay.var("beta")
+        y = relay.nn.batch_norm(y, gamma, beta, mean, var, axis=3)
+        y = y[0]
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs["data_layout"] = "NCHW16c"
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 16))
+        bias = relay.var("bias", shape=(1, 1, 1, 16))
+        x = relay.layout_transform(x, src_layout="NHWC", dst_layout="NCHW")
+        x = relay.layout_transform(x, src_layout="NCHW", dst_layout="NCHW16c")
+        weight = relay.layout_transform(weight, src_layout="HWIO", dst_layout="OIHW")
+        y = relay.nn.conv2d(
+            x, weight, channels=16, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
+        )
+        bias = relay.layout_transform(bias, src_layout="NHWC", dst_layout="NCHW")
+        bias = relay.layout_transform(bias, src_layout="NCHW", dst_layout="NCHW16c")
+        add = relay.add(y, bias)
+        y = relay.layout_transform(add, src_layout="NCHW16c", dst_layout="NCHW")
+        y = relay.layout_transform(y, src_layout="NCHW", dst_layout="NHWC")
+        mean = relay.mean(y, axis=3, exclude=True)
+        var = relay.variance(y, axis=3, exclude=True)
+        denom = relay.const(1.0) / relay.sqrt(var + relay.const(1e-05))
+        gamma = relay.var("gamma", shape=(16,))
+        denom = denom * gamma
+        denom_expand1 = relay.expand_dims(denom, axis=1, num_newaxis=2)
+        denom_expand2 = relay.expand_dims(denom_expand1, axis=0)
+        denom_nchwc16 = relay.layout_transform(
+            denom_expand2, src_layout="NCHW", dst_layout="NCHW16c"
+        )
+        out = add * denom_nchwc16
+        beta = relay.var("beta", shape=(16,))
+        numerator = (-mean) * denom + beta
+        numerator_expand1 = relay.expand_dims(numerator, axis=1, num_newaxis=2)
+        numerator_expand2 = relay.expand_dims(numerator_expand1, axis=0)
+        numerator_nchwc16 = relay.layout_transform(
+            numerator_expand2, src_layout="NCHW", dst_layout="NCHW16c"
+        )
+        out = out + numerator_nchwc16
+        out = relay.layout_transform(out, src_layout="NCHW16c", dst_layout="NCHW")
+        y = relay.layout_transform(out, src_layout="NCHW", dst_layout="NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = before()
+        desired_layouts = {"nn.conv2d": ["NCHW", "default"], "nn.batch_norm": ["NHWC", "default"]}
+        a = run_opt_pass(
+            a,
+            [
+                transform.InferType(),
+                relay.transform.ConvertLayout(desired_layouts),
+                transform.SimplifyInference(),
+                transform.CanonicalizeOps(),
+                transform.AlterOpLayout(),
+            ],
+        )
+        b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 def test_alter_layout_concatenate():
     """ NCHW, NHWC and corner case concatenate layout transform."""