You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by du...@apache.org on 2016/05/28 00:55:12 UTC

[2/2] incubator-systemml git commit: Adding some more internal SystemML-NN documentation for clarification.

Adding some more internal SystemML-NN documentation for clarification.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/b14d55be
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/b14d55be
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/b14d55be

Branch: refs/heads/master
Commit: b14d55bed7e1960db69337d7c2fd840d89e630c2
Parents: ba60e73
Author: Mike Dusenberry <mw...@us.ibm.com>
Authored: Fri May 27 17:55:08 2016 -0700
Committer: Mike Dusenberry <mw...@us.ibm.com>
Committed: Fri May 27 17:55:08 2016 -0700

----------------------------------------------------------------------
 .../staging/SystemML-NN/nn/layers/affine.dml    |  3 +++
 scripts/staging/SystemML-NN/nn/layers/conv.dml  |  3 +++
 .../SystemML-NN/nn/layers/conv_builtin.dml      |  3 +++
 .../nn/layers/cross_entropy_loss.dml            |  4 +++
 .../staging/SystemML-NN/nn/layers/dropout.dml   |  3 ++-
 .../staging/SystemML-NN/nn/layers/l1_loss.dml   |  3 +++
 .../staging/SystemML-NN/nn/layers/l2_loss.dml   |  3 +++
 .../staging/SystemML-NN/nn/layers/log_loss.dml  |  3 +++
 .../staging/SystemML-NN/nn/layers/softmax.dml   | 27 +++++++++++++-------
 scripts/staging/SystemML-NN/nn/layers/tanh.dml  |  4 +--
 .../staging/SystemML-NN/nn/optim/adagrad.dml    |  3 +++
 scripts/staging/SystemML-NN/nn/optim/adam.dml   |  3 +++
 .../staging/SystemML-NN/nn/optim/rmsprop.dml    |  3 +++
 .../SystemML-NN/nn/optim/sgd_momentum.dml       |  3 +++
 .../SystemML-NN/nn/optim/sgd_nesterov.dml       |  3 +++
 15 files changed, 59 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/layers/affine.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/affine.dml b/scripts/staging/SystemML-NN/nn/layers/affine.dml
