You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/06/20 03:56:00 UTC

[incubator-mxnet] branch master updated: fixing var-seq-len rnn backward() operator (#15278)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4d96671  fixing var-seq-len rnn backward() operator (#15278)
4d96671 is described below

commit 4d9667121ae6fb643f2a02ab15e25231ed756cde
Author: stephenrawls <10...@users.noreply.github.com>
AuthorDate: Wed Jun 19 20:55:33 2019 -0700

    fixing var-seq-len rnn backward() operator (#15278)
    
    * fixing var-seq-len rnn backward() operator
    
    * updating var-length lstm to test backward pass
    
    * removing a bit of debug printing to stderr that I forgot to remove earlier
    
    * resolving TODO about using int32 for sequence_length
    
    * setting rtol and atol similar to other tests in this file
---
 src/operator/rnn-inl.h             | 18 ++++++++++--
 tests/python/gpu/test_gluon_gpu.py | 58 +++++++++++++++++++++++---------------
 2 files changed, 51 insertions(+), 25 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 1046f01..328e28d 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -1583,8 +1583,11 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
   int dtype = in_types[rnn_enum::kData];
   int itype = dtype;
   if (param.use_sequence_length) {
-    itype = in_types[rnn_enum::kSequenceLength];
-    if (param.mode == rnn_enum::kLstm) itype -= 1;
+      size_t seq_len_input_idx = rnn_enum::kSequenceLength;
+      if  (param.mode != rnn_enum::kLstm) {
+        seq_len_input_idx -= 1;
+      }
+    itype = in_types[seq_len_input_idx];
   }
 
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
@@ -1649,7 +1652,7 @@ void RNNStatefulGradCompute(const OpStatePtr& state,
   // Hacky. This relies on fact that seq-len type is either the last input,
   // or we aren't using seq-len input and this type should be same as dtype.
   // Would prefer direct access to RNNParam object here but not sure how to get.
-  int itype = inputs[inputs.size()-1].type_flag_;
+  int itype = outputs[outputs.size()-1].type_flag_;
 
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       MSHADOW_TYPE_SWITCH(itype, IType, {
@@ -1669,6 +1672,15 @@ void RNNStatefulGradCompute(const OpStatePtr& state,
             }
           }
 
+
+          if (param.use_sequence_length) {
+            size_t seq_len_input_idx = rnn_enum::kSequenceLength;
+            if  (param.mode != rnn_enum::kLstm) {
+              seq_len_input_idx -= 1;
+            }
+            in_data.push_back(outputs[seq_len_input_idx]);
+          }
+
           op.Backward(ctx, out_grad, in_data, out_data, req, in_grad);
         });
     });
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index b60814a..fc65029 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -227,19 +227,6 @@ def check_layer_bidirectional(size, in_size, proj_size):
 
 
 def check_layer_bidirectional_varseqlen(size, in_size):
-    class RefBiLSTMVarSeqLen(gluon.Block):
-        def __init__(self, size, **kwargs):
-            super(RefBiLSTMVarSeqLen, self).__init__(**kwargs)
-            with self.name_scope():
-                self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='l0')
-                self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='r0')
-
-        def forward(self, inpt, sequence_length):
-            fwd = self._lstm_fwd(inpt)
-            bwd_inpt = nd.SequenceReverse(inpt, sequence_length=sequence_length, use_sequence_length=True)
-            bwd = self._lstm_bwd(bwd_inpt)
-            bwd = nd.SequenceReverse(bwd, sequence_length=sequence_length, use_sequence_length=True)
-            return nd.concat(fwd, bwd, dim=2)
     weights = {}
     for d in ['l', 'r']:
         weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size))
@@ -248,31 +235,58 @@ def check_layer_bidirectional_varseqlen(size, in_size):
         weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,))
 
     net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=True, prefix='lstm_')
-    ref_net = RefBiLSTMVarSeqLen(size, prefix='lstm_')
+    ref_net  = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=False, prefix='lstm_ref_')
     net.initialize()
     ref_net.initialize()
     net_params = net.collect_params()
     ref_net_params = ref_net.collect_params()
     for k in weights:
         net_params[k].set_data(weights[k])
-        ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k])
-
+        ref_net_params[k.replace("lstm_", "lstm_ref_")].set_data(weights[k])
 
     batch_size = 10
     num_timesteps = 11
     data = mx.random.uniform(shape=(num_timesteps, batch_size, in_size))
+    data_np = data.asnumpy()
 
-    # TODO: figure out why int32 doesn't work here
-    sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size)).astype("float")
-
-    net_output = net(data, sequence_length=sequence_length).asnumpy()
-    ref_net_output = ref_net(data, sequence_length).asnumpy()
+    sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size)).astype("int32")
     sequence_length_np = sequence_length.asnumpy().astype("int32")
 
+    # Reference net is processing batch elements one at a time, so that it is "perfectly sized"
+    # Because of that, we need to accumulate gradients in reference net.
+    for p in ref_net.collect_params().values():
+        p.grad_req = 'add'
+
+    ref_net_output = []
+    with autograd.record():
+        net_output = net(data.copy(), sequence_length=sequence_length.copy())
+
+        for b in range(batch_size):
+            data_slice = mx.nd.array(data_np[:sequence_length_np[b], b, :]).reshape(sequence_length_np[b], 1, in_size)
+            ref_output_slice = ref_net(data_slice)
+            ref_net_output.append(ref_output_slice)
+
+    net_output_np = net_output.asnumpy()
+
     # TODO: test state return value as well output
     # Only compare the valid sections for each batch entry
     for b in range(batch_size):
-        assert_allclose(net_output[:sequence_length_np[b], b], ref_net_output[:sequence_length_np[b], b])
+        assert_allclose(net_output_np[:sequence_length_np[b], b], ref_net_output[b].asnumpy().squeeze(1),
+                        rtol=1e-2, atol=1e-6)
+
+    # Now test backward
+    net_output.backward()
+
+    for ref_output_slice in ref_net_output:
+        ref_output_slice.backward()
+
+    ref_net_params = ref_net.collect_params()
+
+    for k in weights:
+        net_grad = net_params[k].grad()
+        ref_net_grad = ref_net_params[k.replace('lstm_', 'lstm_ref_')].grad()
+        assert_almost_equal(net_grad.asnumpy(), ref_net_grad.asnumpy(),
+                            rtol=1e-2, atol=1e-6)
 
 
 @with_seed()