Posted to commits@systemml.apache.org by du...@apache.org on 2017/04/11 00:23:05 UTC

[2/2] incubator-systemml git commit: [SYSTEMML-1468] Add new 1D/2D "Scale & Shift" layers

[SYSTEMML-1468] Add new 1D/2D "Scale & Shift" layers

A "Scale & Shift" layer introduces learnable parameters
(`gamma`, `beta`) to scale and shift the input on either
a per-feature basis (1D) or a per-channel basis (2D).

  `y = x*gamma + beta`

Closes #453.
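
As a rough usage sketch (not part of this commit; shapes are illustrative, and the `source` paths assume the script runs from the SystemML-NN root), the new layers can be used as follows:

  source("nn/layers/scale_shift1d.dml") as scale_shift1d
  source("nn/layers/scale_shift2d.dml") as scale_shift2d

  # 1D: per-feature scale & shift over (N, D) inputs
  N = 4
  D = 16
  X = rand(rows=N, cols=D)
  [gamma, beta] = scale_shift1d::init(D)        # gamma = 1, beta = 0 (identity)
  out = scale_shift1d::forward(X, gamma, beta)  # shape (N, D)

  # 2D: per-channel scale & shift over unrolled (N, C*Hin*Win) volumes
  C = 3
  Hin = 8
  Win = 8
  X2 = rand(rows=N, cols=C*Hin*Win)
  [gamma2, beta2] = scale_shift2d::init(C)      # shapes (C, 1)
  out2 = scale_shift2d::forward(X2, gamma2, beta2, C, Hin, Win)  # shape (N, C*Hin*Win)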


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/65172565
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/65172565
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/65172565

Branch: refs/heads/master
Commit: 6517256511b5953b4efea97600164261243a8402
Parents: f5ef628
Author: Mike Dusenberry <mw...@us.ibm.com>
Authored: Mon Apr 10 17:20:55 2017 -0700
Committer: Mike Dusenberry <mw...@us.ibm.com>
Committed: Mon Apr 10 17:20:55 2017 -0700

----------------------------------------------------------------------
 .../SystemML-NN/nn/layers/scale_shift1d.dml     |  95 +++++
 .../SystemML-NN/nn/layers/scale_shift2d.dml     | 107 ++++++
 .../staging/SystemML-NN/nn/test/grad_check.dml  | 379 +++++++++++++------
 .../staging/SystemML-NN/nn/test/run_tests.dml   |  21 +-
 4 files changed, 486 insertions(+), 116 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/65172565/scripts/staging/SystemML-NN/nn/layers/scale_shift1d.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/scale_shift1d.dml b/scripts/staging/SystemML-NN/nn/layers/scale_shift1d.dml
new file mode 100644
index 0000000..7e162a3
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/scale_shift1d.dml
@@ -0,0 +1,95 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * 1D Scale & Shift layer.
+ */
+
+forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta)
+    return (matrix[double] out) {
+  /*
+   * Computes the forward pass for a 1D scale & shift layer. The input
+   * data has N examples, each with D features.
+   *
+   * A 1D scale & shift layer introduces learnable parameters
+   * (gamma, beta) to scale and shift the input on a per-feature basis.
+   *
+   *   `y = x*gamma + beta`
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, D).
+   *  - gamma: Scale parameters, of shape (1, D).
+   *  - beta: Shift parameters, of shape (1, D).
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, D).
+   */
+  # Scale and shift
+  out = X*gamma + beta  # shape (N, D)
+}
+
+backward = function(matrix[double] dout, matrix[double] out,
+                    matrix[double] X, matrix[double] gamma, matrix[double] beta)
+      return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
+  /*
+   * Computes the backward pass for a 1D scale & shift layer.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of shape (N, D).
+   *  - out: Outputs from the forward pass, of shape (N, D).
+   *  - X: Inputs, of shape (N, D).
+   *  - gamma: Scale parameters, of shape (1, D).
+   *  - beta: Shift parameters, of shape (1, D).
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, D).
+   *  - dgamma: Gradient wrt `gamma`, of shape (1, D).
+   *  - dbeta: Gradient wrt `beta`, of shape (1, D).
+   *
+   */
+  # Compute gradients during training
+  dgamma = colSums(dout*X)  # shape (1, D)
+  dbeta = colSums(dout)  # shape (1, D)
+  dX = dout * gamma  # shape (N, D)
+}
+
+init = function(int D)
+    return (matrix[double] gamma, matrix[double] beta) {
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * By default, we initialize to an identity function, with a scale
+   * filler of `1`, and a shift filler of `0`.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   *
+   * Inputs:
+   *  - D: Dimensionality of the input features (number of features).
+   *
+   * Outputs:
+   *  - gamma: Scale parameters, of shape (1, D).
+   *  - beta: Shift parameters, of shape (1, D).
+   */
+  gamma = matrix(1, rows=1, cols=D)
+  beta = matrix(0, rows=1, cols=D)
+}
+
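
For readers verifying `backward` above, the gradients follow from differentiating the element-wise forward rule `y[n,d] = X[n,d]*gamma[d] + beta[d]` (a sketch; `dout` is the upstream gradient of the loss `L` wrt `y`):

  dL/dgamma[d] = sum_n dout[n,d]*X[n,d]   ->  dgamma = colSums(dout*X)
  dL/dbeta[d]  = sum_n dout[n,d]          ->  dbeta  = colSums(dout)
  dL/dX[n,d]   = dout[n,d]*gamma[d]       ->  dX     = dout*gamma

The 2D layer below applies the same rule per channel, replacing `colSums` with `util::channel_sums` and broadcasting `gamma`/`beta` across each channel's Hin*Win values.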

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/65172565/scripts/staging/SystemML-NN/nn/layers/scale_shift2d.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/scale_shift2d.dml b/scripts/staging/SystemML-NN/nn/layers/scale_shift2d.dml
new file mode 100644
index 0000000..79c884a
--- /dev/null
+++ b/scripts/staging/SystemML-NN/nn/layers/scale_shift2d.dml
@@ -0,0 +1,107 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * 2D Scale & Shift layer.
+ */
+source("nn/util.dml") as util
+
+forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
+                   int C, int Hin, int Win)
+    return (matrix[double] out) {
+  /*
+   * Computes the forward pass for a 2D scale & shift layer.  The input
+   * data has N examples, each represented as a 3D volume unrolled into
+   * a single vector.
+   *
+   * A 2D scale & shift layer introduces learnable parameters
+   * (gamma, beta) to scale and shift the input on a per-channel basis.
+   *
+   *   `y = x*gamma + beta`
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - gamma: Scale parameters, of shape (C, 1).
+   *  - beta: Shift parameters, of shape (C, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, C*Hin*Win).
+   */
+  # Scale and shift
+  scaled = bias_multiply(X, gamma)  # shape (N, C*Hin*Win)
+  out = bias_add(scaled, beta)  # shape (N, C*Hin*Win)
+}
+
+backward = function(matrix[double] dout, matrix[double] out,
+                    matrix[double] X, matrix[double] gamma, matrix[double] beta,
+                    int C, int Hin, int Win)
+      return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
+  /*
+   * Computes the backward pass for a 2D scale & shift layer.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win).
+   *  - out: Outputs from the forward pass, of shape (N, C*Hin*Win).
+   *  - X: Input data matrix to the forward pass, of
+   *      shape (N, C*Hin*Win).
+   *  - gamma: Scale parameters, of shape (C, 1).
+   *  - beta: Shift parameters, of shape (C, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
+   *  - dgamma: Gradient wrt `gamma`, of shape (C, 1).
+   *  - dbeta: Gradient wrt `beta`, of shape (C, 1).
+   *
+   */
+  # Compute gradients during training
+  dgamma = util::channel_sums(dout*X, C, Hin, Win)  # shape (C, 1)
+  dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)
+  dX = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
+}
+
+init = function(int C)
+    return (matrix[double] gamma, matrix[double] beta) {
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * By default, we initialize to an identity function, with a scale
+   * filler of `1`, and a shift filler of `0`.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   *
+   * Inputs:
+   *  - C: Number of input channels (dimensionality of input depth).
+   *
+   * Outputs:
+   *  - gamma: Scale parameters, of shape (C, 1).
+   *  - beta: Shift parameters, of shape (C, 1).
+   */
+  gamma = matrix(1, rows=C, cols=1)
+  beta = matrix(0, rows=C, cols=1)
+}
+
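
A minimal end-to-end sketch of one training step with the new 2D layer (illustrative only; the `l2_loss` source path and the plain gradient-descent update below are assumptions, not part of this commit):

  source("nn/layers/scale_shift2d.dml") as scale_shift2d
  source("nn/layers/l2_loss.dml") as l2_loss  # path assumed; grad_check.dml uses the same module

  N = 4      # num examples
  C = 3      # num channels
  Hin = 8    # input height
  Win = 8    # input width
  lr = 0.01  # illustrative learning rate
  X = rand(rows=N, cols=C*Hin*Win)
  y = rand(rows=N, cols=C*Hin*Win)
  [gamma, beta] = scale_shift2d::init(C)

  # forward pass, loss, and backward pass
  out = scale_shift2d::forward(X, gamma, beta, C, Hin, Win)
  loss = l2_loss::forward(out, y)
  dout = l2_loss::backward(out, y)
  [dX, dgamma, dbeta] = scale_shift2d::backward(dout, out, X, gamma, beta, C, Hin, Win)

  # vanilla gradient-descent update of the learnable per-channel parameters
  gamma = gamma - lr*dgamma
  beta = beta - lr*dbeta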

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/65172565/scripts/staging/SystemML-NN/nn/test/grad_check.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/grad_check.dml b/scripts/staging/SystemML-NN/nn/test/grad_check.dml
index f21811c..516fe2a 100644
--- a/scripts/staging/SystemML-NN/nn/test/grad_check.dml
+++ b/scripts/staging/SystemML-NN/nn/test/grad_check.dml
@@ -39,6 +39,8 @@ source("nn/layers/max_pool2d.dml") as max_pool2d
 source("nn/layers/max_pool2d_builtin.dml") as max_pool2d_builtin
 source("nn/layers/relu.dml") as relu
 source("nn/layers/rnn.dml") as rnn
+source("nn/layers/scale_shift1d.dml") as scale_shift1d
+source("nn/layers/scale_shift2d.dml") as scale_shift2d
 source("nn/layers/sigmoid.dml") as sigmoid
 source("nn/layers/softmax.dml") as softmax
 source("nn/layers/tanh.dml") as tanh
@@ -229,6 +231,113 @@ batch_norm1d = function() {
   }
 }
 
+batch_norm2d = function() {
+  /*
+   * Gradient check for the 2D (spatial) batch normalization layer.
+   */
+  print("Grad checking the 2D (spatial) batch normalization layer with L2 loss.")
+
+  # Generate data
+  N = 3  # num examples
+  C = 2  # num channels
+  Hin = 5  # input height
+  Win = 5  # input width
+  mu = 0.9  # momentum
+  eps = 1e-5  # epsilon
+  X = rand(rows=N, cols=C*Hin*Win)
+  y = rand(rows=N, cols=C*Hin*Win)
+  gamma = rand(rows=C, cols=1)
+  beta = rand(rows=C, cols=1)
+  ema_mean = rand(rows=C, cols=1)
+  ema_var = rand(rows=C, cols=1)
+  #[dummy, dummy, ema_mean, ema_var] = batch_norm2d::init(C)
+
+  # Check training & testing modes
+  for (i in 1:2) {
+    if (i == 1)
+      mode = 'train'
+    else
+      mode = 'test'
+    print(" - Grad checking the '"+mode+"' mode.")
+
+    # Compute analytical gradients of loss wrt parameters
+    [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+        batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
+    dout = l2_loss::backward(out, y)
+    [dX, dgamma, dbeta] = batch_norm2d::backward(dout, out, ema_mean_upd, ema_var_upd,
+                                                 cache_mean, cache_var, cache_norm,
+                                                 X, gamma, beta, C, Hin, Win, mode,
+                                                 ema_mean, ema_var, mu, eps)
+
+    # Grad check
+    h = 1e-5
+    print("   - Grad checking X.")
+    for (i in 1:nrow(X)) {
+      for (j in 1:ncol(X)) {
+        # Compute numerical derivative
+        old = as.scalar(X[i,j])
+        X[i,j] = old - h
+        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
+        lossmh = l2_loss::forward(outmh, y)
+        X[i,j] = old + h
+        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
+        lossph = l2_loss::forward(outph, y)
+        X[i,j] = old  # reset
+        dX_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+        # Check error
+        rel_error = test_util::check_rel_grad_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+      }
+    }
+
+    print("   - Grad checking gamma.")
+    for (i in 1:nrow(gamma)) {
+      for (j in 1:ncol(gamma)) {
+        # Compute numerical derivative
+        old = as.scalar(gamma[i,j])
+        gamma[i,j] = old - h
+        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
+        lossmh = l2_loss::forward(outmh, y)
+        gamma[i,j] = old + h
+        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
+        lossph = l2_loss::forward(outph, y)
+        gamma[i,j] = old  # reset
+        dgamma_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+        # Check error
+        rel_error = test_util::check_rel_grad_error(as.scalar(dgamma[i,j]), dgamma_num,
+                                                    lossph, lossmh)
+      }
+    }
+
+    print("   - Grad checking beta.")
+    for (i in 1:nrow(beta)) {
+      for (j in 1:ncol(beta)) {
+        # Compute numerical derivative
+        old = as.scalar(beta[i,j])
+        beta[i,j] = old - h
+        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
+        lossmh = l2_loss::forward(outmh, y)
+        beta[i,j] = old + h
+        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+            batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
+        lossph = l2_loss::forward(outph, y)
+        beta[i,j] = old  # reset
+        dbeta_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+        # Check error
+        rel_error = test_util::check_rel_grad_error(as.scalar(dbeta[i,j]), dbeta_num,
+                                                    lossph, lossmh)
+      }
+    }
+  }
+}
+
 conv2d = function() {
   /*
    * Gradient check for the 2D convolutional layer using `im2col`.
@@ -1199,6 +1308,168 @@ rnn = function() {
   }
 }
 
+scale_shift1d = function() {
+  /*
+   * Gradient check for the 1D scale & shift layer.
+   */
+  print("Grad checking the 1D scale & shift layer with L2 loss.")
+
+  # Generate data
+  N = 3  # num examples
+  D = 100  # num features
+  X = rand(rows=N, cols=D)
+  y = rand(rows=N, cols=D)
+  [gamma, beta] = scale_shift1d::init(D)
+
+  # Compute analytical gradients of loss wrt parameters
+  out = scale_shift1d::forward(X, gamma, beta)
+  dout = l2_loss::backward(out, y)
+  [dX, dgamma, dbeta] = scale_shift1d::backward(dout, out, X, gamma, beta)
+
+  # Grad check
+  h = 1e-5
+  print(" - Grad checking X.")
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      outmh = scale_shift1d::forward(X, gamma, beta)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      outph = scale_shift1d::forward(X, gamma, beta)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+      # Check error
+      rel_error = test_util::check_rel_grad_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking gamma.")
+  for (i in 1:nrow(gamma)) {
+    for (j in 1:ncol(gamma)) {
+      # Compute numerical derivative
+      old = as.scalar(gamma[i,j])
+      gamma[i,j] = old - h
+      outmh = scale_shift1d::forward(X, gamma, beta)
+      lossmh = l2_loss::forward(outmh, y)
+      gamma[i,j] = old + h
+      outph = scale_shift1d::forward(X, gamma, beta)
+      lossph = l2_loss::forward(outph, y)
+      gamma[i,j] = old  # reset
+      dgamma_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+      # Check error
+      rel_error = test_util::check_rel_grad_error(as.scalar(dgamma[i,j]), dgamma_num,
+                                                  lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking beta.")
+  for (i in 1:nrow(beta)) {
+    for (j in 1:ncol(beta)) {
+      # Compute numerical derivative
+      old = as.scalar(beta[i,j])
+      beta[i,j] = old - h
+      outmh = scale_shift1d::forward(X, gamma, beta)
+      lossmh = l2_loss::forward(outmh, y)
+      beta[i,j] = old + h
+      outph = scale_shift1d::forward(X, gamma, beta)
+      lossph = l2_loss::forward(outph, y)
+      beta[i,j] = old  # reset
+      dbeta_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+      # Check error
+      rel_error = test_util::check_rel_grad_error(as.scalar(dbeta[i,j]), dbeta_num,
+                                                  lossph, lossmh)
+    }
+  }
+}
+
+scale_shift2d = function() {
+  /*
+   * Gradient check for the 2D scale & shift layer.
+   */
+  print("Grad checking the 2D scale & shift layer with L2 loss.")
+
+  # Generate data
+  N = 3  # num examples
+  C = 2  # num channels
+  Hin = 5  # input height
+  Win = 5  # input width
+  X = rand(rows=N, cols=C*Hin*Win)
+  y = rand(rows=N, cols=C*Hin*Win)
+  [gamma, beta] = scale_shift2d::init(C)
+
+  # Compute analytical gradients of loss wrt parameters
+  out = scale_shift2d::forward(X, gamma, beta, C, Hin, Win)
+  dout = l2_loss::backward(out, y)
+  [dX, dgamma, dbeta] = scale_shift2d::backward(dout, out, X, gamma, beta, C, Hin, Win)
+
+  # Grad check
+  h = 1e-5
+  print(" - Grad checking X.")
+  for (i in 1:nrow(X)) {
+    for (j in 1:ncol(X)) {
+      # Compute numerical derivative
+      old = as.scalar(X[i,j])
+      X[i,j] = old - h
+      outmh = scale_shift2d::forward(X, gamma, beta, C, Hin, Win)
+      lossmh = l2_loss::forward(outmh, y)
+      X[i,j] = old + h
+      outph = scale_shift2d::forward(X, gamma, beta, C, Hin, Win)
+      lossph = l2_loss::forward(outph, y)
+      X[i,j] = old  # reset
+      dX_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+      # Check error
+      rel_error = test_util::check_rel_grad_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking gamma.")
+  for (i in 1:nrow(gamma)) {
+    for (j in 1:ncol(gamma)) {
+      # Compute numerical derivative
+      old = as.scalar(gamma[i,j])
+      gamma[i,j] = old - h
+      outmh = scale_shift2d::forward(X, gamma, beta, C, Hin, Win)
+      lossmh = l2_loss::forward(outmh, y)
+      gamma[i,j] = old + h
+      outph = scale_shift2d::forward(X, gamma, beta, C, Hin, Win)
+      lossph = l2_loss::forward(outph, y)
+      gamma[i,j] = old  # reset
+      dgamma_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+      # Check error
+      rel_error = test_util::check_rel_grad_error(as.scalar(dgamma[i,j]), dgamma_num,
+                                                  lossph, lossmh)
+    }
+  }
+
+  print(" - Grad checking beta.")
+  for (i in 1:nrow(beta)) {
+    for (j in 1:ncol(beta)) {
+      # Compute numerical derivative
+      old = as.scalar(beta[i,j])
+      beta[i,j] = old - h
+      outmh = scale_shift2d::forward(X, gamma, beta, C, Hin, Win)
+      lossmh = l2_loss::forward(outmh, y)
+      beta[i,j] = old + h
+      outph = scale_shift2d::forward(X, gamma, beta, C, Hin, Win)
+      lossph = l2_loss::forward(outph, y)
+      beta[i,j] = old  # reset
+      dbeta_num = (lossph-lossmh) / (2*h)  # numerical derivative
+
+      # Check error
+      rel_error = test_util::check_rel_grad_error(as.scalar(dbeta[i,j]), dbeta_num,
+                                                  lossph, lossmh)
+    }
+  }
+}
+
 sigmoid = function() {
   /*
    * Gradient check for the sigmoid nonlinearity layer.
@@ -1276,114 +1547,6 @@ softmax = function() {
   }
 }
 
-batch_norm2d = function() {
-  /*
-   * Gradient check for the 2D (spatial) batch normalization layer.
-   */
-  print("Grad checking the 2D (spatial) batch normalization layer with L2 loss.")
-
-  # Generate data
-  N = 3 # num examples
-  N = 2  # num examples
-  C = 2  # num channels
-  Hin = 5  # input height
-  Win = 5  # input width
-  mu = 0.9  # momentum
-  eps = 1e-5  # epsilon
-  X = rand(rows=N, cols=C*Hin*Win)
-  y = rand(rows=N, cols=C*Hin*Win)
-  gamma = rand(rows=C, cols=1)
-  beta = rand(rows=C, cols=1)
-  ema_mean = rand(rows=C, cols=1)
-  ema_var = rand(rows=C, cols=1)
-  #[dummy, dummy, ema_mean, ema_var] = batch_norm2d::init(C)
-
-  # Check training & testing modes
-  for (i in 1:2) {
-    if (i == 1)
-      mode = 'train'
-    else
-      mode = 'test'
-    print(" - Grad checking the '"+mode+"' mode.")
-
-    # Compute analytical gradients of loss wrt parameters
-    [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
-        batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
-    dout = l2_loss::backward(out, y)
-    [dX, dgamma, dbeta] = batch_norm2d::backward(dout, out, ema_mean_upd, ema_var_upd,
-                                                 cache_mean, cache_var, cache_norm,
-                                                 X, gamma, beta, C, Hin, Win, mode,
-                                                 ema_mean, ema_var, mu, eps)
-
-    # Grad check
-    h = 1e-5
-    print("   - Grad checking X.")
-    for (i in 1:nrow(X)) {
-      for (j in 1:ncol(X)) {
-        # Compute numerical derivative
-        old = as.scalar(X[i,j])
-        X[i,j] = old - h
-        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
-            batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
-        lossmh = l2_loss::forward(outmh, y)
-        X[i,j] = old + h
-        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
-            batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
-        lossph = l2_loss::forward(outph, y)
-        X[i,j] = old  # reset
-        dX_num = (lossph-lossmh) / (2*h)  # numerical derivative
-
-        # Check error
-        rel_error = test_util::check_rel_grad_error(as.scalar(dX[i,j]), dX_num, lossph, lossmh)
-      }
-    }
-
-    print("   - Grad checking gamma.")
-    for (i in 1:nrow(gamma)) {
-      for (j in 1:ncol(gamma)) {
-        # Compute numerical derivative
-        old = as.scalar(gamma[i,j])
-        gamma[i,j] = old - h
-        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
-            batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
-        lossmh = l2_loss::forward(outmh, y)
-        gamma[i,j] = old + h
-        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
-            batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
-        lossph = l2_loss::forward(outph, y)
-        gamma[i,j] = old  # reset
-        dgamma_num = (lossph-lossmh) / (2*h)  # numerical derivative
-
-        # Check error
-        rel_error = test_util::check_rel_grad_error(as.scalar(dgamma[i,j]), dgamma_num,
-                                                    lossph, lossmh)
-      }
-    }
-
-    print("   - Grad checking beta.")
-    for (i in 1:nrow(beta)) {
-      for (j in 1:ncol(beta)) {
-        # Compute numerical derivative
-        old = as.scalar(beta[i,j])
-        beta[i,j] = old - h
-        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
-            batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
-        lossmh = l2_loss::forward(outmh, y)
-        beta[i,j] = old + h
-        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
-            batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
-        lossph = l2_loss::forward(outph, y)
-        beta[i,j] = old  # reset
-        dbeta_num = (lossph-lossmh) / (2*h)  # numerical derivative
-
-        # Check error
-        rel_error = test_util::check_rel_grad_error(as.scalar(dbeta[i,j]), dbeta_num,
-                                                    lossph, lossmh)
-      }
-    }
-  }
-}
-
 tanh = function() {
   /*
    * Gradient check for the hyperbolic tangent (tanh) nonlinearity

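The new grad checks above follow the same pattern as the existing ones: each entry of `X`, `gamma`, and `beta` is perturbed by +/- h and the analytical gradient is compared against the centered finite difference (the exact relative-error test lives in `test_util::check_rel_grad_error`):

  grad_num[i,j] = (loss(theta[i,j] + h) - loss(theta[i,j] - h)) / (2*h),  with h = 1e-5
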
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/65172565/scripts/staging/SystemML-NN/nn/test/run_tests.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/test/run_tests.dml b/scripts/staging/SystemML-NN/nn/test/run_tests.dml
index 0279363..ee4da68 100644
--- a/scripts/staging/SystemML-NN/nn/test/run_tests.dml
+++ b/scripts/staging/SystemML-NN/nn/test/run_tests.dml
@@ -29,34 +29,39 @@ print("")
 print("Starting grad checks.")
 print("---")
 
-# Loss functions
+# Loss & loss-related functions
 tmp = grad_check::cross_entropy_loss()
 tmp = grad_check::l1_loss()
+tmp = grad_check::l1_reg()
 tmp = grad_check::l2_loss()
+tmp = grad_check::l2_reg()
 tmp = grad_check::log_loss()
+print("")
 
-# Other layers
+# Core layers
 tmp = grad_check::affine()
 tmp = grad_check::batch_norm1d()
 tmp = grad_check::batch_norm2d()
-tmp = grad_check::conv2d_simple()
 tmp = grad_check::conv2d()
 tmp = grad_check::conv2d_builtin()
+tmp = grad_check::conv2d_simple()
 tmp = grad_check::dropout()
-tmp = grad_check::l1_reg()
-tmp = grad_check::l2_reg()
 tmp = grad_check::lstm()
-tmp = grad_check::max_pool2d_simple()
 tmp = grad_check::max_pool2d()
 tmp = grad_check::max_pool2d_builtin()
+tmp = grad_check::max_pool2d_simple()
 tmp = grad_check::relu()
 tmp = grad_check::rnn()
+tmp = grad_check::scale_shift1d()
+tmp = grad_check::scale_shift2d()
 tmp = grad_check::sigmoid()
 tmp = grad_check::softmax()
 tmp = grad_check::tanh()
+print("")
 
 # Example model
 tmp = grad_check::two_layer_affine_l2_net()
+print("")
 
 print("---")
 print("Grad checks complete -- look for any ERRORs or WARNINGs.")
@@ -71,11 +76,11 @@ print("---")
 
 tmp = test::batch_norm1d()
 tmp = test::batch_norm2d()
-tmp = test::im2col()
-tmp = test::padding()
 tmp = test::conv2d()
 tmp = test::cross_entropy_loss()
+tmp = test::im2col()
 tmp = test::max_pool2d()
+tmp = test::padding()
 tmp = test::tanh()
 
 print("---")