Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/15 02:14:40 UTC

[GitHub] Godricly opened a new pull request #10915: [MXNET-9704] An assertion check for invalid layout

URL: https://github.com/apache/incubator-mxnet/pull/10915
 
 
   ## Description ##
   A fix for issue #9704.
   ## Checklist ##
   ### Essentials ###
   Please feel free to remove inapplicable items for your PR.
   - [X] Changes are complete (i.e. I finished coding on this PR)
   - [ ] All changes have test coverage:
     I'm not sure how to write a unit test for this one. A rough testing script is posted below.
   - [x] Code is well-documented: 
   - For user-facing API changes, the API doc strings have been updated.
   - Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
   (This link is broken; please update it.)
   - [x] To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with it
   
   ### Changes ###
   - [x] Assertion checks for invalid layouts in convolution and pooling layers.
   
   ## Comments ##
     Adds assertion checks for the layouts of conv and pooling layers, with documentation updates. The pooling assertions were added earlier by someone else. I'm not sure whether other layouts are supported as well.
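
To illustrate the kind of check described above, here is a minimal, self-contained sketch that does not depend on MXNet. `check_layout` and `TransposedConv2DSketch` are hypothetical names, not MXNet code, and the supported-layout tuple is illustrative only. The idea is to reject an unsupported `layout` string in the constructor, before any operator runs, instead of failing later inside the backend.

```python
def check_layout(layer_name, layout, supported):
    """Fail fast if `layout` is not one of the `supported` layout strings.

    Hypothetical helper: it only illustrates the kind of constructor-time
    check this PR describes; it is not MXNet's actual implementation.
    """
    assert layout in supported, "%s only supports %s layout for now, got %r" % (
        layer_name, " and ".join(repr(s) for s in supported), layout)
    return layout


class TransposedConv2DSketch:
    """Toy stand-in for a layer constructor (hypothetical, not a Gluon class)."""

    # Illustrative assumption: the layouts the real backend accepts may differ.
    SUPPORTED = ("NCHW",)

    def __init__(self, channels, kernel_size, layout="NCHW"):
        # Validate first, so a bad layout never reaches the operator.
        self.layout = check_layout(type(self).__name__, layout, self.SUPPORTED)
        self.channels = channels
        self.kernel_size = kernel_size


if __name__ == "__main__":
    TransposedConv2DSketch(5, (1, 1))  # default layout is accepted
    try:
        TransposedConv2DSketch(5, (1, 1), layout="NHWC")
    except AssertionError as err:
        print(err)
```

Note that plain `assert` statements are skipped when Python runs with `-O`; raising a `ValueError` instead would keep the check active in that mode.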
   
   @gianlucacorrado @zhreshold @anjishnu
   
   A rough debug script is posted below; feel free to modify it.
   ```python
   import sys
   # Use a local source checkout of MXNet; adjust the path to your setup.
   sys.path.insert(0, '../incubator-mxnet/python')
   from mxnet import nd
   from mxnet.gluon import nn

   # 1D layers with the default 'NCW' layout: (batch, channels, width)
   a = nd.zeros((5, 2, 10))
   conv1 = nn.Conv1D(5, 1)
   conv1t = nn.Conv1DTranspose(5, 1)
   pool1_ave = nn.AvgPool1D()
   pool1_max = nn.MaxPool1D()
   pool1_ave_g = nn.GlobalAvgPool1D()
   pool1_max_g = nn.GlobalMaxPool1D()
   conv1.initialize()
   conv1t.initialize()
   print(conv1(a).shape)
   print(conv1t(a).shape)
   print(pool1_ave(a).shape)
   print(pool1_max(a).shape)
   print(pool1_ave_g(a).shape)
   print(pool1_max_g(a).shape)

   # 2D layers with the default 'NCHW' layout: (batch, channels, height, width)
   a = nd.zeros((5, 2, 10, 10))
   conv2 = nn.Conv2D(5, (1, 1))
   pool2_ave = nn.AvgPool2D()
   pool2_max = nn.MaxPool2D()
   pool2_ave_g = nn.GlobalAvgPool2D()
   pool2_max_g = nn.GlobalMaxPool2D()
   conv2.initialize()
   print(conv2(a).shape)
   # layout='NHWC' is passed on purpose to exercise the new layout check;
   # the assertion is caught so the rest of the script still runs.
   try:
       conv2t = nn.Conv2DTranspose(5, (1, 1), layout='NHWC')
       conv2t.initialize()
       print(conv2t(a).shape)
   except AssertionError as e:
       print('Conv2DTranspose:', e)
   print(pool2_ave(a).shape)
   print(pool2_max(a).shape)
   print(pool2_ave_g(a).shape)
   print(pool2_max_g(a).shape)

   # 3D layers with the default 'NCDHW' layout: (batch, channels, depth, height, width)
   a = nd.zeros((5, 2, 10, 10, 10))
   conv3 = nn.Conv3D(5, (1, 1, 1))
   pool3_ave = nn.AvgPool3D()
   pool3_max = nn.MaxPool3D()
   pool3_ave_g = nn.GlobalAvgPool3D()
   pool3_max_g = nn.GlobalMaxPool3D()
   conv3.initialize()
   print(conv3(a).shape)
   # layout='NDHWC' is passed on purpose to exercise the new layout check.
   try:
       conv3t = nn.Conv3DTranspose(5, (1, 1, 1), layout='NDHWC')
       conv3t.initialize()
       print(conv3t(a).shape)
   except AssertionError as e:
       print('Conv3DTranspose:', e)
   print(pool3_ave(a).shape)
   print(pool3_max(a).shape)
   print(pool3_ave_g(a).shape)
   print(pool3_max_g(a).shape)
   ```
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services