Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/03/23 19:52:16 UTC

[GitHub] [incubator-mxnet] sxjscience opened a new issue #17893: [Bug][Numpy] Wrong gradient of np.where

sxjscience opened a new issue #17893: [Bug][Numpy] Wrong gradient of np.where
URL: https://github.com/apache/incubator-mxnet/issues/17893
 
 
   ## Description
   Example 1: Using np.where(array, array, scalar)
   
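   ```python
   import mxnet as mx
   mx.npx.set_np()  # enable NumPy-compatible semantics
   
   a = mx.np.array([1, 0, 1])  # condition
   b = mx.np.array([2, 3, 4])  # true branch; we want d(c)/d(b)
   
   b.attach_grad()
   
   with mx.autograd.record():
       c = mx.np.where(a, b, -1)  # false branch is a Python scalar
   c.backward()  # default head gradient of ones
   print(b.grad)
   ```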
   Output: [0. 1. 0.]
   
   Example 2: Using np.where(array, array, array)
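   ```python
   import mxnet as mx
   mx.npx.set_np()  # enable NumPy-compatible semantics
   
   a = mx.np.array([1, 0, 1])  # condition
   b = mx.np.array([2, 3, 4])  # true branch; we want d(c)/d(b)
   
   b.attach_grad()
   
   with mx.autograd.record():
       c = mx.np.where(a, b, mx.np.array([-1, -1, -1]))  # false branch is an array
   c.backward()  # default head gradient of ones
   print(b.grad)
   ```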
   Output: [1. 0. 1.]
   
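   The second result is correct: the gradient with respect to `b` should be 1 where the condition `a` is nonzero and 0 elsewhere. The first result, from the scalar variant, is the exact complement of the correct gradient.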
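For reference, the gradient of `where(cond, x, y)` routes the upstream gradient to `x` where `cond` is nonzero and to `y` everywhere else. The sketch below computes this expected result with plain NumPy (not MXNet) so it can be compared against `b.grad`. The helper name `where_backward` is hypothetical, the sketch assumes all inputs share one shape (no broadcasting), and it assumes the default head gradient of ones used by `c.backward()`.

```python
import numpy as np

def where_backward(cond, ograd):
    """Hypothetical reference backward for where(cond, x, y).

    Assumes cond, x, y and ograd all have the same shape (no broadcasting).
    Returns (grad_x, grad_y).
    """
    mask = cond != 0
    zeros = np.zeros_like(ograd)
    # x receives the upstream gradient wherever the condition is true...
    grad_x = np.where(mask, ograd, zeros)
    # ...and y receives it everywhere else.
    grad_y = np.where(mask, zeros, ograd)
    return grad_x, grad_y

a = np.array([1, 0, 1])
# Assumed head gradient: ones, matching c.backward() with no argument.
ograd = np.ones(3)
grad_b, _ = where_backward(a, ograd)
print(grad_b)  # [1. 0. 1.]
```

This matches the output of Example 2.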
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

[GitHub] [incubator-mxnet] sxjscience closed issue #17893: [Bug][Numpy] Wrong gradient of np.where

Posted by GitBox <gi...@apache.org>.
sxjscience closed issue #17893: [Bug][Numpy] Wrong gradient of np.where
URL: https://github.com/apache/incubator-mxnet/issues/17893
 
 
   


[GitHub] [incubator-mxnet] hgt312 commented on issue #17893: [Bug][Numpy] Wrong gradient of np.where

Posted by GitBox <gi...@apache.org>.
hgt312 commented on issue #17893: [Bug][Numpy] Wrong gradient of np.where
URL: https://github.com/apache/incubator-mxnet/issues/17893#issuecomment-603106484
 
 
   Fixed in https://github.com/apache/incubator-mxnet/pull/17899


[GitHub] [incubator-mxnet] sxjscience commented on issue #17893: [Bug][Numpy] Wrong gradient of np.where

Posted by GitBox <gi...@apache.org>.
sxjscience commented on issue #17893: [Bug][Numpy] Wrong gradient of np.where
URL: https://github.com/apache/incubator-mxnet/issues/17893#issuecomment-603039124
 
 
   @hgt312 I suspect the bug is in the backward pass of the scalar case of `where`: https://github.com/apache/incubator-mxnet/blob/9a355ebc1dfc5c087d41bb24c946e7f773e01af9/src/operator/numpy/np_where_op.cc#L236-L254
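In the scalar case only one operand is an array, so the backward pass must know whether that array is the true branch (`x`) or the false branch (`y`). The sketch below illustrates this in plain NumPy with a hypothetical `is_left` flag; it is not the code in `np_where_op.cc`. The buggy Example 1 output `[0. 1. 0.]` is exactly what this sketch returns with the flag inverted, which is consistent with, but does not prove, a swapped branch mask.

```python
import numpy as np

def where_scalar_backward(cond, ograd, is_left):
    """Hypothetical backward for where(cond, array, scalar) / where(cond, scalar, array).

    is_left=True: the array is the true branch (x).
    is_left=False: the array is the false branch (y).
    Only the array operand gets a gradient; the scalar gets none.
    """
    mask = cond != 0
    zeros = np.zeros_like(ograd)
    if is_left:
        return np.where(mask, ograd, zeros)
    return np.where(mask, zeros, ograd)

a = np.array([1, 0, 1])
ograd = np.ones(3)  # assumed default head gradient of ones
# Example 1 is where(a, b, -1): b is the true branch, so is_left=True.
print(where_scalar_backward(a, ograd, is_left=True))   # [1. 0. 1.]
# Inverting the flag reproduces the incorrect Example 1 output.
print(where_scalar_backward(a, ograd, is_left=False))  # [0. 1. 0.]
```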
