Posted to commits@systemml.apache.org by du...@apache.org on 2016/07/13 00:00:05 UTC

incubator-systemml git commit: [SYSTEMML-618][SYSTEMML-807][SYSTEMML-808] Adding new LSTM and RNN layers.

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 289eeaaff -> d926ea052


[SYSTEMML-618][SYSTEMML-807][SYSTEMML-808] Adding new LSTM and RNN layers.

This adds two new recurrent neural net layers: LSTM and RNN.
Generically, recurrent neural nets operate over sequences, and at each
timestep the current input and the network's output from the previous
timestep are both passed in as inputs.  The `rnn` layer implements this
generic algorithm, and an LSTM (Long Short-Term Memory) cell is simply
an RNN cell with a slightly more complex update that maintains an
additional internal cell state.  In practice, the LSTM is generally
preferred to the vanilla RNN.
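
For reference, a minimal usage sketch of the new layers in DML (the
dimensions below are illustrative; only the function signatures and
paths are taken from the layers added in this commit):

  source("nn/layers/lstm.dml") as lstm
  source("nn/layers/rnn.dml") as rnn

  N = 4   # number of examples (sequences) in the batch
  T = 10  # number of timesteps per sequence
  D = 8   # number of input features per timestep
  M = 16  # number of neurons in the layer
  X = rand(rows=N, cols=T*D)  # each sequence flattened into one row
  return_seq = TRUE

  # LSTM: returns outputs, final cell state, and caches for the backward pass.
  [W, b, out0, c0] = lstm::init(N, D, M)
  [out, c, cache_out, cache_c, cache_ifog] = lstm::forward(X, W, b, T, D, return_seq, out0, c0)
  # With return_seq=TRUE, `out` has shape (N, T*M); otherwise only the
  # final timestep is returned, with shape (N, M).

  # Vanilla RNN: same calling pattern, with a single (D+M, M) weight matrix.
  [W2, b2, out02] = rnn::init(N, D, M)
  [out2, cache_out2] = rnn::forward(X, W2, b2, T, D, return_seq, out02)

When returning the full sequence of outputs, a loss can be applied over
all timesteps at once, which is how the new gradient checks below
exercise both layers.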


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/d926ea05
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/d926ea05
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/d926ea05

Branch: refs/heads/master
Commit: d926ea0522af05aed2d3b1f01c5d76eb88c008b8
Parents: 289eeaa
Author: Mike Dusenberry <mw...@us.ibm.com>
Authored: Tue Jul 12 16:57:25 2016 -0700
Committer: Mike Dusenberry <mw...@us.ibm.com>
Committed: Tue Jul 12 16:57:25 2016 -0700

----------------------------------------------------------------------
 scripts/staging/SystemML-NN/nn/layers/lstm.dml  | 255 +++++++++++++++++++
 scripts/staging/SystemML-NN/nn/layers/rnn.dml   | 182 +++++++++++++
 .../staging/SystemML-NN/nn/test/grad_check.dml  | 247 ++++++++++++++++++
 scripts/staging/SystemML-NN/nn/test/tests.dml   |   2 +
 4 files changed, 686 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d926ea05/scripts/staging/SystemML-NN/nn/layers/lstm.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/lstm.dml b/scripts/staging/SystemML-NN/nn/layers/lstm.dml
new file mode 100644
index 0000000..b0fdd52
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/lstm.dml
@@ -0,0 +1,255 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * LSTM layer.
+ */
+source("nn/layers/sigmoid.dml") as sigmoid
+source("nn/layers/tanh.dml") as tanh
+
+forward = function(matrix[double] X, matrix[double] W, matrix[double] b, int T, int D,
+                   boolean return_sequences, matrix[double] out0, matrix[double] c0)
+    return (matrix[double] out, matrix[double] c,
+            matrix[double] cache_out, matrix[double] cache_c, matrix[double] cache_ifog) {
+  /*
+   * Computes the forward pass for an LSTM layer with M neurons.
+   * The input data has N sequences of T timesteps, each with D features.
+   *
+   * In an LSTM, an internal cell state is maintained, additive
+   * interactions operate over the cell state at each timestep, and
+   * some amount of this cell state is exposed as output at each
+   * timestep.  Additionally, the output of the previous timestep is fed
+   * back in as an additional input at the current timestep.
+   *
+   * Reference:
+   *  - Long Short-Term Memory, Hochreiter & Schmidhuber, 1997
+   *    - http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, T*D).
+   *  - W: Weights (parameters) matrix, of shape (D+M, 4M).
+   *  - b: Biases vector, of shape (1, 4M).
+   *  - T: Length of example sequences (number of timesteps).
+   *  - D: Dimensionality of the input features.
+   *  - return_sequences: Whether to return `out` at all timesteps,
+   *      or just for the final timestep.
+   *  - out0: Output matrix at previous timestep, of shape (N, M).
+   *      Note: This is *optional* and could just be an empty (all-zero) matrix.
+   *  - c0: Initial cell state matrix, of shape (N, M).
+   *      Note: This is *optional* and could just be an empty (all-zero) matrix.
+   *
+   * Outputs:
+   *  - out: If `return_sequences` is True, outputs for all timesteps,
+   *      of shape (N, T*M).  Else, outputs for the final timestep, of
+   *      shape (N, M).
+   *  - c: Cell state for final timestep, of shape (N, M).
+   *  - cache_out: Cache of outputs, of shape (T, N*M).
+   *      Note: This is used for performance during training.
+   *  - cache_c: Cache of cell state, of shape (T, N*M).
+   *      Note: This is used for performance during training.
+   *  - cache_ifog: Cache of intermediate values, of shape (T, N*4*M).
+   *      Note: This is used for performance during training.
+   */
+  N = nrow(X)
+  M = as.integer(ncol(W)/4)
+  out_prev = out0
+  c_prev = c0
+  c = c_prev
+  if (return_sequences) {
+    out = matrix(0, rows=N, cols=T*M)
+  }
+  else {
+    out = matrix(0, rows=N, cols=M)
+  }
+  # caches to be used during the backward pass for performance
+  cache_out = matrix(0, rows=T, cols=N*M)
+  cache_c = matrix(0, rows=T, cols=N*M) 
+  cache_ifog = matrix(0, rows=T, cols=N*4*M)
+
+  for (t in 1:T) {  # each timestep
+    X_t = X[,(t-1)*D+1:t*D]  # shape (N, D)
+    input = cbind(X_t, out_prev)  # shape (N, D+M)
+    ifog = input %*% W + b  # input, forget, output, and g gates; shape (N, 4M)
+    tmp = sigmoid::forward(ifog[,1:3*M])  # i,f,o gates squashed with sigmoid
+    ifog[,1:3*M] = tmp
+    tmp = tanh::forward(ifog[,3*M+1:4*M])  # g gate squashed with tanh
+    ifog[,3*M+1:4*M] = tmp
+    # c_t = f*prev_c + i*g
+    c = ifog[,M+1:2*M]*c_prev + ifog[,1:M]*ifog[,3*M+1:4*M]  # shape (N, M)
+    # out_t = o*tanh(c)
+    tmp = tanh::forward(c)
+    out_t = ifog[,2*M+1:3*M] * tmp  # shape (N, M)
+
+    # store
+    if (return_sequences) {
+      out[,(t-1)*M+1:t*M] = out_t
+    }
+    else {
+      out = out_t
+    }
+    out_prev = out_t
+    c_prev = c
+    cache_out[t,] = matrix(out_t, rows=1, cols=N*M)  # reshape
+    cache_c[t,] = matrix(c, rows=1, cols=N*M)  # reshape
+    cache_ifog[t,] = matrix(ifog, rows=1, cols=N*4*M)  # reshape
+  }
+}
+
+backward = function(matrix[double] dout, matrix[double] dc, 
+                    matrix[double] X, matrix[double] W, matrix[double] b, int T, int D,
+                    boolean given_sequences, matrix[double] out0, matrix[double] c0,
+                    matrix[double] cache_out, matrix[double] cache_c, matrix[double] cache_ifog)
+    return (matrix[double] dX, matrix[double] dW, matrix[double] db,
+            matrix[double] dout0, matrix[double] dc0) {
+  /*
+   * Computes the backward pass for an LSTM layer with M neurons.
+   *
+   * Inputs:
+   *  - dout: Gradient on output from upstream.  If `given_sequences`
+   *      is True, contains gradients on outputs for all timesteps,
+   *      of shape (N, T*M).  Else, contains gradient on output for
+   *      the final timestep, of shape (N, M).
+   *  - dc: Gradient on final (current) cell state from later in time,
+   *      of shape (N, M).
+   *  - X: Input data matrix, of shape (N, T*D).
+   *  - W: Weights (parameters) matrix, of shape (D+M, 4M).
+   *  - b: Biases vector, of shape (1, 4M).
+   *  - T: Length of example sequences (number of timesteps).
+   *  - D: Dimensionality of the input features.
+   *  - given_sequences: Whether `dout` is for all timesteps,
+   *      or just for the final timestep.  This is based on whether
+   *      `return_sequences` was true in the forward pass.
+   *  - out0: Output matrix at previous timestep, of shape (N, M).
+   *      Note: This is *optional* and could just be an empty (all-zero) matrix.
+   *  - c0: Initial cell state matrix, of shape (N, M).
+   *      Note: This is *optional* and could just be an empty (all-zero) matrix.
+   *  - cache_out: Cache of outputs, of shape (T, N*M).
+   *      Note: This is used for performance during training.
+   *  - cache_c: Cache of cell state, of shape (T, N*M).
+   *      Note: This is used for performance during training.
+   *  - cache_ifog: Cache of intermediate values, of shape (T, N*4*M).
+   *      Note: This is used for performance during training.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of shape (N, T*D).
+   *  - dW: Gradient wrt W, of shape (D+M, 4M).
+   *  - db: Gradient wrt b, of shape (1, 4M).
+   *  - dout0: Gradient wrt out0, of shape (N, M).
+   *  - dc0: Gradient wrt c0, of shape (N, M).
+   */
+  N = nrow(X)
+  M = as.integer(ncol(W)/4)
+  dX = matrix(0, rows=N, cols=T*D)
+  dW = matrix(0, rows=D+M, cols=4*M)
+  db = matrix(0, rows=1, cols=4*M)
+  dout0 = matrix(0, rows=N, cols=M)
+  dc0 = matrix(0, rows=N, cols=M)
+  dct = dc
+  if (!given_sequences) {
+    # only given dout for output at final timestep, so prepend empty douts for all other timesteps
+    dout = cbind(matrix(0, rows=N, cols=(T-1)*M), dout)  # shape (N, T*M)
+  }
+
+  t = T
+  for (iter in 1:T) {  # each timestep in reverse order
+    X_t = X[,(t-1)*D+1:t*D]  # shape (N, D)
+    dout_t = dout[,(t-1)*M+1:t*M]  # shape (N, M)
+    out_t = matrix(cache_out[t,], rows=N, cols=M)  # shape (N, M)
+    ct = matrix(cache_c[t,], rows=N, cols=M)  # shape (N, M)
+    if (t == 1) {
+      out_prev = out0  # shape (N, M)
+      c_prev = c0  # shape (N, M)
+    }
+    else {
+      out_prev = matrix(cache_out[t-1,], rows=N, cols=M)  # shape (N, M)
+      c_prev = matrix(cache_c[t-1,], rows=N, cols=M)  # shape (N, M)
+    }
+    input = cbind(X_t, out_prev)  # shape (N, D+M)
+    ifog = matrix(cache_ifog[t,], rows=N, cols=4*M)
+    i = ifog[,1:M]  # input gate, shape (N, M)
+    f = ifog[,M+1:2*M]  # forget gate, shape (N, M)
+    o = ifog[,2*M+1:3*M]  # output gate, shape (N, M)
+    g = ifog[,3*M+1:4*M]  # g gate, shape (N, M)
+
+    tmp = tanh::backward(dout_t, ct)
+    dct = dct + o * tmp  # shape (N, M)
+    tmp = tanh::forward(ct)
+    do = tmp * dout_t  # output gate, shape (N, M)
+    df = c_prev * dct  # forget gate, shape (N, M)
+    dc_prev = f * dct  # shape (N, M)
+    di = g * dct  # input gate, shape (N, M)
+    dg = i * dct  # g gate, shape (N, M)
+    
+    di_raw = i * (1-i) * di
+    df_raw = f * (1-f) * df
+    do_raw = o * (1-o) * do
+    dg_raw = (1 - g^2) * dg
+    difog_raw = cbind(di_raw, cbind(df_raw, cbind(do_raw, dg_raw)))  # shape (N, 4M)
+
+    dW = dW + t(input) %*% difog_raw  # shape (D+M, 4M)
+    db = db + colSums(difog_raw)  # shape (1, 4M)
+    dinput = difog_raw %*% t(W)  # shape (N, D+M)
+    dX[,(t-1)*D+1:t*D] = dinput[,1:D]
+    dout_prev = dinput[,D+1:D+M]  # shape (N, M)
+    if (t == 1) {
+      dout0 = dout_prev  # shape (N, M)
+      dc0 = dc_prev  # shape (N, M)
+    }
+    else {
+      dout[,(t-2)*M+1:(t-1)*M] = dout[,(t-2)*M+1:(t-1)*M] + dout_prev  # shape (N, M)
+      dct = dc_prev  # shape (N, M)
+    }
+    t = t-1
+  }
+}
+
+init = function(int N, int D, int M)
+    return (matrix[double] W, matrix[double] b, matrix[double] out0, matrix[double] c0) {
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   * 
+   * We use the Glorot uniform heuristic which limits the magnification
+   * of inputs/gradients during forward/backward passes by scaling
+   * uniform weights by a factor of sqrt(6/(fan_in + fan_out)).
+   *
+   * Inputs:
+   *  - N: Number of examples in batch.
+   *  - D: Dimensionality of the input features.
+   *  - M: Number of neurons in this layer.
+   *
+   * Outputs:
+   *  - W: Weights (parameters) matrix, of shape (D+M, 4M).
+   *  - b: Biases vector, of shape (1, 4M).
+   *  - out0: Dummy output matrix at previous timestep, of shape (N, M).
+   *  - c0: Initial empty cell state matrix, of shape (N, M).
+   */
+  fan_in = D+M
+  fan_out = 4*M
+  scale = sqrt(6/(fan_in+fan_out))
+  W = rand(rows=D+M, cols=4*M, min=-scale, max=scale, pdf="uniform")
+  b = matrix(0, rows=1, cols=4*M) 
+  out0 = matrix(0, rows=N, cols=M)
+  c0 = matrix(0, rows=N, cols=M)
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d926ea05/scripts/staging/SystemML-NN/nn/layers/rnn.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/rnn.dml b/scripts/staging/SystemML-NN/nn/layers/rnn.dml
new file mode 100644
index 0000000..6c432bd
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/rnn.dml
@@ -0,0 +1,182 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Simple (Vanilla) RNN layer.
+ */
+source("nn/layers/tanh.dml") as tanh
+
+forward = function(matrix[double] X, matrix[double] W, matrix[double] b, int T, int D,
+                   boolean return_sequences, matrix[double] out0)
+    return (matrix[double] out, matrix[double] cache_out) {
+  /*
+   * Computes the forward pass for a simple RNN layer with M neurons.
+   * The input data has N sequences of T timesteps, each with D features.
+   *
+   * In a simple RNN, the output of the previous timestep is fed back
+   * in as an additional input at the current timestep.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, T*D).
+   *  - W: Weights (parameters) matrix, of shape (D+M, M).
+   *  - b: Biases vector, of shape (1, M).
+   *  - T: Length of example sequences (number of timesteps).
+   *  - D: Dimensionality of the input features.
+   *  - return_sequences: Whether to return `out` at all timesteps,
+   *      or just for the final timestep.
+   *  - out0: Output matrix at previous timestep, of shape (N, M).
+   *      Note: This is *optional* and could just be an empty (all-zero) matrix.
+   *
+   * Outputs:
+   *  - out: If `return_sequences` is True, outputs for all timesteps,
+   *      of shape (N, T*M).  Else, outputs for the final timestep, of
+   *      shape (N, M).
+   *  - cache_out: Cache of outputs, of shape (T, N*M).
+   *      Note: This is used for performance during training.
+   */
+  N = nrow(X)
+  M = ncol(W)
+  out_prev = out0
+  if (return_sequences) {
+    out = matrix(0, rows=N, cols=T*M)
+  }
+  else {
+    out = matrix(0, rows=N, cols=M)
+  }
+  # caches to be used during the backward pass for performance
+  cache_out = matrix(0, rows=T, cols=N*M)
+
+  for (t in 1:T) {  # each timestep
+    X_t = X[,(t-1)*D+1:t*D]  # shape (N, D)
+    input = cbind(X_t, out_prev)  # shape (N, D+M)
+    out_t = tanh::forward(input %*% W + b)  # shape (N, M)
+    # store
+    if (return_sequences) {
+      out[,(t-1)*M+1:t*M] = out_t
+    }
+    else {
+      out = out_t
+    }
+    out_prev = out_t
+    cache_out[t,] = matrix(out_t, rows=1, cols=N*M)  # reshape
+  }
+}
+
+backward = function(matrix[double] dout, matrix[double] X, matrix[double] W, matrix[double] b,
+                    int T, int D, boolean given_sequences, matrix[double] out0,
+                    matrix[double] cache_out) 
+    return (matrix[double] dX, matrix[double] dW, matrix[double] db, matrix[double] dout0) {
+  /*
+   * Computes the backward pass for a simple RNN layer with M neurons.
+   *
+   * Inputs:
+   *  - dout: Gradient on output from upstream.  If `given_sequences`
+   *      is True, contains gradients on outputs for all timesteps,
+   *      of shape (N, T*M).  Else, contains gradient on output for
+   *      the final timestep, of shape (N, M).
+   *  - X: Input data matrix, of shape (N, T*D).
+   *  - W: Weights (parameters) matrix, of shape (D+M, M).
+   *  - b: Biases vector, of shape (1, M).
+   *  - T: Length of example sequences (number of timesteps).
+   *  - D: Dimensionality of the input features.
+   *  - given_sequences: Whether `dout` is for all timesteps,
+   *      or just for the final timestep.  This is based on whether
+   *      `return_sequences` was true in the forward pass.
+   *  - out0: Output matrix at previous timestep, of shape (N, M).
+   *      Note: This is *optional* and could just be an empty (all-zero) matrix.
+   *  - cache_out: Cache of outputs, of shape (T, N*M).
+   *      Note: This is used for performance during training.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of shape (N, T*D).
+   *  - dW: Gradient wrt W, of shape (D+M, M).
+   *  - db: Gradient wrt b, of shape (1, M).
+   *  - dout0: Gradient wrt out0, of shape (N, M).
+   */
+  N = nrow(X)
+  M = ncol(W)
+  dX = matrix(0, rows=N, cols=T*D)
+  dW = matrix(0, rows=D+M, cols=M)
+  db = matrix(0, rows=1, cols=M)
+  dout0 = matrix(0, rows=N, cols=M)
+  if (!given_sequences) {
+    # only given dout for output at final timestep, so prepend empty douts for all other timesteps
+    dout = cbind(matrix(0, rows=N, cols=(T-1)*M), dout)  # shape (N, T*M)
+  }
+
+  t = T
+  for (iter in 1:T) {  # each timestep in reverse order
+    X_t = X[,(t-1)*D+1:t*D]  # shape (N, D)
+    dout_t = dout[,(t-1)*M+1:t*M]  # shape (N, M)
+    out_t = matrix(cache_out[t,], rows=N, cols=M)  # shape (N, M)
+    if (t == 1) {
+      out_prev = out0  # shape (N, M)
+    }
+    else {
+      out_prev = matrix(cache_out[t-1,], rows=N, cols=M)  # shape (N, M)
+    }
+    input = cbind(X_t, out_prev)  # shape (N, D+M)
+    dout_t_raw = (1 - out_t^2) * dout_t  # into tanh, shape (N, M)
+    dW = dW + t(input) %*% dout_t_raw  # shape (D+M, M)
+    db = db + colSums(dout_t_raw)  # shape (1, M)
+    dinput = dout_t_raw %*% t(W)  # shape (N, D+M)
+    dX[,(t-1)*D+1:t*D] = dinput[,1:D]
+    dout_prev = dinput[,D+1:D+M]  # shape (N, M)
+    if (t == 1) {
+      dout0 = dout_prev  # shape (N, M)
+    }
+    else {
+      dout[,(t-2)*M+1:(t-1)*M] = dout[,(t-2)*M+1:(t-1)*M] + dout_prev  # shape (N, M)
+    }
+    t = t-1
+  }
+}
+
+init = function(int N, int D, int M)
+    return (matrix[double] W, matrix[double] b, matrix[double] out0) {
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   * 
+   * We use the Glorot uniform heuristic which limits the magnification
+   * of inputs/gradients during forward/backward passes by scaling
+   * uniform weights by a factor of sqrt(6/(fan_in + fan_out)).
+   *
+   * Inputs:
+   *  - N: Number of examples in batch.
+   *  - D: Dimensionality of the input features.
+   *  - M: Number of neurons in this layer.
+   *
+   * Outputs:
+   *  - W: Weights (parameters) matrix, of shape (D+M, M).
+   *  - b: Biases vector, of shape (1, M).
+   *  - out0: Dummy output matrix at previous timestep, of shape (N, M).
+   */
+  fan_in = D+M
+  fan_out = M
+  scale = sqrt(6/(fan_in+fan_out))
+  W = rand(rows=D+M, cols=M, min=-scale, max=scale, pdf="uniform")
+  b = matrix(0, rows=1, cols=M) 
+  out0 = matrix(0, rows=N, cols=M)
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d926ea05/scripts/staging/SystemML-NN/nn/test/grad_check.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/grad_check.dml b/scripts/staging/SystemML-NN/nn/test/grad_check.dml
index af985a3..67bd854 100644
--- a/scripts/staging/SystemML-NN/nn/test/grad_check.dml
+++ b/scripts/staging/SystemML-NN/nn/test/grad_check.dml
@@ -32,9 +32,11 @@ source("nn/layers/l1_reg.dml") as l1_reg
 source("nn/layers/l2_loss.dml") as l2_loss
 source("nn/layers/l2_reg.dml") as l2_reg
 source("nn/layers/log_loss.dml") as log_loss
+source("nn/layers/lstm.dml") as lstm
 source("nn/layers/max_pool.dml") as max_pool
 source("nn/layers/max_pool_builtin.dml") as max_pool_builtin
 source("nn/layers/relu.dml") as relu
+source("nn/layers/rnn.dml") as rnn
 source("nn/layers/sigmoid.dml") as sigmoid
 source("nn/layers/softmax.dml") as softmax
 source("nn/layers/tanh.dml") as tanh
@@ -667,6 +669,150 @@ log_loss = function() {
   }
 }
 
+lstm = function() {
+  /*
+   * Gradient check for the LSTM layer.
+   */
+  print("Grad checking the LSTM layer with L2 loss.")
+
+  # Generate data
+  N = 3  # num examples
+  D = 10  # num features
+  T = 15  # num timesteps (sequence length)
+  M = 5  # num neurons
+  return_seq = TRUE
+  X = rand(rows=N, cols=T*D)
+  y = rand(rows=N, cols=T*M)
+  yc = rand(rows=N, cols=M)
+  out0 = rand(rows=N, cols=M)
+  c0 = rand(rows=N, cols=M)
+  [W, b, dummy, dummy2] = lstm::init(N, D, M)
+
+  # Compute analytical gradients of loss wrt parameters
+  [out, c, cache_out, cache_c, cache_ifog] = lstm::forward(X, W, b, T, D, return_seq, out0, c0)
+  dout = l2_loss::backward(out, y)
+  dc = l2_loss::backward(c, yc)
+  [dX, dW, db, dout0, dc0] = lstm::backward(dout, dc, X, W, b, T, D, return_seq, out0, c0,
+                                            cache_out, cache_c, cache_ifog)
+
+  # Grad check
+  h = 1e-5
+  print(" - Grad checking X.")
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      [outmh, cmh, cache, cache, cache] = lstm::forward(X, W, b, T, D, return_seq, out0, c0)
+      loss_outmh = l2_loss::forward(outmh, y)
+      loss_cmh = l2_loss::forward(cmh, yc)
+      lossmh = loss_outmh + loss_cmh
+      X[i,j] = old + h
+      [outph, cph, cache, cache, cache] = lstm::forward(X, W, b, T, D, return_seq, out0, c0)
+      loss_outph = l2_loss::forward(outph, y)
+      loss_cph = l2_loss::forward(cph, yc)
+      lossph = loss_outph + loss_cph
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h)  # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking W.")
+  for (i in 1:nrow(W)) {
+    for (j in 1:ncol(W)) {
+      # Compute numerical derivative
+      old = as.scalar(W[i,j])
+      W[i,j] = old - h
+      [outmh, cmh, cache, cache, cache] = lstm::forward(X, W, b, T, D, return_seq, out0, c0)
+      loss_outmh = l2_loss::forward(outmh, y)
+      loss_cmh = l2_loss::forward(cmh, yc)
+      lossmh = loss_outmh + loss_cmh
+      W[i,j] = old + h
+      [outph, cph, cache, cache, cache] = lstm::forward(X, W, b, T, D, return_seq, out0, c0)
+      loss_outph = l2_loss::forward(outph, y)
+      loss_cph = l2_loss::forward(cph, yc)
+      lossph = loss_outph + loss_cph
+      W[i,j] = old  # reset
+      dW_num = (lossph - lossmh) / (2 * h)  # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dW[i,j]), dW_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking b.")
+  for (i in 1:nrow(b)) {
+    for (j in 1:ncol(b)) {
+      # Compute numerical derivative
+      old = as.scalar(b[i,j])
+      b[i,j] = old - h
+      [outmh, cmh, cache, cache, cache] = lstm::forward(X, W, b, T, D, return_seq, out0, c0)
+      loss_outmh = l2_loss::forward(outmh, y)
+      loss_cmh = l2_loss::forward(cmh, yc)
+      lossmh = loss_outmh + loss_cmh
+      b[i,j] = old + h
+      [outph, cph, cache, cache, cache] = lstm::forward(X, W, b, T, D, return_seq, out0, c0)
+      loss_outph = l2_loss::forward(outph, y)
+      loss_cph = l2_loss::forward(cph, yc)
+      lossph = loss_outph + loss_cph
+      b[i,j] = old  # reset
+      db_num = (lossph - lossmh) / (2 * h)  # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(db[i,j]), db_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking out0.")
+  for (i in 1:nrow(out0)) {
+    for (j in 1:ncol(out0)) {
+      # Compute numerical derivative
+      old = as.scalar(out0[i,j])
+      out0[i,j] = old - h
+      [outmh, cmh, cache, cache, cache] = lstm::forward(X, W, b, T, D, return_seq, out0, c0)
+      loss_outmh = l2_loss::forward(outmh, y)
+      loss_cmh = l2_loss::forward(cmh, yc)
+      lossmh = loss_outmh + loss_cmh
+      out0[i,j] = old + h
+      [outph, cph, cache, cache, cache] = lstm::forward(X, W, b, T, D, return_seq, out0, c0)
+      loss_outph = l2_loss::forward(outph, y)
+      loss_cph = l2_loss::forward(cph, yc)
+      lossph = loss_outph + loss_cph
+      out0[i,j] = old  # reset
+      dout0_num = (lossph - lossmh) / (2 * h)  # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dout0[i,j]), dout0_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking c0.")
+  for (i in 1:nrow(c0)) {
+    for (j in 1:ncol(c0)) {
+      # Compute numerical derivative
+      old = as.scalar(c0[i,j])
+      c0[i,j] = old - h
+      [outmh, cmh, cache, cache, cache] = lstm::forward(X, W, b, T, D, return_seq, out0, c0)
+      loss_outmh = l2_loss::forward(outmh, y)
+      loss_cmh = l2_loss::forward(cmh, yc)
+      lossmh = loss_outmh + loss_cmh
+      c0[i,j] = old + h
+      [outph, cph, cache, cache, cache] = lstm::forward(X, W, b, T, D, return_seq, out0, c0)
+      loss_outph = l2_loss::forward(outph, y)
+      loss_cph = l2_loss::forward(cph, yc)
+      lossph = loss_outph + loss_cph
+      c0[i,j] = old  # reset
+      dc0_num = (lossph - lossmh) / (2 * h)  # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dc0[i,j]), dc0_num, lossph, lossmh)
+    }
+  }
+}
+
 max_pool = function() {
   /*
    * Gradient check for the max pooling layer.
@@ -841,6 +987,107 @@ relu = function() {
   }
 }
 
+rnn = function() {
+  /*
+   * Gradient check for the simple RNN layer.
+   */
+  print("Grad checking the simple RNN layer with L2 loss.")
+
+  # Generate data
+  N = 3  # num examples
+  D = 10  # num features
+  T = 15  # num timesteps (sequence length)
+  M = 5  # num neurons
+  return_seq = TRUE
+  X = rand(rows=N, cols=T*D)
+  y = rand(rows=N, cols=T*M)
+  out0 = rand(rows=N, cols=M)
+  [W, b, dummy] = rnn::init(N, D, M)
+
+  # Compute analytical gradients of loss wrt parameters
+  [out, cache_out] = rnn::forward(X, W, b, T, D, return_seq, out0)
+  dout = l2_loss::backward(out, y)
+  [dX, dW, db, dout0] = rnn::backward(dout, X, W, b, T, D, return_seq, out0, cache_out)
+
+  # Grad check
+  h = 1e-5
+  print(" - Grad checking X.")
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      [outmh, cache_out] = rnn::forward(X, W, b, T, D, return_seq, out0)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      [outph, cache_out] = rnn::forward(X, W, b, T, D, return_seq, out0)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h)  # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking W.")
+  for (i in 1:nrow(W)) {
+    for (j in 1:ncol(W)) {
+      # Compute numerical derivative
+      old = as.scalar(W[i,j])
+      W[i,j] = old - h
+      [outmh, cache_out] = rnn::forward(X, W, b, T, D, return_seq, out0)
+      lossmh = l2_loss::forward(outmh, y)
+      W[i,j] = old + h
+      [outph, cache_out] = rnn::forward(X, W, b, T, D, return_seq, out0)
+      lossph = l2_loss::forward(outph, y)
+      W[i,j] = old  # reset
+      dW_num = (lossph - lossmh) / (2 * h)  # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dW[i,j]), dW_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking b.")
+  for (i in 1:nrow(b)) {
+    for (j in 1:ncol(b)) {
+      # Compute numerical derivative
+      old = as.scalar(b[i,j])
+      b[i,j] = old - h
+      [outmh, cache_out] = rnn::forward(X, W, b, T, D, return_seq, out0)
+      lossmh = l2_loss::forward(outmh, y)
+      b[i,j] = old + h
+      [outph, cache_out] = rnn::forward(X, W, b, T, D, return_seq, out0)
+      lossph = l2_loss::forward(outph, y)
+      b[i,j] = old  # reset
+      db_num = (lossph - lossmh) / (2 * h)  # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(db[i,j]), db_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking out0.")
+  for (i in 1:nrow(out0)) {
+    for (j in 1:ncol(out0)) {
+      # Compute numerical derivative
+      old = as.scalar(out0[i,j])
+      out0[i,j] = old - h
+      [outmh, cache_out] = rnn::forward(X, W, b, T, D, return_seq, out0)
+      lossmh = l2_loss::forward(outmh, y)
+      out0[i,j] = old + h
+      [outph, cache_out] = rnn::forward(X, W, b, T, D, return_seq, out0)
+      lossph = l2_loss::forward(outph, y)
+      out0[i,j] = old  # reset
+      dout0_num = (lossph - lossmh) / (2 * h)  # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dout0[i,j]), dout0_num, lossph, lossmh)
+    }
+  }
+}
+
 sigmoid = function() {
   /*
    * Gradient check for the sigmoid nonlinearity layer.

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d926ea05/scripts/staging/SystemML-NN/nn/test/tests.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/tests.dml b/scripts/staging/SystemML-NN/nn/test/tests.dml
index cac56c2..5535c85 100644
--- a/scripts/staging/SystemML-NN/nn/test/tests.dml
+++ b/scripts/staging/SystemML-NN/nn/test/tests.dml
@@ -40,10 +40,12 @@ tmp = grad_check::conv_builtin()
 tmp = grad_check::dropout()
 tmp = grad_check::l1_reg()
 tmp = grad_check::l2_reg()
+tmp = grad_check::lstm()
 tmp = grad_check::max_pool_simple()
 tmp = grad_check::max_pool()
 tmp = grad_check::max_pool_builtin()
 tmp = grad_check::relu()
+tmp = grad_check::rnn()
 tmp = grad_check::sigmoid()
 tmp = grad_check::softmax()
 tmp = grad_check::tanh()