You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/16 18:40:26 UTC

[GitHub] szha closed pull request #10940: [WIP] Fused LSTM Cell

szha closed pull request #10940: [WIP] Fused LSTM Cell
URL: https://github.com/apache/incubator-mxnet/pull/10940
 
 
   

This pull request comes from a forked repository (a foreign pull request).
Because GitHub does not display the original diff of such a pull request once
it has been closed or merged, the diff is reproduced below for the sake of
provenance:

diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 281aba45257..c5c929fb8a0 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -297,6 +297,63 @@ def __init__(self, prefix=None, params=None):
     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError
 
+class _FusedBaseRNNCell(HybridRecurrentCell): # pylint: disable=abstract-method
+    """Implementation of recurrent layers."""
+    def __init__(self, hidden_size, input_size,
+                 i2h_weight_initializer, h2h_weight_initializer,
+                 i2h_bias_initializer, h2h_bias_initializer,
+                 mode, **kwargs):
+        super(_FusedBaseRNNCell, self).__init__(**kwargs)
+        self._hidden_size = hidden_size
+        self._input_size = input_size
+        self._mode = mode
+        num_gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]
+        self._gates = num_gates
+        self.i2h_weight = self.params.get('i2h_weight', shape=(num_gates*hidden_size, input_size),
+                                          init=i2h_weight_initializer,
+                                          allow_deferred_init=True)
+        self.h2h_weight = self.params.get('h2h_weight', shape=(num_gates*hidden_size, hidden_size),
+                                          init=h2h_weight_initializer,
+                                          allow_deferred_init=True)
+        self.i2h_bias = self.params.get('i2h_bias', shape=(num_gates*hidden_size,),
+                                        init=i2h_bias_initializer,
+                                        allow_deferred_init=True)
+        self.h2h_bias = self.params.get('h2h_bias', shape=(num_gates*hidden_size,),
+                                        init=h2h_bias_initializer,
+                                        allow_deferred_init=True)
+
+    def hybrid_forward(self, F, inputs, states, i2h_weight,
+                       h2h_weight, i2h_bias, h2h_bias):
+        prefix = 't%d_'%self._counter
+        params = F.concat(i2h_weight.reshape((-1,)),
+                          h2h_weight.reshape((-1,)),
+                          i2h_bias.reshape((-1,)),
+                          h2h_bias.reshape((-1,)), dim=0)
+        states = [s.reshape((-4, 1, -1, 0)) for s in states]
+        rnn = F.RNN(inputs.reshape((-4, 1, -1, 0)), params, *states, state_size=self._hidden_size,
+                    num_layers=1, state_outputs=True, mode=self._mode, name=prefix+'fused')
+
+        if self._mode == 'lstm':
+            name_suffix = ['out', 'out', 'state']
+            rnn = [rnn[i].reshape((-3, -2), name=prefix+suffix) for i, suffix
+                   in zip(range(3), name_suffix)]
+            outputs, states = rnn[0], [rnn[1], rnn[2]]
+        else:
+            name_suffix = ['out', 'out']
+            rnn = [rnn[i].reshape((-3, -2), name=prefix+suffix) for i, suffix
+                   in zip(range(2), name_suffix)]
+            outputs, states = rnn[0], [rnn[1]]
+
+        return outputs, states
+
+    def __repr__(self):
+        s = '{name}({mapping})'
+        shape = self.i2h_weight.shape
+        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates)
+        return s.format(name=self.__class__.__name__,
+                        mapping=mapping,
+                        **self.__dict__)
+
 
 class RNNCell(HybridRecurrentCell):
     r"""Elman RNN recurrent neural network cell.
@@ -398,7 +455,7 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
         return output, [output]
 
 
-class LSTMCell(HybridRecurrentCell):
+class LSTMCell(_FusedBaseRNNCell):
     r"""Long-Short Term Memory (LSTM) network cell.
 
     Each call computes the following function:
@@ -457,22 +514,13 @@ def __init__(self, hidden_size,
                  i2h_weight_initializer=None, h2h_weight_initializer=None,
                  i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                  input_size=0, prefix=None, params=None):
