You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/09/15 18:03:11 UTC

[4/4] systemml git commit: [SYSTEMML-540] Support loading of batch normalization weights in .caffemodel file using Caffe2DML

[SYSTEMML-540] Support loading of batch normalization weights in .caffemodel file using Caffe2DML

- Also fixed Scala formatting.

Closes #662.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f07b5a2d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f07b5a2d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f07b5a2d

Branch: refs/heads/master
Commit: f07b5a2d92f95f28bcdf141d700fc1be0887d735
Parents: ebb6ea6
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Fri Sep 15 11:00:06 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Fri Sep 15 11:01:49 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/api/dl/Caffe2DML.scala     |  732 ++++++-------
 .../apache/sysml/api/dl/Caffe2DMLLoader.scala   |   20 +-
 .../org/apache/sysml/api/dl/CaffeLayer.scala    | 1002 ++++++++++--------
 .../org/apache/sysml/api/dl/CaffeNetwork.scala  |  216 ++--
 .../org/apache/sysml/api/dl/CaffeSolver.scala   |  193 ++--
 .../org/apache/sysml/api/dl/DMLGenerator.scala  |  566 +++++-----
 .../scala/org/apache/sysml/api/dl/Utils.scala   |  484 ++++-----
 .../sysml/api/ml/BaseSystemMLClassifier.scala   |  264 +++--
 .../sysml/api/ml/BaseSystemMLRegressor.scala    |   68 +-
 .../apache/sysml/api/ml/LinearRegression.scala  |   93 +-
 .../sysml/api/ml/LogisticRegression.scala       |  117 +-
 .../org/apache/sysml/api/ml/NaiveBayes.scala    |   62 +-
 .../apache/sysml/api/ml/PredictionUtils.scala   |   32 +-
 .../scala/org/apache/sysml/api/ml/SVM.scala     |   81 +-
 .../org/apache/sysml/api/ml/ScriptsUtils.scala  |   18 +-
 .../scala/org/apache/sysml/api/ml/Utils.scala   |   49 +-
 16 files changed, 2100 insertions(+), 1897 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index a62fae2..6e3e1dc 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -35,10 +35,10 @@ import java.util.HashSet
 import org.apache.sysml.api.DMLScript
 import java.io.File
 import org.apache.spark.SparkContext
-import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.ml.{ Estimator, Model }
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.spark.ml.param.{ DoubleParam, Param, ParamMap, Params }
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics
 import org.apache.sysml.runtime.matrix.data.MatrixBlock
 import org.apache.sysml.runtime.DMLRuntimeException
@@ -55,7 +55,7 @@ import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyze
 DESIGN OF CAFFE2DML:
 
 1. Caffe2DML is designed to fit well into the mllearn framework. Hence, the key methods that were to be implemented are:
-- `getTrainingScript` for the Estimator class. 
+- `getTrainingScript` for the Estimator class.
 - `getPredictionScript` for the Model class.
 
 These methods should be the starting point of any developer to understand the DML generated for training and prediction respectively.
@@ -74,7 +74,7 @@ caffe.proto ---> protoc ---> target/generated-sources/caffe/Caffe.java
 - Just like the classes generated by Dml.g4 are used to parse input DML file,
 the target/generated-sources/caffe/Caffe.java class is used to parse the input caffe network/deploy prototxt and solver files.
 
-- You can think of .caffemodel file as DML file with matrix values encoded in it (please see below example). 
+- You can think of .caffemodel file as DML file with matrix values encoded in it (please see below example).
 So it is possible to read .caffemodel file with the Caffe.java class. This is done in Utils.scala's readCaffeNet method.
 
 X = matrix("1.2 3.5 0.999 7.123", rows=2, cols=2)