index 1338de4..e7e4fd8 100644
--- a/scripts/staging/SystemML-NN/nn/layers/affine.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/affine.dml
@@ -66,6 +66,9 @@ init = function(int D, int M)
     return (matrix[double] W, matrix[double] b) {
   /*
    * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
    * 
    * We use the heuristic by He et al. [http://arxiv.org/abs/1502.01852],
    * which limits the magnification of inputs/gradients during

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/layers/conv.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/conv.dml b/scripts/staging/SystemML-NN/nn/layers/conv.dml
index 0fbcf99..1b737f5 100644
--- a/scripts/staging/SystemML-NN/nn/layers/conv.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/conv.dml
@@ -161,6 +161,9 @@ init = function(int F, int C, int Hf, int Wf)
     return (matrix[double] W, matrix[double] b) {
   /*
    * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
    * 
    * We use the heuristic by He et al. [http://arxiv.org/abs/1502.01852],
    * which limits the magnification of inputs/gradients during

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml b/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml
index a73405e..7042eb2 100644
--- a/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/conv_builtin.dml
@@ -125,6 +125,9 @@ init = function(int F, int C, int Hf, int Wf)
     return (matrix[double] W, matrix[double] b) {
   /*
    * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
    * 
    * We use the heuristic by He et al. [http://arxiv.org/abs/1502.01852],
    * which limits the magnification of inputs/gradients during

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml b/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml
index 306ea96..9e3e7cd 100644
--- a/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/cross_entropy_loss.dml
@@ -33,6 +33,10 @@ forward = function(matrix[double] pred, matrix[double] y)
    * inputs consist of N examples, each with K dimensions corresponding
    * to normalized probabilities of K classes.
    *
+   * This can be interpreted as the negative log-likelihood assuming
+   * a Categorical distribution (a Bernoulli generalized to K classes),
+   * i.e., a Multinomial with 1 observation.
+   *
    * Inputs:
    *  - pred: Prediction matrix, of shape (N, K).
    *  - y: Target matrix, of shape (N, K).

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/layers/dropout.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/dropout.dml b/scripts/staging/SystemML-NN/nn/layers/dropout.dml
index e3c34f9..6c0b0d0 100644
--- a/scripts/staging/SystemML-NN/nn/layers/dropout.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/dropout.dml
@@ -42,8 +42,9 @@ forward = function(matrix[double] X, double p, int seed)
    *  - out: Ouptuts, of same shape as X.
    *  - mask: Dropout mask used to compute the output.
    */
-  if (seed == -1)
+  if (seed == -1) {
     seed = as.integer(floor(as.scalar(rand(rows=1, cols=1, min=1, max=100000))))
+  }
   mask = rand(rows=nrow(X), cols=ncol(X), min=0, max=1, seed=seed) <= p
   out = X * mask / p
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml b/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml
index 00db8a7..6c625e8 100644
--- a/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/l1_loss.dml
@@ -31,6 +31,9 @@ forward = function(matrix[double] pred, matrix[double] y)
    * Computes the forward pass for an L1 loss function.  The inputs
    * consist of N examples, each with M dimensions to predict.
    *
+   * This can be interpreted as the negative log-likelihood assuming
+   * a Laplace distribution.
+   *
    * Inputs:
    *  - pred: Prediction matrix, of shape (N, M).
    *  - y: Target matrix, of shape (N, M).

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml b/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml
index 13b6c2d..c4a8618 100644
--- a/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/l2_loss.dml
@@ -31,6 +31,9 @@ forward = function(matrix[double] pred, matrix[double] y)
    * Computes the forward pass for an L2 loss function.  The inputs
    * consist of N examples, each with M dimensions to predict.
    *
+   * This can be interpreted as the negative log-likelihood assuming
+   * a Gaussian distribution.
+   *
    * Inputs:
    *  - pred: Prediction matrix, of shape (N, M).
    *  - y: Target matrix, of shape (N, M).

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/layers/log_loss.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/log_loss.dml b/scripts/staging/SystemML-NN/nn/layers/log_loss.dml
index e3da456..0bcb02e 100644
--- a/scripts/staging/SystemML-NN/nn/layers/log_loss.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/log_loss.dml
@@ -31,6 +31,9 @@ forward = function(matrix[double] pred, matrix[double] y)
   /*
    * Computes the forward pass for a log loss function.
    *
+   * This can be interpreted as the negative log-likelihood assuming
+   * a Bernoulli distribution.
+   *
    * Inputs:
    *  - pred: Prediction matrix, of shape (N, 1).  Predictions should
    *      be probabilities that y=1.

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/layers/softmax.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/softmax.dml b/scripts/staging/SystemML-NN/nn/layers/softmax.dml
index 2576162..111e1b3 100644
--- a/scripts/staging/SystemML-NN/nn/layers/softmax.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/softmax.dml
@@ -29,6 +29,11 @@ forward = function(matrix[double] scores) return (matrix[double] probs) {
    * N examples, and the softmax function transforms them to normalized
    * probabilities.
    *
+   * This can be interpreted as a generalization of the sigmoid
+   * function to multiple classes.
+   *
+   * probs_ij = e^scores_ij / sum_k(e^scores_ik)
+   *
    * Inputs:
    *  - scores: Input data matrix, of shape (N, D).
    *
@@ -36,14 +41,14 @@ forward = function(matrix[double] scores) return (matrix[double] probs) {
    *  - probs: Outputs, of shape (N, D).
    */
   # For numerical stability, we subtract the max score of an example from all scores for that
-  # example.  This is equivalent:
+  # example.  This is equivalent to the original formulation:
   # e^scores_i / sum(e^scores_i) == C*e^scores_i / C*sum(e^scores_i)
   #                              == e^(scores_i+log(C)) / sum(e^(scores_i+log(C))
   # set log(C) = -max(scores_i):
   #                              == e^(scores_i-max(scores_i)) / sum(e^(scores_i-max(scores_i))
   scores = scores - rowMaxs(scores)  # numerical stability
-  unnorm_probs = exp(scores)
-  probs = unnorm_probs / rowSums(unnorm_probs)
+  unnorm_probs = exp(scores)  # unnormalized probabilities
+  probs = unnorm_probs / rowSums(unnorm_probs)  # normalized probabilities
 }
 
 backward = function(matrix[double] dprobs, matrix[double] scores)
@@ -51,11 +56,13 @@ backward = function(matrix[double] dprobs, matrix[double] scores)
   /*
    * Computes the backward pass for a softmax classifier.
    *
+   * Note that dscores_ij has multiple sources:
+   *
    * dprobs_ij/dscores_ij = probs_ij * (1 - probs_ij)
-   * dprobs_ic/dscores_ij = probs_ij * -probs_ic
+   * dprobs_ik/dscores_ij = -probs_ik * probs_ij, for all k != j
    *
    * dloss/dscores_ij = dloss/dprobs_ij * dprobs_ij/dscores_ij + 
-   *                    sum_c(dloss/dprobs_ic * dprobs_ic/dscores_ij)
+   *                    sum_{k!=j}(dloss/dprobs_ik * dprobs_ik/dscores_ij)
    *
    * Inputs:
    *  - dprobs: Derivatives from upstream, of shape (N, D).
@@ -65,9 +72,11 @@ backward = function(matrix[double] dprobs, matrix[double] scores)
    *  - dscores: Gradient wrt scores, of shape (N, D).
    */
   scores = scores - rowMaxs(scores)  # numerical stability
