Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/10/18 12:22:43 UTC

[GitHub] ustcfd opened a new issue #12863: Rewrite the GRUCell, Error:TypeError: forward() takes 3 positional arguments but 4 were given

URL: https://github.com/apache/incubator-mxnet/issues/12863
 
 
   To add graph convolution operations to GRUCell, I need to rewrite MXNet's GRUCell.
   Specifically, the cell needs an extra input, the adjacency matrix (adj). Unfortunately, I get the error “TypeError: forward() takes 3 positional arguments but 4 were given.”
   
   ```
   # My cell. It differs from GRUCell only in that hybrid_forward()
   # takes an extra parameter, the adjacency matrix adj.
   class mgcn_grucell(HybridRecurrentCell):
       def __init__(self, hidden_size,
                    i2h_weight_initializer=None, h2h_weight_initializer=None,
                    i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                    input_size=0, prefix=None, params=None):
           super(mgcn_grucell, self).__init__(prefix=prefix, params=params)
           ...  # rest of __init__ omitted

       def hybrid_forward(self, F, inputs, states, adj):
           ...  # body omitted

   # Excerpts from MXNet's Gluon RNN cell code, with my change to RecurrentCell.
   class HybridRecurrentCell(RecurrentCell, HybridBlock):
       """HybridRecurrentCell supports hybridize."""
       def __init__(self, prefix=None, params=None):
           super(HybridRecurrentCell, self).__init__(prefix=prefix, params=params)

       def hybrid_forward(self, F, x, *args, **kwargs):
           raise NotImplementedError


   class RecurrentCell(Block):
       ...  # rest of the class omitted

       # forward() modified to accept adj
       def forward(self, inputs, states, adj):
           self._counter += 1
           return super(RecurrentCell, self).forward(inputs, states, adj)
   ```
   HybridRecurrentCell inherits from RecurrentCell, yet the error persists even after I add adj to forward() in RecurrentCell.
   How can I fix this?
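One plausible explanation, offered as a hedged sketch rather than a confirmed diagnosis: editing forward() only helps if the edited class is actually in mgcn_grucell's method resolution order. If mgcn_grucell subclasses the HybridRecurrentCell shipped with MXNet, an edited copy of RecurrentCell defined elsewhere is never consulted, and the library's two-argument forward() rejects the extra adj. The pure-Python stand-ins below (Block, HybridBlock, RecurrentCell, HybridRecurrentCell, BrokenCell, FixedCell and try_call are all hypothetical simplifications, not MXNet code) reproduce the same TypeError and show one workaround: override forward() in the cell subclass itself.

```python
# Hypothetical, minimal stand-ins for the Gluon class hierarchy. These are NOT
# the real MXNet classes, only enough to show how forward() is dispatched.


class Block:
    def __call__(self, *args):
        # Calling a block dispatches to forward() with the same positional args
        return self.forward(*args)


class HybridBlock(Block):
    def forward(self, x, *args):
        # Extra positional inputs are passed straight through to hybrid_forward()
        return self.hybrid_forward("F", x, *args)


class RecurrentCell(Block):
    def __init__(self):
        self._counter = 0

    def forward(self, inputs, states):
        # Fixed arity: only (inputs, states) are accepted here
        self._counter += 1
        return super(RecurrentCell, self).forward(inputs, states)


class HybridRecurrentCell(RecurrentCell, HybridBlock):
    # MRO: HybridRecurrentCell -> RecurrentCell -> HybridBlock -> Block
    pass


class BrokenCell(HybridRecurrentCell):
    # Only hybrid_forward() gains the extra input; forward() is inherited
    def hybrid_forward(self, F, inputs, states, adj):
        return inputs, states, adj


class FixedCell(HybridRecurrentCell):
    # forward() defined on the subclass itself comes first in the MRO
    def forward(self, inputs, states, adj):
        self._counter += 1
        # Skip RecurrentCell.forward and let HybridBlock.forward pass adj on
        return HybridBlock.forward(self, inputs, states, adj)

    def hybrid_forward(self, F, inputs, states, adj):
        return inputs, states, adj


def try_call(cell):
    # Call the cell with (inputs, states, adj) and report any TypeError
    try:
        return cell("x", ["h"], "adj")
    except TypeError as err:
        return "TypeError: %s" % err


if __name__ == "__main__":
    print(try_call(BrokenCell()))  # reproduces the "4 were given" TypeError
    print(try_call(FixedCell()))   # adj reaches hybrid_forward()
```

Assuming MXNet 1.x, where HybridBlock.forward(x, *args) passes extra positional inputs on to hybrid_forward() and supplies the cell's parameters as keyword arguments, the analogous change would be a forward(self, inputs, states, adj) method on mgcn_grucell itself that calls HybridBlock.forward(self, inputs, states, adj), leaving the library's RecurrentCell untouched. RecurrentCell.unroll() appears to call the cell with only (inputs, states), so unroll() would likely need a similar override.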
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.