[GitHub] szha closed pull request #9977: Cpu lstm inference

szha closed pull request #9977: Cpu lstm inference

diff --git a/python/mxnet/gluon/rnn/ b/python/mxnet/gluon/rnn/
index 204f3c9bd50..2fac39923c7 100644
--- a/python/mxnet/gluon/rnn/
+++ b/python/mxnet/gluon/rnn/
@@ -23,6 +23,7 @@
 from __future__ import print_function
 __all__ = ['RNN', 'LSTM', 'GRU']
+from ...autograd import is_training
 from ... import ndarray
 from .. import Block
 from . import rnn_cell
@@ -185,15 +186,17 @@ def forward(self, inputs, states=None):
             for i in range(self._dir):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
-        if inputs.context.device_type == 'gpu':
-            out = self._forward_gpu(inputs, states)
+        if inputs.context.device_type == 'gpu' or \
+            (not is_training() and self._mode == 'lstm'):
+            out = self._forward_kernel(inputs, states)
-            out = self._forward_cpu(inputs, states)
+            out = self._forward(inputs, states)
         # out is (output, state)
         return out[0] if skip_states else out
-    def _forward_cpu(self, inputs, states):
+    def _forward(self, inputs, states):
+        """forward using gluon cell"""
         ns = len(states)
         axis = self._layout.find('T')
         states = sum(zip(*((j for j in i) for i in states)), ())
@@ -207,7 +210,8 @@ def _forward_cpu(self, inputs, states):
         return outputs, new_states
-    def _forward_gpu(self, inputs, states):
+    def _forward_kernel(self, inputs, states):
+        """ forward using CUDNN or CPU kernel"""
         if self._layout == 'NTC':
             inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
         ctx = inputs.context
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index b4735b8eec6..13c077dd9e3 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -34,7 +34,11 @@
 #include <vector>
 #include <string>
 #include <utility>
+#include "./math.h"
+#include "./math_functions-inl.h"
 #include "./operator_common.h"
