You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@tvm.apache.org by li...@apache.org on 2020/06/21 20:32:57 UTC

[incubator-tvm] branch v0.6 updated: [BACKPORT-0.6] fix small bug about dense_grad (#5868)

This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch v0.6
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/v0.6 by this push:
     new cf367be  [BACKPORT-0.6] fix small bug about dense_grad (#5868)
cf367be is described below

commit cf367be119d3e323437ade8870175715de65b1fc
Author: Yizhi Liu <li...@apache.org>
AuthorDate: Sun Jun 21 13:32:47 2020 -0700

    [BACKPORT-0.6] fix small bug about dense_grad (#5868)
    
    Co-authored-by: handar423 <47...@users.noreply.github.com>
    
    Co-authored-by: handar423 <47...@users.noreply.github.com>
---
 python/tvm/relay/op/_tensor_grad.py       | 7 ++++---
 tests/python/relay/test_op_grad_level2.py | 1 +
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 944e51e..416bff6 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -388,9 +388,10 @@ def bias_add_grad(orig, grad):
 def dense_grad(orig, grad):
     """Returns [grad' @ weight, data @ grad']"""
     data, weight = orig.args
-    return [collapse_sum_like(transpose(grad) * weight, data),
-            collapse_sum_like(data * transpose(grad), weight)]
-
+    return [collapse_sum_like(_nn.dense(grad, transpose(weight),
+                                        units=weight.checked_type.shape[1]), data),
+            collapse_sum_like(_nn.dense(transpose(grad), transpose(data),
+                                        units=data.checked_type.shape[1]), weight)]
 
 @register_gradient("reshape")
 def reshape_grad(orig, grad):
diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py
index 57b1e2c..2dbe0d6 100644
--- a/tests/python/relay/test_op_grad_level2.py
+++ b/tests/python/relay/test_op_grad_level2.py
@@ -161,6 +161,7 @@ def verify_dense_grad(d_shape, w_shape):
 def test_dense_grad():
     verify_dense_grad((1, 8), (16, 8))
     verify_dense_grad((1, 4), (3, 4))
+    verify_dense_grad((5, 4), (3, 4))
 
 
 def verify_batch_flatten_grad(d_shape):