Posted to issues@mxnet.apache.org by GitBox <gi...@apache.org> on 2022/02/04 10:23:05 UTC

[GitHub] [incubator-mxnet] mazeltovlee opened a new issue #20875: Inconsistent behavior between infer_shape and infer_shape_partial when setting batch_size to be 0

mazeltovlee opened a new issue #20875:
URL: https://github.com/apache/incubator-mxnet/issues/20875


   ## Description
   `infer_shape` cannot infer the output shape when the `batch_size` (the first dimension of the input) is 0, whereas `infer_shape_partial` still returns the correct shape.
   Both calls are made on the same graph:
   ```
   import mxnet as mx
   net = mx.sym.Variable('data', shape=(0,10,10,3))
   net = mx.sym.reshape(data=net, shape=(0, 100, 3))
   net = mx.sym.expand_dims(data=net, axis=2)
   ```
   When I run `net.infer_shape()`, it raises a `UserWarning` and returns `(None, None, None)`:
   ```
   /usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:5: UserWarning: Cannot decide shape for the following arguments (0s in shape means unknown dimensions). Consider providing them as input:
   	data: (0, 10, 10, 3)
     """
   (None, None, None)
   ```
   But when I run `net.infer_shape_partial()`, it returns a reasonable output shape:
   ```
   ([(0, 10, 10, 3)], [(0, 100, 1, 3)], [])
   ```
   
   I found that if I change the input shape from `(0,10,10,3)` to `(10,0,10,3)`, both `infer_shape` and `infer_shape_partial` return the correct output shape: `([(10, 10, 10, 3)], [(10, 100, 1, 3)], [])`. I expected MXNet to still infer the symbol's shape when the `batch_size` (the first dimension of the symbol) is 0. If a `batch_size` of 0 is not acceptable for shape inference, MXNet could at least make the behavior of `infer_shape` and `infer_shape_partial` consistent.
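
The inconsistency described above can be checked mechanically. The sketch below is only an illustration: `outputs_resolved` is a hypothetical helper, not an MXNet API, and it runs on the result tuples copied from this report rather than calling MXNet. It assumes that an unresolved result shows up either as `None` or as an empty shape tuple, which is what the outputs in this issue show.

```python
# Hypothetical helper (not part of MXNet): decides whether a result returned by
# infer_shape() / infer_shape_partial() contains usable output shapes.

def outputs_resolved(result):
    """Return True if every output shape in `result` is present and non-empty.

    `result` is the (arg_shapes, out_shapes, aux_shapes) triple.
    Assumption taken from the outputs in this report: a failed inference
    appears as out_shapes being None, or as an empty tuple () for an output.
    """
    _arg_shapes, out_shapes, _aux_shapes = result
    if out_shapes is None:
        return False
    return all(shape is not None and len(shape) > 0 for shape in out_shapes)


# Values copied verbatim from the outputs reported in this issue.
full_result = (None, None, None)                           # net.infer_shape()
partial_result = ([(0, 10, 10, 3)], [(0, 100, 1, 3)], [])  # net.infer_shape_partial()
flatten_partial = ([(0, 10, 10, 3)], [()], [])             # partial, with the extra flatten op

print(outputs_resolved(full_result))      # False
print(outputs_resolved(partial_result))   # True
print(outputs_resolved(flatten_partial))  # False
```

With MXNet installed, the literal tuples could be replaced by the actual return values of `net.infer_shape()` and `net.infer_shape_partial()`; if the two calls were consistent, the helper would return the same value for both.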
   
   Another issue I found is that if I add one more operator, `mx.sym.flatten()`, `net.infer_shape_partial()` is also unable to infer the output shape:
   ```
   import mxnet as mx
   net = mx.sym.Variable('data', shape=(0,10,10,3))
   net = mx.sym.flatten(data=net, axes=[0,3,2,1])
   net = mx.sym.reshape(data=net, shape=(0, 100, 3))
   net = mx.sym.expand_dims(data=net, axis=2)
   print(net.infer_shape_partial())  # ([(0, 10, 10, 3)], [()], [])
   ```
   
   
   ## To Reproduce
   You can either copy and paste the code above or run the Colab notebook linked below to reproduce these issues:
   https://colab.research.google.com/drive/1KZp5Qw4vuMi9aA6kNIkPQBWGQlbTcDtv?usp=sharing
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscribe@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


