Posted to commits@systemml.apache.org by du...@apache.org on 2016/05/19 01:02:42 UTC

[1/2] incubator-systemml git commit: [SYSTEMML-618] Deep Learning DML Library

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 9dc42b2fc -> 781d24d86


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/test/conv_simple.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/conv_simple.dml b/scripts/staging/SystemML-NN/nn/test/conv_simple.dml
new file mode 100644
index 0000000..f065668
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/test/conv_simple.dml
@@ -0,0 +1,211 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * 2D Convolutional layer.
+ *
+ * This implementation is intended to be a simple, reference version.
+ */
+forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
+                   int C, int Hin, int Win, int Hf, int Wf,
+                   int strideh, int stridew, int padh, int padw)
+    return (matrix[double] out, int Hout, int Wout) {
+  /*
+   * Computes the forward pass for a 2D spatial convolutional layer with
+   * F filters.  The input data has N examples, each represented as a 3D
+   * volume unrolled into a single vector.
+   *
+   * This implementation is intended to be a simple, reference version.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - W: Weights (parameters) matrix, of shape (F, C*Hf*Wf).
+   *  - b: Biases vector, of shape (F, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, F*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   */
+  N = nrow(X)
+  F = nrow(W)
+  Hout = as.integer((Hin + 2 * padh - Hf) / strideh + 1)
+  Wout = as.integer((Win + 2 * padw - Wf) / stridew + 1)
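+  # Worked example of the output-size formula (illustrative values only): with Hin=Win=5,
+  # Hf=Wf=3, strideh=stridew=1, and padh=padw=1, Hout = (5 + 2*1 - 3)/1 + 1 = 5, so each
+  # filter produces a 5x5 output map.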
+  
+  # Create output volume
+  out = matrix(0, rows=N, cols=F*Hout*Wout)
+
+  # Convolution - Simple reference implementation
+  parfor (n in 1:N) {  # all examples
+    Xn = matrix(X[n,], rows=C, cols=Hin*Win)
+    # Pad image
+    Xn_padded = matrix(0, rows=C, cols=(Hin+2*padh)*(Win+2*padw))  # zeros
+    parfor (c in 1:C) {
+      Xn_slice = matrix(Xn[c,], rows=Hin, cols=Win)  # depth slice C reshaped
+      Xn_padded_slice = matrix(Xn_padded[c,], rows=Hin+2*padh, cols=Win+2*padw)
+      Xn_padded_slice[padh+1:padh+Hin, padw+1:padw+Win] = Xn_slice
+      Xn_padded[c, ] = matrix(Xn_padded_slice, rows=1, cols=(Hin+2*padh)*(Win+2*padw))  # reshape
+    }
+    # Convolve image with filters
+    parfor (f in 1:F, check=0) {  # all filters
+      parfor (hout in 1:Hout, check=0) {  # all output rows
+        h0 = (hout-1) * strideh + 1
+        parfor (wout in 1:Wout, check=0) {  # all output columns
+          w0 = (wout-1) * stridew + 1
+          # Create a patch of the input example corresponding spatially to the filter sizes
+          Xn_padded_patch = matrix(0, rows=C, cols=Hf*Wf)  # zeros
+          parfor (c in 1:C, check=0) {
+            Xn_padded_slice = matrix(Xn_padded[c,], rows=Hin+2*padh, cols=Win+2*padw)  # reshape
+            Xn_padded_patch[c,] = 
+              matrix(Xn_padded_slice[h0:h0-1+Hf, w0:w0-1+Wf], rows=1, cols=Hf*Wf)  # reshape
+          }
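+          # Each output value is the dot product of the f-th filter (flattened to length
+          # C*Hf*Wf) with the flattened input patch, plus the filter's bias.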
+          out[n, (f-1)*Hout*Wout + (hout-1)*Wout + wout] = 
+            W[f,] %*% matrix(Xn_padded_patch, rows=C*Hf*Wf, cols=1) + b[f,]
+        }
+      }
+    }
+  }
+}
+
+backward = function(matrix[double] dout, int Hout, int Wout,
+                    matrix[double] X, matrix[double] W, matrix[double] b,
+                    int C, int Hin, int Win, int Hf, int Wf,
+                    int strideh, int stridew, int padh, int padw)
+    return (matrix[double] dX, matrix[double] dW, matrix[double] db) {
+  /*
+   * Computes the backward pass for a 2D spatial convolutional layer
+   * with F filters.
+   *
+   * This implementation is intended to be a simple, reference version.
+   *
+   * Inputs:
+   *  - dout: Derivatives from upstream, of shape (N, F*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - X: Previous input data matrix, of shape (N, C*Hin*Win).
+   *  - W: Weights (parameters) matrix, of shape (F, C*Hf*Wf).
+   *  - b: Biases vector, of shape (F, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of shape (N, C*Hin*Win).
+   *  - dW: Gradient wrt W, of shape (F, C*Hf*Wf).
+   *  - db: Gradient wrt b, of shape (F, 1).
+   */
+  N = nrow(X)
+  F = nrow(W)
+  Hout = as.integer((Hin + 2 * padh - Hf) / strideh + 1)
+  Wout = as.integer((Win + 2 * padw - Wf) / stridew + 1)
+  
+  # Create gradient volumes
+  dX = matrix(0, rows=N, cols=C*Hin*Win)
+  dW = matrix(0, rows=F, cols=C*Hf*Wf)
+  db = matrix(0, rows=F, cols=1)
+
+  # Partial derivatives for convolution - Simple reference implementation
+  for (n in 1:N) {  # all examples
+    Xn = matrix(X[n,], rows=C, cols=Hin*Win)
+    # Pad image
+    Xn_padded = matrix(0, rows=C, cols=(Hin+2*padh)*(Win+2*padw))  # zeros
+    parfor (c in 1:C) {
+      Xn_slice = matrix(Xn[c,], rows=Hin, cols=Win)  # depth slice C reshaped
+      Xn_padded_slice = matrix(Xn_padded[c,], rows=Hin+2*padh, cols=Win+2*padw)
+      Xn_padded_slice[padh+1:padh+Hin, padw+1:padw+Win] = Xn_slice
+      Xn_padded[c, ] = matrix(Xn_padded_slice, rows=1, cols=(Hin+2*padh)*(Win+2*padw))  # reshape
+    }
+    dXn_padded = matrix(0, rows=C, cols=(Hin+2*padh)*(Win+2*padw))
+    for (f in 1:F) {  # all filters
+      for (hout in 1:Hout) {  # all output rows
+        h0 = (hout-1) * strideh + 1
+        for (wout in 1:Wout) {  # all output columns
+          w0 = (wout-1) * stridew + 1
+          # Create a patch of the input example corresponding spatially to the filter sizes
+          Xn_padded_patch = matrix(0, rows=C, cols=Hf*Wf)  # zeros
+          dXn_padded_patch = matrix(W[f,] * dout[n, (f-1)*Hout*Wout + (hout-1)*Wout + wout],
+                                    rows=C, cols=Hf*Wf)  # reshape
+          for (c in 1:C) {
+            Xn_padded_slice = matrix(Xn_padded[c,], rows=Hin+2*padh, cols=Win+2*padw)  # reshape
+            Xn_padded_patch[c,] = 
+              matrix(Xn_padded_slice[h0:h0-1+Hf, w0:w0-1+Wf], rows=1, cols=Hf*Wf)  # reshape
+            dXn_padded_slice = matrix(0, rows=Hin+2*padh, cols=Win+2*padw)
+            dXn_padded_slice[h0:h0-1+Hf, w0:w0-1+Wf] =
+              matrix(dXn_padded_patch[c,], rows=Hf, cols=Wf)  # reshape
+            dXn_padded[c,] = dXn_padded[c,] +
+              matrix(dXn_padded_slice, rows=1, cols=(Hin+2*padh)*(Win+2*padw))
+          }
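+          # Accumulate parameter gradients: dW[f,] sums the input patch scaled by the
+          # upstream gradient over all (example, output position) pairs, and db[f,] sums
+          # the upstream gradient itself.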
+          dW[f,] = dW[f,] + matrix(Xn_padded_patch, rows=1, cols=C*Hf*Wf) *
+            dout[n, (f-1)*Hout*Wout + (hout-1)*Wout + wout]
+          db[f,] = db[f,] + dout[n, (f-1)*Hout*Wout + (hout-1)*Wout + wout]
+        }
+      }
+    }
+    # Unpad derivs on input
+    dXn = matrix(0, rows=C, cols=Hin*Win)
+    parfor (c in 1:C, check=0) {
+      dXn_padded_slice = matrix(dXn_padded[c,], rows=(Hin+2*padh), cols=(Win+2*padw))
+      dXn_slice = dXn_padded_slice[padh+1:padh+Hin, padw+1:padw+Win]
+      dXn[c, ] = matrix(dXn_slice, rows=1, cols=Hin*Win)
+    }
+    dX[n,] = matrix(dXn, rows=1, cols=C*Hin*Win)
+  }
+}
+
+init = function(int F, int C, int Hf, int Wf)
+    return (matrix[double] W, matrix[double] b) {
+  /*
+   * Initialize the parameters of this layer.
+   * 
+   * We use the heuristic by He et al. [http://arxiv.org/abs/1502.01852],
+   * which limits the magnification of inputs/gradients during
+   * forward/backward passes by scaling unit-Gaussian weights by a
+   * factor of sqrt(2/n), under the assumption of relu neurons.
+   *
+   * Inputs:
+   *  - F: Number of filters.
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *
+   * Outputs:
+   *  - W: Weights (parameters) matrix, of shape (F, C*Hf*Wf).
+   *  - b: Biases vector, of shape (F, 1).
+   */
+  W = rand(rows=F, cols=C*Hf*Wf, pdf="normal") * sqrt(2.0/(C*Hf*Wf))
+  b = matrix(0, rows=F, cols=1) 
+}
+
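+# A minimal usage sketch for this layer (illustrative shapes and arbitrary random data):
+#
+#   N = 4; C = 3; Hin = 8; Win = 8; F = 2; Hf = 3; Wf = 3
+#   [W, b] = init(F, C, Hf, Wf)
+#   X = rand(rows=N, cols=C*Hin*Win)
+#   [out, Hout, Wout] = forward(X, W, b, C, Hin, Win, Hf, Wf, 1, 1, 1, 1)
+#   dout = rand(rows=N, cols=F*Hout*Wout)  # stand-in for an upstream gradient
+#   [dX, dW, db] = backward(dout, Hout, Wout, X, W, b, C, Hin, Win, Hf, Wf, 1, 1, 1, 1)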

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/test/grad_check.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/grad_check.dml b/scripts/staging/SystemML-NN/nn/test/grad_check.dml
new file mode 100644
index 0000000..af985a3
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/test/grad_check.dml
@@ -0,0 +1,1139 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Gradient checks for various architectures.
+ */
+source("nn/layers/affine.dml") as affine
+source("nn/layers/conv.dml") as conv
+source("nn/layers/conv_builtin.dml") as conv_builtin
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/dropout.dml") as dropout
+source("nn/layers/l1_loss.dml") as l1_loss
+source("nn/layers/l1_reg.dml") as l1_reg
+source("nn/layers/l2_loss.dml") as l2_loss
+source("nn/layers/l2_reg.dml") as l2_reg
+source("nn/layers/log_loss.dml") as log_loss
+source("nn/layers/max_pool.dml") as max_pool
+source("nn/layers/max_pool_builtin.dml") as max_pool_builtin
+source("nn/layers/relu.dml") as relu
+source("nn/layers/sigmoid.dml") as sigmoid
+source("nn/layers/softmax.dml") as softmax
+source("nn/layers/tanh.dml") as tanh
+source("nn/test/conv_simple.dml") as conv_simple
+source("nn/test/max_pool_simple.dml") as max_pool_simple
+source("nn/util.dml") as util
+
+check_rel_error = function(double dw_a, double dw_n, double lossph, double lossmh)
+    return (double rel_error) {
+  /*
+   * Check and report any issues with the relative error measure between
+   * the analytical and numerical partial derivatives.
+   *
+   *  - Issues an "ERROR" statement for relative errors > 1e-2, 
+   *  indicating that the gradient is likely incorrect.
+   *  - Issues a "WARNING" statement for relative errors < 1e-2
+   *  but > 1e-4, indicating that the may be incorrect.
+   *
+   * Inputs:
+   *  - dw_a: Analytical partial derivative wrt w.
+   *  - dw_n: Numerical partial derivative wrt w.
+   *  - lossph: Loss evaluated with w set to w+h.
+   *  - lossmh: Loss evaluated with w set to w-h.
+   *
+   * Outputs:
+   *  - rel_error: Relative error measure between the two derivatives.
+   */
+  # Compute relative error
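+  # (compute_rel_error is assumed here to implement the standard measure
+  #  |dw_a - dw_n| / max(|dw_a|, |dw_n|); see nn/util.dml for the actual definition.)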
+  rel_error = util::compute_rel_error(dw_a, dw_n)
+  
+  # Evaluate relative error
+  if (rel_error > 1e-2) {
+      print("ERROR: Relative error " + rel_error + " > 1e-2 with " + dw_a +
+            " analytical vs " + dw_n + " numerical, with lossph " + lossph +
+            " and lossmh " + lossmh)
+  }
+  else if (rel_error > 1e-4 & rel_error <= 1e-2) {
+      print("WARNING: Relative error " + rel_error + " <= 1e-2 & > 1e-4 with " + dw_a +
+            " analytical vs " + dw_n + " numerical, with lossph " + lossph +
+            " and lossmh " + lossmh)
+  }
+}
+
+affine = function() {
+  /*
+   * Gradient check for the affine layer.
+   */
+  print("Grad checking the affine layer with L2 loss.")
+
+  # Generate data
+  N = 3 # num examples
+  D = 100 # num features
+  M = 10 # num neurons
+  X = rand(rows=N, cols=D)
+  y = rand(rows=N, cols=M)
+  [W, b] = affine::init(D, M)
+
+  # Compute analytical gradients of loss wrt parameters
+  out = affine::forward(X, W, b)
+  dout = l2_loss::backward(out, y)
+  [dX, dW, db] = affine::backward(dout, X, W, b)
+
+  # Grad check
+  h = 1e-5
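+  # All numerical gradients in these checks use centered finite differences:
+  # dL/dx is approximated by (L(x+h) - L(x-h)) / (2h).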
+  print(" - Grad checking X.")
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      outmh = affine::forward(X, W, b)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      outph = affine::forward(X, W, b)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking W.")
+  for (i in 1:nrow(W)) {
+    for (j in 1:ncol(W)) {
+      # Compute numerical derivative
+      old = as.scalar(W[i,j])
+      W[i,j] = old - h
+      outmh = affine::forward(X, W, b)
+      lossmh = l2_loss::forward(outmh, y)
+      W[i,j] = old + h
+      outph = affine::forward(X, W, b)
+      lossph = l2_loss::forward(outph, y)
+      W[i,j] = old  # reset
+      dW_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dW[i,j]), dW_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking b.")
+  for (i in 1:nrow(b)) {
+    for (j in 1:ncol(b)) {
+      # Compute numerical derivative
+      old = as.scalar(b[i,j])
+      b[i,j] = old - h
+      outmh = affine::forward(X, W, b)
+      lossmh = l2_loss::forward(outmh, y)
+      b[i,j] = old + h
+      outph = affine::forward(X, W, b)
+      lossph = l2_loss::forward(outph, y)
+      b[i,j] = old  # reset
+      db_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(db[i,j]), db_num, lossph, lossmh)
+    }
+  }
+}
+
+conv = function() {
+  /*
+   * Gradient check for the convolutional layer using `im2col`.
+   */
+  print("Grad checking the `im2col` convolutional layer with L2 loss.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 2  # num channels
+  Hin = 5  # input height
+  Win = 5  # input width
+  F = 2  # num filters
+  Hf = 3  # filter height
+  Wf = 3  # filter width
+  stride = 1
+  pad = 1
+  X = rand(rows=N, cols=C*Hin*Win)
+  y = rand(rows=N, cols=F*Hin*Win)
+
+  # Create layers
+  [W, b] = conv::init(F, C, Hf, Wf)
+
+  # Compute analytical gradients of loss wrt parameters
+  [out, Hout, Wout] = conv::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+  dout = l2_loss::backward(out, y)
+  [dX, dW, db] =
+    conv::backward(dout, Hout, Wout, X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+
+  # Grad check
+  h = 1e-5
+  print(" - Grad checking X.")
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      [outmh, Hout, Wout] = conv::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      [outph, Hout, Wout] = conv::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking W.")
+  for (i in 1:nrow(W)) {
+    for (j in 1:ncol(W)) {
+      # Compute numerical derivative
+      old = as.scalar(W[i,j])
+      W[i,j] = old - h
+      [outmh, Hout, Wout] = conv::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossmh = l2_loss::forward(outmh, y)
+      W[i,j] = old + h
+      [outph, Hout, Wout] = conv::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossph = l2_loss::forward(outph, y)
+      W[i,j] = old  # reset
+      dW_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dW[i,j]), dW_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking b.")
+  for (i in 1:nrow(b)) {
+    for (j in 1:ncol(b)) {
+      # Compute numerical derivative
+      old = as.scalar(b[i,j])
+      b[i,j] = old - h
+      [outmh, Hout, Wout] = conv::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossmh = l2_loss::forward(outmh, y)
+      b[i,j] = old + h
+      [outph, Hout, Wout] = conv::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossph = l2_loss::forward(outph, y)
+      b[i,j] = old  # reset
+      db_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(db[i,j]), db_num, lossph, lossmh)
+    }
+  }
+}
+
+conv_builtin = function() {
+  /*
+   * Gradient check for the convolutional layer using built-in functions.
+   */
+  print("Grad checking the built-in convolutional layer with L2 loss.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 2  # num channels
+  Hin = 5  # input height
+  Win = 5  # input width
+  F = 2  # num filters
+  Hf = 3  # filter height
+  Wf = 3  # filter width
+  stride = 1
+  pad = 1
+  X = rand(rows=N, cols=C*Hin*Win)
+  y = rand(rows=N, cols=F*Hin*Win)
+
+  # Create layers
+  [W, b] = conv_builtin::init(F, C, Hf, Wf)
+
+  # Compute analytical gradients of loss wrt parameters
+  [out, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+  dout = l2_loss::backward(out, y)
+  [dX, dW, db] =
+    conv_builtin::backward(dout, Hout, Wout, X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+
+  # Grad check
+  h = 1e-5
+  print(" - Grad checking X.")
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      [outmh, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      [outph, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking W.")
+  for (i in 1:nrow(W)) {
+    for (j in 1:ncol(W)) {
+      # Compute numerical derivative
+      old = as.scalar(W[i,j])
+      W[i,j] = old - h
+      [outmh, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossmh = l2_loss::forward(outmh, y)
+      W[i,j] = old + h
+      [outph, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossph = l2_loss::forward(outph, y)
+      W[i,j] = old  # reset
+      dW_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dW[i,j]), dW_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking b.")
+  for (i in 1:nrow(b)) {
+    for (j in 1:ncol(b)) {
+      # Compute numerical derivative
+      old = as.scalar(b[i,j])
+      b[i,j] = old - h
+      [outmh, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossmh = l2_loss::forward(outmh, y)
+      b[i,j] = old + h
+      [outph, Hout, Wout] = conv_builtin::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossph = l2_loss::forward(outph, y)
+      b[i,j] = old  # reset
+      db_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(db[i,j]), db_num, lossph, lossmh)
+    }
+  }
+}
+
+conv_simple = function() {
+  /*
+   * Gradient check for the simple reference convolutional layer.
+   */
+  print("Grad checking the simple reference convolutional layer with L2 loss.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 2  # num channels
+  Hin = 5  # input height
+  Win = 5  # input width
+  F = 2  # num filters
+  Hf = 3  # filter height
+  Wf = 3  # filter width
+  stride = 1
+  pad = 1
+  X = rand(rows=N, cols=C*Hin*Win)
+  y = rand(rows=N, cols=F*Hin*Win)
+
+  # Create layers
+  [W, b] = conv_simple::init(F, C, Hf, Wf)
+
+  # Compute analytical gradients of loss wrt parameters
+  [out, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+  dout = l2_loss::backward(out, y)
+  [dX, dW, db] =
+    conv_simple::backward(dout, Hout, Wout, X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+
+  # Grad check
+  h = 1e-5
+  print(" - Grad checking X.")
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      [outmh, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      [outph, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking W.")
+  for (i in 1:nrow(W)) {
+    for (j in 1:ncol(W)) {
+      # Compute numerical derivative
+      old = as.scalar(W[i,j])
+      W[i,j] = old - h
+      [outmh, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossmh = l2_loss::forward(outmh, y)
+      W[i,j] = old + h
+      [outph, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossph = l2_loss::forward(outph, y)
+      W[i,j] = old  # reset
+      dW_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dW[i,j]), dW_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking b.")
+  for (i in 1:nrow(b)) {
+    for (j in 1:ncol(b)) {
+      # Compute numerical derivative
+      old = as.scalar(b[i,j])
+      b[i,j] = old - h
+      [outmh, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossmh = l2_loss::forward(outmh, y)
+      b[i,j] = old + h
+      [outph, Hout, Wout] = conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+      lossph = l2_loss::forward(outph, y)
+      b[i,j] = old  # reset
+      db_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(db[i,j]), db_num, lossph, lossmh)
+    }
+  }
+}
+
+cross_entropy_loss = function() {
+  /*
+   * Gradient check for the cross-entropy loss function.
+   */
+  print("Grad checking the cross-entropy loss function.")
+
+  # Generate data
+  N = 3 # num examples
+  K = 10 # num targets
+  pred = rand(rows=N, cols=K, min=0, max=1, pdf="uniform")
+  pred = pred / rowSums(pred)  # normalized probs
+  y = rand(rows=N, cols=K, min=0, max=1, pdf="uniform")
+  y = y / rowSums(y)  # normalized probs
+
+  # Compute analytical gradient
+  dpred = cross_entropy_loss::backward(pred, y)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(pred)) {
+    for (j in 1:ncol(pred)) {
+      # Compute numerical derivative
+      old = as.scalar(pred[i,j])
+      pred[i,j] = old - h
+      lossmh = cross_entropy_loss::forward(pred, y)
+      pred[i,j] = old + h
+      lossph = cross_entropy_loss::forward(pred, y)
+      pred[i,j] = old  # reset pred[i,j]
+      dpred_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dpred[i,j]), dpred_num, lossph, lossmh)
+    }
+  }
+}
+
+dropout = function() {
+  /*
+   * Gradient check for the (inverted) dropout layer.
+   */
+  print("Grad checking the (inverted) dropout layer with L2 loss.")
+
+  # Generate data
+  N = 3  # num examples
+  M = 100  # num neurons
+  p = 0.5  # probability of dropping neuron output
+  seed = as.integer(floor(as.scalar(rand(rows=1, cols=1, min=1, max=100000))))  # random seed
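+  # The same seed is reused for every forward call below so that each pass samples an
+  # identical dropout mask, keeping the loss a deterministic function of X for the check.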
+  X = rand(rows=N, cols=M)
+  y = rand(rows=N, cols=M)
+
+  # Compute analytical gradients of loss wrt parameters
+  [out, mask] = dropout::forward(X, p, seed)
+  dout = l2_loss::backward(out, y)
+  dX = dropout::backward(dout, X, p, mask)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      [outmh, mask] = dropout::forward(X, p, seed)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      [outph, mask] = dropout::forward(X, p, seed)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+}
+
+l1_loss = function() {
+  /*
+   * Gradient check for the L1 loss function.
+   */
+  print("Grad checking the L1 loss function.")
+
+  # Generate data
+  N = 3 # num examples
+  D = 2 # num targets
+  pred = rand(rows=N, cols=D)
+  y = rand(rows=N, cols=D)
+
+  # Compute analytical gradient
+  dpred = l1_loss::backward(pred, y)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(pred)) {
+    for (j in 1:ncol(pred)) {
+      # Compute numerical derivative
+      old = as.scalar(pred[i,j])
+      pred[i,j] = old - h
+      lossmh = l1_loss::forward(pred, y)
+      pred[i,j] = old + h
+      lossph = l1_loss::forward(pred, y)
+      pred[i,j] = old  # reset pred[i,j]
+      dpred_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dpred[i,j]), dpred_num, lossph, lossmh)
+    }
+  }
+}
+
+l1_reg = function() {
+  /*
+   * Gradient check for the L1 regularization function.
+   */
+  print("Grad checking the L1 regularization function.")
+
+  # Generate data
+  D = 5 # num features
+  M = 3 # num neurons
+  lambda = 0.01
+  W = rand(rows=D, cols=M)
+
+  # Compute analytical gradient
+  dW = l1_reg::backward(W, lambda)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(W)) {
+    for (j in 1:ncol(W)) {
+      # Compute numerical derivative
+      old = as.scalar(W[i,j])
+      W[i,j] = old - h
+      reg_lossmh = l1_reg::forward(W, lambda)
+      W[i,j] = old + h
+      reg_lossph = l1_reg::forward(W, lambda)
+      W[i,j] = old  # reset W[i,j]
+      dW_num = (reg_lossph - reg_lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dW[i,j]), dW_num, reg_lossph, reg_lossmh)
+    }
+  }
+}
+
+l2_loss = function() {
+  /*
+   * Gradient check for the L2 loss function.
+   */
+  print("Grad checking the L2 loss function.")
+
+  # Generate data
+  N = 3 # num examples
+  D = 2 # num targets
+  pred = rand(rows=N, cols=D)
+  y = rand(rows=N, cols=D)
+
+  # Compute analytical gradient
+  dpred = l2_loss::backward(pred, y)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(pred)) {
+    for (j in 1:ncol(pred)) {
+      # Compute numerical derivative
+      old = as.scalar(pred[i,j])
+      pred[i,j] = old - h
+      lossmh = l2_loss::forward(pred, y)
+      pred[i,j] = old + h
+      lossph = l2_loss::forward(pred, y)
+      pred[i,j] = old  # reset pred[i,j]
+      dpred_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dpred[i,j]), dpred_num, lossph, lossmh)
+    }
+  }
+}
+
+l2_reg = function() {
+  /*
+   * Gradient check for the L2 regularization function.
+   */
+  print("Grad checking the L2 regularization function.")
+
+  # Generate data
+  D = 5 # num features
+  M = 3 # num neurons
+  lambda = 0.01
+  W = rand(rows=D, cols=M)
+
+  # Compute analytical gradient
+  dW = l2_reg::backward(W, lambda)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(W)) {
+    for (j in 1:ncol(W)) {
+      # Compute numerical derivative
+      old = as.scalar(W[i,j])
+      W[i,j] = old - h
+      reg_lossmh = l2_reg::forward(W, lambda)
+      W[i,j] = old + h
+      reg_lossph = l2_reg::forward(W, lambda)
+      W[i,j] = old  # reset W[i,j]
+      dW_num = (reg_lossph - reg_lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dW[i,j]), dW_num, reg_lossph, reg_lossmh)
+    }
+  }
+}
+
+log_loss = function() {
+  /*
+   * Gradient check for the log loss function.
+   */
+  print("Grad checking the log loss function.")
+
+  # Generate data
+  N = 20 # num examples
+  D = 1 # num targets
+  pred = rand(rows=N, cols=D, min=0, max=1, pdf="uniform")
+  y = round(rand(rows=N, cols=D, min=0, max=1, pdf="uniform"))
+
+  # Compute analytical gradient
+  dpred = log_loss::backward(pred, y)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(pred)) {
+    for (j in 1:ncol(pred)) {
+      # Compute numerical derivative
+      old = as.scalar(pred[i,j])
+      pred[i,j] = old - h
+      lossmh = log_loss::forward(pred, y)
+      pred[i,j] = old + h
+      lossph = log_loss::forward(pred, y)
+      pred[i,j] = old  # reset pred[i,j]
+      dpred_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dpred[i,j]), dpred_num, lossph, lossmh)
+    }
+  }
+}
+
+max_pool = function() {
+  /*
+   * Gradient check for the max pooling layer.
+   */
+  print("Grad checking the max pooling layer with L2 loss.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 2  # num channels
+  Hin = 4  # input height
+  Win = 4  # input width
+  Hf = 2  # pool filter height
+  Wf = 2  # pool filter width
+  stride = 2
+  X = rand(rows=N, cols=C*Hin*Win)
+  y = rand(rows=N, cols=C*2*2)
+
+  # Compute analytical gradients of loss wrt parameters
+  [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+  dout = l2_loss::backward(out, y)
+  dX = max_pool::backward(dout, Hout, Wout, X, C, Hin, Win, Hf, Wf, stride, stride)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      [outmh, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      [outph, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+}
+
+max_pool_builtin = function() {
+  /*
+   * Gradient check for the max pooling layer.
+   */
+  print("Grad checking the built-in max pooling layer with L2 loss.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 2  # num channels
+  Hin = 4  # input height
+  Win = 4  # input width
+  Hf = 2  # pool filter height
+  Wf = 2  # pool filter width
+  stride = 2
+  X = rand(rows=N, cols=C*Hin*Win)
+  y = rand(rows=N, cols=C*2*2)
+
+  # Compute analytical gradients of loss wrt parameters
+  [out, Hout, Wout] = max_pool_builtin::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+  dout = l2_loss::backward(out, y)
+  dX = max_pool_builtin::backward(dout, Hout, Wout, X, C, Hin, Win, Hf, Wf, stride, stride)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      [outmh, Hout, Wout] = max_pool_builtin::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      [outph, Hout, Wout] = max_pool_builtin::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+}
+
+max_pool_simple = function() {
+  /*
+   * Gradient check for the simple reference max pooling layer.
+   */
+  print("Grad checking the simple reference max pooling layer with L2 loss.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 2  # num channels
+  Hin = 4  # input height
+  Win = 4  # input width
+  Hf = 2  # pool filter height
+  Wf = 2  # pool filter width
+  stride = 2
+  X = rand(rows=N, cols=C*Hin*Win)
+  y = rand(rows=N, cols=C*2*2)
+
+  # Compute analytical gradients of loss wrt parameters
+  [out, Hout, Wout] = max_pool_simple::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+  dout = l2_loss::backward(out, y)
+  dX = max_pool_simple::backward(dout, Hout, Wout, X, C, Hin, Win, Hf, Wf, stride, stride)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      [outmh, Hout, Wout] = max_pool_simple::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      [outph, Hout, Wout] = max_pool_simple::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+}
+
+relu = function() {
+  /*
+   * Gradient check for the ReLU nonlinearity layer.
+   *
+   * NOTE: This could result in a false-negative in which the test
+   * fails due to a kink being crossed in the nonlinearity.  This
+   * occurs when the tests, f(x-h) and f(x+h), end up on opposite
+   * sides of the zero threshold of max(0, fx).  For now, just run
+   * the tests again.  In the future, we can explicitly check for
+   * this and rerun the test automatically.
+   */
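+  # Kink example: if some X[i,j] = 1e-6 with h = 1e-5, then X[i,j]-h < 0 < X[i,j]+h, so the
+  # centered difference straddles the kink of max(0, x) and can disagree with the analytical
+  # gradient even though the gradient code is correct.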
+  print("Grad checking the ReLU nonlinearity layer with L2 loss.")
+
+  # Generate data
+  N = 3 # num examples
+  M = 10 # num neurons
+  X = rand(rows=N, cols=M, min=-5, max=5)
+  y = rand(rows=N, cols=M)
+
+  # Compute analytical gradients of loss wrt parameters
+  out = relu::forward(X)
+  dout = l2_loss::backward(out, y)
+  dX = relu::backward(dout, X)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      outmh = relu::forward(X)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      outph = relu::forward(X)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+}
+
+sigmoid = function() {
+  /*
+   * Gradient check for the sigmoid nonlinearity layer.
+   */
+  print("Grad checking the sigmoid nonlinearity layer with L2 loss.")
+
+  # Generate data
+  N = 3 # num examples
+  M = 10 # num neurons
+  X = rand(rows=N, cols=M)
+  y = rand(rows=N, cols=M)
+
+  # Compute analytical gradients of loss wrt parameters
+  out = sigmoid::forward(X)
+  dout = l2_loss::backward(out, y)
+  dX = sigmoid::backward(dout, X)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      outmh = sigmoid::forward(X)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      outph = sigmoid::forward(X)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+}
+
+softmax = function() {
+  /*
+   * Gradient check for the softmax layer.
+   */
+  print("Grad checking the softmax layer with L2 loss.")
+
+  # Generate data
+  N = 3 # num examples
+  D = 10 # num classes
+  X = rand(rows=N, cols=D)
+  y = rand(rows=N, cols=D, min=0, max=1, pdf="uniform")
+  y = y / rowSums(y)
+
+  # Compute analytical gradients of loss wrt parameters
+  out = softmax::forward(X)
+  dout = l2_loss::backward(out, y)
+  dX = softmax::backward(dout, X)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      outmh = softmax::forward(X)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      outph = softmax::forward(X)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+}
+
+tanh = function() {
+  /*
+   * Gradient check for the hyperbolic tangent (tanh) nonlinearity layer.
+   */
+  print("Grad checking the tanh nonlinearity layer with L2 loss.")
+
+  # Generate data
+  N = 3 # num examples
+  M = 10 # num neurons
+  X = rand(rows=N, cols=M)
+  y = rand(rows=N, cols=M)
+
+  # Compute analytical gradients of loss wrt parameters
+  out = tanh::forward(X)
+  dout = l2_loss::backward(out, y)
+  dX = tanh::backward(dout, X)
+
+  # Grad check
+  h = 1e-5
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      outmh = tanh::forward(X)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      outph = tanh::forward(X)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+}
+
+two_layer_affine_l2_net = function() {
+  /*
+   * Gradient check for a two-layer, fully-connected, feed-forward
+   * network with ReLU nonlinearity and L2 loss.
+   *
+   * NOTE: This could result in a false-negative in which the test
+   * fails due to a kink being crossed in the ReLU nonlinearity.  This
+   * occurs when the tests, f(x-h) and f(x+h), end up on opposite
+   * sides of the zero threshold of max(0, fx).  For now, just run
+   * the tests again.  In the future, we can explicitly check for
+   * this and rerun the test automatically.
+   */
+  print("Grad checking a two-layer, fully-connected, feed-forward network with a ReLU " +
+        "nonlinearity, and an L2 loss function.")
+
+  # Generate input data
+  N = 1000 # num examples
+  D = 100 # num features
+  yD = 5 # num targets
+  X = rand(rows=N, cols=D, pdf="normal") * 0.0001
+  y = rand(rows=N, cols=yD)
+
+  # Create 2-layer, fully-connected network
+  M = 10 # number of hidden neurons
+  [W1, b1] = affine::init(D, M)
+  [W2, b2] = affine::init(M, yD)
+
+  # Optimize for a short "burn-in" period to move the network into a characteristic
+  # mode of operation and unmask any real issues.
+  print(" - Burn-in:")
+  lr = 0.0001
+  decay = 0.99
+  for(i in 1:5) {
+    # Compute forward and backward passes of net
+    [pred, loss, dX, dW1, db1, dW2, db2] = two_layer_affine_l2_net_run(X, y, W1, b1, W2, b2)
+    print("   - L2 loss: " + loss)
+
+    # Optimize with basic SGD
+    W1 = W1 - lr * dW1
+    b1 = b1 - lr * db1
+    W2 = W2 - lr * dW2
+    b2 = b2 - lr * db2
+    lr = lr * decay
+  }
+
+  # Compute analytical gradients
+  [pred, loss, dX, dW1, db1, dW2, db2] = two_layer_affine_l2_net_run(X, y, W1, b1, W2, b2)
+  
+  # Grad check
+  h = 1e-5
+  print(" - Grad checking X.")
+  for (i in 1:2) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old_w = as.scalar(X[i,j])
+      X[i,j] = old_w - h
+      [lossmh, pred, aout, hout] = two_layer_affine_l2_net_forward(X, y, W1, b1, W2, b2)
+      X[i,j] = old_w + h
+      [lossph, pred, aout, hout] = two_layer_affine_l2_net_forward(X, y, W1, b1, W2, b2)
+      X[i,j] = old_w  # reset X[i,j]
+      dX_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking W1.")
+  for (i in 1:nrow(W1)) {
+    for (j in 1:ncol(W1)) {
+      # Compute numerical derivative
+      old_w = as.scalar(W1[i,j])
+      W1[i,j] = old_w - h
+      [lossmh, pred, aout, hout] = two_layer_affine_l2_net_forward(X, y, W1, b1, W2, b2)
+      W1[i,j] = old_w + h
+      [lossph, pred, aout, hout] = two_layer_affine_l2_net_forward(X, y, W1, b1, W2, b2)
+      W1[i,j] = old_w  # reset W1[i,j]
+      dWij_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dW1[i,j]), dWij_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking W2.")
+  for (i in 1:nrow(W2)) {
+    for (j in 1:ncol(W2)) {
+      # Compute numerical derivative
+      old_w = as.scalar(W2[i,j])
+      W2[i,j] = old_w - h
+      [lossmh, pred, aout, hout] = two_layer_affine_l2_net_forward(X, y, W1, b1, W2, b2)
+      W2[i,j] = old_w + h
+      [lossph, pred, aout, hout] = two_layer_affine_l2_net_forward(X, y, W1, b1, W2, b2)
+      W2[i,j] = old_w  # reset W2[i,j]
+      dWij_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(dW2[i,j]), dWij_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking b1.")
+  for (i in 1:nrow(b1)) {
+    for (j in 1:ncol(b1)) {
+      # Compute numerical derivative
+      old_b = as.scalar(b1[i,j])
+      b1[i,j] = old_b - h
+      [lossmh, pred, aout, hout] = two_layer_affine_l2_net_forward(X, y, W1, b1, W2, b2)
+      b1[i,j] = old_b + h
+      [lossph, pred, aout, hout] = two_layer_affine_l2_net_forward(X, y, W1, b1, W2, b2)
+      b1[i,j] = old_b  # reset b1[i,j]
+      dbij_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(db1[i,j]), dbij_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking b2.")
+  for (i in 1:nrow(b2)) {
+    for (j in 1:ncol(b2)) {
+      # Compute numerical derivative
+      old_b = as.scalar(b2[i,j])
+      b2[i,j] = old_b - h
+      [lossmh, pred, aout, hout] = two_layer_affine_l2_net_forward(X, y, W1, b1, W2, b2)
+      b2[i,j] = old_b + h
+      [lossph, pred, aout, hout] = two_layer_affine_l2_net_forward(X, y, W1, b1, W2, b2)
+      b2[i,j] = old_b  # reset b2[i,j]
+      dbij_num = (lossph - lossmh) / (2 * h) # numerical derivative
+
+      # Check error
+      rel_error = check_rel_error(as.scalar(db2[i,j]), dbij_num, lossph, lossmh)
+    }
+  }
+}
+
+/*
+ * Test network with forward/backward functions.
+ */
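+# Architecture: affine (D -> M) -> ReLU -> affine (M -> yD), evaluated with an L2 loss.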
+two_layer_affine_l2_net_run = function(matrix[double] X, matrix[double] y,
+                                       matrix[double] W1, matrix[double] b1,
+                                       matrix[double] W2, matrix[double] b2)
+    return (matrix[double] pred, double loss,
+            matrix[double] dX,
+            matrix[double] dW1, matrix[double] db1,
+            matrix[double] dW2, matrix[double] db2) {
+  # Compute forward pass
+  [loss, pred, aout, hout] = two_layer_affine_l2_net_forward(X, y, W1, b1, W2, b2)
+
+  # Compute backward pass
+  [dX, dpred, daout, dhout, dW1, db1, dW2, db2] =
+    two_layer_affine_l2_net_backward(X, y, pred, aout, hout, W1, b1, W2, b2)
+}
+
+two_layer_affine_l2_net_forward = function(matrix[double] X, matrix[double] y,
+                                           matrix[double] W1, matrix[double] b1,
+                                           matrix[double] W2, matrix[double] b2)
+    return (double loss, matrix[double] pred, matrix[double] aout, matrix[double] hout) {
+  # Compute forward pass
+  hout = affine::forward(X, W1, b1)
+  aout = relu::forward(hout)
+  pred = affine::forward(aout, W2, b2)
+
+  # Compute loss
+  loss = l2_loss::forward(pred, y)
+}
+
+two_layer_affine_l2_net_backward = function(matrix[double] X, matrix[double] y, matrix[double] pred,
+                                            matrix[double] aout, matrix[double] hout,
+                                            matrix[double] W1, matrix[double] b1,
+                                            matrix[double] W2, matrix[double] b2)
+    return (matrix[double] dX, matrix[double] dpred,
+            matrix[double] daout, matrix[double] dhout,
+            matrix[double] dW1, matrix[double] db1, matrix[double] dW2, matrix[double] db2) {
+  # Compute backward pass
+  dpred = l2_loss::backward(pred, y)
+  [daout, dW2, db2] = affine::backward(dpred, aout, W2, b2)
+  dhout = relu::backward(daout, hout)
+  [dX, dW1, db1] = affine::backward(dhout, X, W1, b1)
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/test/max_pool_simple.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/max_pool_simple.dml b/scripts/staging/SystemML-NN/nn/test/max_pool_simple.dml
new file mode 100644
index 0000000..2f90779
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/test/max_pool_simple.dml
@@ -0,0 +1,130 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Max pooling layer.
+ *
+ * This implementation is intended to be a simple, reference version.
+ */
+forward = function(matrix[double] X, int C, int Hin, int Win, int Hf, int Wf,
+                   int strideh, int stridew)
+    return (matrix[double] out, int Hout, int Wout) {
+  /*
+   * Computes the forward pass for a 2D spatial max pooling layer.
+   * The input data has N examples, each represented as a 3D volume
+   * unrolled into a single vector.
+   *
+   * This implementation is intended to be a simple, reference version.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, C*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   */
+  N = nrow(X)
+  Hout = as.integer((Hin - Hf) / strideh + 1)
+  Wout = as.integer((Win - Wf) / stridew + 1)
+
+  # Create output volume
+  out = matrix(0, rows=N, cols=C*Hout*Wout)  
+
+  # Max pooling
+  parfor (n in 1:N, check=0) {  # all examples
+    img = matrix(X[n,], rows=C, cols=Hin*Win)
+    parfor (c in 1:C, check=0) {  # all channels
+      img_slice = matrix(img[c,], rows=Hin, cols=Win)
+      parfor (hout in 1:Hout, check=0) {  # all output rows
+        hin = (hout-1) * strideh + 1
+        parfor (wout in 1:Wout, check=0) {  # all output columns
+          win = (wout-1) * stridew + 1
+          out[n, (c-1)*Hout*Wout + (hout-1)*Wout + wout] =
+            max(img_slice[hin:hin+Hf-1, win:win+Wf-1])
+        }
+      }
+    }
+  }
+}
+
+backward = function(matrix[double] dout, int Hout, int Wout, matrix[double] X, int C, 
+                    int Hin, int Win, int Hf, int Wf, int strideh, int stridew)
+    return (matrix[double] dX) {
+  /*
+   * Computes the backward pass for a 2D spatial max pooling layer.
+   * The input data has N examples, each represented as a 3D volume
+   * unrolled into a single vector.
+   *
+   * Inputs:
+   *  - dout: Derivatives from upstream, of shape (N, C*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of shape (N, C*Hin*Win).
+   */
+  N = nrow(X)
+  
+  # Create gradient volume
+  dX = matrix(0, rows=N, cols=C*Hin*Win)
+  
+  # Gradient of max pooling
+  parfor (n in 1:N, check=0) {  # all examples
+    img = matrix(X[n,], rows=C, cols=Hin*Win)
+    dimg = matrix(0, rows=C, cols=Hin*Win)
+    parfor (c in 1:C, check=0) {  # all channels
+      img_slice = matrix(img[c,], rows=Hin, cols=Win)
+      dimg_slice = matrix(0, rows=Hin, cols=Win)
+      for (hout in 1:Hout) {  # all output rows
+        hin = (hout-1) * strideh + 1
+        for (wout in 1:Wout) {  # all output columns
+          win = (wout-1) * stridew + 1
+          img_slice_patch = img_slice[hin:hin+Hf-1, win:win+Wf-1]
+          max_val = max(img_slice_patch)
+          max_val_ind = ppred(img_slice_patch, max_val, "==")  # max value indicator
+          # gradient passes through only for the max value in this patch
+          dimg_slice_patch = max_val_ind * dout[n, (c-1)*Hout*Wout + (hout-1)*Wout + wout]
+          dimg_slice[hin:hin+Hf-1, win:win+Wf-1] =
+            dimg_slice[hin:hin+Hf-1, win:win+Wf-1] + dimg_slice_patch
+        }
+      }
+      dimg[c,] = matrix(dimg_slice, rows=1, cols=Hin*Win)
+    }
+    dX[n,] = matrix(dimg, rows=1, cols=C*Hin*Win)
+  }
+}
+
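+# A minimal usage sketch for this layer (illustrative shapes and arbitrary random data):
+#
+#   N = 2; C = 3; Hin = 6; Win = 6; Hf = 2; Wf = 2; stride = 2
+#   X = rand(rows=N, cols=C*Hin*Win)
+#   [out, Hout, Wout] = forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+#   dout = rand(rows=N, cols=C*Hout*Wout)  # stand-in for an upstream gradient
+#   dX = backward(dout, Hout, Wout, X, C, Hin, Win, Hf, Wf, stride, stride)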

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/test.dml b/scripts/staging/SystemML-NN/nn/test/test.dml
new file mode 100644
index 0000000..58ee3e1
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/test/test.dml
@@ -0,0 +1,192 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Various tests, not including gradient checks.
+ */
+source("nn/layers/conv.dml") as conv
+source("nn/layers/conv_builtin.dml") as conv_builtin
+source("nn/layers/max_pool.dml") as max_pool
+source("nn/layers/max_pool_builtin.dml") as max_pool_builtin
+source("nn/test/conv_simple.dml") as conv_simple
+source("nn/test/max_pool_simple.dml") as max_pool_simple
+source("nn/util.dml") as util
+
+conv = function() {
+  /*
+   * Test for the `conv` functions.
+   */
+  print("Testing the conv functions.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 3  # num channels
+  Hin = 5  # input height
+  Win = 5  # input width
+  F = 2  # num filters
+  Hf = 3  # filter height
+  Wf = 3  # filter width
+  stride = 1
+  pad = 1
+  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+
+  # Create layer
+  [W, b] = conv::init(F, C, Hf, Wf)
+
+  # Forward
+  [out, Hout, Wout] = conv::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+  [out_simple, Hout_simple, Wout_simple] =
+    conv_simple::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+  [out_builtin, Hout_builtin, Wout_builtin] =
+    conv_builtin::forward(X, W, b, C, Hin, Win, Hf, Wf, stride, stride, pad, pad)
+
+  # Equivalency check
+  out = matrix(out, rows=1, cols=N*F*Hout*Wout)
+  out_simple = matrix(out_simple, rows=1, cols=N*F*Hout*Wout)
+  out_builtin = matrix(out_builtin, rows=1, cols=N*F*Hout*Wout)
+  for (i in 1:length(out)) {
+    rel_error = util::check_rel_error(as.scalar(out[1,i]), as.scalar(out_simple[1,i]), 1e-10, 1e-12)
+    rel_error = util::check_rel_error(as.scalar(out[1,i]), as.scalar(out_builtin[1,i]), 1e-10, 1e-12)
+  }
+}
+
+im2col = function() {
+  /*
+   * Test for the `im2col` and `col2im` functions.
+   */
+  print("Testing the im2col and col2im functions.")
+
+  # Generate data
+  C = 3  # num channels
+  Hin = 5  # input height
+  Win = 5  # input width
+  Hf = 3  # filter height
+  Wf = 3  # filter width
+  stride = 2
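+  # Padding is chosen so that the output has the same spatial size as the input ("same"
+  # padding): Hout = (Hin + 2*pad - Hf)/stride + 1 = Hin for this choice of pad.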
+  pad = (Hin * stride - Hin + Hf - stride) / 2
+  Hout = as.integer((Hin + 2 * pad - Hf) / stride + 1)
+  Wout = as.integer((Win + 2 * pad - Wf) / stride + 1)
+  x = rand(rows=C, cols=Hin*Win)
+
+  # pad
+  x_pad = util::pad_image(x, Hin, Win, pad, pad)
+
+  # im2col
+  x_cols = util::im2col(x_pad, Hin+2*pad, Win+2*pad, Hf, Wf, stride, stride)
+
+  # col2im
+  x_pad2 = util::col2im(x_cols, C, Hin+2*pad, Win+2*pad, Hf, Wf, stride, stride, "none")
+
+  # Equivalency check
+  equivalent = util::all_equal(x_pad, x_pad2)
+  if (!equivalent)
+    print("ERROR: im2col and then col2im does not yield the original image.")
+}
+
+padding = function() {
+  /*
+   * Test for the `pad_image` and `unpad_image` functions.
+   */
+  print("Testing the padding and unpadding functions.")
+
+  # Generate data
+  C = 3  # num channels
+  Hin = 5  # input height
+  Win = 5  # input width
+  pad = 3  # padding
+  x = rand(rows=C, cols=Hin*Win)
+
+  # Pad image
+  x_pad = util::pad_image(x, Hin, Win, pad, pad)
+  
+  # Check for padded rows & columns
+  for (c in 1:C) {
+    x_pad_slice = matrix(x_pad[c,], rows=Hin+2*pad, cols=Win+2*pad)
+    for (i in 1:pad) {
+      rowsum = sum(x_pad_slice[i,])
+      colsum = sum(x_pad_slice[,i])
+      if (rowsum != 0)
+        print("ERROR: Padding was not applied to row " + i + ".")
+      if (colsum != 0)
+        print("ERROR: Padding was not applied to column " + i + ".")
+    }
+  }
+
+  # Unpad image
+  x1 = util::unpad_image(x_pad, Hin, Win, pad, pad)
+
+  # Equivalency check
+  equivalent = util::all_equal(x, x1)
+  if (!equivalent)
+    print("ERROR: Padding and then unpadding does not yield the original image.")
+}
+
+max_pool = function() {
+  /*
+   * Test for the `max_pool` functions.
+   */
+  print("Testing the max pool functions.")
+
+  # Generate data
+  N = 2  # num examples
+  C = 3  # num channels
+  Hin = 8  # input height
+  Win = 8  # input width
+  Hf = 2  # filter height
+  Wf = 2  # filter width
+  stride = 2
+  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+
+  # Forward
+  [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+  [out_simple, Hout_simple, Wout_simple] =
+    max_pool_simple::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+  [out_builtin, Hout_builtin, Wout_builtin] =
+    max_pool_builtin::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+
+  # Equivalency check
+  out = matrix(out, rows=1, cols=N*C*Hout*Wout)
+  out_simple = matrix(out_simple, rows=1, cols=N*C*Hout*Wout)
+  out_builtin = matrix(out_builtin, rows=1, cols=N*C*Hout*Wout)
+  for (i in 1:length(out)) {
+    rel_error = util::check_rel_error(as.scalar(out[1,i]), as.scalar(out_simple[1,i]), 1e-10, 1e-12)
+    rel_error = util::check_rel_error(as.scalar(out[1,i]), as.scalar(out_builtin[1,i]), 1e-10, 1e-12)
+  }
+
+  # ---
+  # Check for correct behavior
+  # Generate data
+  C = 2  # num channels
+  Hin = 4  # input height
+  Win = 4  # input width
+  X = matrix(seq(1,16,1), rows=Hin, cols=Win)
+  X = matrix(rbind(X, t(X)), rows=1, cols=C*Hin*Win)
+  X = rbind(X, X)  # N=2
+
+  # Forward
+  [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, Hf, Wf, stride, stride)
+
+  # Equivalency check
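+  # Channel 1 of each example is the row-major 4x4 matrix [1..16], which 2x2
+  # pooling with stride 2 reduces to [6 8; 14 16]; channel 2 is its transpose,
+  # which reduces to [6 14; 8 16].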
+  target = matrix("6 8 14 16 6 14 8 16", rows=1, cols=C*Hout*Wout)
+  target = rbind(target, target)  # N=2
+  tmp = util::check_all_equal(out, target)
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/test/tests.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/tests.dml b/scripts/staging/SystemML-NN/nn/test/tests.dml
new file mode 100644
index 0000000..1b91967
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/test/tests.dml
@@ -0,0 +1,72 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Script to run tests.
+ */
+source("nn/test/grad_check.dml") as grad_check
+source("nn/test/test.dml") as test
+
+print("")
+print("Starting grad checks.")
+print("---")
+
+tmp = grad_check::cross_entropy_loss()
+tmp = grad_check::l1_loss()
+tmp = grad_check::l2_loss()
+tmp = grad_check::log_loss()
+tmp = grad_check::affine()
+tmp = grad_check::conv_simple()
+tmp = grad_check::conv()
+tmp = grad_check::conv_builtin()
+tmp = grad_check::dropout()
+tmp = grad_check::l1_reg()
+tmp = grad_check::l2_reg()
+tmp = grad_check::max_pool_simple()
+tmp = grad_check::max_pool()
+tmp = grad_check::max_pool_builtin()
+tmp = grad_check::relu()
+tmp = grad_check::sigmoid()
+tmp = grad_check::softmax()
+tmp = grad_check::tanh()
+tmp = grad_check::two_layer_affine_l2_net()
+
+print("---")
+print("Grad checks complete -- look for any ERRORs or WARNINGs.")
+print("If any tests involving ReLUs failed, try a few times " +
+      "to ensure that they were not false negatives due to " +
+      "kinks being crossed.")
+print("")
+
+print("")
+print("Starting other tests.")
+print("---")
+
+tmp = test::im2col()
+tmp = test::padding()
+tmp = test::conv()
+tmp = test::max_pool()
+
+print("---")
+print("Other tests complete -- look for any ERRORs or WARNINGs.")
+print("")
+print("")
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/util.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/util.dml b/scripts/staging/SystemML-NN/nn/util.dml
new file mode 100644
index 0000000..213363b
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/util.dml
@@ -0,0 +1,266 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Utility functions.
+ */
+all_equal = function(matrix[double] X1, matrix[double] X2)
+    return(boolean equivalent) {
+  /*
+   * Determine if two matrices are equivalent.
+   *
+   * Inputs:
+   *  - X1: Input matrix, of shape (any, any).
+   *  - X2: Input matrix, of same shape as X1.
+   *
+   * Outputs:
+   *  - equivalent: Whether or not the two matrices are equivalent.
+   */
+  equivalent = as.logical(prod(ppred(X1, X2, "==")))
+}
+
+check_all_equal = function(matrix[double] X1, matrix[double] X2)
+    return(boolean equivalent) {
+  /*
+   * Check if two matrices are equivalent, and report any issues.
+   *
+   *  - Issues an "ERROR" statement if elements of the two matrices
+   *  are not equal.
+   *
+   * Inputs:
+   *  - X1: Input matrix, of shape (any, any).
+   *  - X2: Input matrix, of same shape as X1.
+   *
+   * Outputs:
+   *  - equivalent: Whether or not the two matrices are equivalent.
+   */
+  # Determine if matrices are equivalent
+  equivalent = all_equal(X1, X2)
+
+  # Report if the matrices are not equivalent
+  if (!equivalent) {
+      print("ERROR: The two matrices are not equivalent.")
+  }
+}
+
+compute_rel_error = function(double x1, double x2) return (double rel_error) {
+  /*
+   * Relative error measure between two values.
+   *
+   * Uses smoothing to avoid divide-by-zero errors.
+   *
+   * Inputs:
+   *  - x1: First value.
+   *  - x2: Second value.
+   *
+   * Outputs:
+   *  - rel_error: Relative error measure between the two values.
+   */
+  rel_error = abs(x1 - x2) / max(1e-8, abs(x1) + abs(x2))
+}
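+
+# Worked example (illustrative): for x1 = 1.0 and x2 = 1.001,
+# rel_error = |1.0 - 1.001| / max(1e-8, |1.0| + |1.001|) = 0.001 / 2.001 ~= 5.0e-4.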
+
+check_rel_error = function(double x1, double x2, double thresh_error, double thresh_warn)
+    return (double rel_error) {
+  /*
+   * Check and report any issues with the relative error measure between
+   * two values.
+   *
+   *  - Issues an "ERROR" statement for relative errors > thresh_error,
+   *  indicating that the implementation is likely incorrect.
+   *  - Issues a "WARNING" statement for relative errors < thresh_error
+   *  but > thresh_warn, indicating that the implementation may be incorrect.
+   *
+   * Inputs:
+   *  - x1: First value.
+   *  - x2: Second value.
+   *  - thresh_error: Error threshold.
+   *  - thresh_warn: Warning threshold.
+   *
+   * Outputs:
+   *  - rel_error: Relative error measure between the two values.
+   */
+  # Compute relative error
+  rel_error = compute_rel_error(x1, x2)
+
+  # Evaluate relative error
+  if (rel_error > thresh_error) {
+      print("ERROR: Relative error " + rel_error + " > " + thresh_error + " with " + x1 +
+            " vs " + x2 + ".")
+  }
+  else if (rel_error > thresh_warn & rel_error < thresh_error) {
+      print("WARNING: Relative error " + rel_error + " > " + thresh_warn + " with " + x1 +
+            " vs " + x2 + ".")
+  }
+}
+
+im2col = function(matrix[double] img, int Hin, int Win, int Hf, int Wf, int strideh, int stridew)
+    return (matrix[double] img_cols) {
+  /*
+   * Rearrange local image regions (patches) into columns.
+   *
+   * Assumes image has already been padded as necessary.
+   *
+   * Inputs:
+   *  - img: Input image, of shape (C, Hin*Win), where C is the number
+   *      of input channels (depth).
+   *  - Hin: Input height, including padding.
+   *  - Win: Input width, including padding.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *
+   * Outputs:
+   *  - img_cols: Local spatial regions (patches) of the image stretched
+   *      out into columns, of shape (C*Hf*Wf, Hout*Wout).
+   */
+  C = nrow(img)
+  Hout = as.integer((Hin - Hf) / strideh + 1)
+  Wout = as.integer((Win - Wf) / stridew + 1)
+
+  img_cols = matrix(0, rows=C*Hf*Wf, cols=Hout*Wout)  # zeros
+  parfor (hout in 1:Hout, check=0) {  # all output rows
+    hin = (hout-1) * strideh + 1
+    parfor (wout in 1:Wout, check=0) {  # all output columns
+      win = (wout-1) * stridew + 1
+      # Extract a local patch of the input image corresponding spatially to the filter sizes.
+      img_patch = matrix(0, rows=C, cols=Hf*Wf)  # zeros
+      parfor (c in 1:C) {  # all channels
+        img_slice = matrix(img[c,], rows=Hin, cols=Win)  # reshape
+        img_patch[c,] = matrix(img_slice[hin:hin+Hf-1, win:win+Wf-1], rows=1, cols=Hf*Wf)
+      }
+      img_cols[,(hout-1)*Wout + wout] = matrix(img_patch, rows=C*Hf*Wf, cols=1)  # reshape
+    }
+  }
+}
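+
+# Example (illustrative; the names and shapes below are hypothetical): from a
+# script that sources this file as `util`, a single-channel 4x4 padded image
+# with 3x3 filters and stride 1 gives Hout = Wout = (4-3)/1+1 = 2, so the
+# result has shape (C*Hf*Wf, Hout*Wout) = (9, 4):
+#
+#   x_pad  = matrix(seq(1,16,1), rows=1, cols=16)   # C=1, Hin=Win=4 (padded)
+#   x_cols = util::im2col(x_pad, 4, 4, 3, 3, 1, 1)  # shape (9, 4)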
+
+col2im = function(matrix[double] img_cols, int C, int Hin, int Win, int Hf, int Wf,
+                  int strideh, int stridew, string reduction)
+    return (matrix[double] img) {
+  /*
+   * Create an image from columns of local image regions (patches).
+   *
+   * The reduction strategy determines how to deal with overlapping
+   * patches.  If it is set to "add", any overlapping patches will be
+   * added together when creating the image.  This is useful when
+   * computing gradients on the original image given gradients on the
+   * patches.  Otherwise, if "none" is provided, any overlapping
+   * patches will just override previous ones when creating the image.
+   * This is useful when recreating an image from the output of
+   * `im2col`.
+   *
+   * Assumes original image was already padded as necessary.
+   *
+   * Inputs:
+   *  - img_cols: Local spatial regions (patches) of the image stretched
+   *      out into columns, of shape (C*Hf*Wf, Hout*Wout).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height, including padding.
+   *  - Win: Input width, including padding.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - reduction: The reduction strategy to use for overlapping
+   *      patches.  Valid options are "add" and "none".
+   *
+   * Outputs:
+   *  - img: Input image, of shape (C, Hin*Win).
+   */
+  Hout = as.integer((Hin - Hf) / strideh + 1)
+  Wout = as.integer((Win - Wf) / stridew + 1)
+
+  img = matrix(0, rows=C, cols=Hin*Win)  # zeros
+  for (hout in 1:Hout) {  # all output rows
+    hin = (hout-1) * strideh + 1
+    for (wout in 1:Wout) {  # all output columns
+      win = (wout-1) * stridew + 1
+      # Extract a local patch of the input image corresponding spatially to the filter sizes.
+      img_patch = matrix(img_cols[,(hout-1)*Wout + wout], rows=C, cols=Hf*Wf)  # reshape
+      parfor (c in 1:C) {  # all channels
+        img_patch_slice = matrix(img_patch[c,], rows=Hf, cols=Wf)  # reshape
+        if (reduction == "add") {
+          img_slice = matrix(0, rows=Hin, cols=Win)
+          img_slice[hin:hin+Hf-1, win:win+Wf-1] = img_patch_slice
+          img[c,] = img[c,] + matrix(img_slice, rows=1, cols=Hin*Win)
+        } else {
+          img_slice = matrix(img[c,], rows=Hin, cols=Win)
+          img_slice[hin:hin+Hf-1, win:win+Wf-1] = img_patch_slice
+          img[c,] = matrix(img_slice, rows=1, cols=Hin*Win)
+        }
+      }
+    }
+  }
+}
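+
+# Example of the reduction modes (illustrative; the names below are
+# hypothetical): round-tripping a padded image through im2col and col2im with
+# "none" recovers the original image, since overlapping patches carry identical
+# values, while "add" sums the overlaps, which is the behavior needed when
+# accumulating gradients from patches back onto the image:
+#
+#   x_cols = util::im2col(x_pad, Hp, Wp, Hf, Wf, strideh, stridew)
+#   x_none = util::col2im(x_cols, C, Hp, Wp, Hf, Wf, strideh, stridew, "none")  # == x_pad
+#   x_add  = util::col2im(x_cols, C, Hp, Wp, Hf, Wf, strideh, stridew, "add")   # overlaps summed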
+
+pad_image = function(matrix[double] img, int Hin, int Win, int padh, int padw)
+    return (matrix[double] img_padded) {
+  /*
+   * Pads an image along the height and width dimensions with zeros.
+   *
+   * Inputs:
+   *  - img: Input image, of shape (C, Hin*Win), where C is the number
+   *      of input channels (depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *
+   * Outputs:
+   *  - img_padded: The input image padded along the height and width
+   *      dimensions, of shape (C, (Hin+2*padh)*(Win+2*padw)).
+   */
+  C = nrow(img)
+  img_padded = matrix(0, rows=C, cols=(Hin+2*padh)*(Win+2*padw))  # zeros
+  parfor (c in 1:C) {
+    img_slice = matrix(img[c,], rows=Hin, cols=Win)  # depth slice C reshaped
+    img_padded_slice = matrix(0, rows=Hin+2*padh, cols=Win+2*padw)
+    img_padded_slice[padh+1:padh+Hin, padw+1:padw+Win] = img_slice
+    img_padded[c,] = matrix(img_padded_slice, rows=1, cols=(Hin+2*padh)*(Win+2*padw))  # reshape
+  }
+}
+
+unpad_image = function(matrix[double] img_padded, int Hin, int Win, int padh, int padw)
+    return (matrix[double] img) {
+  /*
+   * Unpads an image along the height and width dimensions.
+   *
+   * Inputs:
+   *  - img_padded: The input image padded along the height and width
+   *      dimensions, of shape (C, (Hin+2*padh)*(Win+2*padw)).
+   *  - Hin: Input height of unpadded image.
+   *  - Win: Input width of unpadded image.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *
+   * Outputs:
+   *  - img: Input image, of shape (C, Hin*Win), where C is the number
+   *      of input channels (depth).
+   */
+  C = nrow(img_padded)
+  img = matrix(0, rows=C, cols=Hin*Win)
+  parfor (c in 1:C) {
+    img_padded_slice = matrix(img_padded[c,], rows=(Hin+2*padh), cols=(Win+2*padw))
+    img_slice = img_padded_slice[padh+1:padh+Hin, padw+1:padw+Win]
+    img[c,] = matrix(img_slice, rows=1, cols=Hin*Win)
+  }
+}
+


[2/2] incubator-systemml git commit: [SYSTEMML-618] Deep Learning DML Library

Posted by du...@apache.org.
[SYSTEMML-618] Deep Learning DML Library

This introduces a new deep learning DML library, SystemML-NN.  SystemML-NN is a layer-based library written in pure DML that provides layers with a simple forward/backward API, covering affine, spatial convolution, and max-pooling layers, non-linearities (relu, sigmoid, softmax, tanh), dropout, and loss functions, along with optimizers and gradient checks.

**SystemML-NN**:
_Current status:_
* Layers:
  * Core:
    * Affine
    * Spatial Convolution
    * Max Pooling
  * Nonlinearities:
    * ReLU
    * Sigmoid
    * Softmax
    * Tanh
  * Loss:
    * Cross-entropy loss
    * L1 loss
    * L2 loss
    * Log ("Logistic") loss
  * Regularization:
    * Dropout
    * L1 reg
    * L2 reg
* Optimizers:
  * Adagrad
  * Adam
  * RMSprop
  * SGD
  * SGD w/ Momentum
  * SGD w/ Nesterov Momentum
* Tests:
  * Gradient Checks

The upstream, experimental version of the codebase lives at [https://github.com/dusenberrymw/systemml-nn](https://github.com/dusenberrymw/systemml-nn), and this addition represents the portion of the codebase that is relatively stable.  The idea is to rapidly work on the upstream repo, and then contribute stable pieces to the main SystemML repo.
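
Each layer follows the same calling convention; a minimal sketch of the forward/backward API using the relu layer (the variable names and shapes below are illustrative only):

  source("nn/layers/relu.dml") as relu
  X = rand(rows=4, cols=8, pdf="normal")  # toy input
  out = relu::forward(X)                  # forward pass
  dout = matrix(1, rows=4, cols=8)        # upstream gradient (illustrative)
  dX = relu::backward(dout, X)            # backward pass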

Closes #160.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/781d24d8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/781d24d8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/781d24d8

Branch: refs/heads/master
Commit: 781d24d86dea1de880c6b66a75882ecfa5f1086c
Parents: 9dc42b2
Author: Mike Dusenberry <mw...@us.ibm.com>
Authored: Wed May 18 19:02:16 2016 -0600
Committer: Mike Dusenberry <mw...@us.ibm.com>
Committed: Wed May 18 19:02:16 2016 -0600

----------------------------------------------------------------------
 scripts/staging/SystemML-NN/.gitignore          |    2 +
 scripts/staging/SystemML-NN/README.md           |  182 +++
 .../staging/SystemML-NN/nn/layers/affine.dml    |   86 ++
 scripts/staging/SystemML-NN/nn/layers/conv.dml  |  183 +++
 .../SystemML-NN/nn/layers/conv_builtin.dml      |  147 +++
 .../nn/layers/cross_entropy_loss.dml            |   65 +
 .../staging/SystemML-NN/nn/layers/dropout.dml   |   70 ++
 .../staging/SystemML-NN/nn/layers/l1_loss.dml   |   62 +
 .../staging/SystemML-NN/nn/layers/l1_reg.dml    |   53 +
 .../staging/SystemML-NN/nn/layers/l2_loss.dml   |   62 +
 .../staging/SystemML-NN/nn/layers/l2_reg.dml    |   53 +
 .../staging/SystemML-NN/nn/layers/log_loss.dml  |   65 +
 .../staging/SystemML-NN/nn/layers/max_pool.dml  |  133 ++
 .../SystemML-NN/nn/layers/max_pool_builtin.dml  |   92 ++
 scripts/staging/SystemML-NN/nn/layers/relu.dml  |   55 +
 .../staging/SystemML-NN/nn/layers/sigmoid.dml   |   54 +
 .../staging/SystemML-NN/nn/layers/softmax.dml   |   73 ++
 scripts/staging/SystemML-NN/nn/layers/tanh.dml  |   57 +
 .../staging/SystemML-NN/nn/optim/adagrad.dml    |   72 ++
 scripts/staging/SystemML-NN/nn/optim/adam.dml   |   92 ++
 .../staging/SystemML-NN/nn/optim/rmsprop.dml    |   74 ++
 scripts/staging/SystemML-NN/nn/optim/sgd.dml    |   40 +
 .../SystemML-NN/nn/optim/sgd_momentum.dml       |   66 +
 .../SystemML-NN/nn/optim/sgd_nesterov.dml       |   75 ++
 .../staging/SystemML-NN/nn/test/conv_simple.dml |  211 ++++
 .../staging/SystemML-NN/nn/test/grad_check.dml  | 1139 ++++++++++++++++++
 .../SystemML-NN/nn/test/max_pool_simple.dml     |  130 ++
 scripts/staging/SystemML-NN/nn/test/test.dml    |  192 +++
 scripts/staging/SystemML-NN/nn/test/tests.dml   |   72 ++
 scripts/staging/SystemML-NN/nn/util.dml         |  266 ++++
 30 files changed, 3923 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/.gitignore
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/.gitignore b/scripts/staging/SystemML-NN/.gitignore
new file mode 100644
index 0000000..a1b402b
--- /dev/null
+++ b/scripts/staging/SystemML-NN/.gitignore
@@ -0,0 +1,2 @@
+scratch_space
+*.ipynb

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/README.md
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/README.md b/scripts/staging/SystemML-NN/README.md
new file mode 100644
index 0000000..965da67
--- /dev/null
+++ b/scripts/staging/SystemML-NN/README.md
@@ -0,0 +1,182 @@
+<!--
+{% comment %}
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+{% endcomment %}
+-->
+
+# SystemML-NN
+
+### A deep learning library for [Apache SystemML](https://github.com/apache/incubator-systemml).
+
+## Examples:
+### Neural net for regression with vanilla SGD:
+```python
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/l2_loss.dml") as l2_loss
+source("nn/layers/relu.dml") as relu
+source("nn/optim/sgd.dml") as sgd
+
+# Generate input data
+N = 1024 # num examples
+D = 100 # num features
+t = 1 # num targets
+X = rand(rows=N, cols=D, pdf="normal")
+y = rand(rows=N, cols=t)
+
+# Create 2-layer network:
+## affine1 -> relu1 -> affine2
+M = 64 # number of neurons
+[W1, b1] = affine::init(D, M)
+[W2, b2] = affine::init(M, t)
+
+# Initialize optimizer
+lr = 0.05  # learning rate
+mu = 0.9  # momentum
+decay = 0.99  # learning rate decay constant
+
+# Optimize
+print("Starting optimization")
+batch_size = 32
+epochs = 5
+iters = 1024 / batch_size
+for (e in 1:epochs) {
+  for(i in 1:iters) {
+    # Get next batch
+    beg = (i-1) * batch_size + 1
+    X_batch = X[beg:beg+batch_size-1,]
+    y_batch = y[beg:beg+batch_size-1,]
+
+    # Compute forward pass
+    out1 = affine::forward(X_batch, W1, b1)
+    outr1 = relu::forward(out1)
+    out2 = affine::forward(outr1, W2, b2)
+
+    # Compute loss
+    loss = l2_loss::forward(out2, y_batch)
+    print("L2 loss: " + loss)
+
+    # Compute backward pass
+    dout2 = l2_loss::backward(out2, y_batch)
+    [doutr1, dW2, db2] = affine::backward(dout2, outr1, W2, b2)
+    dout1 = relu::backward(doutr1, out1)
+    [dX_batch, dW1, db1] = affine::backward(dout1, X_batch, W1, b1)
+
+    # Optimize with vanilla SGD
+    W1 = sgd::update(W1, dW1, lr)
+    b1 = sgd::update(b1, db1, lr)
+    W2 = sgd::update(W2, dW2, lr)
+    b2 = sgd::update(b2, db2, lr)
+  }
+  # Decay learning rate
+  lr = lr * decay
+}
+```
+
+### Neural net for multi-class classification with dropout and SGD w/ Nesterov momentum:
+```python
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/dropout.dml") as dropout
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+# Generate input data
+N = 1024 # num examples
+D = 100 # num features
+t = 5 # num targets
+X = rand(rows=N, cols=D, pdf="normal")
+classes = round(rand(rows=N, cols=1, min=1, max=t, pdf="uniform"))
+y = matrix(0, rows=N, cols=t)
+parfor (i in 1:N) {
+  y[i, as.scalar(classes[i,1])] = 1  # one-hot encoding
+}
+
+# Create network:
+# affine1 -> relu1 -> dropout1 -> affine2 -> relu2 -> dropout2 -> affine3 -> softmax
+H1 = 64 # number of neurons in 1st hidden layer
+H2 = 64 # number of neurons in 2nd hidden layer
+p = 0.5  # dropout probability
+[W1, b1] = affine::init(D, H1)
+[W2, b2] = affine::init(H1, H2)
+[W3, b3] = affine::init(H2, t)
+
+# Initialize SGD w/ Nesterov momentum optimizer
+lr = 0.05  # learning rate
+mu = 0.5  # momentum
+decay = 0.99  # learning rate decay constant
+vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+
+# Optimize
+print("Starting optimization")
+batch_size = 64
+epochs = 10
+iters = 1024 / batch_size
+for (e in 1:epochs) {
+  for(i in 1:iters) {
+    # Get next batch
+    beg = (i-1) * batch_size + 1
+    X_batch = X[beg:beg+batch_size-1,]
+    y_batch = y[beg:beg+batch_size-1,]
+
+    # Compute forward pass
+    ## layer 1:
+    out1 = affine::forward(X_batch, W1, b1)
+    outr1 = relu::forward(out1)
+    [outd1, maskd1] = dropout::forward(outr1, p, -1)
+    ## layer 2:
+    out2 = affine::forward(outd1, W2, b2)
+    outr2 = relu::forward(out2)
+    [outd2, maskd2] = dropout::forward(outr2, p, -1)
+    ## layer 3:
+    out3 = affine::forward(outd2, W3, b3)
+    probs = softmax::forward(out3)
+
+    # Compute loss
+    loss = cross_entropy_loss::forward(probs, y_batch)
+    print("Cross entropy loss: " + loss)
+
+    # Compute backward pass
+    ## loss:
+    dprobs = cross_entropy_loss::backward(probs, y_batch)
+    ## layer 3:
+    dout3 = softmax::backward(dprobs, out3)
+    [doutd2, dW3, db3] = affine::backward(dout3, outd2, W3, b3)
+    ## layer 2:
+    doutr2 = dropout::backward(doutd2, outr2, p, maskd2)
+    dout2 = relu::backward(doutr2, out2)
+    [doutd1, dW2, db2] = affine::backward(dout2, outd1, W2, b2)
+    ## layer 1:
+    doutr1 = dropout::backward(doutd1, outr1, p, maskd1)
+    dout1 = relu::backward(doutr1, out1)
+    [dX_batch, dW1, db1] = affine::backward(dout1, X_batch, W1, b1)
+
+    # Optimize with SGD w/ Nesterov momentum
+    [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+    [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+    [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+    [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+    [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+    [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+  }
+  # Anneal momentum towards 0.999
+  mu = mu + (0.999 - mu)/(1+epochs-e)
+  # Decay learning rate
+  lr = lr * decay
+}
+```
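+
+A note on the momentum schedule above (illustrative arithmetic): the update
+`mu = mu + (0.999 - mu)/(1+epochs-e)` moves `mu` a fraction `1/(1+epochs-e)` of
+the remaining distance toward 0.999 each epoch.  With `epochs=10` and `mu=0.5`,
+the first epoch moves it to roughly 0.55, and the final epoch (denominator 1)
+sets it to 0.999 exactly.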

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/affine.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/affine.dml b/scripts/staging/SystemML-NN/nn/layers/affine.dml
new file mode 100644
index 0000000..1338de4
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/affine.dml
@@ -0,0 +1,86 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Fully-connected (affine) layer.
+ */
+forward = function(matrix[double] X, matrix[double] W, matrix[double] b)
+    return (matrix[double] out) {
+  /*
+   * Computes the forward pass for a fully-connected (affine) layer with
+   * M neurons.  The input data has N examples, each with D features.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, D).
+   *  - W: Weights (parameters) matrix, of shape (D, M).
+   *  - b: Biases vector, of shape (1, M).
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, M).
+   */
+  out = X %*% W + b
+}
+
+backward = function(matrix[double] dout, matrix[double] X,
+                    matrix[double] W, matrix[double] b)
+    return (matrix[double] dX, matrix[double] dW, matrix[double] db) {
+  /*
+   * Computes the backward pass for a fully-connected (affine) layer
+   * with M neurons.
+   *
+   * Inputs:
+   *  - dout: Derivatives from upstream, of shape (N, M).
+   *  - X: Previous input data matrix, of shape (N, D).
+   *  - W: Weights (parameters) matrix, of shape (D, M).
+   *  - b: Biases vector, of shape (1, M).
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of shape (N, D).
+   *  - dW: Gradient wrt W, of shape (D, M).
+   *  - db: Gradient wrt b, of shape (1, M).
+   */
+  dX = dout %*% t(W)
+  dW = t(X) %*% dout
+  db = colSums(dout)
+}
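+
+# Derivation note (for reference): with out = X %*% W + b, the chain rule gives
+# dX = dout %*% t(W), dW = t(X) %*% dout, and db = colSums(dout), i.e. the bias
+# gradient simply sums the upstream gradient over the N examples.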
+
+init = function(int D, int M)
+    return (matrix[double] W, matrix[double] b) {
+  /*
+   * Initialize the parameters of this layer.
+   * 
+   * We use the heuristic by He et al. [http://arxiv.org/abs/1502.01852],
+   * which limits the magnification of inputs/gradients during
+   * forward/backward passes by scaling unit-Gaussian weights by a
+   * factor of sqrt(2/n), under the assumption of relu neurons.
+   *
+   * Inputs:
+   *  - D: Dimensionality of the input features.
+   *  - M: Number of neurons in this layer.
+   *
+   * Outputs:
+   *  - W: Weight matrix, of shape (D, M).
+   *  - b: Biases vector, of shape (1, M).
+   */
+  W = rand(rows=D, cols=M, pdf="normal") * sqrt(2.0/D)
+  b = matrix(0, rows=1, cols=M) 
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/conv.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/conv.dml b/scripts/staging/SystemML-NN/nn/layers/conv.dml
new file mode 100644
index 0000000..0fbcf99
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/conv.dml
@@ -0,0 +1,183 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * 2D Convolutional layer.
+ */
+source("nn/util.dml") as util
+
+forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
+                   int C, int Hin, int Win, int Hf, int Wf,
+                   int strideh, int stridew, int padh, int padw)
+    return (matrix[double] out, int Hout, int Wout) {
+  /*
+   * Computes the forward pass for a 2D spatial convolutional layer with
+   * F filters.  The input data has N examples, each represented as a 3D
+   * volume unrolled into a single vector.
+   *
+   * This implementation uses `im2col` internally for each image to
+   * extract local image regions (patches) into columns, and then
+   * performs a matrix multiplication with the filters to compute the
+   * output maps.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - W: Weights (parameters) matrix, of shape (F, C*Hf*Wf).
+   *  - b: Biases vector, of shape (F, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, F*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   */
+  N = nrow(X)
+  F = nrow(W)
+  Hout = as.integer((Hin + 2 * padh - Hf) / strideh + 1)
+  Wout = as.integer((Win + 2 * padw - Wf) / stridew + 1)
+  
+  # Create output volume
+  out = matrix(0, rows=N, cols=F*Hout*Wout)
+
+  # Convolution - im2col implementation
+  parfor (n in 1:N) {  # all examples
+    Xn = matrix(X[n,], rows=C, cols=Hin*Win)  # reshape
+
+    # Pad image
+    Xn_padded = util::pad_image(Xn, Hin, Win, padh, padw)  # shape (C, (Hin+2*padh)*(Win+2*padw))
+
+    # Extract local image patches into columns with im2col, of shape (C*Hf*Wf, Hout*Wout)
+    Xn_padded_cols = util::im2col(Xn_padded, Hin+2*padh, Win+2*padw, Hf, Wf, strideh, stridew)
+
+    # Convolve patches with filters
+    outn = W %*% Xn_padded_cols + b  # shape (F, Hout*Wout)
+    out[n,] = matrix(outn, rows=1, cols=F*Hout*Wout)  # reshape
+  }
+}
+
+backward = function(matrix[double] dout, int Hout, int Wout,
+                    matrix[double] X, matrix[double] W, matrix[double] b,
+                    int C, int Hin, int Win, int Hf, int Wf,
+                    int strideh, int stridew, int padh, int padw)
+    return (matrix[double] dX, matrix[double] dW, matrix[double] db) {
+  /*
+   * Computes the backward pass for a 2D spatial convolutional layer
+   * with F filters.
+   *
+   * This implementation uses `im2col` and `col2im` internally.
+   *
+   * Inputs:
+   *  - dout: Derivatives from upstream, of shape (N, F*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - X: Previous input data matrix, of shape (N, C*Hin*Win).
+   *  - W: Weights (parameters) matrix, of shape (F, C*Hf*Wf).
+   *  - b: Biases vector, of shape (F, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of shape (N, C*Hin*Win).
+   *  - dW: Gradient wrt W, of shape (F, C*Hf*Wf).
+   *  - db: Gradient wrt b, of shape (F, 1).
+   */
+  N = nrow(X)
+  F = nrow(W)
+  
+  # Create gradient volumes
+  dX = matrix(0, rows=N, cols=C*Hin*Win)
+  dW = matrix(0, rows=F, cols=C*Hf*Wf)
+  db = matrix(0, rows=F, cols=1)
+
+  # Create convenience gradient volumes for dW and db that will allow
+  # for one gradient to be stored per example, allowing for parallel
+  # computation at the expense of memory.  We will reduce at the end.
+  dWN = matrix(0, rows=N, cols=F*C*Hf*Wf)
+  dbN = matrix(0, rows=N, cols=F)
+
+  # Partial derivatives for convolution - im2col implementation
+  parfor (n in 1:N) {  # all examples
+    doutn = matrix(dout[n,], rows=F, cols=Hout*Wout)
+
+    # Compute dW
+    Xn = matrix(X[n,], rows=C, cols=Hin*Win)  # reshape
+    Xn_padded = util::pad_image(Xn, Hin, Win, padh, padw)  # shape (C, (Hin+2*padh)*(Win+2*padw))
+    Xn_padded_cols = util::im2col(Xn_padded, Hin+2*padh, Win+2*padw, Hf, Wf, strideh, stridew)
+    #dW = dW + doutn %*% t(Xn_padded_cols)
+    dWN[n,] = matrix(doutn %*% t(Xn_padded_cols), rows=1, cols=F*C*Hf*Wf)
+
+    # Compute db
+    #db = db + rowSums(doutn)
+    dbN[n,] = matrix(rowSums(doutn), rows=1, cols=F)
+
+    # Compute dX
+    dXn_padded_cols = t(W) %*% doutn  # shape (C*Hf*Wf, Hout*Wout)
+    dXn_padded =
+      util::col2im(dXn_padded_cols, C, Hin+2*padh, Win+2*padw, Hf, Wf, strideh, stridew, "add")
+    dXn = util::unpad_image(dXn_padded, Hin, Win, padh, padw)
+    dX[n,] = matrix(dXn, rows=1, cols=C*Hin*Win)  # reshape
+  }
+
+  # Reduce convenience gradient volumes with one gradient per example
+  # into single gradients for W and b.
+  dW = matrix(colSums(dWN), rows=F, cols=C*Hf*Wf)
+  db = matrix(colSums(dbN), rows=F, cols=1)
+}
+
+init = function(int F, int C, int Hf, int Wf)
+    return (matrix[double] W, matrix[double] b) {
+  /*
+   * Initialize the parameters of this layer.
+   * 
+   * We use the heuristic by He et al. [http://arxiv.org/abs/1502.01852],
+   * which limits the magnification of inputs/gradients during
+   * forward/backward passes by scaling unit-Gaussian weights by a
+   * factor of sqrt(2/n), under the assumption of relu neurons.
+   *
+   * Inputs:
+   *  - F: Number of filters.
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *
+   * Outputs:
+   *  - W: Weights (parameters) matrix, of shape (F, C*Hf*Wf).
+   *  - b: Biases vector, of shape (F, 1).
+   */
+  W = rand(rows=F, cols=C*Hf*Wf, pdf="normal") * sqrt(2.0/(C*Hf*Wf))
+  b = matrix(0, rows=F, cols=1) 
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml b/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml
new file mode 100644
index 0000000..a73405e
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml
@@ -0,0 +1,147 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * 2D Convolutional layer.
+ */
+forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
+                   int C, int Hin, int Win, int Hf, int Wf,
+                   int strideh, int stridew, int padh, int padw)
+    return (matrix[double] out, int Hout, int Wout) {
+  /*
+   * Computes the forward pass for a 2D spatial convolutional layer with
+   * F filters.  The input data has N examples, each represented as a 3D
+   * volume unrolled into a single vector.
+   *
+   * This implementation uses `im2col` internally for each image to
+   * extract local image regions (patches) into columns, and then
+   * performs a matrix multiplication with the filters to compute the
+   * output maps.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - W: Weights (parameters) matrix, of shape (F, C*Hf*Wf).
+   *  - b: Biases vector, of shape (F, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, F*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   */
+  N = nrow(X)
+  F = nrow(W)
+  Hout = as.integer((Hin + 2 * padh - Hf) / strideh + 1)
+  Wout = as.integer((Win + 2 * padw - Wf) / stridew + 1)
+  
+  # Convolution - built-in implementation
+  out = conv2d(X, W, input_shape=[N,C,Hin,Win], filter_shape=[F,C,Hf,Wf],
+               stride=[strideh,stridew], padding=[padh,padw])
+
+  # Add bias term to each output filter
+  bias = b
+  for (i in 1:Hout*Wout-1)
+    bias = cbind(bias, b)  # creating shape (F, Hout*Wout)
+  out = out + matrix(bias, rows=1, cols=F*Hout*Wout)
+}
+
+backward = function(matrix[double] dout, int Hout, int Wout,
+                    matrix[double] X, matrix[double] W, matrix[double] b,
+                    int C, int Hin, int Win, int Hf, int Wf,
+                    int strideh, int stridew, int padh, int padw)
+    return (matrix[double] dX, matrix[double] dW, matrix[double] db) {
+  /*
+   * Computes the backward pass for a 2D spatial convolutional layer
+   * with F filters.
+   *
+   * This implementation uses `im2col` and `col2im` internally.
+   *
+   * Inputs:
+   *  - dout: Derivatives from upstream, of shape (N, F*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - X: Previous input data matrix, of shape (N, C*Hin*Win).
+   *  - W: Weights (parameters) matrix, of shape (F, C*Hf*Wf).
+   *  - b: Biases vector, of shape (F, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of shape (N, C*Hin*Win).
+   *  - dW: Gradient wrt W, of shape (F, C*Hf*Wf).
+   *  - db: Gradient wrt b, of shape (F, 1).
+   */
+  N = nrow(X)
+  F = nrow(W)
+  
+  # Partial derivatives for convolution - built-in implementation
+  dW = conv2d_backward_filter(X, dout, stride=[strideh,stridew], padding=[padh,padw],
+                              input_shape=[N,C,Hin,Win], filter_shape=[F,C,Hf,Wf])
+  dX = conv2d_backward_data(W, dout, stride=[strideh, stridew], padding=[padh,padw],
+                            input_shape=[N,C,Hin,Win], filter_shape=[F,C,Hf,Wf])
+
+  # Partial derivatives for bias terms
+  db = matrix(0, rows=F, cols=1)
+  for (n in 1:N) {
+    doutn = matrix(dout[n,], rows=F, cols=Hout*Wout)
+    db = db + rowSums(doutn)
+  }
+}
+
+init = function(int F, int C, int Hf, int Wf)
+    return (matrix[double] W, matrix[double] b) {
+  /*
+   * Initialize the parameters of this layer.
+   * 
+   * We use the heuristic by He et al. [http://arxiv.org/abs/1502.01852],
+   * which limits the magnification of inputs/gradients during
+   * forward/backward passes by scaling unit-Gaussian weights by a
+   * factor of sqrt(2/n), under the assumption of relu neurons.
+   *
+   * Inputs:
+   *  - F: Number of filters.
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *
+   * Outputs:
+   *  - W: Weights (parameters) matrix, of shape (F, C*Hf*Wf).
+   *  - b: Biases vector, of shape (F, 1).
+   */
+  W = rand(rows=F, cols=C*Hf*Wf, pdf="normal") * sqrt(2.0/(C*Hf*Wf))
+  b = matrix(0, rows=F, cols=1) 
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml b/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml
new file mode 100644
index 0000000..6b9840f
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml
@@ -0,0 +1,65 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Cross-entropy loss function.
+ *
+ * L_i = -y_i^T * log(pred_i), where y_i and pred_i are K-dimensional
+ *  vectors of class probs.
+ * L = (1/N) sum(L_i) for i=1 to N, where N is the number of examples.
+ */
+forward = function(matrix[double] pred, matrix[double] y) 
+    return (double loss) {
+  /*
+   * Computes the forward pass for a cross-entropy loss function.  The
+   * inputs consist of N examples, each with K dimensions corresponding
+   * to normalized probabilities of K classes.
+   *
+   * Inputs:
+   *  - pred: Prediction matrix, of shape (N, K).
+   *  - y: Target matrix, of shape (N, K).
+   *
+   * Outputs:
+   *  - loss: Scalar loss, of shape (1).
+   */
+  N = nrow(y)
+  losses = rowSums(-y * log(pred))
+  loss = sum(losses) / N
+}
+
+backward = function(matrix[double] pred, matrix[double] y) 
+    return (matrix[double] dpred) {
+  /*
+   * Computes the backward pass of a cross-entropy loss function.  The
+   * inputs consist of N examples, each with K dimensions corresponding
+   * to normalized probabilities of K classes.
+   *
+   * Inputs:
+   *  - pred: Prediction matrix, of shape (N, K).
+   *  - y: Target matrix, of shape (N, K).
+   *
+   * Outputs:
+   *  - dpred: Gradient wrt pred, of shape (N, K).
+   */
+  N = nrow(y)
+  dpred = (1/N) * -y * (1/pred)
+}
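+
+# Derivation note (for reference): since L = -(1/N) * sum_i y_i^T log(pred_i),
+# the element-wise derivative is dL/dpred = -(1/N) * y / pred, which is what
+# `backward` computes above.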
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/dropout.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/dropout.dml b/scripts/staging/SystemML-NN/nn/layers/dropout.dml
new file mode 100644
index 0000000..e3c34f9
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/dropout.dml
@@ -0,0 +1,70 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Dropout layer.
+ */
+forward = function(matrix[double] X, double p, int seed)
+    return (matrix[double] out, matrix[double] mask) {
+  /*
+   * Computes the forward pass for an inverted dropout layer.
+   *
+   * Drops each input element with probability 1-p (i.e., keeps it with
+   * probability p), and divides by p to maintain the expected values of
+   * those inputs (which are the outputs of neurons) at test time.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (any, any).
+   *  - p: Probability of keeping a neuron output.
+   *  - seed: [Optional: -1] Random number generator seed.  Setting this
+   *      allows for deterministic evaluation.  Set to -1 for a random
+   *      seed.
+   *
+   * Outputs:
+   *  - out: Outputs, of same shape as X.
+   *  - mask: Dropout mask used to compute the output.
+   */
+  if (seed == -1)
+    seed = as.integer(floor(as.scalar(rand(rows=1, cols=1, min=1, max=100000))))
+  mask = rand(rows=nrow(X), cols=ncol(X), min=0, max=1, seed=seed) <= p
+  out = X * mask / p
+}
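+
+# Scaling note (for reference): each element is kept with probability p and the
+# kept values are scaled by 1/p, so E[out] = p * (X/p) + (1-p) * 0 = X, which is
+# why no additional scaling is needed at test time (inverted dropout).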
+
+backward = function(matrix[double] dout, matrix[double] X, double p, matrix[double] mask)
+    return (matrix[double] dX) {
+  /*
+   * Computes the backward pass for an inverted dropout layer.
+   *
+   * Applies the mask to the upstream gradient, and divides by p to
+   * maintain the expected values at test time.
+   *
+   * Inputs:
+   *  - dout: Derivatives from upstream, of same shape as X.
+   *  - X: Previous input data matrix, of shape (any, any).
+   *  - p: Previous probability of keeping a neuron output.
+   *  - mask: Previous dropout mask used to compute the output.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of same shape as X.
+   */
+  dX = mask / p * dout
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml b/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml
new file mode 100644
index 0000000..00db8a7
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml
@@ -0,0 +1,62 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * L1 loss function.
+ *
+ * L_i = sum_j(abs((pred_i)_j - (y_i)_j)) for all j.
+ * L = (1/N) sum(L_i) for i=1 to N, where N is the number of examples.
+ */
+forward = function(matrix[double] pred, matrix[double] y) 
+    return (double loss) {
+  /*
+   * Computes the forward pass for an L1 loss function.  The inputs
+   * consist of N examples, each with M dimensions to predict.
+   *
+   * Inputs:
+   *  - pred: Prediction matrix, of shape (N, M).
+   *  - y: Target matrix, of shape (N, M).
+   *
+   * Outputs:
+   *  - loss: Scalar loss, of shape (1).
+   */
+  N = nrow(y)
+  losses = rowSums(abs(pred - y))
+  loss = sum(losses) / N
+}
+
+backward = function(matrix[double] pred, matrix[double] y) 
+    return (matrix[double] dpred) {
+  /*
+   * Computes the backward pass for an L1 loss function.  The inputs
+   * consist of N examples, each with M dimensions to predict.
+   *
+   * Inputs:
+   *  - pred: Prediction matrix, of shape (N, M).
+   *  - y: Target matrix, of shape (N, M).
+   *
+   * Outputs:
+   *  - dpred: Gradient wrt pred, of shape (N, M).
+   */
+  N = nrow(y)
+  dpred = sign(pred - y) / N
+}
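+
+# Derivation note (for reference): d|z|/dz = sign(z) (taking the subgradient 0
+# at z = 0), so with z = pred - y, dL/dpred = sign(pred - y) / N.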
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/l1_reg.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/l1_reg.dml b/scripts/staging/SystemML-NN/nn/layers/l1_reg.dml
new file mode 100644
index 0000000..28de74c
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/l1_reg.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * L1 regularization.
+ */
+forward = function(matrix[double] X, double lambda) return (double reg_loss) {
+  /*
+   * Computes the forward pass for an L1 regularization function.
+   *
+   * Inputs:
+   *  - X: Parameters, of shape (any, any).
+   *  - lambda: Regularization strength.
+   *      A typical value is 0.01.
+   *
+   * Outputs:
+   *  - reg_loss: Scalar L1 regularization loss, of shape (1).
+   */
+  reg_loss = lambda * sum(abs(X))
+}
+
+backward = function(matrix[double] X, double lambda) return (matrix[double] dX) {
+  /*
+   * Computes the backward pass for an L1 regularization function.
+   *
+   * Inputs:
+   *  - X: Parameters, of shape (any, any).
+   *  - lambda: Regularization strength.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of same shape as X. 
+   */
+  dX = lambda * sign(X)
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml b/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml
new file mode 100644
index 0000000..13b6c2d
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml
@@ -0,0 +1,62 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * L2 loss function.
+ *
+ * L_i = (1/2) * ||pred_i - y_i||_2^2
+ * L = (1/N) sum(L_i) for i=1 to N, where N is the number of examples.
+ */
+forward = function(matrix[double] pred, matrix[double] y) 
+    return (double loss) {
+  /*
+   * Computes the forward pass for an L2 loss function.  The inputs
+   * consist of N examples, each with M dimensions to predict.
+   *
+   * Inputs:
+   *  - pred: Prediction matrix, of shape (N, M).
+   *  - y: Target matrix, of shape (N, M).
+   *
+   * Outputs:
+   *  - loss: Scalar loss, of shape (1).
+   */
+  N = nrow(y)
+  losses = 0.5 * rowSums((pred - y)^2)
+  loss = sum(losses) / N
+}
+
+backward = function(matrix[double] pred, matrix[double] y) 
+    return (matrix[double] dpred) {
+  /*
+   * Computes the backward pass for an L2 loss function.  The inputs
+   * consist of N examples, each with M dimensions to predict.
+   *
+   * Inputs:
+   *  - pred: Prediction matrix, of shape (N, M).
+   *  - y: Target matrix, of shape (N, M).
+   *
+   * Outputs:
+   *  - dpred: Gradient wrt pred, of shape (N, M).
+   */
+  N = nrow(y)
+  dpred = (pred - y) / N
+}
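+
+# Derivation note (for reference): d/dz (1/2) * z^2 = z with z = pred - y, so
+# dL/dpred = (pred - y) / N.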
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/l2_reg.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/l2_reg.dml b/scripts/staging/SystemML-NN/nn/layers/l2_reg.dml
new file mode 100644
index 0000000..22df974
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/l2_reg.dml
@@ -0,0 +1,53 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * L2 regularization.
+ */
+forward = function(matrix[double] X, double lambda) return (double reg_loss) {
+  /*
+   * Computes the forward pass for an L2 regularization function.
+   *
+   * Inputs:
+   *  - X: Parameters, of shape (any, any).
+   *  - lambda: Regularization strength.
+   *      A typical value is 0.01.
+   *
+   * Outputs:
+   *  - reg_loss: Scalar L2 regularization loss, of shape (1).
+   */
+  reg_loss = 0.5 * lambda * sum(X^2)
+}
+
+backward = function(matrix[double] X, double lambda) return (matrix[double] dX) {
+  /*
+   * Computes the backward pass for an L2 regularization function.
+   *
+   * Inputs:
+   *  - X: Parameters, of shape (any, any).
+   *  - lambda: Regularization strength.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of same shape as X. 
+   */
+  dX = lambda * X
+}
+
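
A short sketch of how this regularizer is typically combined with a data loss;
the data loss and its gradient below are placeholders for whatever the model
produced:

  source("nn/layers/l2_reg.dml") as l2_reg

  lambda = 0.01
  W = rand(rows=10, cols=5, min=-0.1, max=0.1, seed=1)  # parameters
  data_loss = 1.23                                      # placeholder data loss
  dW_data = rand(rows=10, cols=5, seed=2)               # placeholder data gradient wrt W

  loss = data_loss + l2_reg::forward(W, lambda)         # total objective
  dW = dW_data + l2_reg::backward(W, lambda)            # total gradient wrt W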

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/log_loss.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/log_loss.dml b/scripts/staging/SystemML-NN/nn/layers/log_loss.dml
new file mode 100644
index 0000000..e3da456
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/log_loss.dml
@@ -0,0 +1,65 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Log loss function.
+ *
+ * L_i = -y_i*log(pred_i) - (1-y_i)*log(1-pred_i), where y_i is a
+ *  binary target, and pred_i is the predicted probability that y_i = 1.
+ * L = (1/N) sum(L_i) for i=1 to N, where N is the number of examples.
+ */
+forward = function(matrix[double] pred, matrix[double] y) 
+    return (double loss) {
+  /*
+   * Computes the forward pass for a log loss function.
+   *
+   * Inputs:
+   *  - pred: Prediction matrix, of shape (N, 1).  Predictions should
+   *      be probabilities that y=1.
+   *  - y: Target matrix, of shape (N, 1).  Targets should be binary
+   *      in the set {0,1}.
+   *
+   * Outputs:
+   *  - loss: Scalar loss, of shape (1).
+   */
+  N = nrow(y)
+  losses = -y * log(pred) - (1-y) * log(1-pred)
+  loss = sum(losses) / N
+}
+
+backward = function(matrix[double] pred, matrix[double] y) 
+    return (matrix[double] dpred) {
+  /*
+   * Computes the backward pass for a log loss function.
+   *
+   * Inputs:
+   *  - pred: Prediction matrix, of shape (N, 1).  Predictions should
+   *      be probabilities that y=1.
+   *  - y: Target matrix, of shape (N, 1).  Targets should be binary
+   *      in the set {0,1}.
+   *
+   * Outputs:
+   *  - dpred: Gradient wrt pred, of shape (N, 1).
+   */
+  N = nrow(y)
+  dpred = (1/N) * (pred-y) / (pred * (1-pred))
+}
+
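
A minimal sketch, assuming probability predictions strictly inside (0, 1) and
binary targets generated here as placeholders:

  source("nn/layers/log_loss.dml") as log_loss

  N = 8
  pred = rand(rows=N, cols=1, min=0.05, max=0.95, seed=7)  # probabilities of y=1
  y = round(rand(rows=N, cols=1, min=0, max=1, seed=8))    # binary targets in {0,1}

  loss = log_loss::forward(pred, y)
  dpred = log_loss::backward(pred, y)
  print("Log loss: " + loss)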

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/max_pool.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/max_pool.dml b/scripts/staging/SystemML-NN/nn/layers/max_pool.dml
new file mode 100644
index 0000000..5dc4638
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/max_pool.dml
@@ -0,0 +1,133 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Max pooling layer.
+ */
+source("nn/util.dml") as util
+
+forward = function(matrix[double] X, int C, int Hin, int Win, int Hf, int Wf,
+                   int strideh, int stridew)
+    return (matrix[double] out, int Hout, int Wout) {
+  /*
+   * Computes the forward pass for a 2D spatial max pooling layer.
+   * The input data has N examples, each represented as a 3D volume
+   * unrolled into a single vector.
+   *
+   * This implementation uses `im2col` internally for each image to
+   * extract local image regions (patches) of each channel slice into
+   * columns, and then performs max pooling over the patches to compute
+   * the output maps.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, C*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   */
+  N = nrow(X)
+  Hout = as.integer((Hin - Hf) / strideh + 1)
+  Wout = as.integer((Win - Wf) / stridew + 1)
+
+  # Create output volume
+  out = matrix(0, rows=N, cols=C*Hout*Wout)
+
+  # Max pooling - im2col implementation
+  parfor (n in 1:N) {  # all examples
+    img = matrix(X[n,], rows=C, cols=Hin*Win)  # reshape
+    img_maxes = matrix(0, rows=C, cols=Hout*Wout)  # zeros
+
+    parfor (c in 1:C) {  # all channels
+      # Extract local image slice patches into columns with im2col, of shape (Hf*Wf, Hout*Wout)
+      img_slice_cols = util::im2col(img[c,], Hin, Win, Hf, Wf, strideh, stridew)
+
+      # Max pooling on patches
+      img_maxes[c,] = colMaxs(img_slice_cols)
+    }
+
+    out[n,] = matrix(img_maxes, rows=1, cols=C*Hout*Wout)
+  }
+}
+
+backward = function(matrix[double] dout, int Hout, int Wout, matrix[double] X, int C,
+                    int Hin, int Win, int Hf, int Wf, int strideh, int stridew)
+    return (matrix[double] dX) {
+  /*
+   * Computes the backward pass for a 2D spatial max pooling layer.
+   * The input data has N examples, each represented as a 3D volume
+   * unrolled into a single vector.
+   *
+   * Inputs:
+   *  - dout: Derivatives from upstream, of shape (N, C*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of shape (N, C*Hin*Win).
+   */
+  N = nrow(X)
+  
+  # Create gradient volume
+  dX = matrix(0, rows=N, cols=C*Hin*Win)
+  
+  # Gradient of max pooling
+  parfor (n in 1:N, check=0) {  # all examples
+    img = matrix(X[n,], rows=C, cols=Hin*Win)
+    dimg = matrix(0, rows=C, cols=Hin*Win)
+    parfor (c in 1:C, check=0) {  # all channels
+      img_slice = matrix(img[c,], rows=Hin, cols=Win)
+      dimg_slice = matrix(0, rows=Hin, cols=Win)
+      for (hout in 1:Hout, check=0) {  # all output rows
+        hin = (hout-1) * strideh + 1
+        for (wout in 1:Wout) {  # all output columns
+          win = (wout-1) * stridew + 1
+          img_slice_patch = img_slice[hin:hin+Hf-1, win:win+Wf-1]
+          max_val = max(img_slice_patch)
+          max_val_ind = ppred(img_slice_patch, max_val, "==")  # max value indicator
+          # gradient passes through only for the max value in this patch
+          dimg_slice_patch = max_val_ind * dout[n, (c-1)*Hout*Wout + (hout-1)*Wout + wout]
+          dimg_slice[hin:hin+Hf-1, win:win+Wf-1] =
+            dimg_slice[hin:hin+Hf-1, win:win+Wf-1] + dimg_slice_patch
+        }
+      }
+      dimg[c,] = matrix(dimg_slice, rows=1, cols=Hin*Win)
+    }
+    dX[n,] = matrix(dimg, rows=1, cols=C*Hin*Win)
+  }
+}
+
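
A small sketch of 2x2 max pooling with stride 2 over a single 1-channel 4x4
image (all shapes here are illustrative):

  source("nn/layers/max_pool.dml") as max_pool

  N = 1
  C = 1
  Hin = 4
  Win = 4
  X = matrix(seq(1, N*C*Hin*Win), rows=N, cols=C*Hin*Win)  # one 4x4 image as a row

  [out, Hout, Wout] = max_pool::forward(X, C, Hin, Win, 2, 2, 2, 2)
  # out has shape (1, 1*2*2); each entry is the max of one 2x2 patch

  dout = matrix(1, rows=N, cols=C*Hout*Wout)  # upstream gradient of ones
  dX = max_pool::backward(dout, Hout, Wout, X, C, Hin, Win, 2, 2, 2, 2)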

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/max_pool_builtin.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/max_pool_builtin.dml b/scripts/staging/SystemML-NN/nn/layers/max_pool_builtin.dml
new file mode 100644
index 0000000..97e991a
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/max_pool_builtin.dml
@@ -0,0 +1,92 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Max pooling layer.
+ */
+forward = function(matrix[double] X, int C, int Hin, int Win, int Hf, int Wf,
+                   int strideh, int stridew)
+    return (matrix[double] out, int Hout, int Wout) {
+  /*
+   * Computes the forward pass for a 2D spatial max pooling layer.
+   * The input data has N examples, each represented as a 3D volume
+   * unrolled into a single vector.
+   *
+   * This implementation uses `im2col` internally for each image to
+   * extract local image regions (patches) of each channel slice into
+   * columns, and then performs max pooling over the patches to compute
+   * the output maps.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, C*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   */
+  N = nrow(X)
+  Hout = as.integer((Hin - Hf) / strideh + 1)
+  Wout = as.integer((Win - Wf) / stridew + 1)
+
+  # Max pooling - built-in implementation
+  out = max_pool(X, input_shape=[N,C,Hin,Win], pool_size=[Hf,Wf], stride=[strideh,stridew],
+                 padding=[0,0])
+}
+
+backward = function(matrix[double] dout, int Hout, int Wout, matrix[double] X, int C,
+                    int Hin, int Win, int Hf, int Wf, int strideh, int stridew)
+    return (matrix[double] dX) {
+  /*
+   * Computes the backward pass for a 2D spatial max pooling layer.
+   * The input data has N examples, each represented as a 3D volume
+   * unrolled into a single vector.
+   *
+   * Inputs:
+   *  - dout: Derivatives from upstream, of shape (N, C*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of shape (N, C*Hin*Win).
+   */
+  N = nrow(X)
+  
+  # Gradient of max pooling
+  dX = max_pool_backward(X, dout, input_shape=[N,C,Hin,Win], pool_size=[Hf,Wf],
+                         stride=[strideh,stridew], padding=[0,0])
+}
+
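
Since this file exposes the same interface as nn/layers/max_pool.dml, the two
implementations can be swapped freely; a quick (illustrative) consistency
check might look like:

  source("nn/layers/max_pool.dml") as max_pool
  source("nn/layers/max_pool_builtin.dml") as max_pool_builtin

  N = 2
  C = 3
  Hin = 8
  Win = 8
  X = rand(rows=N, cols=C*Hin*Win, seed=3)

  [out1, Hout1, Wout1] = max_pool::forward(X, C, Hin, Win, 2, 2, 2, 2)
  [out2, Hout2, Wout2] = max_pool_builtin::forward(X, C, Hin, Win, 2, 2, 2, 2)
  print("Max abs difference: " + max(abs(out1 - out2)))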

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/relu.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/relu.dml b/scripts/staging/SystemML-NN/nn/layers/relu.dml
new file mode 100644
index 0000000..a5c5230
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/relu.dml
@@ -0,0 +1,55 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Rectified Linear Unit (ReLU) nonlinearity layer.
+ */
+forward = function(matrix[double] X) return (matrix[double] out) {
+  /*
+   * Computes the forward pass for a ReLU nonlinearity layer.
+   *
+   * Performs an element-wise evaluation of f(input) = max(0, input).
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (any, any).
+   *
+   * Outputs:
+   *  - out: Outputs, of same shape as X.
+   */
+  out = max(0.0, X)
+}
+
+backward = function(matrix[double] dout, matrix[double] X) return (matrix[double] dX) {
+  /*
+   * Computes the backward pass for a ReLU nonlinearity layer.
+   *
+   * Essentially performs a pass-through of the upstream gradient for cells > 0.
+   *
+   * Inputs:
+   *  - dout: Derivatives from upstream, of same shape as X.
+   *  - X: Previous input data matrix, of shape (any, any).
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of same shape as X.
+   */
+  dX = (X > 0) * dout
+}
+
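
A quick sketch of the forward/backward pair; the upstream gradient dout below
is a random placeholder for whatever the next layer would pass back:

  source("nn/layers/relu.dml") as relu

  X = rand(rows=4, cols=6, min=-1, max=1, seed=5)
  out = relu::forward(X)               # elementwise max(0, X)

  dout = rand(rows=4, cols=6, seed=6)  # placeholder upstream gradient
  dX = relu::backward(dout, X)         # gradient flows only where X > 0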

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/sigmoid.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/sigmoid.dml b/scripts/staging/SystemML-NN/nn/layers/sigmoid.dml
new file mode 100644
index 0000000..a7066f2
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/sigmoid.dml
@@ -0,0 +1,54 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Sigmoid nonlinearity layer.
+ */
+forward = function(matrix[double] X) return (matrix[double] out) {
+  /*
+   * Computes the forward pass for a sigmoid nonlinearity layer.
+   *
+   * sigmoid(x) = 1 / (1 + e^-x)
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (any, any).
+   *
+   * Outputs:
+   *  - out: Outputs, of same shape as X.
+   */
+  out = 1 / (1 + exp(-X))
+}
+
+backward = function(matrix[double] dout, matrix[double] X) return (matrix[double] dX) {
+  /*
+   * Computes the backward pass for a sigmoid nonlinearity layer.
+   *
+   * Inputs:
+   *  - dout: Derivatives from upstream, of same shape as X.
+   *  - X: Previous input data matrix, of shape (any, any).
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of same shape as X.
+   */
+  out = 1 / (1 + exp(-X))
+  dX = out * (1 - out) * dout
+}
+
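
A sketch of chaining this layer with the log loss above for a toy binary
classifier; the scores X and targets y are placeholders:

  source("nn/layers/sigmoid.dml") as sigmoid
  source("nn/layers/log_loss.dml") as log_loss

  N = 8
  X = rand(rows=N, cols=1, min=-2, max=2, seed=9)         # placeholder scores
  y = round(rand(rows=N, cols=1, min=0, max=1, seed=10))  # binary targets in {0,1}

  probs = sigmoid::forward(X)
  loss = log_loss::forward(probs, y)
  dprobs = log_loss::backward(probs, y)
  dX = sigmoid::backward(dprobs, X)  # backprop through the sigmoid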

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/softmax.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/softmax.dml b/scripts/staging/SystemML-NN/nn/layers/softmax.dml
new file mode 100644
index 0000000..2576162
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/softmax.dml
@@ -0,0 +1,73 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Softmax classifier layer.
+ */
+forward = function(matrix[double] scores) return (matrix[double] probs) {
+  /*
+   * Computes the forward pass for a softmax classifier.  The inputs
+   * are interpreted as unnormalized log-probabilities for each of
+   * N examples, and the softmax function transforms them to normalized
+   * probabilities.
+   *
+   * Inputs:
+   *  - scores: Input data matrix, of shape (N, D).
+   *
+   * Outputs:
+   *  - probs: Outputs, of shape (N, D).
+   */
+  # For numerical stability, we subtract the max score of an example from all scores for that
+  # example.  This is equivalent:
+  # e^scores_i / sum(e^scores_i) == (C*e^scores_i) / (C*sum(e^scores_i))
+  #                              == e^(scores_i+log(C)) / sum(e^(scores_i+log(C)))
+  # set log(C) = -max(scores_i):
+  #                              == e^(scores_i-max(scores_i)) / sum(e^(scores_i-max(scores_i)))
+  scores = scores - rowMaxs(scores)  # numerical stability
+  unnorm_probs = exp(scores)
+  probs = unnorm_probs / rowSums(unnorm_probs)
+}
+
+backward = function(matrix[double] dprobs, matrix[double] scores)
+    return (matrix[double] dscores) {
+  /*
+   * Computes the backward pass for a softmax classifier.
+   *
+   * dprobs_ij/dscores_ij = probs_ij * (1 - probs_ij)
+   * dprobs_ic/dscores_ij = -probs_ic * probs_ij, for c != j
+   *
+   * dloss/dscores_ij = dloss/dprobs_ij * dprobs_ij/dscores_ij
+   *                    + sum_{c != j}(dloss/dprobs_ic * dprobs_ic/dscores_ij)
+   *
+   * Inputs:
+   *  - dprobs: Derivatives from upstream, of shape (N, D).
+   *  - scores: Previous input data matrix, of shape (N, D).
+   *
+   * Outputs:
+   *  - dscores: Gradient wrt scores, of shape (N, D).
+   */
+  scores = scores - rowMaxs(scores)  # numerical stability
+  unnorm_probs = exp(scores)
+  probs = unnorm_probs / rowSums(unnorm_probs)
+  dscores = dprobs * probs
+  dscores = dscores - probs * rowSums(dscores)
+}
+
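
A sketch showing that the forward pass produces one probability distribution
per row, and how an upstream gradient would be pushed back through; dprobs is
a random placeholder:

  source("nn/layers/softmax.dml") as softmax

  scores = rand(rows=4, cols=5, min=-3, max=3, seed=11)
  probs = softmax::forward(scores)
  print("Sum of probs (should equal the number of rows, 4): " + sum(probs))

  dprobs = rand(rows=4, cols=5, seed=12)  # placeholder upstream gradient
  dscores = softmax::backward(dprobs, scores)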

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/layers/tanh.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/tanh.dml b/scripts/staging/SystemML-NN/nn/layers/tanh.dml
new file mode 100644
index 0000000..e886081
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/tanh.dml
@@ -0,0 +1,57 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Tanh nonlinearity layer.
+ */
+forward = function(matrix[double] X) return (matrix[double] out) {
+  /*
+   * Computes the forward pass for a tanh nonlinearity layer.
+   *
+   * tanh(x) = (e^x - e^-x) / (e^x + e^-x)
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (any, any).
+   *
+   * Outputs:
+   *  - out: Outputs, of same shape as X.
+   */
+  # Simplification of the above formulation:
+  sigma2X = 1 / (1 + exp(-2*X))
+  out = 2 * sigma2X - 1
+}
+
+backward = function(matrix[double] dout, matrix[double] X) return (matrix[double] dX) {
+  /*
+   * Computes the backward pass for a tanh nonlinearity layer.
+   *
+   * Inputs:
+   *  - dout: Derivatives from upstream, of same shape as X.
+   *  - X: Previous input data matrix, of shape (any, any).
+   *
+   * Outputs:
+   *  - dX: Gradient wrt X, of same shape as X.
+   */
+  sigma2X = 1 / (1 + exp(-2*X))
+  out = 2 * sigma2X - 1
+  dX = (1 - out^2) * dout
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/optim/adagrad.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/optim/adagrad.dml b/scripts/staging/SystemML-NN/nn/optim/adagrad.dml
new file mode 100644
index 0000000..daa5f5e
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/optim/adagrad.dml
@@ -0,0 +1,72 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Adagrad optimizer.
+ */
+update = function(matrix[double] X, matrix[double] dX, double lr, double epsilon,
+                  matrix[double] cache)
+    return (matrix[double] X, matrix[double] cache) {
+  /*
+   * Performs an Adagrad update.
+   *
+   * This is an adaptive learning rate optimizer that maintains the
+   * sum of squared gradients to automatically adjust the effective
+   * learning rate.
+   *
+   * Reference:
+   *  - Adaptive Subgradient Methods for Online Learning and Stochastic
+   *    Optimization, Duchi et al.
+   *      - http://jmlr.org/papers/v12/duchi11a.html
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   *  - dX: Gradient of X wrt a loss function being optimized, of
+   *      same shape as X.
+   *  - lr: Learning rate.
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Typical values are in the range of [1e-8, 1e-4].
+   *  - cache: State that maintains per-parameter sum of squared
+   *      gradients, of same shape as X.
+   *
+   * Outputs:
+   *  - X: Updated parameters X, of same shape as input X.
+   *  - v: Updated velocity of the parameters X, of same shape as
+   *      input v.
+   */
+  cache = cache + dX^2
+  X = X - lr * dX / (sqrt(cache) + epsilon)
+}
+
+init = function(matrix[double] X) return (matrix[double] cache) {
+  /*
+   * Initialize the state for this optimizer.
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   * 
+   * Outputs:
+   *  - cache: State that maintains per-parameter sum of squared
+   *      gradients, of same shape as X.
+   */
+  cache = matrix(0, rows=nrow(X), cols=ncol(X))
+}
+
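
A sketch of the typical call pattern: the cache is initialized once and then
threaded through every update; the gradients below are random placeholders
for whatever a backward pass would produce:

  source("nn/optim/adagrad.dml") as adagrad

  X = rand(rows=10, cols=4, seed=13)  # parameters
  cache = adagrad::init(X)            # per-parameter sum of squared gradients
  lr = 0.01
  epsilon = 1e-8

  for (i in 1:100) {
    dX = rand(rows=10, cols=4)  # placeholder gradient for this step
    [X, cache] = adagrad::update(X, dX, lr, epsilon, cache)
  }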

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/optim/adam.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/optim/adam.dml b/scripts/staging/SystemML-NN/nn/optim/adam.dml
new file mode 100644
index 0000000..05152f4
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/optim/adam.dml
@@ -0,0 +1,92 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Adam optimizer.
+ */
+update = function(matrix[double] X, matrix[double] dX, double lr, double beta1, double beta2,
+                  double epsilon, int t, matrix[double] m, matrix[double] v)
+    return (matrix[double] X, matrix[double] m, matrix[double] v) {
+  /*
+   * Performs an Adam update.
+   *
+   * Reference:
+   *  - Adam: A Method for Stochastic Optimization, Kingma, Ba.
+   *    - http://arxiv.org/abs/1412.6980
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   *  - dX: Gradient of X wrt a loss function being optimized, of
+   *      same shape as X.
+   *  - lr: Learning rate.  Recommended value is 0.001.
+   *  - beta1: Exponential decay rate for the 1st moment estimates.
+   *      Recommended value is 0.9.
+   *  - beta2: Exponential decay rate for the 2nd moment estimates.
+   *      Recommended value is 0.999.
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Recommended value is 1e-8.
+   *  - t: Timestep, starting at 0.
+   *  - m: State containing the 1st moment (mean) estimate by
+   *      maintaining exponential moving averages of the gradients, of
+   *      same shape as X.
+   *  - v: State containing the 2nd raw moment (uncentered variance)
+   *      estimate by maintaining exponential moving averages of the
+   *      squared gradients, of same shape as X.
+   *
+   * Outputs:
+   *  - X: Updated parameters X, of same shape as input X.
+   *  - m: Updated state containing the 1st moment (mean) estimate by
+   *      maintaining exponential moving averages of the gradients, of
+   *      same shape as X.
+   *  - v: Updated state containing the 2nd raw moment (uncentered
+   *      variance) estimate by maintaining exponential moving averages
+   *      of the squared gradients, of same shape as X.
+   */
+  t = t + 1
+  m = beta1 * m + (1 - beta1) * dX  # update biased 1st moment estimate
+  v = beta2 * v + (1 - beta2) * dX^2  # update biased 2nd raw moment estimate
+  #m = m / (1 - beta1^t)  # compute bias-corrected 1st moment estimate
+  #v = v / (1 - beta2^t)  # compute bias-corrected 2nd raw moment estimate
+  #X = X - lr * m / (sqrt(v) + epsilon)  # param update
+  # Simplified for computational efficiency:
+  lr = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
+  X = X - lr * m / (sqrt(v) + epsilon)
+}
+
+init = function(matrix[double] X) return (matrix[double] m, matrix[double] v) {
+  /*
+   * Initialize the state for this optimizer.
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   * 
+   * Outputs:
+   *  - m: Initial state containing the 1st moment (mean) estimate by
+   *      maintaining exponential moving averages of the gradients, of
+   *      same shape as X.
+   *  - v: Initial state containing the 2nd raw moment (uncentered
+   *      variance) estimate by maintaining exponential moving averages
+   *      of the squared gradients, of same shape as X.
+   */
+  m = matrix(0, rows=nrow(X), cols=ncol(X))
+  v = matrix(0, rows=nrow(X), cols=ncol(X))
+}
+
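
A sketch of an optimization loop; note that the caller passes the current
timestep t (starting at 0) and carries the moment estimates m and v forward
between calls.  The gradients here are random placeholders:

  source("nn/optim/adam.dml") as adam

  X = rand(rows=10, cols=4, seed=14)  # parameters
  [m, v] = adam::init(X)
  lr = 0.001
  beta1 = 0.9
  beta2 = 0.999
  epsilon = 1e-8

  for (i in 1:100) {
    dX = rand(rows=10, cols=4)  # placeholder gradient for this step
    t = i - 1                   # timestep starts at 0
    [X, m, v] = adam::update(X, dX, lr, beta1, beta2, epsilon, t, m, v)
  }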

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/optim/rmsprop.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/optim/rmsprop.dml b/scripts/staging/SystemML-NN/nn/optim/rmsprop.dml
new file mode 100644
index 0000000..31b78d5
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/optim/rmsprop.dml
@@ -0,0 +1,74 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * RMSprop optimizer.
+ */
+update = function(matrix[double] X, matrix[double] dX, double lr, double decay_rate,
+                  double epsilon, matrix[double] cache)
+    return (matrix[double] X, matrix[double] cache) {
+  /*
+   * Performs an RMSprop update.
+   *
+   * This is an adaptive learning rate optimizer that can be viewed
+   * as an adjustment of the Adagrad method to use a moving average
+   * of the sum of squared gradients in order to improve convergence.
+   *
+   * Reference:
+   *  - Neural Networks for Machine Learning, Lecture 6a, Hinton,
+   *    slide 29.
+   *    - http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   *  - dX: Gradient of X wrt a loss function being optimized, of
+   *      same shape as X.
+   *  - lr: Learning rate.
+   *  - decay_rate: Term controlling the rate of the moving average.
+   *      Typical values are in the range of [0.9, 0.999].
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Typical values are in the range of [1e-8, 1e-4].
+   *  - cache: State that maintains the moving average of the squared
+   *      gradients, of same shape as X.
+   *
+   * Outputs:
+   *  - X: Updated parameters X, of same shape as input X.
+   *  - cache: Updated state that maintains the moving average of the
+   *      squared gradients, of same shape as X.
+   */
+  cache = decay_rate * cache + (1 - decay_rate) * dX^2
+  X = X - lr * dX / (sqrt(cache) + epsilon)
+}
+
+init = function(matrix[double] X) return (matrix[double] cache) {
+  /*
+   * Initialize the state for this optimizer.
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   * 
+   * Outputs:
+   *  - cache: State that maintains the moving average of the squared
+   *      gradients, of same shape as X.
+   */
+  cache = matrix(0, rows=nrow(X), cols=ncol(X))
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/optim/sgd.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/optim/sgd.dml b/scripts/staging/SystemML-NN/nn/optim/sgd.dml
new file mode 100644
index 0000000..554569a
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/optim/sgd.dml
@@ -0,0 +1,40 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Stochastic Gradient Descent (SGD) optimizer.
+ */
+update = function(matrix[double] X, matrix[double] dX, double lr) return (matrix[double] X) {
+  /*
+   * Performs a vanilla SGD update.
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   *  - dX: Gradient of X wrt a loss function being optimized, of
+   *      same shape as X.
+   *  - lr: Learning rate.
+   *
+   * Outputs:
+   *  - X: Updated parameters X, of same shape as input X.
+   */
+  X = X - lr * dX
+}
+
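
The simplest of the optimizers; a one-step sketch with a placeholder gradient:

  source("nn/optim/sgd.dml") as sgd

  X = rand(rows=10, cols=4, seed=15)   # parameters
  dX = rand(rows=10, cols=4, seed=16)  # placeholder gradient from a backward pass
  lr = 0.01
  X = sgd::update(X, dX, lr)           # X <- X - lr * dX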

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/optim/sgd_momentum.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/optim/sgd_momentum.dml b/scripts/staging/SystemML-NN/nn/optim/sgd_momentum.dml
new file mode 100644
index 0000000..22a88f2
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/optim/sgd_momentum.dml
@@ -0,0 +1,66 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Stochastic Gradient Descent with momentum (SGD-momentum) optimizer.
+ */
+update = function(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v)
+    return (matrix[double] X, matrix[double] v) {
+  /*
+   * Performs an SGD update with momentum.
+   *
+   * In SGD with momentum, we assume that the parameters have a velocity
+   * that continues with some momentum, and that is influenced by the
+   * gradient.
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   *  - dX: Gradient of X wrt a loss function being optimized, of
+   *      same shape as X.
+   *  - lr: Learning rate.
+   *  - mu: Momentum value.
+   *      Typical values are in the range of [0.5, 0.99], usually
+   *      started at the lower end and annealed towards the higher end.
+   *  - v: State maintaining the velocity of the parameters X, of same
+   *      shape as X.
+   *
+   * Outputs:
+   *  - X: Updated parameters X, of same shape as input X.
+   *  - v: Updated velocity of the parameters X, of same shape as
+   *      input v.
+   */
+  v = mu * v - lr * dX  # update velocity
+  X = X + v  # update position
+}
+
+init = function(matrix[double] X) return (matrix[double] v) {
+  /*
+   * Initialize the state for this optimizer.
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   * 
+   * Outputs:
+   *  - v: Initial velocity of the parameters X.
+   */
+  v = matrix(0, rows=nrow(X), cols=ncol(X))
+}
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/781d24d8/scripts/staging/SystemML-NN/nn/optim/sgd_nesterov.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/optim/sgd_nesterov.dml b/scripts/staging/SystemML-NN/nn/optim/sgd_nesterov.dml
new file mode 100644
index 0000000..aac6522
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/optim/sgd_nesterov.dml
@@ -0,0 +1,75 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Stochastic Gradient Descent with Nesterov momentum (SGD-Nesterov) optimizer.
+ */
+update = function(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v)
+    return (matrix[double] X, matrix[double] v) {
+  /*
+   * Performs an SGD update with Nesterov momentum.
+   *
+   * As with regular SGD with momentum, in SGD with Nesterov momentum,
+   * we assume that the parameters have a velocity that continues
+   * with some momentum, and that is influenced by the gradient.
+   * In this view specifically, we perform the position update from the
+   * position that the momentum is about to carry the parameters to,
+   * rather than from the previous position.  Additionally, we always
+   * store the parameters in their position after momentum.
+   *
+   * Reference:
+   *  - Advances in optimizing Recurrent Networks, Bengio et al., section 3.5.
+   *    - http://arxiv.org/abs/1212.0901
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   *  - dX: Gradient of X wrt a loss function being optimized, of
+   *      same shape as X.
+   *  - lr: Learning rate.
+   *  - mu: Momentum value.
+   *      Typical values are in the range of [0.5, 0.99], usually
+   *      started at the lower end and annealed towards the higher end.
+   *  - v: State maintaining the velocity of the parameters X, of same
+   *      shape as X.
+   *
+   * Outputs:
+   *  - X: Updated parameters X, of same shape as input X.
+   *  - v: Updated velocity of the parameters X, of same shape as
+   *      input v.
+   */
+  v_prev = v
+  v = mu * v - lr * dX  # update velocity
+  X = X - mu * v_prev + (1 + mu) * v  # update position, including momentum
+}
+
+init = function(matrix[double] X) return (matrix[double] v) {
+  /*
+   * Initialize the state for this optimizer.
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   * 
+   * Outputs:
+   *  - v: Initial velocity of the parameters X.
+   */
+  v = matrix(0, rows=nrow(X), cols=ncol(X))
+}
+
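
A sketch of the call pattern: the velocity is initialized once and threaded
through each step, with random placeholder gradients standing in for a real
backward pass:

  source("nn/optim/sgd_nesterov.dml") as sgd_nesterov

  X = rand(rows=10, cols=4, seed=17)  # parameters
  v = sgd_nesterov::init(X)           # zero initial velocity
  lr = 0.01
  mu = 0.9

  for (i in 1:100) {
    dX = rand(rows=10, cols=4)  # placeholder gradient for this step
    [X, v] = sgd_nesterov::update(X, dX, lr, mu, v)
  }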