Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/10 05:28:51 UTC

[GitHub] haven-jeon opened a new issue #12778: Cannot print parameter summary of embedding layer

URL: https://github.com/apache/incubator-mxnet/issues/12778
 
 
   ## Description
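   `mx.viz.print_summary()` reports 0 parameters for the `Embedding` layer, even though `nn.Embedding` holds a trainable weight matrix. The `Total params` line therefore leaves the embedding weights out.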
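For reference, a Gluon `nn.Embedding(input_dim, output_dim)` block holds one weight matrix of shape `(input_dim, output_dim)` and no bias, so its summary row should show `input_dim * output_dim` parameters. A minimal arithmetic sketch; the vocabulary size of 10000 is a hypothetical placeholder, since `vocab` is not defined in the report:

```python
def embedding_param_count(input_dim, output_dim):
    """Parameters in an embedding table: one weight row per token, no bias."""
    return input_dim * output_dim

# Hypothetical vocabulary size; the report uses len(vocab.idx_to_token).
vocab_size = 10000
print(embedding_param_count(vocab_size, 50))  # 500000
```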
   
   ## Environment info
   
   MXNet 1.3.0
   
   Package used (Python/R/Scala/Julia):
   I am using Python.
   
   ## Minimum reproducible example
   
   ```python
   import mxnet as mx
   from mxnet import gluon
   from mxnet.gluon import nn
   
   class SentClassificationModel(gluon.HybridBlock):
       def __init__(self, vocab_size, num_embed, **kwargs):
           super(SentClassificationModel, self).__init__(**kwargs)
           with self.name_scope():
               self.embed = nn.Embedding(input_dim=vocab_size, output_dim=num_embed)
               self.drop = nn.Dropout(0.3)
               self.fc = nn.Dense(100, activation='relu')
               self.out = nn.Dense(2)
   
       def hybrid_forward(self, F, inputs):
           em_out = self.drop(self.embed(inputs))
           fc_out = self.fc(em_out)
           return self.out(fc_out)
   
   ctx = mx.gpu()  # use mx.cpu() on a machine without a GPU
   
   # The report used vocab_size=len(vocab.idx_to_token); 10000 is a placeholder
   # so the snippet runs without the reporter's vocabulary object.
   vocab_size = 10000
   model = SentClassificationModel(vocab_size=vocab_size, num_embed=50)
   
   model.initialize(mx.init.Xavier(), ctx=ctx)
   model.hybridize()
   
   # Calling the hybridized block on a Symbol returns a Symbol for print_summary.
   mx.viz.print_summary(
       model(mx.sym.var('data')),
       shape={'data': (1, 30)},  # set your shape here
   )
   ```
   
   > ________________________________________________________________________________________________________________________
   Layer (type)                                        Output Shape            Param #     Previous Layer                  
   ========================================================================================================================
   data(null)                                          30                      0                                           
   ________________________________________________________________________________________________________________________
   sentclassificationmodel0_embedding0_fwd(Embedding)  30x50                   0           data                            
   ________________________________________________________________________________________________________________________
   sentclassificationmodel0_dropout0_fwd(Dropout)      30x50                   0           sentclassificationmodel0_embeddi
   ________________________________________________________________________________________________________________________
   sentclassificationmodel0_dense0_fwd(FullyConnected) 100                     3100        sentclassificationmodel0_dropout
   ________________________________________________________________________________________________________________________
   sentclassificationmodel0_dense0_relu_fwd(Activation)100                     0           sentclassificationmodel0_dense0_
   ________________________________________________________________________________________________________________________
   sentclassificationmodel0_dense1_fwd(FullyConnected) 2                       202         sentclassificationmodel0_dense0_
   ========================================================================================================================
   Total params: 3302
   ________________________________________________________________________________________________________________________
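A sketch of the counting rule the Embedding row appears to be missing, assuming the node layout produced by `Symbol.tojson()` in MXNet 1.x (a `"nodes"` list whose entries carry `"op"`, `"name"`, string-valued `"attrs"`, and `"inputs"`). `node_param_count` is a hypothetical helper, not MXNet's actual `print_summary` code, and the node below is hand-written with a placeholder `input_dim` of 10000:

```python
import json

def node_param_count(node):
    """Hypothetical counter: parameters of an Embedding node in symbol JSON.

    Attribute values in the symbol JSON are strings, so they are cast to int.
    Other operator types are out of scope for this sketch and return 0.
    """
    if node.get("op") == "Embedding":
        attrs = node.get("attrs", {})
        return int(attrs["input_dim"]) * int(attrs["output_dim"])
    return 0

# Hand-written node shaped like an entry of json.loads(sym.tojson())["nodes"];
# input_dim=10000 is a hypothetical vocabulary size.
node_json = '''{"op": "Embedding",
  "name": "sentclassificationmodel0_embedding0_fwd",
  "attrs": {"input_dim": "10000", "output_dim": "50"},
  "inputs": [[0, 0, 0], [1, 0, 0]]}'''
node = json.loads(node_json)

embed_params = node_param_count(node)
reported_total = 3302  # "Total params" from the summary above
print(embed_params)                   # 500000
print(reported_total + embed_params)  # 503302
```

With the Embedding row counted this way, the reported total would grow by `vocab_size * num_embed` (here `10000 * 50`).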
   
   
   
   