@@ -91,7 +91,7 @@ trait CaffeLayer {
   def forward(dmlScript:StringBuilder, isPrediction:Boolean):Unit;
   def backward(dmlScript:StringBuilder, outSuffix:String):Unit;
   ...
-} 
+}
 trait CaffeSolver {
   def sourceFileName:String;
   def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit;
@@ -114,67 +114,85 @@ To shield from network files that violates this restriction, Caffe2DML performs
 6. Caffe2DML also expects the layers to be in sorted order.
 
 ***************************************************************************************/
-
-object Caffe2DML  {
-  val LOG = LogFactory.getLog(classOf[Caffe2DML].getName()) 
+object Caffe2DML {
+  val LOG = LogFactory.getLog(classOf[Caffe2DML].getName())
   // ------------------------------------------------------------------------
   def layerDir = "nn/layers/"
   def optimDir = "nn/optim/"
-  
+
   // Naming conventions:
-  val X = "X"; val y = "y"; val batchSize = "BATCH_SIZE"; val numImages = "num_images"; val numValidationImages = "num_validation"
+  val X    = "X"; val y        = "y"; val batchSize = "BATCH_SIZE"; val numImages = "num_images"; val numValidationImages = "num_validation"
   val XVal = "X_val"; val yVal = "y_val"
-  
+
   val USE_NESTEROV_UDF = {
     // Developer environment variable flag 'USE_NESTEROV_UDF' until codegen starts working.
     // Then, we will remove this flag and also the class org.apache.sysml.udf.lib.SGDNesterovUpdate
     val envFlagNesterovUDF = System.getenv("USE_NESTEROV_UDF")
     envFlagNesterovUDF != null && envFlagNesterovUDF.toBoolean
   }
-  
+
   def main(args: Array[String]): Unit = {
-	// Arguments: [train_script | predict_script] $OUTPUT_DML_FILE $SOLVER_FILE $INPUT_CHANNELS $INPUT_HEIGHT $INPUT_WIDTH $NUM_ITER
-	if(args.length < 6) throwUsageError
-	val outputDMLFile = args(1)
-	val solverFile = args(2)
-	val inputChannels = args(3)
-	val inputHeight = args(4)
-	val inputWidth = args(5)
-	val caffeObj = new Caffe2DML(new SparkContext(), solverFile, inputChannels, inputHeight, inputWidth)
-	if(args(0).equals("train_script")) {
-		Utils.writeToFile(caffeObj.getTrainingScript(true)._1.getScriptString, outputDMLFile)
-	}
-	else if(args(0).equals("predict_script")) {
-		Utils.writeToFile(new Caffe2DMLModel(caffeObj).getPredictionScript(true)._1.getScriptString, outputDMLFile)
-	}
-	else {
-		throwUsageError
-	}
-  }
-  def throwUsageError():Unit = {
-	throw new RuntimeException("Incorrect usage: train_script OUTPUT_DML_FILE SOLVER_FILE INPUT_CHANNELS INPUT_HEIGHT INPUT_WIDTH"); 
+    // Arguments: [train_script | predict_script] $OUTPUT_DML_FILE $SOLVER_FILE $INPUT_CHANNELS $INPUT_HEIGHT $INPUT_WIDTH $NUM_ITER
+    if (args.length < 6) throwUsageError
+    val outputDMLFile = args(1)
+    val solverFile    = args(2)
+    val inputChannels = args(3)
+    val inputHeight   = args(4)
+    val inputWidth    = args(5)
+    val caffeObj      = new Caffe2DML(new SparkContext(), solverFile, inputChannels, inputHeight, inputWidth)
+    if (args(0).equals("train_script")) {
+      Utils.writeToFile(caffeObj.getTrainingScript(true)._1.getScriptString, outputDMLFile)
+    } else if (args(0).equals("predict_script")) {
+      Utils.writeToFile(new Caffe2DMLModel(caffeObj).getPredictionScript(true)._1.getScriptString, outputDMLFile)
+    } else {
+      throwUsageError
+    }
   }
+  def throwUsageError(): Unit =
+    throw new RuntimeException("Incorrect usage: train_script OUTPUT_DML_FILE SOLVER_FILE INPUT_CHANNELS INPUT_HEIGHT INPUT_WIDTH");
 }
 
-class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter, 
-    val solver:CaffeSolver, val net:CaffeNetwork, 
-    val lrPolicy:LearningRatePolicy, val numChannels:String, val height:String, val width:String) extends Estimator[Caffe2DMLModel] 
-  with BaseSystemMLClassifier with DMLGenerator {
+class Caffe2DML(val sc: SparkContext,
+                val solverParam: Caffe.SolverParameter,
+                val solver: CaffeSolver,
+                val net: CaffeNetwork,
+                val lrPolicy: LearningRatePolicy,
+                val numChannels: String,
+                val height: String,
+                val width: String)
+    extends Estimator[Caffe2DMLModel]
+    with BaseSystemMLClassifier
+    with DMLGenerator {
   // --------------------------------------------------------------
   // Invoked by Python, MLPipeline
-  def this(sc: SparkContext, solver1:Caffe.SolverParameter, networkPath:String, numChannels:String, height:String, width:String) {
-    this(sc, solver1, Utils.parseSolver(solver1), 
-        new CaffeNetwork(networkPath, caffe.Caffe.Phase.TRAIN, numChannels, height, width),
-        new LearningRatePolicy(solver1), numChannels, height, width)
+  def this(sc: SparkContext, solver1: Caffe.SolverParameter, networkPath: String, numChannels: String, height: String, width: String) {
+    this(
+      sc,
+      solver1,
+      Utils.parseSolver(solver1),
+      new CaffeNetwork(networkPath, caffe.Caffe.Phase.TRAIN, numChannels, height, width),
+      new LearningRatePolicy(solver1),
+      numChannels,
+      height,
+      width
+    )
   }
-  def this(sc: SparkContext, solver1:Caffe.SolverParameter, numChannels:String, height:String, width:String) {
-    this(sc, solver1, Utils.parseSolver(solver1), new CaffeNetwork(solver1.getNet, caffe.Caffe.Phase.TRAIN, numChannels, height, width), 
-        new LearningRatePolicy(solver1), numChannels, height, width)
+  def this(sc: SparkContext, solver1: Caffe.SolverParameter, numChannels: String, height: String, width: String) {
+    this(
+      sc,
+      solver1,
+      Utils.parseSolver(solver1),
+      new CaffeNetwork(solver1.getNet, caffe.Caffe.Phase.TRAIN, numChannels, height, width),
+      new LearningRatePolicy(solver1),
+      numChannels,
+      height,
+      width
+    )
   }
-  def this(sc: SparkContext, solverPath:String, numChannels:String, height:String, width:String) {
+  def this(sc: SparkContext, solverPath: String, numChannels: String, height: String, width: String) {
     this(sc, Utils.readCaffeSolver(solverPath), numChannels, height, width)
   }
-  val uid:String = "caffe_classifier_" + (new Random).nextLong
+  val uid: String = "caffe_classifier_" + (new Random).nextLong
   override def copy(extra: org.apache.spark.ml.param.ParamMap): Estimator[Caffe2DMLModel] = {
     val that = new Caffe2DML(sc, solverParam, solver, net, lrPolicy, numChannels, height, width)
     copyValues(that, extra)
@@ -188,221 +206,223 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
     mloutput = baseFit(df, sc)
     new Caffe2DMLModel(this)
   }
-	// --------------------------------------------------------------
+  // --------------------------------------------------------------
   // Returns true if last 2 of 4 dimensions are 1.
   // The first dimension refers to number of input datapoints.
   // The second dimension refers to number of classes.
-  def isClassification():Boolean = {
+  def isClassification(): Boolean = {
     val outShape = getOutputShapeOfLastLayer
     return outShape._2 == 1 && outShape._3 == 1
   }
-  def getOutputShapeOfLastLayer():(Int, Int, Int) = {
+  def getOutputShapeOfLastLayer(): (Int, Int, Int) = {
     val out = net.getCaffeLayer(net.getLayers().last).outputShape
-    (out._1.toInt, out._2.toInt, out._3.toInt) 
+    (out._1.toInt, out._2.toInt, out._3.toInt)
   }
-  
+
   // Used for simplifying transfer learning
-  private val layersToIgnore:HashSet[String] = new HashSet[String]() 
-  def setWeightsToIgnore(layerName:String):Unit = layersToIgnore.add(layerName)
-  def setWeightsToIgnore(layerNames:ArrayList[String]):Unit = layersToIgnore.addAll(layerNames)
-  	  
+  private val layersToIgnore: HashSet[String]                 = new HashSet[String]()
+  def setWeightsToIgnore(layerName: String): Unit             = layersToIgnore.add(layerName)
+  def setWeightsToIgnore(layerNames: ArrayList[String]): Unit = layersToIgnore.addAll(layerNames)
+
   // Input parameters to prediction and scoring script
-  val inputs:java.util.HashMap[String, String] = new java.util.HashMap[String, String]()
-  def setInput(key: String, value:String):Unit = inputs.put(key, value)
+  val inputs: java.util.HashMap[String, String]  = new java.util.HashMap[String, String]()
+  def setInput(key: String, value: String): Unit = inputs.put(key, value)
   customAssert(solverParam.getTestIterCount <= 1, "Multiple test_iter variables are not supported")
   customAssert(solverParam.getMaxIter > 0, "Please set max_iter to a positive value")
   customAssert(net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[IsLossLayer]).length == 1, "Expected exactly one loss layer")
-    
+
   // TODO: throw error or warning if user tries to set solver_mode == GPU instead of using setGPU method
-  
+
   // Method called by Python mllearn to visualize variable of certain layer
-  def visualizeLayer(layerName:String, varType:String, aggFn:String): Unit = visualizeLayer(net, layerName, varType, aggFn)
-  
-  def getTrainAlgo():String = if(inputs.containsKey("$train_algo")) inputs.get("$train_algo") else "minibatch"
-  def getTestAlgo():String = if(inputs.containsKey("$test_algo")) inputs.get("$test_algo") else "minibatch"
+  def visualizeLayer(layerName: String, varType: String, aggFn: String): Unit = visualizeLayer(net, layerName, varType, aggFn)
 
-  def summary(sparkSession:org.apache.spark.sql.SparkSession):Unit = {
+  def getTrainAlgo(): String = if (inputs.containsKey("$train_algo")) inputs.get("$train_algo") else "minibatch"
+  def getTestAlgo(): String  = if (inputs.containsKey("$test_algo")) inputs.get("$test_algo") else "minibatch"
+
+  def summary(sparkSession: org.apache.spark.sql.SparkSession): Unit = {
     val header = Seq("Name", "Type", "Output", "Weight", "Bias", "Top", "Bottom")
-    val entries = net.getLayers.map(l => (l, net.getCaffeLayer(l))).map(l => {
-      val layer = l._2
-      (l._1, layer.param.getType, 
-          "(, " + layer.outputShape._1 + ", " + layer.outputShape._2 + ", " + layer.outputShape._3 + ")",
-          if(layer.weightShape != null) "[" + layer.weightShape()(0) + " X " + layer.weightShape()(1) + "]" else "",
-          if(layer.biasShape != null) "[" + layer.biasShape()(0) + " X " + layer.biasShape()(1) + "]" else "",
-          layer.param.getTopList.mkString(","),
-          layer.param.getBottomList.mkString(",")
-      )
-    })
+    val entries = net.getLayers
+      .map(l => (l, net.getCaffeLayer(l)))
+      .map(l => {
+        val layer = l._2
+        (l._1,
+         layer.param.getType,
+         "(, " + layer.outputShape._1 + ", " + layer.outputShape._2 + ", " + layer.outputShape._3 + ")",
+         if (layer.weightShape != null) "[" + layer.weightShape()(0) + " X " + layer.weightShape()(1) + "]" else "",
+         if (layer.biasShape != null) "[" + layer.biasShape()(0) + " X " + layer.biasShape()(1) + "]" else "",
+         layer.param.getTopList.mkString(","),
+         layer.param.getBottomList.mkString(","))
+      })
     import sparkSession.implicits._
-    sc.parallelize(entries).toDF(header : _*).show(net.getLayers.size)
+    sc.parallelize(entries).toDF(header: _*).show(net.getLayers.size)
   }
-  
+
   // ================================================================================================
   // The below method parses the provided network and solver file and generates DML script.
-	def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
-	  val startTrainingTime = System.nanoTime()
-	  
-    reset                                 // Reset the state of DML generator for training script.
-    
+  def getTrainingScript(isSingleNode: Boolean): (Script, String, String) = {
+    val startTrainingTime = System.nanoTime()
+
+    reset // Reset the state of DML generator for training script.
+
     // Flags passed by user
-	  val DEBUG_TRAINING = if(inputs.containsKey("$debug")) inputs.get("$debug").toLowerCase.toBoolean else false
-	  assign(tabDMLScript, "debug", if(DEBUG_TRAINING) "TRUE" else "FALSE")
-	  
-	  appendHeaders(net, solver, true)      // Appends DML corresponding to source and externalFunction statements.
-	  readInputData(net, true)              // Read X_full and y_full
-	  // Initialize the layers and solvers. Reads weights and bias if $weights is set.
-	  initWeights(net, solver, inputs.containsKey("$weights"), layersToIgnore)
-	  
-	  // Split into training and validation set
-	  // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
-	  val shouldValidate = solverParam.getTestInterval > 0 && solverParam.getTestIterCount > 0 && solverParam.getTestIter(0) > 0
-	  trainTestSplit(if(shouldValidate) solverParam.getTestIter(0) else 0)
-	  
-	  // Set iteration-related variables such as num_iters_per_epoch, lr, etc.
-	  ceilDivide(tabDMLScript, "num_iters_per_epoch", Caffe2DML.numImages, Caffe2DML.batchSize)
-	  assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
-	  assign(tabDMLScript, "max_iter", ifdef("$max_iter", solverParam.getMaxIter.toString))
-	  assign(tabDMLScript, "e", "0")
-	  
-	  val lossLayers = getLossLayers(net)
-	  // ----------------------------------------------------------------------------
-	  // Main logic
-	  forBlock("iter", "1", "max_iter") {
-		performTrainingIter(lossLayers, shouldValidate)
-		if(getTrainAlgo.toLowerCase.equals("batch")) {
-			assign(tabDMLScript, "e", "iter")
-			tabDMLScript.append("# Learning rate\n")
-			lrPolicy.updateLearningRate(tabDMLScript)
-		}
-		else {
-			ifBlock("iter %% num_iters_per_epoch == 0") {
-				// After every epoch, update the learning rate
-				assign(tabDMLScript, "e", "e + 1")
-				tabDMLScript.append("# Learning rate\n")
-				lrPolicy.updateLearningRate(tabDMLScript)
-			}
-		}
-	  }
-	  // ----------------------------------------------------------------------------
-	  
-	  // Check if this is necessary
-	  if(doVisualize) tabDMLScript.append("print(" + asDMLString("Visualization counter:") + " + viz_counter)")
-	  
-	  val trainingScript = tabDMLScript.toString()
-	  // Print script generation time and the DML script on stdout
-	  System.out.println("Time taken to generate training script from Caffe proto: " + ((System.nanoTime() - startTrainingTime)*1e-9) + " seconds." )
-	  if(DEBUG_TRAINING) Utils.prettyPrintDMLScript(trainingScript)
-	  
-	  // Set input/output variables and execute the script
-	  val script = dml(trainingScript).in(inputs)
-	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.out(l.weight))
-	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.out(l.bias))
-	  (script, "X_full", "y_full")
-	}
-	// ================================================================================================
-  
-  private def performTrainingIter(lossLayers:List[IsLossLayer], shouldValidate:Boolean):Unit = {
-	getTrainAlgo.toLowerCase match {
-      case "minibatch" => 
-          getTrainingBatch(tabDMLScript)
-          // -------------------------------------------------------
-          // Perform forward, backward and update on minibatch
-          forward; backward; update
-          // -------------------------------------------------------
-          displayLoss(lossLayers(0), shouldValidate)
-          performSnapshot
+    val DEBUG_TRAINING = if (inputs.containsKey("$debug")) inputs.get("$debug").toLowerCase.toBoolean else false
+    assign(tabDMLScript, "debug", if (DEBUG_TRAINING) "TRUE" else "FALSE")
+
+    appendHeaders(net, solver, true) // Appends DML corresponding to source and externalFunction statements.
+    readInputData(net, true)         // Read X_full and y_full
+    // Initialize the layers and solvers. Reads weights and bias if $weights is set.
+    initWeights(net, solver, inputs.containsKey("$weights"), layersToIgnore)
+
+    // Split into training and validation set
+    // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
+    val shouldValidate = solverParam.getTestInterval > 0 && solverParam.getTestIterCount > 0 && solverParam.getTestIter(0) > 0
+    trainTestSplit(if (shouldValidate) solverParam.getTestIter(0) else 0)
+
+    // Set iteration-related variables such as num_iters_per_epoch, lr, etc.
+    ceilDivide(tabDMLScript, "num_iters_per_epoch", Caffe2DML.numImages, Caffe2DML.batchSize)
+    assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
+    assign(tabDMLScript, "max_iter", ifdef("$max_iter", solverParam.getMaxIter.toString))
+    assign(tabDMLScript, "e", "0")
+
+    val lossLayers = getLossLayers(net)
+    // ----------------------------------------------------------------------------
+    // Main logic
+    forBlock("iter", "1", "max_iter") {
+      performTrainingIter(lossLayers, shouldValidate)
+      if (getTrainAlgo.toLowerCase.equals("batch")) {
+        assign(tabDMLScript, "e", "iter")
+        tabDMLScript.append("# Learning rate\n")
+        lrPolicy.updateLearningRate(tabDMLScript)
+      } else {
+        ifBlock("iter %% num_iters_per_epoch == 0") {
+          // After every epoch, update the learning rate
+          assign(tabDMLScript, "e", "e + 1")
+          tabDMLScript.append("# Learning rate\n")
+          lrPolicy.updateLearningRate(tabDMLScript)
+        }
+      }
+    }
+    // ----------------------------------------------------------------------------
+
+    // Check if this is necessary
+    if (doVisualize) tabDMLScript.append("print(" + asDMLString("Visualization counter:") + " + viz_counter)")
+
+    val trainingScript = tabDMLScript.toString()
+    // Print script generation time and the DML script on stdout
+    System.out.println("Time taken to generate training script from Caffe proto: " + ((System.nanoTime() - startTrainingTime) * 1e-9) + " seconds.")
+    if (DEBUG_TRAINING) Utils.prettyPrintDMLScript(trainingScript)
+
+    // Set input/output variables and execute the script
+    val script = dml(trainingScript).in(inputs)
+    net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.out(l.weight))
+    net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.out(l.bias))
+    (script, "X_full", "y_full")
+  }
+  // ================================================================================================
+
+  private def performTrainingIter(lossLayers: List[IsLossLayer], shouldValidate: Boolean): Unit =
+    getTrainAlgo.toLowerCase match {
+      case "minibatch" =>
+        getTrainingBatch(tabDMLScript)
+        // -------------------------------------------------------
+        // Perform forward, backward and update on minibatch
+        forward; backward; update
+        // -------------------------------------------------------
+        displayLoss(lossLayers(0), shouldValidate)
+        performSnapshot
       case "batch" => {
-	      // -------------------------------------------------------
-	      // Perform forward, backward and update on entire dataset
-	      forward; backward; update
-	      // -------------------------------------------------------
-	      displayLoss(lossLayers(0), shouldValidate)
-	      performSnapshot
+        // -------------------------------------------------------
+        // Perform forward, backward and update on entire dataset
+        forward; backward; update
+        // -------------------------------------------------------
+        displayLoss(lossLayers(0), shouldValidate)
+        performSnapshot
       }
       case "allreduce_parallel_batches" => {
-    	  // This setting uses the batch size provided by the user
-	      if(!inputs.containsKey("$parallel_batches")) {
-	        throw new RuntimeException("The parameter parallel_batches is required for allreduce_parallel_batches")
-	      }
-	      // The user specifies the number of parallel_batches
-	      // This ensures that the user of generated script remembers to provide the commandline parameter $parallel_batches
-	      assign(tabDMLScript, "parallel_batches", "$parallel_batches") 
-	      assign(tabDMLScript, "group_batch_size", "parallel_batches*" + Caffe2DML.batchSize)
-	      assign(tabDMLScript, "groups", "as.integer(ceil(" + Caffe2DML.numImages + "/group_batch_size))")
-	      // Grab groups of mini-batches
-	      forBlock("g", "1", "groups") {
-	        // Get next group of mini-batches
-	        assign(tabDMLScript, "group_beg", "((g-1) * group_batch_size) %% " + Caffe2DML.numImages + " + 1")
-	        assign(tabDMLScript, "group_end", "min(" + Caffe2DML.numImages + ", group_beg + group_batch_size - 1)")
-	        assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[group_beg:group_end,]")
-	        assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[group_beg:group_end,]")
-	        initializeGradients("parallel_batches")
-	        assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
-	        parForBlock("j", "1", "parallel_batches") {
-	          // Get a mini-batch in this group
-	          assign(tabDMLScript, "beg", "((j-1) * " + Caffe2DML.batchSize + ") %% nrow(X_group_batch) + 1")
-	          assign(tabDMLScript, "end", "min(nrow(X_group_batch), beg + " + Caffe2DML.batchSize + " - 1)")
-	          assign(tabDMLScript, "Xb", "X_group_batch[beg:end,]")
-	          assign(tabDMLScript, "yb", "y_group_batch[beg:end,]")
-	          forward; backward
-	          flattenGradients
-	        }
-	        aggregateAggGradients    
-	        update
-	        // -------------------------------------------------------
-	        assign(tabDMLScript, "Xb", "X_group_batch")
-	        assign(tabDMLScript, "yb", "y_group_batch")
-	        displayLoss(lossLayers(0), shouldValidate)
-	        performSnapshot
-	      }
-      }
-      case "allreduce" => {
-    	  // This is distributed synchronous gradient descent
-    	  // -------------------------------------------------------
-    	  // Perform forward, backward and update on minibatch in parallel
-    	  assign(tabDMLScript, "beg", "((iter-1) * " + Caffe2DML.batchSize + ") %% " + Caffe2DML.numImages + " + 1")
-    	  assign(tabDMLScript, "end", " min(beg +  " + Caffe2DML.batchSize + " - 1, " + Caffe2DML.numImages + ")")
-    	  assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[beg:end,]")
-    	  assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[beg:end,]")
-    	  assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
-          tabDMLScript.append("local_batch_size = nrow(y_group_batch)\n")
-          val localBatchSize = "local_batch_size"
-          initializeGradients(localBatchSize)
-          parForBlock("j", "1", localBatchSize) {
-            assign(tabDMLScript, "Xb", "X_group_batch[j,]")
-            assign(tabDMLScript, "yb", "y_group_batch[j,]")
+        // This setting uses the batch size provided by the user
+        if (!inputs.containsKey("$parallel_batches")) {
+          throw new RuntimeException("The parameter parallel_batches is required for allreduce_parallel_batches")
+        }
+        // The user specifies the number of parallel_batches
+        // This ensures that the user of generated script remembers to provide the commandline parameter $parallel_batches
+        assign(tabDMLScript, "parallel_batches", "$parallel_batches")
+        assign(tabDMLScript, "group_batch_size", "parallel_batches*" + Caffe2DML.batchSize)
+        assign(tabDMLScript, "groups", "as.integer(ceil(" + Caffe2DML.numImages + "/group_batch_size))")
+        // Grab groups of mini-batches
+        forBlock("g", "1", "groups") {
+          // Get next group of mini-batches
+          assign(tabDMLScript, "group_beg", "((g-1) * group_batch_size) %% " + Caffe2DML.numImages + " + 1")
+          assign(tabDMLScript, "group_end", "min(" + Caffe2DML.numImages + ", group_beg + group_batch_size - 1)")
+          assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[group_beg:group_end,]")
+          assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[group_beg:group_end,]")
+          initializeGradients("parallel_batches")
+          assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
+          parForBlock("j", "1", "parallel_batches") {
+            // Get a mini-batch in this group
+            assign(tabDMLScript, "beg", "((j-1) * " + Caffe2DML.batchSize + ") %% nrow(X_group_batch) + 1")
+            assign(tabDMLScript, "end", "min(nrow(X_group_batch), beg + " + Caffe2DML.batchSize + " - 1)")
+            assign(tabDMLScript, "Xb", "X_group_batch[beg:end,]")
+            assign(tabDMLScript, "yb", "y_group_batch[beg:end,]")
             forward; backward
-          flattenGradients
+            flattenGradients
           }
-          aggregateAggGradients    
+          aggregateAggGradients
           update
           // -------------------------------------------------------
           assign(tabDMLScript, "Xb", "X_group_batch")
           assign(tabDMLScript, "yb", "y_group_batch")
           displayLoss(lossLayers(0), shouldValidate)
           performSnapshot
+        }
+      }
+      case "allreduce" => {
+        // This is distributed synchronous gradient descent
+        // -------------------------------------------------------
+        // Perform forward, backward and update on minibatch in parallel
+        assign(tabDMLScript, "beg", "((iter-1) * " + Caffe2DML.batchSize + ") %% " + Caffe2DML.numImages + " + 1")
+        assign(tabDMLScript, "end", " min(beg +  " + Caffe2DML.batchSize + " - 1, " + Caffe2DML.numImages + ")")
+        assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[beg:end,]")
+        assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[beg:end,]")
+        assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
+        tabDMLScript.append("local_batch_size = nrow(y_group_batch)\n")
+        val localBatchSize = "local_batch_size"
+        initializeGradients(localBatchSize)
+        parForBlock("j", "1", localBatchSize) {
+          assign(tabDMLScript, "Xb", "X_group_batch[j,]")
+          assign(tabDMLScript, "yb", "y_group_batch[j,]")
+          forward; backward
+          flattenGradients
+        }
+        aggregateAggGradients
+        update
+        // -------------------------------------------------------
+        assign(tabDMLScript, "Xb", "X_group_batch")
+        assign(tabDMLScript, "yb", "y_group_batch")
+        displayLoss(lossLayers(0), shouldValidate)
+        performSnapshot
       }
       case _ => throw new DMLRuntimeException("Unsupported train algo:" + getTrainAlgo)
     }
-  }
   // -------------------------------------------------------------------------------------------
   // Helper functions to generate DML
   // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
-  private def trainTestSplit(numValidationBatches:Int):Unit = {
-    if(numValidationBatches > 0) {
-      if(solverParam.getDisplay <= 0) 
+  private def trainTestSplit(numValidationBatches: Int): Unit =
+    if (numValidationBatches > 0) {
+      if (solverParam.getDisplay <= 0)
         throw new DMLRuntimeException("Since test_iter and test_interval is greater than zero, you should set display to be greater than zero")
       tabDMLScript.append(Caffe2DML.numValidationImages).append(" = " + numValidationBatches + " * " + Caffe2DML.batchSize + "\n")
       tabDMLScript.append("# Sanity check to ensure that validation set is not too large\n")
       val maxValidationSize = "ceil(0.3 * " + Caffe2DML.numImages + ")"
-      ifBlock(Caffe2DML.numValidationImages  + " > " + maxValidationSize) {
+      ifBlock(Caffe2DML.numValidationImages + " > " + maxValidationSize) {
         assign(tabDMLScript, "max_test_iter", "floor(" + maxValidationSize + " / " + Caffe2DML.batchSize + ")")
-        tabDMLScript.append("stop(" +
-            dmlConcat(asDMLString("Too large validation size. Please reduce test_iter to "), "max_test_iter") 
-            + ")\n")
+        tabDMLScript.append(
+          "stop(" +
+          dmlConcat(asDMLString("Too large validation size. Please reduce test_iter to "), "max_test_iter")
+          + ")\n"
+        )
       }
       val one = "1"
-      val rl = int_add(Caffe2DML.numValidationImages, one)
+      val rl  = int_add(Caffe2DML.numValidationImages, one)
       rightIndexing(tabDMLScript.append(Caffe2DML.X).append(" = "), "X_full", rl, Caffe2DML.numImages, null, null)
       tabDMLScript.append("; ")
       rightIndexing(tabDMLScript.append(Caffe2DML.y).append(" = "), "y_full", rl, Caffe2DML.numImages, null, null)
@@ -412,41 +432,39 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
       rightIndexing(tabDMLScript.append(Caffe2DML.yVal).append(" = "), "y_full", one, Caffe2DML.numValidationImages, null, null)
       tabDMLScript.append("; ")
       tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y)\n")
-    }
-    else {
+    } else {
       assign(tabDMLScript, Caffe2DML.X, "X_full")
-	    assign(tabDMLScript, Caffe2DML.y, "y_full")
-	    tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(" + Caffe2DML.y + ")\n")
+      assign(tabDMLScript, Caffe2DML.y, "y_full")
+      tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(" + Caffe2DML.y + ")\n")
     }
-  }
-  
+
   // Append the DML to display training and validation loss
-  private def displayLoss(lossLayer:IsLossLayer, shouldValidate:Boolean):Unit = {
-    if(solverParam.getDisplay > 0) {
+  private def displayLoss(lossLayer: IsLossLayer, shouldValidate: Boolean): Unit = {
+    if (solverParam.getDisplay > 0) {
       // Append the DML to compute training loss
-      if(!getTrainAlgo.toLowerCase.startsWith("allreduce")) {
+      if (!getTrainAlgo.toLowerCase.startsWith("allreduce")) {
         // Compute training loss for allreduce
         tabDMLScript.append("# Compute training loss & accuracy\n")
         ifBlock("iter  %% " + solverParam.getDisplay + " == 0") {
           assign(tabDMLScript, "loss", "0"); assign(tabDMLScript, "accuracy", "0")
           lossLayer.computeLoss(dmlScript, numTabs)
           assign(tabDMLScript, "training_loss", "loss"); assign(tabDMLScript, "training_accuracy", "accuracy")
-          tabDMLScript.append(print( dmlConcat( asDMLString("Iter:"), "iter", 
-              asDMLString(", training loss:"), "training_loss", asDMLString(", training accuracy:"), "training_accuracy" )))
+          tabDMLScript.append(
+            print(dmlConcat(asDMLString("Iter:"), "iter", asDMLString(", training loss:"), "training_loss", asDMLString(", training accuracy:"), "training_accuracy"))
+          )
           appendTrainingVisualizationBody(dmlScript, numTabs)
           printClassificationReport
         }
-      }
-      else {
+      } else {
         Caffe2DML.LOG.info("Training loss is not printed for train_algo=" + getTrainAlgo)
       }
-      if(shouldValidate) {
-        if(  getTrainAlgo.toLowerCase.startsWith("allreduce") &&
+      if (shouldValidate) {
+        if (getTrainAlgo.toLowerCase.startsWith("allreduce") &&
             getTestAlgo.toLowerCase.startsWith("allreduce")) {
           Caffe2DML.LOG.warn("The setting: train_algo=" + getTrainAlgo + " and test_algo=" + getTestAlgo + " is not recommended. Consider changing test_algo=minibatch")
         }
         // Append the DML to compute validation loss
-        val numValidationBatches = if(solverParam.getTestIterCount > 0) solverParam.getTestIter(0) else 0
+        val numValidationBatches = if (solverParam.getTestIterCount > 0) solverParam.getTestIter(0) else 0
         tabDMLScript.append("# Compute validation loss & accuracy\n")
         ifBlock("iter  %% " + solverParam.getTestInterval + " == 0") {
           assign(tabDMLScript, "loss", "0"); assign(tabDMLScript, "accuracy", "0")
@@ -455,11 +473,11 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
               assign(tabDMLScript, "validation_loss", "0")
               assign(tabDMLScript, "validation_accuracy", "0")
               forBlock("iVal", "1", "num_iters_per_epoch") {
-    	          getValidationBatch(tabDMLScript)
-    	          forward;  lossLayer.computeLoss(dmlScript, numTabs)
+                getValidationBatch(tabDMLScript)
+                forward; lossLayer.computeLoss(dmlScript, numTabs)
                 tabDMLScript.append("validation_loss = validation_loss + loss\n")
                 tabDMLScript.append("validation_accuracy = validation_accuracy + accuracy\n")
-    	        }
+              }
               tabDMLScript.append("validation_accuracy = validation_accuracy / num_iters_per_epoch\n")
             }
             case "batch" => {
@@ -467,16 +485,16 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
               net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, false))
               lossLayer.computeLoss(dmlScript, numTabs)
               assign(tabDMLScript, "validation_loss", "loss"); assign(tabDMLScript, "validation_accuracy", "accuracy")
-              
+
             }
             case "allreduce_parallel_batches" => {
               // This setting uses the batch size provided by the user
-              if(!inputs.containsKey("$parallel_batches")) {
+              if (!inputs.containsKey("$parallel_batches")) {
                 throw new RuntimeException("The parameter parallel_batches is required for allreduce_parallel_batches")
               }
               // The user specifies the number of parallel_batches
               // This ensures that the user of generated script remembers to provide the commandline parameter $parallel_batches
-              assign(tabDMLScript, "parallel_batches_val", "$parallel_batches") 
+              assign(tabDMLScript, "parallel_batches_val", "$parallel_batches")
               assign(tabDMLScript, "group_batch_size_val", "parallel_batches_val*" + Caffe2DML.batchSize)
               assign(tabDMLScript, "groups_val", "as.integer(ceil(" + Caffe2DML.numValidationImages + "/group_batch_size_val))")
               assign(tabDMLScript, "validation_accuracy", "0")
@@ -511,8 +529,8 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
               assign(tabDMLScript, "group_validation_loss", matrix("0", Caffe2DML.numValidationImages, "1"))
               assign(tabDMLScript, "group_validation_accuracy", matrix("0", Caffe2DML.numValidationImages, "1"))
               parForBlock("iVal", "1", Caffe2DML.numValidationImages) {
-                assign(tabDMLScript, "Xb",  Caffe2DML.XVal + "[iVal,]")
-                assign(tabDMLScript, "yb",  Caffe2DML.yVal + "[iVal,]")
+                assign(tabDMLScript, "Xb", Caffe2DML.XVal + "[iVal,]")
+                assign(tabDMLScript, "yb", Caffe2DML.yVal + "[iVal,]")
                 net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, false))
                 lossLayer.computeLoss(dmlScript, numTabs)
                 assign(tabDMLScript, "group_validation_loss[iVal,1]", "loss")
@@ -521,124 +539,132 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
               assign(tabDMLScript, "validation_loss", "sum(group_validation_loss)")
               assign(tabDMLScript, "validation_accuracy", "mean(group_validation_accuracy)")
             }
-            
+
             case _ => throw new DMLRuntimeException("Unsupported test algo:" + getTestAlgo)
           }
-          tabDMLScript.append(print( dmlConcat( asDMLString("Iter:"), "iter", 
-              asDMLString(", validation loss:"), "validation_loss", asDMLString(", validation accuracy:"), "validation_accuracy" )))
+          tabDMLScript.append(
+            print(dmlConcat(asDMLString("Iter:"), "iter", asDMLString(", validation loss:"), "validation_loss", asDMLString(", validation accuracy:"), "validation_accuracy"))
+          )
           appendValidationVisualizationBody(dmlScript, numTabs)
         }
       }
     }
   }
-  
-  private def performSnapshot():Unit = {
-    if(solverParam.getSnapshot > 0) {
+  private def appendSnapshotWrite(varName: String, fileName: String): Unit =
+    tabDMLScript.append(write(varName, "snapshot_dir + \"" + fileName + "\"", "binary"))
+  private def performSnapshot(): Unit =
+    if (solverParam.getSnapshot > 0) {
       ifBlock("iter %% " + solverParam.getSnapshot + " == 0") {
         tabDMLScript.append("snapshot_dir= \"" + solverParam.getSnapshotPrefix + "\" + \"/iter_\" + iter + \"/\"\n")
-        net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(
-        	"write(" + l.weight + ", snapshot_dir + \"" + l.param.getName + "_weight.mtx\", format=\"binary\")\n"))
-  		net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(
-  			"write(" + l.bias + ", snapshot_dir + \"" + l.param.getName + "_bias.mtx\", format=\"binary\")\n"))
+        val allLayers = net.getLayers.map(net.getCaffeLayer(_))
+        allLayers.filter(_.weight != null).map(l => appendSnapshotWrite(l.weight, l.param.getName + "_weight.mtx"))
+        allLayers.filter(_.bias != null).map(l => appendSnapshotWrite(l.bias, l.param.getName + "_bias.mtx"))
       }
-  	}
-  }
-  
-  private def forward():Unit = {
+    }
+
+  private def forward(): Unit = {
     tabDMLScript.append("# Perform forward pass\n")
-	  net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, false))
+    net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, false))
   }
-  private def backward():Unit = {
+  private def backward(): Unit = {
     tabDMLScript.append("# Perform backward pass\n")
     net.getLayers.reverse.map(layer => net.getCaffeLayer(layer).backward(tabDMLScript, ""))
   }
-  private def update():Unit = {
+  private def update(): Unit = {
     tabDMLScript.append("# Update the parameters\n")
     net.getLayers.map(layer => solver.update(tabDMLScript, net.getCaffeLayer(layer)))
   }
-  private def initializeGradients(parallel_batches:String):Unit = {
+  private def initializeGradients(parallel_batches: String): Unit = {
     tabDMLScript.append("# Data structure to store gradients computed in parallel\n")
-    net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg", matrix("0", parallel_batches, multiply(nrow(l.weight), ncol(l.weight))))
-      if(l.shouldUpdateBias) assign(tabDMLScript, l.dBias + "_agg", matrix("0", parallel_batches, multiply(nrow(l.bias), ncol(l.bias)))) 
-    })
+    net.getLayers
+      .map(layer => net.getCaffeLayer(layer))
+      .map(l => {
+        if (l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg", matrix("0", parallel_batches, multiply(nrow(l.weight), ncol(l.weight))))
+        if (l.shouldUpdateBias) assign(tabDMLScript, l.dBias + "_agg", matrix("0", parallel_batches, multiply(nrow(l.bias), ncol(l.bias))))
+      })
   }
-  private def flattenGradients():Unit = {
+  private def flattenGradients(): Unit = {
     tabDMLScript.append("# Flatten and store gradients for this parallel execution\n")
     // Note: We multiply by a weighting to allow for proper gradient averaging during the
     // aggregation even with uneven batch sizes.
     assign(tabDMLScript, "weighting", "nrow(Xb)/X_group_batch_size")
-    net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg[j,]", 
-          matrix(l.dWeight, "1", multiply(nrow(l.weight), ncol(l.weight))) + " * weighting") 
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dBias + "_agg[j,]", 
-          matrix(l.dBias, "1", multiply(nrow(l.bias), ncol(l.bias)))  + " * weighting")
-    })
+    // Each gradient is flattened under its own flag so the rows written match the *_agg buffers allocated in initializeGradients.
+    net.getLayers
+      .map(layer => net.getCaffeLayer(layer)).map(l => {
+        if (l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg[j,]", matrix(l.dWeight, "1", multiply(nrow(l.weight), ncol(l.weight))) + " * weighting")
+        if (l.shouldUpdateBias) assign(tabDMLScript, l.dBias + "_agg[j,]", matrix(l.dBias, "1", multiply(nrow(l.bias), ncol(l.bias))) + " * weighting")
+      })
   }
-  private def aggregateAggGradients():Unit = {
+  private def aggregateAggGradients(): Unit = {
     tabDMLScript.append("# Aggregate the gradients\n")
-    net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight, 
-          matrix(colSums(l.dWeight + "_agg"), nrow(l.weight), ncol(l.weight))) 
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dBias, 
-          matrix(colSums(l.dBias + "_agg"), nrow(l.bias), ncol(l.bias)))
-    })
+    // Column-sum each *_agg buffer back into its gradient; bias is reduced only when the layer actually updates its bias.
+    net.getLayers
+      .map(layer => net.getCaffeLayer(layer)).map(l => {
+        if (l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight, matrix(colSums(l.dWeight + "_agg"), nrow(l.weight), ncol(l.weight)))
+        if (l.shouldUpdateBias) assign(tabDMLScript, l.dBias, matrix(colSums(l.dBias + "_agg"), nrow(l.bias), ncol(l.bias)))
+      })
   }
   // -------------------------------------------------------------------------------------------
 }
 