-        super(LSTMCell, self).__init__(prefix=prefix, params=params)
-
-        self._hidden_size = hidden_size
-        self._input_size = input_size
-        self.i2h_weight = self.params.get('i2h_weight', shape=(4*hidden_size, input_size),
-                                          init=i2h_weight_initializer,
-                                          allow_deferred_init=True)
-        self.h2h_weight = self.params.get('h2h_weight', shape=(4*hidden_size, hidden_size),
-                                          init=h2h_weight_initializer,
-                                          allow_deferred_init=True)
-        self.i2h_bias = self.params.get('i2h_bias', shape=(4*hidden_size,),
-                                        init=i2h_bias_initializer,
-                                        allow_deferred_init=True)
-        self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
-                                        init=h2h_bias_initializer,
-                                        allow_deferred_init=True)
+        super(LSTMCell, self).__init__(hidden_size=hidden_size, input_size=input_size,
+                                       i2h_weight_initializer=i2h_weight_initializer,
+                                       h2h_weight_initializer=h2h_weight_initializer,
+                                       i2h_bias_initializer=i2h_bias_initializer,
+                                       h2h_bias_initializer=h2h_bias_initializer,
+                                       mode='lstm',
+                                       prefix=prefix, params=params)
 
     def state_info(self, batch_size=0):
         return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'},
@@ -481,34 +529,6 @@ def state_info(self, batch_size=0):
     def _alias(self):
         return 'lstm'
 