+#include "./mshadow_op.h"
+#include "./linalg.h"
 namespace mxnet {
 namespace op {
@@ -153,6 +157,212 @@ class RNNOp : public Operator {
   RNNParam param_;
 };  // class RNNOp
+template<typename DType>
+class RNNOp<cpu, DType> : public Operator {
+ public:
+  explicit RNNOp(RNNParam param) {
+    this->param_ = param;
+    // RNN Mode
+    param_.lstm_q_ = false;
+    switch (param_.mode) {
+      case rnn_enum::kLstm:
+        param_.lstm_q_ = true;
+        break;
+      default:
+        LOG(FATAL) << "only LSTM is implmented on CPU";
+    }
+  }
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_args) {
+    // Layout TNC
+    CHECK(!ctx.is_train) << "only inference mode is available"
+      "for cpu at the moment.";
+    size_t in_expected = param_.lstm_q_ ? 4 : 3;
+    size_t out_expected = param_.lstm_q_ ? 3 : 2;
+    if (!param_.state_outputs)
+      LOG(FATAL) << "no state outputs is currently not supported for cpu.";
+    CHECK_EQ(req[rnn_enum::kOut], kWriteTo);
+    CHECK_EQ(in_data.size(), in_expected);
+    CHECK_EQ(out_data.size(), out_expected);
+    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    // get input + output tensors
+    // w is a flat 1-D buffer: per-layer i2h_w, h2h_w first, then per-layer i2h_b, h2h_b
+    Tensor<cpu, 3, DType> x =
+        in_data[rnn_enum::kData].get<cpu, 3, DType>(s);  // TNC
+    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
+    Tensor<cpu, 3, DType> hx =
+        in_data[rnn_enum::kState].get<cpu, 3, DType>(s);  // LNC
+    Tensor<cpu, 3, DType> y =
+        out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);  // TNC
+    int64_t seq_len = x.shape_[0];
+    int64_t num_layers = hx.shape_[0];
+    int64_t batch_size = x.shape_[1];
+    int64_t h_channel = hx.shape_[2];
+    int64_t in_channel = x.shape_[2];
+    Tensor<cpu, 2, DType> x_flatten = in_data[rnn_enum::kData]
+      .get_with_shape<cpu, 2, DType>(
+          mshadow::Shape2(seq_len * batch_size, in_channel), s);  // (T*N)C
+    Tensor<cpu, 2, DType> y_flatten = out_data[rnn_enum::kOut]
+      .get_with_shape<cpu, 2, DType>(
+          mshadow::Shape2(
+              y.shape_[0] * y.shape_[1], y.shape_[2]), s);  // (T*N)C
+    CHECK(x.CheckContiguous());
+    CHECK(w.CheckContiguous());
+    CHECK(hx.CheckContiguous());
+    CHECK(y.CheckContiguous());
+    if (param_.lstm_q_) {
+      const size_t kNumMat = 4;
+      int64_t fused_h_ch = kNumMat * h_channel;
+      int64_t h_size = batch_size * fused_h_ch;
+      int64_t num_dir = 1 + param_.bidirectional;
+      int64_t h2h_w_size = h_channel * fused_h_ch;
+      Tensor<cpu, 3, DType> cx =
+          in_data[rnn_enum::kStateCell].get<cpu, 3, DType>(s);
+      CHECK(cx.CheckContiguous());
+      Tensor<cpu, 3, DType> cy =
+          out_data[rnn_enum::kStateCellOut].get<cpu, 3, DType>(s);
+      Tensor<cpu, 3, DType> hy =
+          out_data[rnn_enum::kStateOut].get<cpu, 3, DType>(s);
+      CHECK(cy.CheckContiguous());
+      CHECK(hy.CheckContiguous());
+      DType* workspace_addr =
+      static_cast<DType *>(ctx.requested[rnn_enum::kTempSpace]
+          .get_host_space_internal(sizeof(DType) *
+                                  (seq_len * h_size + h_size
+                                  + y.shape_[0] * y.shape_[1] * y.shape_[2])));
+      Tensor<cpu, 3, DType> i2h_y(
+          workspace_addr, mshadow::Shape3(seq_len, batch_size, fused_h_ch));
+      Tensor<cpu, 2, DType> i2h_y_flatten(
+          workspace_addr, mshadow::Shape2(seq_len * batch_size, fused_h_ch));
+      Tensor<cpu, 2, DType> h2h_y(workspace_addr
+          + seq_len * h_size, mshadow::Shape2(batch_size, fused_h_ch));
+      Tensor<cpu, 3, DType> y_tmp(workspace_addr
+          + (seq_len + 1) * h_size, y.shape_);
+      Tensor<cpu, 2, DType> y_flatten_tmp(workspace_addr
+          + (seq_len + 1) * h_size, y_flatten.shape_);
+      CHECK(i2h_y.CheckContiguous());
+      CHECK(h2h_y.CheckContiguous());
+      CHECK(y_tmp.CheckContiguous());
+      for (int64_t layer = 0; layer < num_layers; layer++) {
+        int reverse_dir = 0;
+        int out_tmp = 0;
+        if (param_.bidirectional && layer % 2)
+          reverse_dir = 1;
+        if (layer / num_dir % 2 == 0)
+          out_tmp = 1;
+        mshadow::Shape<2> i2h_w_shape = mshadow::Shape2(fused_h_ch,
+            (layer < num_dir) ? in_channel : num_dir * h_channel);
+        mshadow::Shape<2> h2h_w_shape = mshadow::Shape2(fused_h_ch, h_channel);
+        int64_t start = layer < num_dir ?
+            (layer * (in_channel * fused_h_ch + h2h_w_size)) :  // input layer
+              (num_dir * (in_channel * fused_h_ch + h2h_w_size)
+              + (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
+        Tensor<cpu, 2, DType> i2h_w(w.dptr_ + start, i2h_w_shape);
+        start += layer < num_dir ?
+            in_channel * fused_h_ch : h2h_w_size * num_dir;
+        Tensor<cpu, 2, DType> h2h_w(w.dptr_ + start, h2h_w_shape);
+        start = num_dir * (in_channel * fused_h_ch + h2h_w_size)
+            + (num_layers - num_dir) * (h2h_w_size * (num_dir + 1))
+              + layer * fused_h_ch * 2;
+        Tensor<cpu, 1, DType> i2h_b = w.Slice(start, start + fused_h_ch);
+        start += fused_h_ch;
+        Tensor<cpu, 1, DType> h2h_b = w.Slice(start, start + fused_h_ch);
+        if (out_tmp) {
+          linalg_gemm(layer < num_dir ? x_flatten:y_flatten, i2h_w,
+              i2h_y_flatten, false, true, s);
+        } else {
+          linalg_gemm(layer < num_dir ? x_flatten:y_flatten_tmp, i2h_w,
+              i2h_y_flatten, false, true, s);
+        }
+        i2h_y_flatten += repmat(i2h_b, seq_len * batch_size);
+        for (int64_t t = 0; t < seq_len; t++) {
+          int64_t timestep = t;
+          if (reverse_dir)
+            timestep = seq_len - 1 - t;
+          linalg_gemm(t == 0 ? hx[layer]:hy[layer], h2h_w, h2h_y,
+              false, true, s);
+          h2h_y += repmat(h2h_b, batch_size);
+          // fused element-wise ops
+          LSTMFusedElementWiseCPUOps(i2h_y[timestep], cx[layer], h2h_y,
+              y[timestep], out_tmp ? y_tmp[timestep]: y[timestep],
+                hy[layer], cy[layer], batch_size, h_channel, t,
+                reverse_dir, out_tmp && (layer == num_layers - 1));
+        }
+      }
+    } else {
+      LOG(FATAL) << "only LSTM is available for cpu at the moment.";
+    }
+  }
+  virtual void Backward(const OpContext &ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+      const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad,
+                        const std::vector<TBlob> &aux_args) {
+    LOG(FATAL) << "LSTM backward is not available for cpu at the moment.";
+  }
+ private:
+  RNNParam param_;
+  void LSTMFusedElementWiseCPUOps(const Tensor<cpu, 2, DType> &i2h_y,
+                                  const Tensor<cpu, 2, DType> &cx,
+                                  const Tensor<cpu, 2, DType> &h2h_y,
+                                  const Tensor<cpu, 2, DType> &y,
+                                  // holding intermediate layer output
+                                  const Tensor<cpu, 2, DType> &tmp,
+                                  const Tensor<cpu, 2, DType> &hy,
+                                  const Tensor<cpu, 2, DType> &cy,
+                                  const int64_t batch_size,
+                                  const int64_t h_channel,
+                                  const int64_t t,
+                                  const int reverse_dir,
+                                  const int copy_tmp2y) {
+    int64_t length = batch_size * h_channel;
+    #pragma omp parallel for
+    for (int64_t ji = 0; ji < length; ++ji) {
+      int64_t j = ji / h_channel;  // batch dim
+      int64_t i = ji % h_channel;
+      int64_t f = i + h_channel;
+      int64_t c = i + h_channel * 2;
+      int64_t o = i + h_channel * 3;
+      int64_t j_pos = j * h_channel * 4;
+      h2h_y.dptr_[j_pos + i] += i2h_y.dptr_[j_pos + i];
+      h2h_y.dptr_[j_pos + f] += i2h_y.dptr_[j_pos + f];
+      h2h_y.dptr_[j_pos + o] += i2h_y.dptr_[j_pos + o];
+      h2h_y.dptr_[j_pos + c] += i2h_y.dptr_[j_pos + c];
+      h2h_y.dptr_[j_pos + i] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + i]));
+      h2h_y.dptr_[j_pos + f] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + f]));
+      h2h_y.dptr_[j_pos + o] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + o]));
+      h2h_y.dptr_[j_pos + c] = tanh(h2h_y.dptr_[j_pos + c]);
+      cy[j][i] = h2h_y.dptr_[j_pos + f] * (t == 0 ? cx[j][i]:cy[j][i])
+          + h2h_y.dptr_[j_pos + i] * h2h_y.dptr_[j_pos + c];
+      hy[j][i] = h2h_y.dptr_[j_pos + o] * tanh(cy[j][i]);
+      tmp[j][i + h_channel * reverse_dir] = hy[j][i];
+      if (copy_tmp2y) {
+        y[j][i] = tmp[j][i];
+        if (reverse_dir)
+          y[j][i + h_channel] = tmp[j][i + h_channel];
+      }
+    }
+  }
+};  // class RNNOp
 template<typename xpu>
 Operator* CreateOp(RNNParam param, int dtype);
