Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/09 14:45:07 UTC

[GitHub] TaoLv commented on a change in pull request #10104: [WIP][MXNET-107] Fused RNN implementation for CPU

URL: https://github.com/apache/incubator-mxnet/pull/10104#discussion_r187064942
 
 

 ##########
 File path: tests/python/unittest/test_operator.py
 ##########
 @@ -91,6 +91,24 @@ def test_lstm_bidirectional():
 
     check_rnn_consistency(stack, fused, T, N, I, H)
 
+# Currently, fused LSTM operator doesn't support dropout.
+# Will change this test after dropout is supported
+@with_seed()
+def test_lstm_dropout():
+    X = mx.sym.Variable('x')
+    Params = mx.sym.Variable('params')
+    HX = mx.sym.Variable('state')
+    CX = mx.sym.Variable('state_cell')
+    T, N, I, H = 300, 20, 800, 800
+    rnn = mx.sym.RNN(data=X, parameters=Params, state=HX, state_cell=CX,
+                     state_size=H, num_layers=5, mode='lstm', p=0.5, state_outputs=True, name='LSTM')
+    exe = rnn.simple_bind(ctx=mx.cpu(), x=(T, N, I))
+    try:
+        out = exe.forward(is_train=False)
+        out[0].wait_to_read()
+        assert False  # should not reach here
+    except mx.base.MXNetError as err:
 
 Review comment:
   Yes. It also ensures that the failure happens at the proper place and that the correct error message is presented, following @reminisce's idea in #10844.
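   A minimal sketch of that pattern, under stated assumptions: `MXNetError`, `FakeExecutor`, `expect_forward_error` and the error text are hypothetical stand-ins (not MXNet's real classes or messages) so the snippet runs without MXNet. Construction stands in for `simple_bind` and must succeed. Only `forward()` may fail, and the test checks the message so that an unrelated `MXNetError` does not pass by accident.

```python
class MXNetError(Exception):
    """Hypothetical stand-in for mx.base.MXNetError so the sketch runs alone."""


class FakeExecutor:
    """Hypothetical executor whose forward pass rejects dropout (p > 0)."""

    def __init__(self, p):
        # Stands in for simple_bind(): binding itself must not raise.
        self.p = p

    def forward(self, is_train=False):
        if self.p > 0:
            # Hypothetical message; the real operator's wording may differ.
            raise MXNetError("Dropout is not supported yet")
        return ["output"]


def expect_forward_error(exe, expected_msg):
    """Assert that forward() (not binding) raises MXNetError containing expected_msg."""
    try:
        exe.forward(is_train=False)
    except MXNetError as err:
        # Check the message too, so an unrelated MXNetError does not pass.
        assert expected_msg in str(err), "unexpected error: %s" % err
        return str(err)
    # Reaching this line means forward() raised nothing at all.
    raise AssertionError("forward() should have raised MXNetError")


exe = FakeExecutor(p=0.5)  # the "bind" step succeeds
print(expect_forward_error(exe, "not supported"))
```

   With pytest, `pytest.raises(MXNetError, match=...)` wrapped around only the `forward()` call expresses the same two checks.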

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services