You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/03/12 01:47:10 UTC

[GitHub] [incubator-mxnet] YutingZhang opened a new issue #14396: mx.nd.Custom not working in subprocess

YutingZhang opened a new issue #14396: mx.nd.Custom not working in subprocess
URL: https://github.com/apache/incubator-mxnet/issues/14396
 
 
   mx.nd.Custom hangs (never returns) when it is called inside a subprocess.
   
   The following code, which reproduces the hang, is taken from https://github.com/wkcn/MobulaOP/issues/40#issuecomment-471803878
   
   ```python
   import sys
   from concurrent import futures

   import mxnet as mx

   # MobulaOP is expected two directories up, relative to the current working
   # directory. The path has to be added *before* importing from ``mobula``;
   # otherwise that checkout is not on sys.path when the import runs.
   sys.path.append('../../')  # Add MobulaOP Path
   from mobula.testing import assert_almost_equal
   
   class AdditionOP(mx.operator.CustomOp):
       """Custom operator computing ``output = a + b``.

       The gradient of a sum with respect to each addend is the upstream
       gradient, so ``backward`` routes ``out_grad[0]`` to both inputs.
       """

       def forward(self, is_train, req, in_data, out_data, aux):
           # ``assign`` honours the write request for the output (e.g.
           # 'write', 'add', 'null') instead of always overwriting the buffer.
           self.assign(out_data[0], req[0], in_data[0] + in_data[1])

       def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
           # With grad_req='add' the input gradients must accumulate; a plain
           # ``in_grad[i][:] = ...`` would silently overwrite them instead.
           for i in range(2):
               self.assign(in_grad[i], req[i], out_grad[0])
   
   @mx.operator.register("AdditionOP")
   class AdditionOPProp(mx.operator.CustomOpProp):
       """Operator metadata for ``AdditionOP``.

       Registered under the op_type name "AdditionOP" so it can be invoked as
       ``mx.nd.Custom(a, b, op_type='AdditionOP')``.
       """

       def __init__(self):
           super(AdditionOPProp, self).__init__()

       def list_arguments(self):
           """Return the names of the two inputs, in positional order."""
           input_names = ['a', 'b']
           return input_names

       def list_outputs(self):
           """Return the name of the single output."""
           output_names = ['output']
           return output_names

       def infer_shape(self, in_shape):
           """Leave input shapes unchanged; the output takes the shape of 'a'."""
           lhs_shape = in_shape[0]
           return in_shape, [lhs_shape]

       def create_operator(self, ctx, shapes, dtypes):
           """Build the operator instance; ctx, shapes and dtypes are unused."""
           return AdditionOP()
   
   def foo():
       """Run AdditionOP forward and backward on small arrays and check results.

       Computes ``c = a + b`` through ``mx.nd.Custom`` while autograd is
       recording, back-propagates a head gradient, and checks (via mobula's
       ``assert_almost_equal``) the output and both input gradients.
       """
       lhs = mx.nd.array([1, 2, 3])
       rhs = mx.nd.array([4, 5, 6])
       for arr in (lhs, rhs):
           arr.attach_grad()

       print("REC")
       with mx.autograd.record():
           total = mx.nd.Custom(lhs, rhs, op_type='AdditionOP')

       head_grad = mx.nd.array([7, 8, 9])
       total.backward(head_grad)

       # d(a + b)/da == d(a + b)/db == 1, so each input gradient equals head_grad.
       assert_almost_equal(lhs + rhs, total)
       for grad in (lhs.grad, rhs.grad):
           assert_almost_equal(grad, head_grad)

       print('Okay :-)')
       summary = 'a + b = c \n {} + {} = {}'.format(
           lhs.asnumpy(), rhs.asnumpy(), total.asnumpy())
       print(summary)
   
   def main():
       """Run ``foo`` in a single-worker process pool and wait for its result.

       This is the reproduction path for the reported problem: the
       ``mx.nd.Custom`` call inside ``foo`` gets stuck when executed in the
       worker process.
       """
       # The context manager shuts the pool down (joining the worker) once the
       # result is available or ``foo`` raises, rather than leaving cleanup to
       # interpreter exit.
       # NOTE(review): if the hang is tied to the 'fork' start method, passing
       # mp_context=multiprocessing.get_context('spawn') (Python 3.7+) may be a
       # workaround -- unverified.
       with futures.ProcessPoolExecutor(max_workers=1) as ex:
           ex.submit(foo).result()

   # Guard required for process pools: child processes that re-import this
   # module (e.g. under the 'spawn' start method) must not start a new pool.
   if __name__ == '__main__':
       main()
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services