Posted to commits@tvm.apache.org by co...@apache.org on 2022/01/03 17:33:17 UTC

[tvm] branch main updated: Fix reduce NCHWc infer layout (do not keep reduced inner c when keepdims=false) (#9821)

This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 11379f7  Fix reduce NCHWc infer layout (do not keep reduced inner c when keepdims=false) (#9821)
11379f7 is described below

commit 11379f710bf9bebf4a7a0cf6c0943899047d11ed
Author: masahi <ma...@gmail.com>
AuthorDate: Tue Jan 4 02:32:36 2022 +0900

    Fix reduce NCHWc infer layout (do not keep reduced inner c when keepdims=false) (#9821)
    
    * Fix reduce NCHWc infer layout (do not keep reduced inner c when keepdims=false)
    
    * black
    
    * lint
---
 src/relay/op/tensor/reduce.cc                   |  2 +-
 tests/python/relay/test_pass_alter_op_layout.py | 19 +++++++++++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index 5001925..d844bb5 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -176,7 +176,7 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
           if (params->exclude) {
             // The primal axis is not reduced, so keep the input packed dim.
             inferred_out_string += packed_dim;
-          } else {
+          } else if (params->keepdims) {
             // If the primal axis is part of reduce axes in the original layout, the inner dim
             // becomes 1 after reduction.
             inferred_out_string += "1" + layout_dim;
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 7514a93..ea7fe0b 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -24,6 +24,7 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
 from tvm.relay.testing import run_infer_type
 import numpy as np
 import tvm.testing
+from tvm.relay import testing
 
 
 def run_opt_pass(expr, passes):
@@ -1452,5 +1453,23 @@ def test_conv2d_strided_slice_packed_to_unpacked():
         assert tvm.ir.structural_equal(a, b)
 
 
+def test_conv2d_reduce_channels():
+    x = relay.var("data", shape=(1, 8, 48, 48))
+    y = relay.nn.conv2d(
+        data=x,
+        weight=relay.var("weight"),
+        kernel_size=(1, 1),
+        channels=8,
+        dilation=1,
+        strides=(47, 47),
+    )
+    z = relay.argmin(y, axis=1)
+
+    mod, params = testing.create_workload(z)
+
+    with tvm.transform.PassContext(opt_level=3):
+        relay.build(mod, params=params, target="llvm")
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
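
For context, here is a minimal standalone sketch (not part of the commit; variable names and shapes are chosen arbitrarily for illustration) of the Relay behavior the layout inference has to match: with keepdims=False the reduced channel axis disappears from the output type entirely, so once AlterOpLayout packs the input into an NCHWc layout, the inferred output layout must drop the packed inner c dim as well rather than keep it as a unit axis.

    import tvm
    from tvm import relay

    # Reduce over the channel axis without keeping the reduced dimension.
    x = relay.var("x", shape=(1, 8, 16, 16), dtype="float32")
    y = relay.argmin(x, axis=1)  # keepdims defaults to False

    mod = tvm.IRModule.from_expr(y)
    mod = relay.transform.InferType()(mod)

    # Prints Tensor[(1, 16, 16), int32]: axis 1 is gone from the result type,
    # so the inferred packed layout for the output cannot retain an inner "c".
    print(mod["main"].ret_type)

The new test_conv2d_reduce_channels test above exercises the same situation end to end: at opt_level=3 the conv2d is altered to a packed NCHWc layout on x86, and the argmin over axis=1 previously made the build fail because the inferred output layout still carried the reduced inner c dim.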