You are viewing a plain-text version of this content. (The hyperlink to the canonical version was lost in this plain-text copy.)
Posted to commits@tvm.apache.org by co...@apache.org on 2022/01/03 17:33:17 UTC
[tvm] branch main updated: Fix reduce NCHWc infer layout (do not keep reduced inner c when keepdims=false) (#9821)
This is an automated email from the ASF dual-hosted git repository.
comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 11379f7 Fix reduce NCHWc infer layout (do not keep reduced inner c when keepdims=false) (#9821)
11379f7 is described below
commit 11379f710bf9bebf4a7a0cf6c0943899047d11ed
Author: masahi <ma...@gmail.com>
AuthorDate: Tue Jan 4 02:32:36 2022 +0900
Fix reduce NCHWc infer layout (do not keep reduced inner c when keepdims=false) (#9821)
* Fix reduce NCHWc infer layout (do not keep reduced inner c when keepdims=false)
* black
* lint
---
src/relay/op/tensor/reduce.cc | 2 +-
tests/python/relay/test_pass_alter_op_layout.py | 19 +++++++++++++++++++
2 files changed, 20 insertions(+), 1 deletion(-)
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index 5001925..d844bb5 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -176,7 +176,7 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
if (params->exclude) {
// The primal axis is not reduced, so keep the input packed dim.
inferred_out_string += packed_dim;
- } else {
+ } else if (params->keepdims) {
// If the primal axis is part of reduce axes in the original layout, the inner dim
// becomes 1 after reduction.
inferred_out_string += "1" + layout_dim;
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 7514a93..ea7fe0b 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -24,6 +24,7 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
from tvm.relay.testing import run_infer_type
import numpy as np
import tvm.testing
+from tvm.relay import testing
def run_opt_pass(expr, passes):
@@ -1452,5 +1453,23 @@ def test_conv2d_strided_slice_packed_to_unpacked():
assert tvm.ir.structural_equal(a, b)
+def test_conv2d_reduce_channels():
+ x = relay.var("data", shape=(1, 8, 48, 48))
+ y = relay.nn.conv2d(
+ data=x,
+ weight=relay.var("weight"),
+ kernel_size=(1, 1),
+ channels=8,
+ dilation=1,
+ strides=(47, 47),
+ )
+ z = relay.argmin(y, axis=1)
+
+ mod, params = testing.create_workload(z)
+
+ with tvm.transform.PassContext(opt_level=3):
+ relay.build(mod, params=params, target="llvm")
+
+
if __name__ == "__main__":
pytest.main([__file__])