Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/04/20 02:44:46 UTC

[GitHub] [incubator-mxnet] sxjscience opened a new issue #18102: [Numpy] The gradient of einsum is wrong

sxjscience opened a new issue #18102:
URL: https://github.com/apache/incubator-mxnet/issues/18102


   The gradient of einsum is not reliable. The following is just one example; there are actually **multiple scenarios** in which the gradient is wrong. This operator has both performance issues, as stated in https://github.com/apache/incubator-mxnet/issues/18043, and numerical problems.
   
   We should recommend that users **not use einsum in MXNet** until these issues are fixed.
   
   ```python
   import numpy as np
   import mxnet as mx
   from numpy.testing import assert_allclose
   mx.npx.set_np()
   
   ctx = mx.cpu()
   
   # Batched matrix multiplication: (1, 1, 5, 3) x (1, 1, 3, 2) -> (1, 1, 5, 2)
   A = mx.np.random.normal(0, 1, (1, 1, 5, 3), ctx=ctx)
   B = mx.np.random.normal(0, 1, (1, 1, 3, 2), ctx=ctx)
   out_grad = mx.np.random.normal(0, 1, (1, 1, 5, 2), ctx=ctx)
   
   A.attach_grad()
   B.attach_grad()
   
   # Gradients computed through einsum
   with mx.autograd.record():
       out = mx.np.einsum('bnij,bnjc->bnic', A, B)
       out.backward(out_grad)
   
   # Ground truth from plain NumPy: dA = out_grad . B^T, dB = A^T . out_grad
   out_gt = A.asnumpy()[0, 0].dot(B.asnumpy()[0, 0])
   A_gt_grad = out_grad.asnumpy()[0, 0].dot(B.asnumpy()[0, 0].T)
   B_gt_grad = A.asnumpy()[0, 0].T.dot(out_grad.asnumpy()[0, 0])
   A_einsum_grad = A.grad.asnumpy()
   B_einsum_grad = B.grad.asnumpy()
   
   # Gradients computed through matmul, for comparison
   A.grad[:] = 0
   B.grad[:] = 0
   with mx.autograd.record():
       out = mx.np.matmul(A, B)
       out.backward(out_grad)
   A_matmul_grad = A.grad.asnumpy()
   B_matmul_grad = B.grad.asnumpy()
   
   # The matmul gradients match the ground truth
   assert_allclose(A_gt_grad, A_matmul_grad[0, 0], 1E-5, 1E-5)
   assert_allclose(B_gt_grad, B_matmul_grad[0, 0], 1E-5, 1E-5)
   # The einsum gradients do not match, so these assertions fail
   assert_allclose(A_gt_grad, A_einsum_grad[0, 0], 1E-5, 1E-5)
   assert_allclose(B_gt_grad, B_einsum_grad[0, 0], 1E-5, 1E-5)
   ```
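
For readers who want a framework-independent reference, the expected gradients of this particular contraction can be sketched in plain NumPy and checked against central finite differences. This is only an illustrative sketch, not part of the original report: the helper names `einsum_reference_grads` and `numerical_grad` are hypothetical, and larger batch dimensions than in the report are assumed so that the batch axes are exercised as well.

```python
import numpy as np

def einsum_reference_grads(A, B, out_grad):
    """Analytic gradients of einsum('bnij,bnjc->bnic', A, B) (hypothetical helper)."""
    # dL/dA[b,n,i,j] = sum_c out_grad[b,n,i,c] * B[b,n,j,c]
    grad_A = np.einsum('bnic,bnjc->bnij', out_grad, B)
    # dL/dB[b,n,j,c] = sum_i A[b,n,i,j] * out_grad[b,n,i,c]
    grad_B = np.einsum('bnij,bnic->bnjc', A, out_grad)
    return grad_A, grad_B

def numerical_grad(f, x, eps=1e-6):
    """Central finite-difference gradient of the scalar f() w.r.t. x (perturbed in place)."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        orig = x[idx]
        x[idx] = orig + eps
        f_plus = f()
        x[idx] = orig - eps
        f_minus = f()
        x[idx] = orig  # restore the original entry
        grad[idx] = (f_plus - f_minus) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3, 5, 3))
B = rng.normal(size=(2, 3, 3, 2))
out_grad = rng.normal(size=(2, 3, 5, 2))

def loss():
    # Scalar whose gradient w.r.t. the einsum output equals out_grad
    return float(np.sum(np.einsum('bnij,bnjc->bnic', A, B) * out_grad))

grad_A, grad_B = einsum_reference_grads(A, B, out_grad)
np.testing.assert_allclose(grad_A, numerical_grad(loss, A), rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(grad_B, numerical_grad(loss, B), rtol=1e-5, atol=1e-5)
```

The same finite-difference check could be pointed at any other einsum subscript string by swapping the contraction inside `loss`, which may help narrow down which scenarios produce wrong gradients.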


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] sxjscience edited a comment on issue #18102: [Numpy] The gradient of einsum is wrong

Posted by GitBox <gi...@apache.org>.
sxjscience edited a comment on issue #18102:
URL: https://github.com/apache/incubator-mxnet/issues/18102#issuecomment-616278906


   @yzhliu @hzfan @szha @leezu  FYI





[GitHub] [incubator-mxnet] yzhliu commented on issue #18102: [Numpy] The gradient of einsum is wrong

Posted by GitBox <gi...@apache.org>.
yzhliu commented on issue #18102:
URL: https://github.com/apache/incubator-mxnet/issues/18102#issuecomment-621511857


   Assignee: @hanke580 





[GitHub] [incubator-mxnet] hanke580 commented on issue #18102: [Numpy] The gradient of einsum is wrong

Posted by GitBox <gi...@apache.org>.
hanke580 commented on issue #18102:
URL: https://github.com/apache/incubator-mxnet/issues/18102#issuecomment-638264351


   The gradient is fixed in PR #18419.
   All checks passed.





[GitHub] [incubator-mxnet] hzfan commented on issue #18102: [Numpy] The gradient of einsum is wrong

Posted by GitBox <gi...@apache.org>.
hzfan commented on issue #18102:
URL: https://github.com/apache/incubator-mxnet/issues/18102#issuecomment-616297738


   Thanks for bringing this up. Will check it out.





[GitHub] [incubator-mxnet] sxjscience commented on issue #18102: [Numpy] The gradient of einsum is wrong

Posted by GitBox <gi...@apache.org>.
sxjscience commented on issue #18102:
URL: https://github.com/apache/incubator-mxnet/issues/18102#issuecomment-616278906


   @yzhliu @hzfan @szha @leezu 





[GitHub] [incubator-mxnet] sxjscience closed issue #18102: [Numpy] The gradient of einsum is wrong

Posted by GitBox <gi...@apache.org>.
sxjscience closed issue #18102:
URL: https://github.com/apache/incubator-mxnet/issues/18102


   

