Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/03/04 07:31:40 UTC

[GitHub] [incubator-mxnet] reminisce opened a new pull request #14315: Shape support scalar tensor

URL: https://github.com/apache/incubator-mxnet/pull/14315
 
 
   ## Description ##
   This PR makes the necessary changes to the shape-related infrastructure code to support scalars, i.e., zero-dim tensors.
   See this [RFC](https://github.com/apache/incubator-mxnet/issues/14253) for more information.
   
   `np.sum` is implemented in this PR, and the script below uses it to verify correctness.
   
   Once this PR is merged, we can start updating all the existing shape inference functions to accommodate the new definition of empty shapes. NumPy operators can be implemented in parallel.
   
   Note that until all the existing shape inference functions are updated for the new empty shape definition, CI is expected to fail. While working on specific operators, we can test them locally to verify correctness and then submit PRs. Once all the inference functions have been revised, CI should pass, and we can then merge the dev branch into master.
   
   ```python
   import mxnet as mx
   from mxnet import numpy as np
   
   
   shape = (2, 2)
   
   # 1. Imperative call: sum all elements of a 2x2 array of ones.
   ret = np.sum(mx.nd.ones(shape))
   print("Imperative invoke result==============")
   print(ret)
   
   # 2. Symbolic graph: bind the same reduction and run a forward pass.
   data = mx.sym.var('data', shape=shape)
   ret = mx.sym.numpy.sum(data)
   exe = ret.simple_bind(ctx=mx.cpu(), data=shape)
   exe.forward(data=mx.nd.ones(shape))
   print("Symbol bind result===============")
   print(exe.outputs[0])
   
   # 3. CachedOp: wrap the symbol and invoke it on an NDArray.
   print("CachedOp result===============")
   func = mx.nd.CachedOp(ret)
   print(func(mx.nd.ones(shape)))
   ```
   ```
   Imperative invoke result==============
   
   4.0
   <NDArray  @cpu(0)>
   Symbol bind result===============
   
   4.0
   <NDArray  @cpu(0)>
   CachedOp result===============
   
   4.0
   <NDArray  @cpu(0)>
   ```
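
For comparison, here is a minimal sketch of the reference behavior in the standard `numpy` package (not `mxnet.numpy`), which the scalar support above presumably aims to match. The variable names are illustrative only and are not part of this PR. It shows that a full reduction yields a zero-dim result whose shape is the empty tuple, and that such a scalar still holds one element, unlike a zero-size tensor.

```python
import numpy as onp  # standard NumPy, used here only as a reference

shape = (2, 2)
total = onp.sum(onp.ones(shape))

# A full reduction yields a zero-dim result: ndim is 0 and shape is ().
scalar_ndim = onp.ndim(total)
scalar_shape = onp.shape(total)

# A zero-dim tensor holds exactly one element; a zero-size tensor holds none.
scalar_size = onp.size(total)
zero_size = onp.zeros((0, 2)).size

print(float(total), scalar_ndim, scalar_shape, scalar_size, zero_size)
```

The distinction between "zero dimensions" and "zero elements" is the reason an empty shape tuple needs an unambiguous meaning in the shape-related code.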
   @junrushao1994 @szha @eric-haibin-lin @zheng-da @yzhliu
