Posted to commits@systemml.apache.org by ni...@apache.org on 2019/03/25 19:37:07 UTC

[systemml] branch master updated: [SYSTEMML-540] Added looped_minibatch training algorithm in Keras2DML

This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new b657820  [SYSTEMML-540] Added looped_minibatch training algorithm in Keras2DML
b657820 is described below

commit b657820248fbb42f1c4f27564cdb14865ebeeec1
Author: Niketan Pansare <np...@us.ibm.com>
AuthorDate: Mon Mar 25 12:33:50 2019 -0700

    [SYSTEMML-540] Added looped_minibatch training algorithm in Keras2DML
    
    - This algorithm performs multiple forward-backward passes (as many as the `parallel_batches` parameter) with the given batch size, aggregates the gradients, and finally updates the model.
    - Updated the documentation.
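    - A minimal usage sketch (assuming a compiled Keras model `keras_model`, an existing `spark` session, and illustrative sizes: an effective batch size of 256 emulated as 4 gradient-aggregated passes of 64 samples each):
    
          from systemml.mllearn import Keras2DML
    
          # Assumed for illustration: effective batch size 256, per-pass batch size 64
          batch_size, local_batch_size = 256, 64
          sysml_model = Keras2DML(spark, keras_model, batch_size=local_batch_size)
          sysml_model.set(train_algo="looped_minibatch",
                          parallel_batches=int(batch_size / local_batch_size))
          sysml_model.setGPU(True).setForceGPU(True)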
---
 docs/beginners-guide-caffe2dml.md                  |  2 +-
 docs/beginners-guide-keras2dml.md                  | 35 ++++++++++++-
 src/main/python/systemml/mllearn/estimators.py     | 11 ++--
 .../scala/org/apache/sysml/api/dl/Caffe2DML.scala  | 60 ++++++++++++++--------
 4 files changed, 82 insertions(+), 26 deletions(-)

diff --git a/docs/beginners-guide-caffe2dml.md b/docs/beginners-guide-caffe2dml.md
index 8814283..db74feb 100644
--- a/docs/beginners-guide-caffe2dml.md
+++ b/docs/beginners-guide-caffe2dml.md
@@ -161,7 +161,7 @@ Iter:2000, validation loss:173.66147359346, validation accuracy:97.4897540983606
 
 Unlike Caffe where default train and test algorithm is `minibatch`, you can specify the
 algorithm using the parameters `train_algo` and `test_algo` (valid values are: `minibatch`, `allreduce_parallel_batches`, 
-and `allreduce`). Here are some common settings:
+`looped_minibatch`, and `allreduce`). Here are some common settings:
 
 |                                                                          | PySpark script                                                                                                                           | Changes to Network/Solver                                              |
 |--------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------|
diff --git a/docs/beginners-guide-keras2dml.md b/docs/beginners-guide-keras2dml.md
index 4517be5..2259397 100644
--- a/docs/beginners-guide-keras2dml.md
+++ b/docs/beginners-guide-keras2dml.md
@@ -208,4 +208,37 @@ For example: for the expression `Keras2DML(..., display=100, test_iter=10, test_
 To verify that Keras2DML produces the same results as other Keras backends, we have [Python unit tests](https://github.com/apache/systemml/blob/master/src/main/python/tests/test_nn_numpy.py)
 that compare the results of Keras2DML with those of TensorFlow. We assume that the Keras team ensures that all their backends are consistent with their TensorFlow backend.
 
-
+#### How can I train very deep models on GPU?
+
+Unlike Keras, where the default train and test algorithm is `minibatch`, you can specify the
+algorithm using the parameters `train_algo` and `test_algo` (valid values are: `minibatch`, `allreduce_parallel_batches`, 
+`looped_minibatch`, and `allreduce`). Here are some common settings:
+
+|                                                                          | PySpark script                                                                                                                           | Changes to Network/Solver                                              |
+|--------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------|
+| Single-node CPU execution (similar to Caffe with solver_mode: CPU)       | `lenet.set(train_algo="minibatch", test_algo="minibatch")`                                                                               | Ensure that `batch_size` is set to an appropriate value (for example: 64) |
+| Single-node single-GPU execution                                         | `lenet.set(train_algo="minibatch", test_algo="minibatch").setGPU(True).setForceGPU(True)`                                                | Ensure that `batch_size` is set to an appropriate value (for example: 64) |
+| Single-node multi-GPU execution (similar to Caffe with solver_mode: GPU) | `lenet.set(train_algo="allreduce_parallel_batches", test_algo="minibatch", parallel_batches=num_gpu).setGPU(True).setForceGPU(True)`     | Ensure that `batch_size` is set to an appropriate value (for example: 64) |
+| Distributed prediction                                                   | `lenet.set(test_algo="allreduce")`                                                                                                       |                                                                        |
+| Distributed synchronous training                                         | `lenet.set(train_algo="allreduce_parallel_batches", parallel_batches=num_cluster_cores)`                                                 | Ensure that `batch_size` is set to an appropriate value (for example: 64) |
+
+Here are high-level guidelines to train very deep models on GPU with Keras2DML (and Caffe2DML):
+
+1. If there exists at least one layer/operator that does not fit on the device, allow SystemML's optimizer to perform operator placement based on its memory estimates: `sysml_model.setGPU(True)`.
+2. If each individual layer/operator fits on the device, but the entire network does not fit even with a batch size of 1, then
+- Rely on SystemML's GPU Memory Manager to perform automatic eviction (recommended): `sysml_model.setGPU(True) # Optional: .setForceGPU(True)`
+- Or enable Nvidia's Unified Memory:  `sysml_model.setConfigProperty('sysml.gpu.memory.allocator', 'unified_memory')`
+3. If the entire neural network does not fit in the GPU memory with the user-specified `batch_size`, but fits in the GPU memory with `local_batch_size` such that `1 << local_batch_size < batch_size`, then
+- Use either of the above two options.
+- Or use a `train_algo` that performs multiple forward-backward passes with batch size `local_batch_size`, aggregates the gradients, and finally updates the model:
+```python
+sysml_model = Keras2DML(spark, keras_model, batch_size=local_batch_size)
+sysml_model.set(train_algo="looped_minibatch", parallel_batches=int(batch_size/local_batch_size))
+sysml_model.setGPU(True).setForceGPU(True)
+```
+- Or add `int(batch_size/local_batch_size)` GPUs and perform single-node multi-GPU training with batch size `local_batch_size`:
+```python
+sysml_model = Keras2DML(spark, keras_model, batch_size=local_batch_size)
+sysml_model.set(train_algo="allreduce_parallel_batches", parallel_batches=int(batch_size/local_batch_size))
+sysml_model.setGPU(True).setForceGPU(True)
+```
diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py
index 456280b..0b47d8c 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -923,22 +923,23 @@ class Caffe2DML(BaseSystemMLClassifier):
 
     def set(self, debug=None, train_algo=None, test_algo=None, parallel_batches=None,
             output_activations=None, perform_one_hot_encoding=None, parfor_parameters=None, inline_nn_library=None, use_builtin_lstm_fn=None,
-            perform_fused_backward_update=None):
+            perform_fused_backward_update=None, weight_parallel_batches=None):
         """
         Set input to Caffe2DML
 
         Parameters
         ----------
         debug: to add debugging DML code such as classification report, print DML script, etc (default: False)
-        train_algo: can be minibatch, batch, allreduce_parallel_batches or allreduce (default: minibatch)
-        test_algo: can be minibatch, batch, allreduce_parallel_batches or allreduce (default: minibatch)
-        parallel_batches: number of parallel batches
+        train_algo: can be minibatch, batch, allreduce_parallel_batches, looped_minibatch or allreduce (default: minibatch)
+        test_algo: can be minibatch, batch, allreduce_parallel_batches, looped_minibatch or allreduce (default: minibatch)
+        parallel_batches: number of parallel batches (required for allreduce_parallel_batches or looped_minibatch)
         output_activations: (developer flag) directory to output activations of each layer as csv while prediction. To be used only in batch mode (default: None)
         perform_one_hot_encoding: should perform one-hot encoding in DML using table function (default: True)
         parfor_parameters: dictionary for parfor parameters when using allreduce-style algorithms (default: "")
         inline_nn_library: whether to inline the NN library when generating DML using Caffe2DML (default: False)
         use_builtin_lstm_fn: whether to use builtin lstm function for LSTM layer (default: True)
         perform_fused_backward_update: whether to perform update immediately after backward pass at the script level. Supported for minibatch and batch algorithms. (default: True)
+        weight_parallel_batches: whether to scale gradients by 1/parallel_batches before performing the SGD update (default: True)
         """
         if debug is not None:
             self.estimator.setInput("$debug", str(debug).upper())
@@ -954,6 +955,8 @@ class Caffe2DML(BaseSystemMLClassifier):
             self.estimator.setInput("$use_builtin_lstm_fn", str(use_builtin_lstm_fn).upper())
         if perform_fused_backward_update is not None:
             self.estimator.setInput("$perform_fused_backward_update", str(perform_fused_backward_update).upper())
+        if weight_parallel_batches is not None:
+            self.estimator.setInput("$weight_parallel_batches", str(weight_parallel_batches).upper())
         if output_activations is not None:
             self.estimator.setInput(
                 "$output_activations",
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index c5a20db..9950d69 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -118,7 +118,7 @@ To shield from network files that violates this restriction, Caffe2DML performs
 object Caffe2DML {
   val LOG = LogFactory.getLog(classOf[Caffe2DML].getName())
   // ------------------------------------------------------------------------
-  val USE_PLUS_EQ = true
+  var USE_PLUS_EQ = true
   def nnDir = "nn/"
   def layerDir = nnDir + "layers/"
   def optimDir = nnDir + "optim/"
@@ -157,6 +157,7 @@ object Caffe2DML {
   val rand = new Random
   // Supported Algorithms:
   val MINIBATCH_ALGORITHM = "minibatch"
+  val LOOPED_MINIBATCH_ALGORITHM = "looped_minibatch"
   val BATCH_ALGORITHM = "batch"
   val ALLREDUCE_ALGORITHM = "allreduce"
   val ALLREDUCE_PARALLEL_BATCHES_ALGORITHM = "allreduce_parallel_batches"
@@ -321,6 +322,7 @@ class Caffe2DML(val sc: SparkContext,
         case "$inline_nn_library" => false
         case "$use_builtin_lstm_fn" => true
         case "$perform_fused_backward_update" => true
+        case "$weight_parallel_batches" => true
         case _ => throw new DMLRuntimeException("Unsupported input:" + key)
       }
     } 
@@ -329,7 +331,7 @@ class Caffe2DML(val sc: SparkContext,
   // The below method parses the provided network and solver file and generates DML script.
   def getTrainingScript(isSingleNode: Boolean): (Script, String, String) = {
     val startTrainingTime = System.nanoTime()
-
+    
     reset // Reset the state of DML generator for training script.
 
     // Flags passed by user
@@ -357,7 +359,9 @@ class Caffe2DML(val sc: SparkContext,
       tabDMLScript.append(print(dmlConcat(asDMLString("Iterations (for training loss/accuracy) refers to the number of batches processed where batch size="), Caffe2DML.batchSize)))
     }
     if(getTrainAlgo.toLowerCase.equals(Caffe2DML.ALLREDUCE_PARALLEL_BATCHES_ALGORITHM) ||
-        getTestAlgo.toLowerCase.equals(Caffe2DML.ALLREDUCE_PARALLEL_BATCHES_ALGORITHM)) {
+        getTestAlgo.toLowerCase.equals(Caffe2DML.ALLREDUCE_PARALLEL_BATCHES_ALGORITHM) || 
+        getTrainAlgo.toLowerCase.equals(Caffe2DML.LOOPED_MINIBATCH_ALGORITHM) ||
+        getTestAlgo.toLowerCase.equals(Caffe2DML.LOOPED_MINIBATCH_ALGORITHM)) {
       assign(tabDMLScript, "parallel_batches", "$parallel_batches")
     }
     // ----------------------------------------------------------------------------
@@ -426,7 +430,7 @@ class Caffe2DML(val sc: SparkContext,
           lrPolicy.updateLearningRate(tabDMLScript)
         }
       }
-      case Caffe2DML.ALLREDUCE_PARALLEL_BATCHES_ALGORITHM => {
+      case Caffe2DML.LOOPED_MINIBATCH_ALGORITHM | Caffe2DML.ALLREDUCE_PARALLEL_BATCHES_ALGORITHM => {
         assign(tabDMLScript, "e", "0")
         assign(tabDMLScript, "max_iter", ifdef("$max_iter", solverParam.getMaxIter.toString))
         forBlock("iter", "1", "max_iter", "parallel_batches") {  
@@ -436,7 +440,16 @@ class Caffe2DML(val sc: SparkContext,
             assign(tabDMLScript, "allreduce_start_index", "1")
           }
           initializeGradients("parallel_batches")
-          parForBlock("j", "1", "parallel_batches", "1", getParforParameters()) {
+          val old_USE_PLUS_EQ = Caffe2DML.USE_PLUS_EQ
+          val iterBlock = if(getTrainAlgo.toLowerCase.equals(Caffe2DML.ALLREDUCE_PARALLEL_BATCHES_ALGORITHM)) {
+            parForBlock("j", "1", "parallel_batches", "1", getParforParameters()) _ 
+          }
+          else {
+            Caffe2DML.USE_PLUS_EQ = true
+            forBlock("j", "1", "parallel_batches", "1") _
+          }
+          
+          iterBlock {
             // Get a mini-batch in this group
             assign(tabDMLScript, "beg", "allreduce_start_index + (j-1)*" + Caffe2DML.batchSize)
             assign(tabDMLScript, "end", "allreduce_start_index + j*" + Caffe2DML.batchSize + " - 1")
@@ -463,6 +476,7 @@ class Caffe2DML(val sc: SparkContext,
             }
           }
           performSnapshot
+          Caffe2DML.USE_PLUS_EQ = old_USE_PLUS_EQ
         }
       }
       case Caffe2DML.ALLREDUCE_ALGORITHM => {
@@ -570,7 +584,7 @@ class Caffe2DML(val sc: SparkContext,
     tabDMLScript.append("# Compute validation loss & accuracy\n")
     assign(tabDMLScript, "loss", "0"); assign(tabDMLScript, "accuracy", "0")
     getTestAlgo.toLowerCase match {
-      case Caffe2DML.MINIBATCH_ALGORITHM => {
+      case Caffe2DML.MINIBATCH_ALGORITHM | Caffe2DML.LOOPED_MINIBATCH_ALGORITHM => {
         assign(tabDMLScript, "validation_loss", "0")
         assign(tabDMLScript, "validation_accuracy", "0")
         forBlock("iVal", "1", "num_batches_per_epoch") {
@@ -695,29 +709,35 @@ class Caffe2DML(val sc: SparkContext,
     }
   }
   private def flattenGradients(): Unit = {
-    if(Caffe2DML.USE_PLUS_EQ) {
-      // Note: We multiply by a weighting to allow for proper gradient averaging during the
-      // aggregation even with uneven batch sizes.
+    if(!Caffe2DML.USE_PLUS_EQ) {
+      tabDMLScript.append("# Flatten and store gradients for this parallel execution\n")
+    }
+    val isLoopedMinibatch = getTrainAlgo.toLowerCase.equals(Caffe2DML.LOOPED_MINIBATCH_ALGORITHM)
+    val suffixDML = if(getInputBooleanValue("$weight_parallel_batches")) " * weighting" else ""
+    // Note: We multiply by a weighting to allow for proper gradient averaging during the
+    // aggregation even with uneven batch sizes.
+    if(getInputBooleanValue("$weight_parallel_batches")) {
       assign(tabDMLScript, "weighting", "1/parallel_batches") // "nrow(Xb)/X_group_batch_size")
+    }
+    if(Caffe2DML.USE_PLUS_EQ) {
       net.getLayers
         .map(layer => net.getCaffeLayer(layer))
         .map(l => {
-          if (l.shouldUpdateWeight) assignPlusEq(tabDMLScript, l.dWeight + "_agg", l.dWeight + "*weighting")
-          if (l.shouldUpdateExtraWeight) assignPlusEq(tabDMLScript, l.dExtraWeight + "_agg", l.dExtraWeight + "*weighting")
-          if (l.shouldUpdateWeight) assignPlusEq(tabDMLScript, l.dBias + "_agg", l.dBias + "*weighting")
+          if (l.shouldUpdateWeight) assignPlusEq(tabDMLScript, l.dWeight + "_agg", l.dWeight + suffixDML)
+          if (l.shouldUpdateExtraWeight) assignPlusEq(tabDMLScript, l.dExtraWeight + "_agg", l.dExtraWeight + suffixDML)
+          if (l.shouldUpdateWeight) assignPlusEq(tabDMLScript, l.dBias + "_agg", l.dBias + suffixDML)
         })
     }
     else {
-      tabDMLScript.append("# Flatten and store gradients for this parallel execution\n")
-      // Note: We multiply by a weighting to allow for proper gradient averaging during the
-      // aggregation even with uneven batch sizes.
-      assign(tabDMLScript, "weighting", "1/parallel_batches") // "nrow(Xb)/X_group_batch_size")
+      if(isLoopedMinibatch) {
+        throw new DMLRuntimeException("Flattening and storing gradients is not supported for looped_minibatch algorithm")
+      }
       net.getLayers
         .map(layer => net.getCaffeLayer(layer))
         .map(l => {
-          if (l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg[j,]", matrix(l.dWeight, "1", multiply(nrow(l.weight), ncol(l.weight))) + " * weighting")
-          if (l.shouldUpdateExtraWeight) assign(tabDMLScript, l.dExtraWeight + "_agg[j,]", matrix(l.dExtraWeight, "1", multiply(nrow(l.extraWeight), ncol(l.extraWeight))) + " * weighting")
-          if (l.shouldUpdateWeight) assign(tabDMLScript, l.dBias + "_agg[j,]", matrix(l.dBias, "1", multiply(nrow(l.bias), ncol(l.bias))) + " * weighting")
+          if (l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg[j,]", matrix(l.dWeight, "1", multiply(nrow(l.weight), ncol(l.weight))) + suffixDML)
+          if (l.shouldUpdateExtraWeight) assign(tabDMLScript, l.dExtraWeight + "_agg[j,]", matrix(l.dExtraWeight, "1", multiply(nrow(l.extraWeight), ncol(l.extraWeight))) + suffixDML)
+          if (l.shouldUpdateWeight) assign(tabDMLScript, l.dBias + "_agg[j,]", matrix(l.dBias, "1", multiply(nrow(l.bias), ncol(l.bias))) + suffixDML)
         })
     }
   }
@@ -807,7 +827,7 @@ class Caffe2DMLModel(val numClasses: String, val sc: SparkContext, val solver: C
     val lastLayerShape = estimator.getOutputShapeOfLastLayer
     assign(tabDMLScript, "Prob", matrix("1", Caffe2DML.numImages, (lastLayerShape._1 * lastLayerShape._2 * lastLayerShape._3).toString))
     estimator.getTestAlgo.toLowerCase match {
-      case Caffe2DML.MINIBATCH_ALGORITHM => {
+      case Caffe2DML.MINIBATCH_ALGORITHM | Caffe2DML.LOOPED_MINIBATCH_ALGORITHM => {
         ceilDivide(tabDMLScript(), "num_iters", Caffe2DML.numImages, Caffe2DML.batchSize)
         forBlock("iter", "1", "num_iters") {
           getTestBatch(tabDMLScript)