-    def __repr__(self):
-        s = '{name}({mapping})'
-        shape = self.i2h_weight.shape
-        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
-        return s.format(name=self.__class__.__name__,
-                        mapping=mapping,
-                        **self.__dict__)
-
-    def hybrid_forward(self, F, inputs, states, i2h_weight,
-                       h2h_weight, i2h_bias, h2h_bias):
-        prefix = 't%d_'%self._counter
-        i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
-                               num_hidden=self._hidden_size*4, name=prefix+'i2h')
-        h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
-                               num_hidden=self._hidden_size*4, name=prefix+'h2h')
-        gates = i2h + h2h
-        slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
-        in_gate = F.Activation(slice_gates[0], act_type="sigmoid", name=prefix+'i')
-        forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", name=prefix+'f')
-        in_transform = F.Activation(slice_gates[2], act_type="tanh", name=prefix+'c')
-        out_gate = F.Activation(slice_gates[3], act_type="sigmoid", name=prefix+'o')
-        next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
-                                   name=prefix+'state')
-        next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type="tanh"),
-                                  name=prefix+'out')
-
-        return next_h, [next_h, next_c]
-
 
 class GRUCell(HybridRecurrentCell):
     r"""Gated Rectified Unit (GRU) network cell.
@@ -590,7 +610,7 @@ def _alias(self):
     def __repr__(self):
         s = '{name}({mapping})'
         shape = self.i2h_weight.shape
-        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])
+        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // 3)
         return s.format(name=self.__class__.__name__,
                         mapping=mapping,
                         **self.__dict__)
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index f017d7e65e7..0e35b881cd7 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1046,11 +1046,11 @@ def reshape(self, *shape, **kwargs):
         elif not shape:
             shape = kwargs.get('shape')
             assert shape, "Shape must be provided."
-        if not all(k in ['shape', 'reverse'] for k in kwargs):
+        if not all(k in ['shape', 'reverse', 'name'] for k in kwargs):
             raise TypeError(
                 "Got unknown keywords in reshape: {}. " \
                 "Accepted keyword arguments are 'shape' and 'reverse'.".format(
-                    ', '.join([k for k in kwargs if k not in ['shape', 'reverse']])))
+                    ', '.join([k for k in kwargs if k not in ['shape', 'reverse', 'name']])))
         reverse = kwargs.get('reverse', False)
         handle = NDArrayHandle()
 
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index a7fcb1c8817..f32e65b7de0 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -39,14 +39,17 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs,
   const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
   CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
   TShape dshape;
-  index_t size = 0;
-  bool has_zero = false;
-  int axis = -1;
+  int64_t size = 0, out_size = 0;
+  size_t has_zero = 0;
+  int axis = -1, zero_index = -1;
   for (int i = 0; i < param_.num_args; ++i) {
     TShape tmp = (*in_shape)[i];
     if (tmp.ndim()) {
       axis = CheckAxis(param_.dim, tmp.ndim());
-      has_zero = tmp[axis] == 0 || has_zero;
+      if (tmp[axis] == 0) {
+        has_zero++;
+        zero_index = i;  // only used when there's only one 0.
+      }
       size += tmp[axis];
       tmp[axis] = 0;
       shape_assign(&dshape, tmp);
@@ -56,6 +59,7 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs,
   TShape tmp = (*out_shape)[0];
   if (tmp.ndim()) {
     axis = CheckAxis(param_.dim, tmp.ndim());
+    out_size = tmp[axis];
     tmp[axis] = 0;
     shape_assign(&dshape, tmp);
   }
@@ -71,7 +75,16 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs,
   CHECK(shape_assign(&(*out_shape)[0], dshape))
       << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
 
-  return dshape.Size() != 0;
+  if (has_zero == 1 && out_size) {
+    (*in_shape)[zero_index][axis] = out_size - size;
+  }
+
+  for (const auto& ishape : *in_shape) {
+    if (shape_is_none(ishape)) {
+      return false;
+    }
+  }
+  return !shape_is_none(out_shape->at(0));
 }
 
 static bool ConcatType(const nnvm::NodeAttrs& attrs,
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index c46233c367f..cf1311187c4 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -152,18 +152,33 @@ inline TShape InferReshapeShape(const nnvm::Tuple<IType>& shape,
 }
 
 inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
-                             std::vector<TShape> *in_attrs,
-                             std::vector<TShape> *out_attrs) {
+                         std::vector<TShape> *in_attrs,
+                         std::vector<TShape> *out_attrs) {
   const ReshapeParam& param_ = nnvm::get<ReshapeParam>(attrs.parsed);
   CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
   CHECK_EQ(out_attrs->size(), 1U);
-  const TShape &dshape = (*in_attrs)[0];
+  TShape &dshape = (*in_attrs)[0];
   if (dshape.ndim() == 0) return false;
   TShape oshape;
   if (param_.shape.ndim() != 0) {
     oshape = InferReshapeShape(param_.shape, dshape, param_.reverse);
+
+    // index and counter for number of unknown dimensions, used for inverse shape inference.
+    int zero_index = -1, num_zeros = 0;
+    int64_t size_prod = 1;
+    for (size_t i = 0; i < dshape.ndim(); i++) {
+      if (dshape[i] == 0) {
+        num_zeros++;
+        zero_index = i;
+      } else {
+        size_prod *= dshape[i];
+      }
+    }
+    if (num_zeros == 1 && !shape_is_none((*out_attrs)[0])) {
+      CHECK(shape_assign(&oshape, (*out_attrs)[0]));
+      dshape[zero_index] = oshape.Size() / size_prod;
+    }
   } else if (param_.target_shape.ndim()) {
-    LOG(INFO) << "Using target_shape will be deprecated.";
     oshape = param_.target_shape;
     int neg_count = 0;
     index_t inf_idx = 0;
@@ -189,7 +204,7 @@ inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
     << "Target: " << oshape
     << "\nSource: " << dshape;
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
-  return true;
+  return !shape_is_none(dshape) && !shape_is_none((*out_attrs)[0]);
 }
 
 inline bool FlattenShape(const nnvm::NodeAttrs& attrs,
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index f22b13d6575..871deeb26c4 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -67,20 +67,6 @@ def test_lstm_forget_bias():
                                forget_bias * np.ones(100, ), np.zeros((2 * 100,))])
     assert_allclose(mod.get_params()[0][bias_argument].asnumpy(), expected_bias)
 
-def test_lstm_cpu_inference():
-    # should behave the same as lstm cell
-    EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
-                                  [0.72045636, 0.72045636, 0.95215213, 0.95215213]],
-                                 [[0.95215213, 0.95215213, 0.72045636, 0.72045636],
-                                  [0.95215213, 0.95215213, 0.72045636, 0.72045636]]])
-    x = mx.nd.ones(shape=(2, 2, 2))
-    model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True)
-    model.initialize(mx.init.One())
-    y = model(x).asnumpy()
-
-    mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT,
-                                      rtol=1e-3, atol=1e-5)
-    
 
 def test_gru():
     cell = gluon.rnn.GRUCell(100, prefix='rnn_')


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services