Posted to commits@mxnet.apache.org by zh...@apache.org on 2019/06/20 03:56:00 UTC
[incubator-mxnet] branch master updated: fixing var-seq-len rnn backward() operator (#15278)
This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 4d96671 fixing var-seq-len rnn backward() operator (#15278)
4d96671 is described below
commit 4d9667121ae6fb643f2a02ab15e25231ed756cde
Author: stephenrawls <10...@users.noreply.github.com>
AuthorDate: Wed Jun 19 20:55:33 2019 -0700
fixing var-seq-len rnn backward() operator (#15278)
* fixing var-seq-len rnn backward() operator
* updating the var-length lstm test to exercise the backward pass
* removing a bit of debug printing to stderr that I forgot to remove earlier
* resolving the TODO about using int32 for sequence_length
* setting rtol and atol to match other tests in this file
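
As context for the commit, here is a minimal sketch of the pattern being fixed: running a variable-sequence-length bidirectional LSTM forward under autograd and then calling backward(). The sizes below are illustrative only, and a GPU context is assumed since the use_sequence_length path is implemented in the cuDNN-backed operator (hence the test living in tests/python/gpu/test_gluon_gpu.py).

    import mxnet as mx
    from mxnet import autograd, gluon, nd

    ctx = mx.gpu()  # assumed: the var-seq-len path needs the cuDNN backend
    num_timesteps, batch_size, in_size, hidden = 7, 4, 5, 6  # illustrative sizes

    # use_sequence_length=True makes the layer accept per-batch-element lengths
    net = gluon.rnn.LSTM(hidden, bidirectional=True, use_sequence_length=True)
    net.initialize(ctx=ctx)

    data = mx.random.uniform(shape=(num_timesteps, batch_size, in_size), ctx=ctx)
    # int32 lengths now work; the earlier code had a TODO forcing float here
    seq_len = nd.random.randint(1, num_timesteps + 1,
                                shape=(batch_size,), ctx=ctx).astype("int32")

    with autograd.record():
        out = net(data, sequence_length=seq_len)
    out.backward()  # the call this commit fixes
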
---
src/operator/rnn-inl.h | 18 ++++++++++--
tests/python/gpu/test_gluon_gpu.py | 58 +++++++++++++++++++++++---------------
2 files changed, 51 insertions(+), 25 deletions(-)
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 1046f01..328e28d 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -1583,8 +1583,11 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
int dtype = in_types[rnn_enum::kData];
int itype = dtype;
if (param.use_sequence_length) {
- itype = in_types[rnn_enum::kSequenceLength];
- if (param.mode == rnn_enum::kLstm) itype -= 1;
+ size_t seq_len_input_idx = rnn_enum::kSequenceLength;
+ if (param.mode != rnn_enum::kLstm) {
+ seq_len_input_idx -= 1;
+ }
+ itype = in_types[seq_len_input_idx];
}
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
@@ -1649,7 +1652,7 @@ void RNNStatefulGradCompute(const OpStatePtr& state,
// Hacky. This relies on fact that seq-len type is either the last input,
// or we aren't using seq-len input and this type should be same as dtype.
// Would prefer direct access to RNNParam object here but not sure how to get.
- int itype = inputs[inputs.size()-1].type_flag_;
+ int itype = outputs[outputs.size()-1].type_flag_;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
MSHADOW_TYPE_SWITCH(itype, IType, {
@@ -1669,6 +1672,15 @@ void RNNStatefulGradCompute(const OpStatePtr& state,
}
}
+
+ if (param.use_sequence_length) {
+ size_t seq_len_input_idx = rnn_enum::kSequenceLength;
+ if (param.mode != rnn_enum::kLstm) {
+ seq_len_input_idx -= 1;
+ }
+ in_data.push_back(outputs[seq_len_input_idx]);
+ }
+
op.Backward(ctx, out_grad, in_data, out_data, req, in_grad);
});
});
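
The index arithmetic in both hunks follows from the operator's input layout: rnn_enum in rnn-inl.h orders the inputs kData, kParams, kState, kStateCell, kSequenceLength, and non-LSTM modes (vanilla RNN and GRU) carry no cell-state input, so everything after kState shifts down by one slot. A sketch of the same lookup in Python (enum values taken from rnn-inl.h; the function itself is hypothetical, for illustration only):

    # Input ordering from rnn-inl.h's RNNOpInputs enum:
    #   kData=0, kParams=1, kState=2, kStateCell=3, kSequenceLength=4
    K_SEQUENCE_LENGTH = 4

    def seq_len_input_idx(mode):
        # LSTM is the only mode with a cell-state (kStateCell) input, so for
        # RNN/GRU the sequence_length blob sits one slot earlier in the list.
        return K_SEQUENCE_LENGTH if mode == "lstm" else K_SEQUENCE_LENGTH - 1
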
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index b60814a..fc65029 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -227,19 +227,6 @@ def check_layer_bidirectional(size, in_size, proj_size):
def check_layer_bidirectional_varseqlen(size, in_size):
- class RefBiLSTMVarSeqLen(gluon.Block):
- def __init__(self, size, **kwargs):
- super(RefBiLSTMVarSeqLen, self).__init__(**kwargs)
- with self.name_scope():
- self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='l0')
- self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='r0')
-
- def forward(self, inpt, sequence_length):
- fwd = self._lstm_fwd(inpt)
- bwd_inpt = nd.SequenceReverse(inpt, sequence_length=sequence_length, use_sequence_length=True)
- bwd = self._lstm_bwd(bwd_inpt)
- bwd = nd.SequenceReverse(bwd, sequence_length=sequence_length, use_sequence_length=True)
- return nd.concat(fwd, bwd, dim=2)
weights = {}
for d in ['l', 'r']:
weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size))
@@ -248,31 +235,58 @@ def check_layer_bidirectional_varseqlen(size, in_size):
weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,))
net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=True, prefix='lstm_')
- ref_net = RefBiLSTMVarSeqLen(size, prefix='lstm_')
+ ref_net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=False, prefix='lstm_ref_')
net.initialize()
ref_net.initialize()
net_params = net.collect_params()
ref_net_params = ref_net.collect_params()
for k in weights:
net_params[k].set_data(weights[k])
- ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k])
-
+ ref_net_params[k.replace("lstm_", "lstm_ref_")].set_data(weights[k])
batch_size = 10
num_timesteps = 11
data = mx.random.uniform(shape=(num_timesteps, batch_size, in_size))
+ data_np = data.asnumpy()
- # TODO: figure out why int32 doesn't work here
- sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size)).astype("float")
-
- net_output = net(data, sequence_length=sequence_length).asnumpy()
- ref_net_output = ref_net(data, sequence_length).asnumpy()
+ sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size)).astype("int32")
sequence_length_np = sequence_length.asnumpy().astype("int32")
+ # Reference net is processing batch elements one at a time, so that it is "perfectly sized"
+ # Because of that, we need to accumulate gradients in reference net.
+ for p in ref_net.collect_params().values():
+ p.grad_req = 'add'
+
+ ref_net_output = []
+ with autograd.record():
+ net_output = net(data.copy(), sequence_length=sequence_length.copy())
+
+ for b in range(batch_size):
+ data_slice = mx.nd.array(data_np[:sequence_length_np[b], b, :]).reshape(sequence_length_np[b], 1, in_size)
+ ref_output_slice = ref_net(data_slice)
+ ref_net_output.append(ref_output_slice)
+
+ net_output_np = net_output.asnumpy()
+
# TODO: test state return value as well output
# Only compare the valid sections for each batch entry
for b in range(batch_size):
- assert_allclose(net_output[:sequence_length_np[b], b], ref_net_output[:sequence_length_np[b], b])
+ assert_allclose(net_output_np[:sequence_length_np[b], b], ref_net_output[b].asnumpy().squeeze(1),
+ rtol=1e-2, atol=1e-6)
+
+ # Now test backward
+ net_output.backward()
+
+ for ref_output_slice in ref_net_output:
+ ref_output_slice.backward()
+
+ ref_net_params = ref_net.collect_params()
+
+ for k in weights:
+ net_grad = net_params[k].grad()
+ ref_net_grad = ref_net_params[k.replace('lstm_', 'lstm_ref_')].grad()
+ assert_almost_equal(net_grad.asnumpy(), ref_net_grad.asnumpy(),
+ rtol=1e-2, atol=1e-6)
@with_seed()
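
One detail of the new test worth calling out: the reference net runs one forward/backward per batch element, so its parameter gradients have to accumulate across those calls. Setting grad_req = 'add' does exactly that; with the default grad_req='write', each backward() would overwrite the gradient from the previous slice. A standalone sketch of the behavior, using a toy Dense network with made-up sizes:

    import mxnet as mx
    from mxnet import autograd, gluon

    net = gluon.nn.Dense(1)
    net.initialize()

    # Accumulate gradients across several backward passes instead of overwriting
    for p in net.collect_params().values():
        p.grad_req = 'add'

    for _ in range(3):
        x = mx.nd.ones((2, 4))
        with autograd.record():
            y = net(x)
        y.backward()  # with grad_req='add', each call adds into .grad()

    # The caller then owns zeroing between optimization steps, e.g.:
    #   for p in net.collect_params().values(): p.zero_grad()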