diff --git a/src/operator/ b/src/operator/
index 908428b383c..a60adbcd2fb 100644
--- a/src/operator/
+++ b/src/operator/
@@ -30,7 +30,6 @@ namespace mxnet {
 namespace op {
 Operator *CreateOp<cpu>(RNNParam param, int dtype) {
-  LOG(FATAL) << "RNN is only available for gpu at the moment.";
   Operator *op = NULL;
     op = new RNNOp<cpu, DType>(param);
diff --git a/tests/python/gpu/ b/tests/python/gpu/
index 2cef29c03fa..0d149685873 100644
--- a/tests/python/gpu/
+++ b/tests/python/gpu/
@@ -1523,6 +1523,23 @@ def check_rnn_layer(layer):
     for g, c in zip(gs, cs):
         assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)
+def check_rnn_layer_w_rand_inputs(layer):
+    layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
+    x = mx.nd.uniform(shape=(10, 16, 30))
+    with mx.gpu(0):
+        x = x.copyto(mx.gpu(0))
+        states = layer.begin_state(16)
+        go, gs = layer(x, states)
+    with mx.cpu(0):
+        x = x.copyto(mx.cpu(0))
+        states = layer.begin_state(16)
+        co, cs = layer(x, states)
+    assert_almost_equal(go.asnumpy(), co.asnumpy(), rtol=1e-2, atol=1e-6)
+    for g, c in zip(gs, cs):
+        assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)
 def test_rnn_layer():
     check_rnn_layer(gluon.rnn.RNN(100, num_layers=3))
@@ -1531,7 +1548,7 @@ def test_rnn_layer():
     check_rnn_layer(gluon.rnn.GRU(100, num_layers=3))
     check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
+    check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
 def test_sequence_reverse():
diff --git a/tests/python/unittest/ b/tests/python/unittest/
index 871deeb26c4..f22b13d6575 100644
--- a/tests/python/unittest/
+++ b/tests/python/unittest/
@@ -67,6 +67,20 @@ def test_lstm_forget_bias():
                                forget_bias * np.ones(100, ), np.zeros((2 * 100,))])
     assert_allclose(mod.get_params()[0][bias_argument].asnumpy(), expected_bias)
+def test_lstm_cpu_inference():
+    # should behave the same as lstm cell
+    EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
+                                  [0.72045636, 0.72045636, 0.95215213, 0.95215213]],
+                                 [[0.95215213, 0.95215213, 0.72045636, 0.72045636],
+                                  [0.95215213, 0.95215213, 0.72045636, 0.72045636]]])
+    x = mx.nd.ones(shape=(2, 2, 2))
+    model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True)
+    model.initialize(mx.init.One())
+    y = model(x).asnumpy()
+    mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT,
+                                      rtol=1e-3, atol=1e-5)
 def test_gru():
     cell = gluon.rnn.GRUCell(100, prefix='rnn_')