-  unnorm_probs = exp(scores)
-  probs = unnorm_probs / rowSums(unnorm_probs)
-  dscores = dprobs * probs
-  dscores = dscores - probs * rowSums(dscores)
+  unnorm_probs = exp(scores)  # unnormalized probabilities
+  probs = unnorm_probs / rowSums(unnorm_probs)  # normalized probabilities
+  # After some cancellation:
+  # dscores = dprobs*probs - probs*rowSums(dprobs*probs)
+  dtemp = dprobs * probs
+  dscores = dtemp - probs * rowSums(dtemp)
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/layers/tanh.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/layers/tanh.dml b/scripts/staging/SystemML-NN/nn/layers/tanh.dml
index e886081..0fadf77 100644
--- a/scripts/staging/SystemML-NN/nn/layers/tanh.dml
+++ b/scripts/staging/SystemML-NN/nn/layers/tanh.dml
@@ -26,7 +26,7 @@ forward = function(matrix[double] X) return (matrix[double] out) {
   /*
    * Computes the forward pass for a tanh nonlinearity layer.
    *
-   * tanh(x) = (e^x - e^-x) / (e^x + e^-x)
+   * tanh(x) = (e^x - e^-x) / (e^x + e^-x) = 2*sigmoid(2x) - 1
    *
    * Inputs:
    *  - X: Input data matrix, of shape (any, any).
@@ -34,7 +34,7 @@ forward = function(matrix[double] X) return (matrix[double] out) {
    * Outputs:
    *  - out: Ouptuts, of same shape as X.
    */
-  # Simplification of the above formulation:
+  # Simplification of the above formulation to use the sigmoid function:
   sigma2X = 1 / (1 + exp(-2*X))
   out = 2 * sigma2X - 1
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/optim/adagrad.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/optim/adagrad.dml b/scripts/staging/SystemML-NN/nn/optim/adagrad.dml
index daa5f5e..688109b 100644
--- a/scripts/staging/SystemML-NN/nn/optim/adagrad.dml
+++ b/scripts/staging/SystemML-NN/nn/optim/adagrad.dml
@@ -60,6 +60,9 @@ init = function(matrix[double] X) return (matrix[double] cache) {
   /*
    * Initialize the state for this optimizer.
    *
+   * Note: This is just a convenience function, and state
+   * may be initialized manually if needed.
+   *
    * Inputs:
    *  - X: Parameters to update, of shape (any, any).
    * 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/optim/adam.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/optim/adam.dml b/scripts/staging/SystemML-NN/nn/optim/adam.dml
index 05152f4..a25f74d 100644
--- a/scripts/staging/SystemML-NN/nn/optim/adam.dml
+++ b/scripts/staging/SystemML-NN/nn/optim/adam.dml
@@ -75,6 +75,9 @@ init = function(matrix[double] X) return (matrix[double] m, matrix[double] v) {
   /*
    * Initialize the state for this optimizer.
    *
+   * Note: This is just a convenience function, and state
+   * may be initialized manually if needed.
+   *
    * Inputs:
    *  - X: Parameters to update, of shape (any, any).
    * 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/optim/rmsprop.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/optim/rmsprop.dml b/scripts/staging/SystemML-NN/nn/optim/rmsprop.dml
index 31b78d5..e256000 100644
--- a/scripts/staging/SystemML-NN/nn/optim/rmsprop.dml
+++ b/scripts/staging/SystemML-NN/nn/optim/rmsprop.dml
@@ -62,6 +62,9 @@ init = function(matrix[double] X) return (matrix[double] cache) {
   /*
    * Initialize the state for this optimizer.
    *
+   * Note: This is just a convenience function, and state
+   * may be initialized manually if needed.
+   *
    * Inputs:
    *  - X: Parameters to update, of shape (any, any).
    * 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/optim/sgd_momentum.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/optim/sgd_momentum.dml b/scripts/staging/SystemML-NN/nn/optim/sgd_momentum.dml
index 22a88f2..c2a441b 100644
--- a/scripts/staging/SystemML-NN/nn/optim/sgd_momentum.dml
+++ b/scripts/staging/SystemML-NN/nn/optim/sgd_momentum.dml
@@ -55,6 +55,9 @@ init = function(matrix[double] X) return (matrix[double] v) {
   /*
    * Initialize the state for this optimizer.
    *
+   * Note: This is just a convenience function, and state
+   * may be initialized manually if needed.
+   *
    * Inputs:
    *  - X: Parameters to update, of shape (any, any).
    * 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/b14d55be/scripts/staging/SystemML-NN/nn/optim/sgd_nesterov.dml
----------------------------------------------------------------------
diff --git a/scripts/staging/SystemML-NN/nn/optim/sgd_nesterov.dml b/scripts/staging/SystemML-NN/nn/optim/sgd_nesterov.dml
index aac6522..56c6ab0 100644
--- a/scripts/staging/SystemML-NN/nn/optim/sgd_nesterov.dml
+++ b/scripts/staging/SystemML-NN/nn/optim/sgd_nesterov.dml
@@ -64,6 +64,9 @@ init = function(matrix[double] X) return (matrix[double] v) {
   /*
    * Initialize the state for this optimizer.
    *
+   * Note: This is just a convenience function, and state
+   * may be initialized manually if needed.
+   *
    * Inputs:
    *  - X: Parameters to update, of shape (any, any).
    *