You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by li...@apache.org on 2020/06/21 20:32:57 UTC
[incubator-tvm] branch v0.6 updated: [BACKPORT-0.6] fix small bug about dense_grad (#5868)
This is an automated email from the ASF dual-hosted git repository.
liuyizhi pushed a commit to branch v0.6
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/v0.6 by this push:
new cf367be [BACKPORT-0.6] fix small bug about dense_grad (#5868)
cf367be is described below
commit cf367be119d3e323437ade8870175715de65b1fc
Author: Yizhi Liu <li...@apache.org>
AuthorDate: Sun Jun 21 13:32:47 2020 -0700
[BACKPORT-0.6] fix small bug about dense_grad (#5868)
Co-authored-by: handar423 <47...@users.noreply.github.com>
Co-authored-by: handar423 <47...@users.noreply.github.com>
---
python/tvm/relay/op/_tensor_grad.py | 7 ++++---
tests/python/relay/test_op_grad_level2.py | 1 +
2 files changed, 5 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 944e51e..416bff6 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -388,9 +388,10 @@ def bias_add_grad(orig, grad):
def dense_grad(orig, grad):
"""Returns [grad' @ weight, data @ grad']"""
data, weight = orig.args
- return [collapse_sum_like(transpose(grad) * weight, data),
- collapse_sum_like(data * transpose(grad), weight)]
-
+ return [collapse_sum_like(_nn.dense(grad, transpose(weight),
+ units=weight.checked_type.shape[1]), data),
+ collapse_sum_like(_nn.dense(transpose(grad), transpose(data),
+ units=data.checked_type.shape[1]), weight)]
@register_gradient("reshape")
def reshape_grad(orig, grad):
diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py
index 57b1e2c..2dbe0d6 100644
--- a/tests/python/relay/test_op_grad_level2.py
+++ b/tests/python/relay/test_op_grad_level2.py
@@ -161,6 +161,7 @@ def verify_dense_grad(d_shape, w_shape):
def test_dense_grad():
verify_dense_grad((1, 8), (16, 8))
verify_dense_grad((1, 4), (3, 4))
+ verify_dense_grad((5, 4), (3, 4))
def verify_batch_flatten_grad(d_shape):