Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/03/16 19:26:49 UTC

[GitHub] [incubator-mxnet] arcadiaphy commented on issue #14396: mx.nd.Custom not working in subprocess

arcadiaphy commented on issue #14396: mx.nd.Custom not working in subprocess
URL: https://github.com/apache/incubator-mxnet/issues/14396#issuecomment-473577236
 
 
   After #14363, the threads used by custom operators are created only after the first custom operator runs, so the bug needs a new reproduction script:
   ```
   from concurrent import futures
   
   import mxnet as mx
   import sys
   
   class AdditionOP(mx.operator.CustomOp):
       def __init__(self):
           super(AdditionOP, self).__init__()
       def forward(self, is_train, req, in_data, out_data, aux):
           out_data[0][:] = in_data[0] + in_data[1]
       def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
           in_grad[0][:] = out_grad[0]
           in_grad[1][:] = out_grad[0]
   
   @mx.operator.register("AdditionOP")
   class AdditionOPProp(mx.operator.CustomOpProp):
       def __init__(self):
           super(AdditionOPProp, self).__init__()
       def list_arguments(self):
           return ['a', 'b']
       def list_outputs(self):
           return ['output']
       def infer_shape(self, in_shape):
           return in_shape, [in_shape[0]]
       def create_operator(self, ctx, shapes, dtypes):
           return AdditionOP()
   
   def foo():
       a = mx.nd.array([1, 2, 3])
       b = mx.nd.array([4, 5, 6])
   
       a.attach_grad()
       b.attach_grad()
   
       print("REC")
       with mx.autograd.record():
           c = mx.nd.Custom(a, b, op_type='AdditionOP')
   
       dc = mx.nd.array([7, 8, 9])
       c.backward(dc)
   
       print('Okay :-)')
       print('a + b = c \n {} + {} = {}'.format(a.asnumpy(), b.asnumpy(), c.asnumpy()))
   
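   def main():
       # Run the custom op once in the parent process; after #14363 this
       # is what creates the custom-operator threads.
       foo()
       # Run the same op in a child process, which is where the reported
       # failure shows up.
       ex = futures.ProcessPoolExecutor(1)
       r = ex.submit(foo)
       r.result()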
   
   if __name__ == '__main__':
       main()
   ```
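The comment does not spell out why lazily created threads break the subprocess case. One likely explanation (an assumption, not confirmed in this thread) is that `ProcessPoolExecutor` used the `fork` start method on Linux, and `fork()` copies only the calling thread into the child. The child therefore inherits the parent's record that the threads exist, but not the threads themselves. The sketch below reproduces that effect in plain Python without MXNet; `run_op`, `child_sees_worker`, and `_worker` are hypothetical stand-ins, not MXNet APIs, and the `mp_context` argument requires Python 3.7 or later.

```python
import multiprocessing
import threading
from concurrent import futures

# Hypothetical stand-in for a library that starts a worker thread lazily,
# the way the comment describes custom-op threads after #14363.
_worker = None


def run_op():
    """Start the worker thread on first use, then report whether it is alive."""
    global _worker
    if _worker is None:
        stop = threading.Event()
        # The thread just blocks forever, like an idle worker waiting for jobs.
        _worker = threading.Thread(target=stop.wait, daemon=True)
        _worker.start()
    return _worker.is_alive()


def child_sees_worker():
    """Runs in the child: is the inherited worker thread object still running?"""
    inherited = _worker is not None
    return inherited, inherited and _worker.is_alive()


def main():
    in_parent = run_op()  # parent: creates the thread, so this is True
    # Force the fork start method explicitly (newer Python versions may
    # default to a different method on some platforms).
    ctx = multiprocessing.get_context("fork")
    with futures.ProcessPoolExecutor(1, mp_context=ctx) as ex:
        inherited, alive = ex.submit(child_sees_worker).result()
    return in_parent, inherited, alive


if __name__ == "__main__":
    print(main())
```

On Linux, `main()` should return `(True, True, False)`: the child sees the parent's thread object, but the thread is no longer running there. A library that hands work to such a thread would wait in the child for a result that never arrives, which would be consistent with the failure reported in this issue.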
