Posted to commits@systemml.apache.org by ni...@apache.org on 2018/05/30 22:55:04 UTC

systemml git commit: [SYSTEMML-540] Remove unnecessary variables from batch_norm2d layer

Repository: systemml
Updated Branches:
  refs/heads/master 72fd8fda3 -> 7350a0c6d


[SYSTEMML-540] Remove unnecessary variables from batch_norm2d layer

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7350a0c6
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7350a0c6
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7350a0c6

Branch: refs/heads/master
Commit: 7350a0c6d38b3c018e10d18863295c1a89abc2cd
Parents: 72fd8fd
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Wed May 30 15:37:40 2018 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Wed May 30 15:37:40 2018 -0700

----------------------------------------------------------------------
 scripts/nn/layers/batch_norm2d.dml              | 93 ++++++--------------
 scripts/nn/test/grad_check.dml                  | 30 +++----
 scripts/nn/test/test.dml                        |  3 +-
 .../org/apache/sysml/api/dl/CaffeLayer.scala    | 41 +--------
 4 files changed, 44 insertions(+), 123 deletions(-)
----------------------------------------------------------------------
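
In short, forward no longer returns cache_var and cache_norm; it returns a single
cache_inv_var (the cached inverse standard deviation), and backward drops the unused
out, ema_*, beta, mode and mu arguments. A minimal usage sketch of the updated API,
mirroring the calls in grad_check.dml below (the shapes and hyperparameter values are
illustrative, not taken from the commit):

  source("nn/layers/batch_norm2d.dml") as batch_norm2d

  # Illustrative shapes and hyperparameters (assumed values)
  N = 2
  C = 3
  Hin = 4
  Win = 4
  mu = 0.9    # momentum for the moving averages
  eps = 1e-5  # smoothing term
  X = rand(rows=N, cols=C*Hin*Win)
  dout = rand(rows=N, cols=C*Hin*Win)  # stand-in for the upstream gradient
  [gamma, beta, ema_mean, ema_var] = batch_norm2d::init(C)

  # Forward: cache_var and cache_norm are replaced by a single cache_inv_var
  [out, ema_mean_upd, ema_var_upd, cache_mean, cache_inv_var] =
      batch_norm2d::forward(X, gamma, beta, C, Hin, Win, "train", ema_mean, ema_var, mu, eps)

  # Backward: only dout, the two caches, X, gamma, the shape, and epsilon are needed
  [dX, dgamma, dbeta] = batch_norm2d::backward(dout, cache_mean, cache_inv_var,
                                               X, gamma, C, Hin, Win, eps)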


http://git-wip-us.apache.org/repos/asf/systemml/blob/7350a0c6/scripts/nn/layers/batch_norm2d.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/batch_norm2d.dml b/scripts/nn/layers/batch_norm2d.dml
index 49c6746..8a8555f 100644
--- a/scripts/nn/layers/batch_norm2d.dml
+++ b/scripts/nn/layers/batch_norm2d.dml
@@ -29,7 +29,7 @@ forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
                    matrix[double] ema_mean, matrix[double] ema_var,
                    double mu, double epsilon)
     return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
-            matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {
+            matrix[double] cache_mean, matrix[double] cache_inv_var) {
   /*
    * Computes the forward pass for a 2D (spatial) batch normalization
    * layer.  The input data has N examples, each represented as a 3D
@@ -80,11 +80,8 @@ forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
    *      of shape (C, 1).
    *  - cache_mean: Cache of the batch mean, of shape (C, 1).
    *      Note: This is used for performance during training.
-   *  - cache_var: Cache of the batch variance, of shape (C, 1).
+   *  - cache_inv_var: Cache of the inverse variance, of shape (C, 1).
    *      Note: This is used for performance during training.
-   *  - cache_norm: Cache of the normalized inputs, of
-   *      shape (C, N*Hin*Win). Note: This is used for performance
-   *      during training.
    */
   N = nrow(X)
 
@@ -109,28 +106,24 @@ forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
     ema_var_upd = ema_var
   }
 
+  # Save variable for backward pass
+  cache_mean = mean
+  cache_inv_var = 1/sqrt(var+epsilon)
+  
   # Normalize, shift, and scale
   # norm = (X-mean)*(var+epsilon)^(-1/2)
   #      = (X-mean) / sqrt(var+epsilon)
   centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
-  norm = bias_multiply(centered, 1/sqrt(var+epsilon))  # shape (N, C*Hin*Win)
+  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
   # out = norm*gamma + beta
   scaled = bias_multiply(norm, gamma)  # shape (N, C*Hin*Win)
   out = bias_add(scaled, beta)  # shape (N, C*Hin*Win)
-
-  # Save variable for backward pass
-  cache_mean = mean
-  cache_var = var
-  cache_norm = norm
 }
 
-backward = function(matrix[double] dout, matrix[double] out,
-                    matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
-                    matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm,
-                    matrix[double] X, matrix[double] gamma, matrix[double] beta,
-                    int C, int Hin, int Win, string mode,
-                    matrix[double] ema_mean, matrix[double] ema_var,
-                    double mu, double epsilon)
+backward = function(matrix[double] dout, 
+                    matrix[double] cache_mean, matrix[double] cache_inv_var,
+                    matrix[double] X, matrix[double] gamma, 
+                    int C, int Hin, int Win, double epsilon)
       return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
   /*
    * Computes the backward pass for a 2D (spatial) batch normalization
@@ -138,38 +131,18 @@ backward = function(matrix[double] dout, matrix[double] out,
    *
    * Inputs:
    *  - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win).
-   *  - out: Outputs from the forward pass, of shape (N, C*Hin*Win).
-   *  - ema_mean_upd: Updated exponential moving average of the mean
-   *      from the forward pass, of shape (C, 1).
-   *  - ema_var_upd: Updated exponential moving average of the variance
-   *      from the forward pass, of shape (C, 1).
    *  - cache_mean: Cache of the batch mean from the forward pass, of
    *      shape (C, 1).  Note: This is used for performance during
    *      training.
-   *  - cache_var: Cache of the batch variance from the forward pass,
+   *  - cache_inv_var: Cache of the inverse variance from the forward pass,
    *      of shape (C, 1).  Note: This is used for performance during
    *      training.
-   *  - cache_norm: Cache of the normalized inputs from the forward
-   *      pass, of shape (C, N*Hin*Win).  Note: This is used for
-   *      performance during training.
    *  - X: Input data matrix to the forward pass, of
    *      shape (N, C*Hin*Win).
    *  - gamma: Scale parameters, of shape (C, 1).
-   *  - beta: Shift parameters, of shape (C, 1).
    *  - C: Number of input channels (dimensionality of input depth).
    *  - Hin: Input height.
    *  - Win: Input width.
-   *  - mode: 'train' or 'test' to indicate if the model is currently
-   *      being trained or tested.  During training, the current batch
-   *      mean and variance will be used to normalize the inputs, while
-   *      during testing, the exponential average of the mean and
-   *      variance over all previous batches will be used.
-   *  - ema_mean: Exponential moving average of the mean, of
-   *      shape (C, 1).
-   *  - ema_var: Exponential moving average of the variance, of
-   *      shape (C, 1).
-   *  - mu: Momentum value for moving averages.
-   *      Typical values are in the range of [0.9, 0.999].
    *  - epsilon: Smoothing term to avoid divide by zero errors.
    *      Typical values are in the range of [1e-5, 1e-3].
    *
@@ -181,33 +154,22 @@ backward = function(matrix[double] dout, matrix[double] out,
    */
   N = nrow(X)
   mean = cache_mean
-  var = cache_var
-  norm = cache_norm
   centered = bias_add(X, -mean)  # shape (N, C*Hin*Win)
-
-  if (mode == 'train') {
-    # Compute gradients during training
-    dgamma = util::channel_sums(dout*norm, C, Hin, Win)  # shape (C, 1)
-    dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)
-    dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
-    dvar = util::channel_sums((-1/2) * bias_multiply(centered, (var+epsilon)^(-3/2)) * dnorm,
-                              C, Hin, Win)  # shape (C, 1)
-    dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -1/sqrt(var+epsilon)), C, Hin, Win)
-    dmean_var_branch =  util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, Win)
-    dmean_var_branch = dmean_var_branch * dvar  # we can't use a function within an expression yet
-    dmean = dmean_norm_branch + dmean_var_branch  # shape (C, 1)
-    dX_norm_branch = bias_multiply(dnorm, 1/sqrt(var+epsilon))
-    dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
-    dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar)
-    dX = dX_norm_branch + dX_mean_branch + dX_var_branch  # shape (N, C*Hin*Win)
-  }
-  else {
-    # Compute gradients during testing
-    dgamma = util::channel_sums(dout*norm, C, Hin, Win)  # shape (C, 1)
-    dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)
-    dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
-    dX = bias_multiply(dnorm, 1/sqrt(var+epsilon))  # shape (N, C*Hin*Win)
-  }
+  norm = bias_multiply(centered, cache_inv_var)  # shape (N, C*Hin*Win)
+  # Compute gradients during training
+  dgamma = util::channel_sums(dout*norm, C, Hin, Win)  # shape (C, 1)
+  dbeta = util::channel_sums(dout, C, Hin, Win)  # shape (C, 1)
+  dnorm = bias_multiply(dout, gamma)  # shape (N, C*Hin*Win)
+  dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm,
+                          C, Hin, Win)  # shape (C, 1)
+  dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win)
+  dmean_var_branch =  util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, Win)
+  dmean_var_branch = dmean_var_branch * dvar  # we can't use a function within an expression yet
+  dmean = dmean_norm_branch + dmean_var_branch  # shape (C, 1)
+  dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
+  dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
+  dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar)
+  dX = dX_norm_branch + dX_mean_branch + dX_var_branch  # shape (N, C*Hin*Win)
 }
 
 init = function(int C)
@@ -235,4 +197,3 @@ init = function(int C)
    ema_mean = matrix(0, rows=C, cols=1)
    ema_var = matrix(1, rows=C, cols=1)
 }
-
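
The reason the raw variance and the normalized inputs no longer need to be cached:
every term in the backward pass can be written using only the batch mean and the
inverse standard deviation s = (var + epsilon)^(-1/2). Restating the per-channel
gradients from the hunk above in that form (sums taken per channel, m = N*Hin*Win):

  \begin{aligned}
  \hat{x} &= (x-\mu)\,s, \qquad s = (\sigma^2+\epsilon)^{-1/2}\\
  d\gamma &= \textstyle\sum dout\cdot\hat{x}, \qquad d\beta = \textstyle\sum dout, \qquad d\hat{x} = \gamma\cdot dout\\
  d\sigma^2 &= \textstyle\sum -\tfrac{1}{2}\,(x-\mu)\,s^{3}\,d\hat{x}\\
  d\mu &= \textstyle\sum -s\,d\hat{x} \;+\; d\sigma^2 \textstyle\sum -\tfrac{2}{m}(x-\mu)\\
  dx &= s\,d\hat{x} \;+\; \tfrac{2}{m}(x-\mu)\,d\sigma^2 \;+\; \tfrac{1}{m}\,d\mu
  \end{aligned}

In particular, (var+epsilon)^(-3/2) is simply s^3, and norm can be recomputed cheaply
from cache_mean and cache_inv_var, which is exactly what the new backward does.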

http://git-wip-us.apache.org/repos/asf/systemml/blob/7350a0c6/scripts/nn/test/grad_check.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/grad_check.dml b/scripts/nn/test/grad_check.dml
index 8fbfa76..be34408 100644
--- a/scripts/nn/test/grad_check.dml
+++ b/scripts/nn/test/grad_check.dml
@@ -363,21 +363,16 @@ batch_norm2d = function() {
   #[dummy, dummy, ema_mean, ema_var] = batch_norm2d::init(C)
 
   # Check training & testing modes
-  for (i in 1:2) {
-    if (i == 1)
-      mode = 'train'
-    else
-      mode = 'test'
+  # for (i in 1:1) {
+    mode = 'train'
     print(" - Grad checking the '"+mode+"' mode.")
 
     # Compute analytical gradients of loss wrt parameters
-    [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+    [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var] =
         batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
     dout = l2_loss::backward(out, y)
-    [dX, dgamma, dbeta] = batch_norm2d::backward(dout, out, ema_mean_upd, ema_var_upd,
-                                                 cache_mean, cache_var, cache_norm,
-                                                 X, gamma, beta, C, Hin, Win, mode,
-                                                 ema_mean, ema_var, mu, eps)
+    [dX, dgamma, dbeta] = batch_norm2d::backward(dout, cache_mean, cache_var, 
+                                                 X, gamma, C, Hin, Win, eps)
 
     # Grad check
     h = 1e-5
@@ -387,11 +382,11 @@ batch_norm2d = function() {
         # Compute numerical derivative
         old = as.scalar(X[i,j])
         X[i,j] = old - h
-        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var] =
             batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
         lossmh = l2_loss::forward(outmh, y)
         X[i,j] = old + h
-        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var] =
             batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
         lossph = l2_loss::forward(outph, y)
         X[i,j] = old  # reset
@@ -408,11 +403,11 @@ batch_norm2d = function() {
         # Compute numerical derivative
         old = as.scalar(gamma[i,j])
         gamma[i,j] = old - h
-        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var] =
             batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
         lossmh = l2_loss::forward(outmh, y)
         gamma[i,j] = old + h
-        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var] =
             batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
         lossph = l2_loss::forward(outph, y)
         gamma[i,j] = old  # reset
@@ -430,11 +425,11 @@ batch_norm2d = function() {
         # Compute numerical derivative
         old = as.scalar(beta[i,j])
         beta[i,j] = old - h
-        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+        [outmh, ema_mean_upd, ema_var_upd, cache_mean, cache_var] =
             batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
         lossmh = l2_loss::forward(outmh, y)
         beta[i,j] = old + h
-        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+        [outph, ema_mean_upd, ema_var_upd, cache_mean, cache_var] =
             batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
         lossph = l2_loss::forward(outph, y)
         beta[i,j] = old  # reset
@@ -445,7 +440,7 @@ batch_norm2d = function() {
                                                     lossph, lossmh)
       }
     }
-  }
+  # }
 }
 
 conv2d = function() {
@@ -2497,4 +2492,3 @@ elu = function() {
      }
    }
 }
-
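
For context, the grad check above perturbs each entry by plus/minus h (h = 1e-5 in the
script) and compares the analytical gradient against a centered finite difference of
the loss, here written for X (gamma and beta are checked the same way):

  \frac{\partial L}{\partial X_{ij}} \approx \frac{L(X_{ij}+h) - L(X_{ij}-h)}{2h}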

http://git-wip-us.apache.org/repos/asf/systemml/blob/7350a0c6/scripts/nn/test/test.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/test/test.dml b/scripts/nn/test/test.dml
index e3e136f..59bec5c 100644
--- a/scripts/nn/test/test.dml
+++ b/scripts/nn/test/test.dml
@@ -125,7 +125,7 @@ batch_norm2d = function() {
   [gamma, beta, ema_mean, ema_var] = batch_norm2d::init(C)
 
   # Forward
-  [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var, cache_norm] =
+  [out, ema_mean_upd, ema_var_upd, cache_mean, cache_var] =
       batch_norm2d::forward(X, gamma, beta, C, Hin, Win, mode, ema_mean, ema_var, mu, eps)
 
   # Equivalency check
@@ -1125,4 +1125,3 @@ elu = function() {
     }
   }
 }
-

http://git-wip-us.apache.org/repos/asf/systemml/blob/7350a0c6/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
index 9aad7b3..3e7aff3 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
@@ -279,15 +279,12 @@ class BatchNorm(val param: LayerParameter, val id: Int, val net: CaffeNetwork) e
    *      Note: This is used for performance during training.
    *  - cache_var: Cache of the batch variance, of shape (C, 1).
    *      Note: This is used for performance during training.
-   *  - cache_norm: Cache of the normalized inputs, of
-   *      shape (C, N*Hin*Win). Note: This is used for performance
-   *      during training.
    */
   def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
     val mode = if (isPrediction) "\"test\"" else "\"train\""
     invokeForward(
       dmlScript,
-      List[String](out, withSuffix(ema_mean), withSuffix(ema_var), withSuffix(cache_mean), withSuffix(cache_var), withSuffix(cache_norm)),
+      List[String](out, withSuffix(ema_mean), withSuffix(ema_var), withSuffix(cache_mean), withSuffix(cache_var)),
       X,
       gamma,
       beta,
@@ -307,38 +304,18 @@ class BatchNorm(val param: LayerParameter, val id: Int, val net: CaffeNetwork) e
    *
    * Inputs:
    *  - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win).
-   *  - out: Outputs from the forward pass, of shape (N, C*Hin*Win).
-   *  - ema_mean_upd: Updated exponential moving average of the mean
-   *      from the forward pass, of shape (C, 1).
-   *  - ema_var_upd: Updated exponential moving average of the variance
-   *      from the forward pass, of shape (C, 1).
    *  - cache_mean: Cache of the batch mean from the forward pass, of
    *      shape (C, 1).  Note: This is used for performance during
    *      training.
-   *  - cache_var: Cache of the batch variance from the forward pass,
+   *  - cache_inv_var: Cache of the inverse variance from the forward pass,
    *      of shape (C, 1).  Note: This is used for performance during
    *      training.
-   *  - cache_norm: Cache of the normalized inputs from the forward
-   *      pass, of shape (C, N*Hin*Win).  Note: This is used for
-   *      performance during training.
    *  - X: Input data matrix to the forward pass, of
    *      shape (N, C*Hin*Win).
    *  - gamma: Scale parameters, of shape (C, 1).
-   *  - beta: Shift parameters, of shape (C, 1).
    *  - C: Number of input channels (dimensionality of input depth).
    *  - Hin: Input height.
    *  - Win: Input width.
-   *  - mode: 'train' or 'test' to indicate if the model is currently
-   *      being trained or tested.  During training, the current batch
-   *      mean and variance will be used to normalize the inputs, while
-   *      during testing, the exponential average of the mean and
-   *      variance over all previous batches will be used.
-   *  - ema_mean: Exponential moving average of the mean, of
-   *      shape (C, 1).
-   *  - ema_var: Exponential moving average of the variance, of
-   *      shape (C, 1).
-   *  - mu: Momentum value for moving averages.
-   *      Typical values are in the range of [0.9, 0.999].
    *  - epsilon: Smoothing term to avoid divide by zero errors.
    *      Typical values are in the range of [1e-5, 1e-3].
    *
@@ -354,22 +331,13 @@ class BatchNorm(val param: LayerParameter, val id: Int, val net: CaffeNetwork) e
       outSuffix,
       List[String]("dOut" + id, dgamma, dbeta),
       dout,
-      out,
-      ema_mean,
-      ema_var,
       cache_mean,
       cache_var,
-      cache_norm,
       X,
       gamma,
-      beta,
       numChannels,
       Hin,
       Win,
-      "\"train\"",
-      ema_mean,
-      ema_var,
-      ma_fraction,
       eps
     )
 
@@ -377,8 +345,7 @@ class BatchNorm(val param: LayerParameter, val id: Int, val net: CaffeNetwork) e
   override def weightShape(): Array[Int]      = Array(numChannels.toInt, 1)
   override def biasShape(): Array[Int]        = Array(numChannels.toInt, 1)
   def cache_mean(): String                    = "cache_mean" + id
-  def cache_var(): String                     = "cache_mean" + id
-  def cache_norm(): String                    = "cache_norm" + id
+  def cache_var(): String                     = "cache_var" + id
   var scaleLayer: Scale                       = null
   def gamma(): String                         = { checkNextLayer(); scaleLayer.weight }
   def ma_fraction(): String                   = if (param.getBatchNormParam.hasMovingAverageFraction()) param.getBatchNormParam.getMovingAverageFraction.toString else "0.999"
@@ -1636,4 +1603,4 @@ class DeConvolution(val param: LayerParameter, val id: Int, val net: CaffeNetwor
     if (convParam.hasPadW) convParam.getPadW.toString
     else if (convParam.getPadCount > 0) convParam.getPad(0).toString
     else "0"
-}
+}
\ No newline at end of file
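
On the Caffe2DML side, the trimmed invokeBackward argument list means the generated
script calls the new backward signature directly. An illustrative generated line for
a BatchNorm layer with 64 channels and 28x28 inputs (all variable names and values
here are hypothetical; the real identifiers come from withSuffix and the layer id):

  # Hypothetical generated DML for one BatchNorm layer (names and values illustrative)
  [dOut3, dgamma3, dbeta3] = batch_norm2d::backward(dout3, cache_mean3, cache_var3,
                                                    X3, gamma3, 64, 28, 28, 1e-5)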