-class Caffe2DMLModel(val numClasses:String, val sc: SparkContext, val solver:CaffeSolver,
-    val net:CaffeNetwork, val lrPolicy:LearningRatePolicy,
-    val estimator:Caffe2DML) 
-  extends Model[Caffe2DMLModel] with HasMaxOuterIter with BaseSystemMLClassifierModel with DMLGenerator {
+class Caffe2DMLModel(val numClasses: String, val sc: SparkContext, val solver: CaffeSolver, val net: CaffeNetwork, val lrPolicy: LearningRatePolicy, val estimator: Caffe2DML)
+    extends Model[Caffe2DMLModel]
+    with HasMaxOuterIter
+    with BaseSystemMLClassifierModel
+    with DMLGenerator {
   // --------------------------------------------------------------
   // Invoked by Python, MLPipeline
-  val uid:String = "caffe_model_" + (new Random).nextLong 
-  def this(estimator:Caffe2DML) =  {
-    this(Utils.numClasses(estimator.net), estimator.sc, estimator.solver,
-        estimator.net,
-        // new CaffeNetwork(estimator.solverParam.getNet, caffe.Caffe.Phase.TEST, estimator.numChannels, estimator.height, estimator.width), 
-        estimator.lrPolicy, estimator) 
+  val uid: String = "caffe_model_" + (new Random).nextLong
+  def this(estimator: Caffe2DML) = {
+    this(
+      Utils.numClasses(estimator.net),
+      estimator.sc,
+      estimator.solver,
+      estimator.net,
+      // new CaffeNetwork(estimator.solverParam.getNet, caffe.Caffe.Phase.TEST, estimator.numChannels, estimator.height, estimator.width),
+      estimator.lrPolicy,
+      estimator
+    )
   }
-      
+
   override def copy(extra: org.apache.spark.ml.param.ParamMap): Caffe2DMLModel = {
     val that = new Caffe2DMLModel(numClasses, sc, solver, net, lrPolicy, estimator)
     copyValues(that, extra)
   }
   // --------------------------------------------------------------
-  
-  def modelVariables():List[String] = {
-    net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(_.weight) ++
-    net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(_.bias)
+
+  def modelVariables(): List[String] = {
+    val allLayers = net.getLayers.map(net.getCaffeLayer(_))
+    allLayers.filter(_.weight != null).map(_.weight) ++ allLayers.filter(_.bias != null).map(_.bias)
   }
-    
+
   // ================================================================================================
   // The below method parses the provided network and solver file and generates DML script.
-  def getPredictionScript(isSingleNode:Boolean): (Script, String)  = {
+  def getPredictionScript(isSingleNode: Boolean): (Script, String) = {
     val startPredictionTime = System.nanoTime()
-    
-	  reset                                  // Reset the state of DML generator for training script.
-	  
-	  val DEBUG_PREDICTION = if(estimator.inputs.containsKey("$debug")) estimator.inputs.get("$debug").toLowerCase.toBoolean else false
-	  assign(tabDMLScript, "debug", if(DEBUG_PREDICTION) "TRUE" else "FALSE")
-    
-    appendHeaders(net, solver, false)      // Appends DML corresponding to source and externalFunction statements.
-    readInputData(net, false)              // Read X_full and y_full
+
+    reset // Reset the state of DML generator for prediction script.
+
+    val DEBUG_PREDICTION = if (estimator.inputs.containsKey("$debug")) estimator.inputs.get("$debug").toLowerCase.toBoolean else false
+    assign(tabDMLScript, "debug", if (DEBUG_PREDICTION) "TRUE" else "FALSE")
+
+    appendHeaders(net, solver, false) // Appends DML corresponding to source and externalFunction statements.
+    readInputData(net, false)         // Read X_full and y_full
     assign(tabDMLScript, "X", "X_full")
-    
+
     // Initialize the layers and solvers. Reads weights and bias if readWeights is true.
-    if(!estimator.inputs.containsKey("$weights") && estimator.mloutput == null) 
+    if (!estimator.inputs.containsKey("$weights") && estimator.mloutput == null)
       throw new DMLRuntimeException("Cannot call predict/score without calling either fit or by providing weights")
     val readWeights = estimator.inputs.containsKey("$weights") || estimator.mloutput != null
     initWeights(net, solver, readWeights)
-	  
-	  // Donot update mean and variance in batchnorm
-	  updateMeanVarianceForBatchNorm(net, false)
-	  
-	  val lossLayers = getLossLayers(net)
-	  val lastLayerShape = estimator.getOutputShapeOfLastLayer
-	  assign(tabDMLScript, "Prob", matrix("1", Caffe2DML.numImages, (lastLayerShape._1*lastLayerShape._2*lastLayerShape._3).toString))
-	  estimator.getTestAlgo.toLowerCase match {
+
+    // Do not update mean and variance in batchnorm
+    updateMeanVarianceForBatchNorm(net, false)
+
+    val lossLayers     = getLossLayers(net)
+    val lastLayerShape = estimator.getOutputShapeOfLastLayer
+    assign(tabDMLScript, "Prob", matrix("1", Caffe2DML.numImages, (lastLayerShape._1 * lastLayerShape._2 * lastLayerShape._3).toString))
+    estimator.getTestAlgo.toLowerCase match {
       case "minibatch" => {
         ceilDivide(tabDMLScript(), "num_iters", Caffe2DML.numImages, Caffe2DML.batchSize)
         forBlock("iter", "1", "num_iters") {
@@ -654,12 +680,12 @@ class Caffe2DMLModel(val numClasses:String, val sc: SparkContext, val solver:Caf
       }
       case "allreduce_parallel_batches" => {
         // This setting uses the batch size provided by the user
-        if(!estimator.inputs.containsKey("$parallel_batches")) {
+        if (!estimator.inputs.containsKey("$parallel_batches")) {
           throw new RuntimeException("The parameter parallel_batches is required for allreduce_parallel_batches")
         }
         // The user specifies the number of parallel_batches
         // This ensures that the user of generated script remembers to provide the commandline parameter $parallel_batches
-        assign(tabDMLScript, "parallel_batches", "$parallel_batches") 
+        assign(tabDMLScript, "parallel_batches", "$parallel_batches")
         assign(tabDMLScript, "group_batch_size", "parallel_batches*" + Caffe2DML.batchSize)
         assign(tabDMLScript, "groups", "as.integer(ceil(" + Caffe2DML.numImages + "/group_batch_size))")
         // Grab groups of mini-batches
@@ -688,70 +714,66 @@ class Caffe2DMLModel(val numClasses:String, val sc: SparkContext, val solver:Caf
       }
       case _ => throw new DMLRuntimeException("Unsupported test algo:" + estimator.getTestAlgo)
     }
-    
-    if(estimator.inputs.containsKey("$output_activations")) {
-      if(estimator.getTestAlgo.toLowerCase.equals("batch")) {
-        net.getLayers.map(layer => 
-          tabDMLScript.append(write(net.getCaffeLayer(layer).out, 
-              estimator.inputs.get("$output_activations") + "/" + net.getCaffeLayer(layer).param.getName + "_activations.mtx", "csv") + "\n")
-        )  
-      }
-      else {
+
+    if (estimator.inputs.containsKey("$output_activations")) {
+      if (estimator.getTestAlgo.toLowerCase.equals("batch")) {
+        net.getLayers.map(
+          layer =>
+            tabDMLScript.append(
+              write(net.getCaffeLayer(layer).out, estimator.inputs.get("$output_activations") + "/" + net.getCaffeLayer(layer).param.getName + "_activations.mtx", "csv") + "\n"
+          )
+        )
+      } else {
         throw new DMLRuntimeException("Incorrect usage of output_activations. It should be only used in batch mode.")
       }
     }
-		
-		val predictionScript = dmlScript.toString()
-		System.out.println("Time taken to generate prediction script from Caffe proto:" + ((System.nanoTime() - startPredictionTime)*1e-9) + "secs." )
-		if(DEBUG_PREDICTION) Utils.prettyPrintDMLScript(predictionScript)
-		
-		// Reset state of BatchNorm layer
-		updateMeanVarianceForBatchNorm(net, true)
-		
-	  val script = dml(predictionScript).out("Prob").in(estimator.inputs)
-	  if(estimator.mloutput != null) {
-	    // fit was called
-  	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.in(l.weight, estimator.mloutput.getMatrix(l.weight)))
-  	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.in(l.bias, estimator.mloutput.getMatrix(l.bias)))
-	  }
-	  (script, "X_full")
+
+    val predictionScript = dmlScript.toString()
+    System.out.println("Time taken to generate prediction script from Caffe proto:" + ((System.nanoTime() - startPredictionTime) * 1e-9) + "secs.")
+    if (DEBUG_PREDICTION) Utils.prettyPrintDMLScript(predictionScript)
+
+    // Reset state of BatchNorm layer
+    updateMeanVarianceForBatchNorm(net, true)
+
+    val script = dml(predictionScript).out("Prob").in(estimator.inputs)
+    if (estimator.mloutput != null) {
+      // fit was called
+      net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.in(l.weight, estimator.mloutput.getMatrix(l.weight)))
+      net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.in(l.bias, estimator.mloutput.getMatrix(l.bias)))
+    }
+    (script, "X_full")
   }
   // ================================================================================================
-  
-  def baseEstimator():BaseSystemMLEstimator = estimator
-  
+
+  def baseEstimator(): BaseSystemMLEstimator = estimator
+
   // Prediction
-  def transform(X: MatrixBlock): MatrixBlock = {
-    if(estimator.isClassification) {
+  def transform(X: MatrixBlock): MatrixBlock =
+    if (estimator.isClassification) {
       Caffe2DML.LOG.debug("Prediction assuming classification")
       baseTransform(X, sc, "Prob")
-    }
-    else {
+    } else {
       Caffe2DML.LOG.debug("Prediction assuming segmentation")
       val outShape = estimator.getOutputShapeOfLastLayer
       baseTransform(X, sc, "Prob", outShape._1.toInt, outShape._2.toInt, outShape._3.toInt)
     }
-  }
-  def transform_probability(X: MatrixBlock): MatrixBlock = {
-    if(estimator.isClassification) {
+  def transform_probability(X: MatrixBlock): MatrixBlock =
+    if (estimator.isClassification) {
       Caffe2DML.LOG.debug("Prediction of probability assuming classification")
       baseTransformProbability(X, sc, "Prob")
-    }
-    else {
+    } else {
       Caffe2DML.LOG.debug("Prediction of probability assuming segmentation")
       val outShape = estimator.getOutputShapeOfLastLayer
       baseTransformProbability(X, sc, "Prob", outShape._1.toInt, outShape._2.toInt, outShape._3.toInt)
     }
-  } 
-  def transform(df: ScriptsUtils.SparkDataType): DataFrame = {
-    if(estimator.isClassification) {
+
+  def transform(df: ScriptsUtils.SparkDataType): DataFrame =
+    if (estimator.isClassification) {
       Caffe2DML.LOG.debug("Prediction assuming classification")
       baseTransform(df, sc, "Prob", true)
-    }
-    else {
+    } else {
       Caffe2DML.LOG.debug("Prediction assuming segmentation")
       val outShape = estimator.getOutputShapeOfLastLayer
       baseTransform(df, sc, "Prob", true, outShape._1.toInt, outShape._2.toInt, outShape._3.toInt)
     }
-  }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala
index 30d86fd..19aff63 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,16 +22,16 @@ package org.apache.sysml.api.dl
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.net.MalformedURLException;
-import java.net.URL; 
+import java.net.URL;
 import java.net.URLClassLoader;
 import java.io.File;
 
 class Caffe2DMLLoader {
-  def loadCaffe2DML(filePath:String):Unit = {
-    val url = new File(filePath).toURI().toURL();
-		val classLoader = ClassLoader.getSystemClassLoader().asInstanceOf[URLClassLoader];
-		val method = classOf[URLClassLoader].getDeclaredMethod("addURL", classOf[URL]);
-		method.setAccessible(true);
-	  method.invoke(classLoader, url);
+  def loadCaffe2DML(filePath: String): Unit = {
+    val url         = new File(filePath).toURI().toURL();
+    val classLoader = ClassLoader.getSystemClassLoader().asInstanceOf[URLClassLoader];
+    val method      = classOf[URLClassLoader].getDeclaredMethod("addURL", classOf[URL]);
+    method.setAccessible(true);
+    method.invoke(classLoader, url);
   }
-}
\ No newline at end of file
+}