You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/05/21 20:26:23 UTC

incubator-systemml git commit: [MINOR] Added documentation for Caffe2DML APIs.

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 259742c40 -> 700b08094


[MINOR] Added documentation for Caffe2DML APIs.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/700b0809
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/700b0809
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/700b0809

Branch: refs/heads/master
Commit: 700b080940bf68a71728be12ea4e24c2450e9d13
Parents: 259742c
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Sun May 21 13:24:15 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Sun May 21 13:25:29 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/api/dl/Caffe2DML.scala     | 318 +++++++++----------
 .../org/apache/sysml/api/dl/CaffeSolver.scala   |  96 +++++-
 .../org/apache/sysml/api/dl/DMLGenerator.scala  | 112 ++++++-
 3 files changed, 338 insertions(+), 188 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/700b0809/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index f7f85c3..fe6b159 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -52,6 +52,43 @@ import org.apache.commons.logging.LogFactory
 import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer
 
 
+/***************************************************************************************
+DESIGN OF CAFFE2DML:
+
+1. Caffe2DML is designed to fit well into the mllearn framework. Hence, the key methods that needed to be implemented are:
+- `getTrainingScript` for the Estimator class. 
+- `getPredictionScript` for the Model class.
+
+2. To simplify the DML generation in the getTrainingScript and getPredictionScript methods, we use the DMLGenerator interface.
+This interface generates DML strings for common constructs such as control flow (if, for, while) as well as built-in functions (read, write), etc.
+It also makes this class easier to read :)
+
+3. Additionally, we created mapping classes for layers, solvers and learning rates that map the corresponding Caffe abstractions to the SystemML-NN library.
+This greatly simplifies adding new layers into Caffe2DML:
+trait CaffeLayer {
+  // Any layer that wants to reuse SystemML-NN has to override the following methods, which generate the DML for the given layer:
+  def sourceFileName:String;
+  def init(dmlScript:StringBuilder):Unit;
+  def forward(dmlScript:StringBuilder, isPrediction:Boolean):Unit;
+  def backward(dmlScript:StringBuilder, outSuffix:String):Unit;
+  ...
+} 
+trait CaffeSolver {
+  def sourceFileName:String;
+  def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit;
+  def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit;
+}
+
+4. To simplify the traversal of the network, we created a Network interface:
+trait Network {
+  def getLayers(): List[String]
+  def getCaffeLayer(layerName:String):CaffeLayer
+  def getBottomLayers(layerName:String): Set[String]
+  def getTopLayers(layerName:String): Set[String]
+  def getLayerID(layerName:String): Int
+}
+***************************************************************************************/
+
 object Caffe2DML  {
   val LOG = LogFactory.getLog(classOf[Caffe2DML].getName()) 
   // ------------------------------------------------------------------------
@@ -62,7 +99,7 @@ object Caffe2DML  {
   val X = "X"; val y = "y"; val batchSize = "BATCH_SIZE"; val numImages = "num_images"; val numValidationImages = "num_validation"
   val XVal = "X_val"; val yVal = "y_val"
   
-  var USE_NESTEROV_UDF = {
+  val USE_NESTEROV_UDF = {
     // Developer environment variable flag 'USE_NESTEROV_UDF' until codegen starts working.
     // Then, we will remove this flag and also the class org.apache.sysml.udf.lib.SGDNesterovUpdate
     val envFlagNesterovUDF = System.getenv("USE_NESTEROV_UDF")
@@ -118,6 +155,93 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
   // Method called by Python mllearn to visualize variable of certain layer
   def visualizeLayer(layerName:String, varType:String, aggFn:String): Unit = visualizeLayer(net, layerName, varType, aggFn)
   
+  // ================================================================================================
+  // The method below parses the provided network and solver files and generates the DML training script.
+	def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
+	  val startTrainingTime = System.nanoTime()
+	  
+    reset                                 // Reset the state of DML generator for training script.
+    
+    // Flags passed by user
+	  val DEBUG_TRAINING = if(inputs.containsKey("$debug")) inputs.get("$debug").toLowerCase.toBoolean else false
+	  assign(tabDMLScript, "debug", if(DEBUG_TRAINING) "TRUE" else "FALSE")
+	  
+	  appendHeaders(net, solver, true)      // Appends DML corresponding to source and externalFunction statements.
+	  readInputData(net, true)              // Read X_full and y_full
+	  // Initialize the layers and solvers. Reads weights and bias if $weights is set.
+	  initWeights(net, solver, inputs.containsKey("$weights"), layersToIgnore)
+	  
+	  // Split into training and validation set
+	  // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
+	  val shouldValidate = solverParam.getTestInterval > 0 && solverParam.getTestIterCount > 0 && solverParam.getTestIter(0) > 0
+	  trainTestSplit(if(shouldValidate) solverParam.getTestIter(0) else 0)
+	  
+	  // Set iteration-related variables such as max_epochs, num_iters_per_epoch, lr, etc.
+	  setIterationVariables
+	  val lossLayers = getLossLayers(net)
+	  // ----------------------------------------------------------------------------
+	  // Main logic
+	  forBlock("e", "1", "max_epochs") {
+	    solverParam.getTrainAlgo.toLowerCase match {
+	      case "minibatch" => 
+	        forBlock("i", "1", "num_iters_per_epoch") {
+	          getTrainingBatch(tabDMLScript)
+	          tabDMLScript.append("iter = start_iter + i\n")
+	          forward; backward; update
+	          displayLoss(lossLayers(0), shouldValidate)
+            performSnapshot
+	        }
+	      case "batch" => {
+          tabDMLScript.append("iter = start_iter + i\n")
+          forward; backward; update
+          displayLoss(lossLayers(0), shouldValidate)
+          performSnapshot
+	      }
+	      case "allreduce" => {
+	        forBlock("i", "1", "num_iters_per_epoch") {
+	          getTrainingBatch(tabDMLScript)
+	          assign(tabDMLScript, "X_group_batch", "Xb")
+	          assign(tabDMLScript, "y_group_batch", "yb")
+	          tabDMLScript.append("iter = start_iter + i\n")
+	          initAggGradients
+	          parForBlock("j", "1", "nrow(y_group_batch)") {
+	            assign(tabDMLScript, "Xb", "X_group_batch[j,]")
+	            assign(tabDMLScript, "yb", "y_group_batch[j,]")
+	            forward; backward("_agg")
+              flattenAndStoreAggGradients_j
+	          }
+	          aggregateAggGradients
+            tabDMLScript.append("iter = start_iter + parallel_batches\n")    
+	          update
+            displayLoss(lossLayers(0), shouldValidate)
+            performSnapshot
+	        }
+	      }
+	      case _ => throw new DMLRuntimeException("Unsupported train algo:" + solverParam.getTrainAlgo)
+	    }
+	    // After every epoch, update the learning rate
+	    tabDMLScript.append("# Learning rate\n")
+	    lrPolicy.updateLearningRate(tabDMLScript)
+	    tabDMLScript.append("start_iter = start_iter + num_iters_per_epoch\n")
+	  }
+	  // ----------------------------------------------------------------------------
+	  
+	  // Check if this is necessary
+	  if(doVisualize) tabDMLScript.append("print(" + asDMLString("Visualization counter:") + " + viz_counter)")
+	  
+	  val trainingScript = tabDMLScript.toString()
+	  // Print script generation time and the DML script on stdout
+	  System.out.println("Time taken to generate training script from Caffe proto: " + ((System.nanoTime() - startTrainingTime)*1e-9) + " seconds." )
+	  if(DEBUG_TRAINING) Utils.prettyPrintDMLScript(trainingScript)
+	  
+	  // Set input/output variables and execute the script
+	  val script = dml(trainingScript).in(inputs)
+	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.out(l.weight))
+	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.out(l.bias))
+	  (script, "X_full", "y_full")
+	}
+	// ================================================================================================
+  
   // -------------------------------------------------------------------------------------------
   // Helper functions to generate DML
   // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
@@ -153,30 +277,6 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
     }
   }
   
-  private def printClassificationReport():Unit = {
-    ifBlock("debug"){
-      assign(tabDMLScript, "num_rows_error_measures", min("10", ncol("yb")))
-      assign(tabDMLScript, "error_measures", matrix("0", "num_rows_error_measures", "5"))
-      forBlock("class_i", "1", "num_rows_error_measures") {
-        assign(tabDMLScript, "tp", "sum( (true_yb == predicted_yb) * (true_yb == class_i) )")
-        assign(tabDMLScript, "tp_plus_fp", "sum( (predicted_yb == class_i) )")
-        assign(tabDMLScript, "tp_plus_fn", "sum( (true_yb == class_i) )")
-        assign(tabDMLScript, "precision", "tp / tp_plus_fp")
-        assign(tabDMLScript, "recall", "tp / tp_plus_fn")
-        assign(tabDMLScript, "f1Score", "2*precision*recall / (precision+recall)")
-        assign(tabDMLScript, "error_measures[class_i,1]", "class_i")
-        assign(tabDMLScript, "error_measures[class_i,2]", "precision")
-        assign(tabDMLScript, "error_measures[class_i,3]", "recall")
-        assign(tabDMLScript, "error_measures[class_i,4]", "f1Score")
-        assign(tabDMLScript, "error_measures[class_i,5]", "tp_plus_fn")
-      }
-      val dmlTab = "\\t"
-      val header = "class    " + dmlTab + "precision" + dmlTab + "recall  " + dmlTab + "f1-score" + dmlTab + "num_true_labels\\n"
-      val errorMeasures = "toString(error_measures, decimal=7, sep=" + asDMLString(dmlTab) + ")"
-      tabDMLScript.append(print(dmlConcat(asDMLString(header), errorMeasures)))
-    }
-  }
-  
   // Append the DML to display training and validation loss
   private def displayLoss(lossLayer:IsLossLayer, shouldValidate:Boolean):Unit = {
     if(solverParam.getDisplay > 0) {
@@ -275,54 +375,9 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
           matrix(colSums(l.dBias + "_agg"), nrow(l.bias), ncol(l.bias)))
     })
   }
-  // -------------------------------------------------------------------------------------------
-  
-  private def multiply(v1:String, v2:String):String = v1 + "*" + v2
-  private def colSums(m:String):String = "colSums(" + m + ")"
-  
-	// Script generator
-	def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
-	  val startTrainingTime = System.nanoTime()
-	  val DEBUG_TRAINING = if(inputs.containsKey("$debug")) inputs.get("$debug").toLowerCase.toBoolean else false
-    reset()
-	  
-	  // Add source for layers as well as solver as well as visualization header
-	  source(net, solver, Array[String]("l2_reg"))
-	  appendVisualizationHeaders(dmlScript, numTabs)
-	  
-	  if(Caffe2DML.USE_NESTEROV_UDF) {
-	    tabDMLScript(dmlScript, numTabs).append("update_nesterov = externalFunction(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v, double lambda) return (matrix[double] X, matrix[double] v) implemented in (classname=\"org.apache.sysml.udf.lib.SGDNesterovUpdate\",exectype=\"mem\");  \n")
-	  }
-	  
-	  // Read and convert to one-hote encoding
-	  assign(tabDMLScript, "X_full", "read(\" \", format=\"csv\")")
-	  assign(tabDMLScript, "y_full", "read(\" \", format=\"csv\")")
-	  tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y_full)\n")
-	  tabDMLScript.append("weights = ifdef($weights, \" \")\n")
-	  tabDMLScript.append("debug = ifdef($debug, FALSE)\n")
-	  tabDMLScript.append("# Convert to one-hot encoding (Assumption: 1-based labels) \n")
-	  tabDMLScript.append("y_full = table(seq(1," + Caffe2DML.numImages + ",1), y_full, " + Caffe2DML.numImages + ", " + Utils.numClasses(net) + ")\n")
-	  
-	  // Initialize the layers and solvers
-	  tabDMLScript.append("# Initialize the layers and solvers\n")
-	  net.getLayers.map(layer => net.getCaffeLayer(layer).init(tabDMLScript))
-	  if(inputs.containsKey("$weights")) {
-		  // Loading existing weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias
-		  tabDMLScript.append("# Load the weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias\n")
-		  net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(read(l.weight, l.param.getName + "_weight.mtx")))
-		  net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(read(l.bias, l.param.getName + "_bias.mtx")))
-	  }
-	  net.getLayers.map(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
-	  
-	  // Split into training and validation set
-	  // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
-	  val shouldValidate = solverParam.getTestInterval > 0 && solverParam.getTestIterCount > 0 && solverParam.getTestIter(0) > 0
-	  trainTestSplit(if(shouldValidate) solverParam.getTestIter(0) else 0)
-	  
-	  // Set iteration-related variables such as max_epochs, num_iters_per_epoch, lr, etc.
-	  val lossLayers = net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]).map(layer => net.getCaffeLayer(layer).asInstanceOf[IsLossLayer])
-	  if(lossLayers.length != 1) throw new DMLRuntimeException("Expected exactly one loss layer")
-	  solverParam.getTrainAlgo.toLowerCase match {
+  // Set iteration-related variables such as max_epochs, num_iters_per_epoch, lr, etc.
+  def setIterationVariables():Unit = {
+    solverParam.getTrainAlgo.toLowerCase match {
 	    case "batch" => 
 	      assign(tabDMLScript, "max_epochs", solverParam.getMaxIter.toString)
 	    case _ => {
@@ -332,68 +387,8 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
 	  }
 	  assign(tabDMLScript, "start_iter", "0")
 	  assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
-	  
-	  // ----------------------------------------------------------------------------
-	  // Main logic
-	  forBlock("e", "1", "max_epochs") {
-	    solverParam.getTrainAlgo.toLowerCase match {
-	      case "minibatch" => 
-	        forBlock("i", "1", "num_iters_per_epoch") {
-	          getTrainingBatch(tabDMLScript)
-	          tabDMLScript.append("iter = start_iter + i\n")
-	          forward; backward; update
-	          displayLoss(lossLayers(0), shouldValidate)
-            performSnapshot
-	        }
-	      case "batch" => {
-          tabDMLScript.append("iter = start_iter + i\n")
-          forward; backward; update
-          displayLoss(lossLayers(0), shouldValidate)
-          performSnapshot
-	      }
-	      case "allreduce" => {
-	        forBlock("i", "1", "num_iters_per_epoch") {
-	          getTrainingBatch(tabDMLScript)
-	          assign(tabDMLScript, "X_group_batch", "Xb")
-	          assign(tabDMLScript, "y_group_batch", "yb")
-	          tabDMLScript.append("iter = start_iter + i\n")
-	          initAggGradients
-	          parForBlock("j", "1", "nrow(y_group_batch)") {
-	            assign(tabDMLScript, "Xb", "X_group_batch[j,]")
-	            assign(tabDMLScript, "yb", "y_group_batch[j,]")
-	            forward; backward("_agg")
-              flattenAndStoreAggGradients_j
-	          }
-	          aggregateAggGradients
-            tabDMLScript.append("iter = start_iter + parallel_batches\n")    
-	          update
-            displayLoss(lossLayers(0), shouldValidate)
-            performSnapshot
-	        }
-	      }
-	      case _ => throw new DMLRuntimeException("Unsupported train algo:" + solverParam.getTrainAlgo)
-	    }
-	    // After every epoch, update the learning rate
-	    tabDMLScript.append("# Learning rate\n")
-	    lrPolicy.updateLearningRate(tabDMLScript)
-	    tabDMLScript.append("start_iter = start_iter + num_iters_per_epoch\n")
-	  }
-	  // ----------------------------------------------------------------------------
-	  
-	  // Check if this is necessary
-	  if(doVisualize) tabDMLScript.append("print(" + asDMLString("Visualization counter:") + " + viz_counter)")
-	  
-	  val trainingScript = tabDMLScript.toString()
-	  // Print script generation time and the DML script on stdout
-	  System.out.println("Time taken to generate training script from Caffe proto: " + ((System.nanoTime() - startTrainingTime)*1e-9) + " seconds." )
-	  if(DEBUG_TRAINING) Utils.prettyPrintDMLScript(trainingScript)
-	  
-	  // Set input/output variables and execute the script
-	  val script = dml(trainingScript).in(inputs)
-	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.out(l.weight))
-	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.out(l.bias))
-	  (script, "X_full", "y_full")
-	}
+  }
+  // -------------------------------------------------------------------------------------------
 }
 
 class Caffe2DMLModel(val mloutput: MLResults,  
@@ -431,40 +426,33 @@ class Caffe2DMLModel(val mloutput: MLResults,
 	  ml.execute(script)
 	}
     
+  // ================================================================================================
+  // The method below parses the provided network and solver files and generates the DML prediction script.
   def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String)  = {
-    reset()
     val startPredictionTime = System.nanoTime()
-	  val DEBUG_PREDICTION = if(estimator.inputs.containsKey("$debug")) estimator.inputs.get("$debug").toLowerCase.toBoolean else false
+    
+	  reset                                  // Reset the state of DML generator for prediction script.
 	  
-	  // Append source statements for each layer
-	  source(net, solver, null)
-    tabDMLScript.append("weights = ifdef($weights, \" \")\n")
-	  // Initialize the layers and solvers
-	  tabDMLScript.append("# Initialize the layers and solvers\n")
-	  net.getLayers.map(layer => net.getCaffeLayer(layer).init(tabDMLScript))
-	  if(mloutput == null && estimator.inputs.containsKey("$weights")) {
-		  // fit was not called
-		  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(read(l.weight, l.param.getName + "_weight.mtx")))
-		  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(read(l.bias, l.param.getName + "_bias.mtx")))
-	  }
-	  else if(mloutput == null) {
-		  throw new DMLRuntimeException("Cannot call predict/score without calling either fit or by providing weights")
+	  val DEBUG_PREDICTION = if(estimator.inputs.containsKey("$debug")) estimator.inputs.get("$debug").toLowerCase.toBoolean else false
+	  assign(tabDMLScript, "debug", if(DEBUG_PREDICTION) "TRUE" else "FALSE")
+    
+    appendHeaders(net, solver, false)      // Appends DML corresponding to source and externalFunction statements.
+    readInputData(net, false)              // Read X_full and y_full
+    assign(tabDMLScript, "X", "X_full")
+    
+    // Initialize the layers and solvers. Reads weights and bias if readWeights is true.
+    val readWeights = {
+	    if(mloutput == null && estimator.inputs.containsKey("$weights")) true
+	    else if(mloutput == null) throw new DMLRuntimeException("Cannot call predict/score without calling either fit or by providing weights")
+	    else false
 	  }
-	  net.getLayers.map(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
-	  
-//	  if(estimator.inputs.containsKey("$debug") && estimator.inputs.get("$debug").equals("TRUE")) {
-//		  System.out.println("The output shape of layers:")
-//		  net.getLayers.map(layer =>  System.out.println(net.getCaffeLayer(layer).param.getName + " " + net.getCaffeLayer(layer).outputShape))
-//	  }
+    initWeights(net, solver, readWeights)
 	  
 	  // Donot update mean and variance in batchnorm
-	  net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).map(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var = false)
-	  tabDMLScript.append("X_full = read(\" \", format=\"csv\")\n")
-	  assign(tabDMLScript, "X", "X_full")
-	  tabDMLScript.append(Caffe2DML.numImages + " = nrow(X_full)\n")
+	  updateMeanVarianceForBatchNorm(net, false)
+	  
+	  val lossLayers = getLossLayers(net)
 	  
-	  val lossLayers = net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]).map(layer => net.getCaffeLayer(layer).asInstanceOf[IsLossLayer])
-	  customAssert(lossLayers.length == 1, "Expected exactly one loss layer, but found " + lossLayers.length + ":" + net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]))
 	  assign(tabDMLScript, "Prob", matrix("0", Caffe2DML.numImages, numClasses))
 	  estimator.solverParam.getTestAlgo.toLowerCase match {
       case "minibatch" => {
@@ -495,8 +483,8 @@ class Caffe2DMLModel(val mloutput: MLResults,
 		System.out.println("Time taken to generate prediction script from Caffe proto:" + ((System.nanoTime() - startPredictionTime)*1e-9) + "secs." )
 		if(DEBUG_PREDICTION) Utils.prettyPrintDMLScript(predictionScript)
 		
-		// Reset
-		net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).map(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var = true)
+		// Reset state of BatchNorm layer
+		updateMeanVarianceForBatchNorm(net, true)
 		
 	  val script = dml(predictionScript).out("Prob").in(estimator.inputs)
 	  if(mloutput != null) {
@@ -504,9 +492,9 @@ class Caffe2DMLModel(val mloutput: MLResults,
   	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.in(l.weight, mloutput.getBinaryBlockMatrix(l.weight)))
   	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.in(l.bias, mloutput.getBinaryBlockMatrix(l.bias)))
 	  }
-	  
 	  (script, "X_full")
   }
+  // ================================================================================================
   
   // Prediction
   def transform(X: MatrixBlock): MatrixBlock = {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/700b0809/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
index 0620e44..0e39192 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
@@ -88,11 +88,30 @@ class LearningRatePolicy(lr_policy:String="exp", base_lr:Double=0.01) {
   }
 }
 
-/**
- * lambda: regularization parameter
- * momentum: Momentum value. Typical values are in the range of [0.5, 0.99], usually started at the lower end and annealed towards the higher end.
- */
 class SGD(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
+  /*
+   * Performs an SGD update with momentum.
+   *
+   * In SGD with momentum, we assume that the parameters have a velocity
+   * that continues with some momentum, and that is influenced by the
+   * gradient.
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   *  - dX: Gradient wrt `X` of a loss function being optimized, of
+   *      same shape as `X`.
+   *  - lr: Learning rate.
+   *  - mu: Momentum value.
+   *      Typical values are in the range of [0.5, 0.99], usually
+   *      started at the lower end and annealed towards the higher end.
+   *  - v: State maintaining the velocity of the parameters `X`, of same
+   *      shape as `X`.
+   *
+   * Outputs:
+   *  - X: Updated parameters `X`, of same shape as input `X`.
+   *  - v: Updated velocity of the parameters `X`, of same shape as
+   *      input `X`.
+   */
   def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
     l2reg_update(lambda, dmlScript, layer)
     if(momentum == 0) {
@@ -117,13 +136,34 @@ class SGD(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
   def sourceFileName:String = if(momentum == 0) "sgd" else "sgd_momentum" 
 }
 
-/**
- * lambda: regularization parameter
- * epsilon: Smoothing term to avoid divide by zero errors. Typical values are in the range of [1e-8, 1e-4].
- * 
- * See Adaptive Subgradient Methods for Online Learning and Stochastic Optimization, Duchi et al.
- */
 class AdaGrad(lambda:Double=5e-04, epsilon:Double=1e-6) extends CaffeSolver {
+  /*
+   * Performs an Adagrad update.
+   *
+   * This is an adaptive learning rate optimizer that maintains the
+   * sum of squared gradients to automatically adjust the effective
+   * learning rate.
+   *
+   * Reference:
+   *  - Adaptive Subgradient Methods for Online Learning and Stochastic
+   *    Optimization, Duchi et al.
+   *      - http://jmlr.org/papers/v12/duchi11a.html
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   *  - dX: Gradient wrt `X` of a loss function being optimized, of
+   *      same shape as `X`.
+   *  - lr: Learning rate.
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Typical values are in the range of [1e-8, 1e-4].
+   *  - cache: State that maintains per-parameter sum of squared
+   *      gradients, of same shape as `X`.
+   *
+   * Outputs:
+   *  - X: Updated parameters `X`, of same shape as input `X`.
+   *  - cache: State that maintains per-parameter sum of squared
+   *      gradients, of same shape as `X`.
+   */
   def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
     l2reg_update(lambda, dmlScript, layer)
     if(layer.shouldUpdateWeight) dmlScript.append("\t").append("["+ commaSep(layer.weight, layer.weight+"_cache") + "] " + 
@@ -138,11 +178,39 @@ class AdaGrad(lambda:Double=5e-04, epsilon:Double=1e-6) extends CaffeSolver {
   def sourceFileName:String = "adagrad"
 }
 
-/**
- * lambda: regularization parameter
- * momentum: Momentum value. Typical values are in the range of [0.5, 0.99], usually started at the lower end and annealed towards the higher end.
- */
 class Nesterov(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
+  /*
+   * Performs an SGD update with Nesterov momentum.
+   *
+   * As with regular SGD with momentum, in SGD with Nesterov momentum,
+   * we assume that the parameters have a velocity that continues
+   * with some momentum, and that is influenced by the gradient.
+   * In this view specifically, we perform the position update from the
+   * position that the momentum is about to carry the parameters to,
+   * rather than from the previous position.  Additionally, we always
+   * store the parameters in their position after momentum.
+   *
+   * Reference:
+   *  - Advances in optimizing Recurrent Networks, Bengio et al.,
+   *    section 3.5.
+   *    - http://arxiv.org/abs/1212.0901
+   *
+   * Inputs:
+   *  - X: Parameters to update, of shape (any, any).
+   *  - dX: Gradient wrt `X` of a loss function being optimized, of
+   *      same shape as `X`.
+   *  - lr: Learning rate.
+   *  - mu: Momentum value.
+   *      Typical values are in the range of [0.5, 0.99], usually
+   *      started at the lower end and annealed towards the higher end.
+   *  - v: State maintaining the velocity of the parameters `X`, of same
+   *      shape as `X`.
+   *
+   * Outputs:
+   *  - X: Updated parameters `X`, of same shape as input `X`.
+   *  - v: Updated velocity of the parameters `X`, of same shape as
+   *      input `X`.
+   */
   def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
     val fn = if(Caffe2DML.USE_NESTEROV_UDF) "update_nesterov" else "sgd_nesterov::update"
     val lastParameter = if(Caffe2DML.USE_NESTEROV_UDF) (", " + lambda) else ""

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/700b0809/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index 668d996..456b032 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -54,7 +54,7 @@ trait BaseDMLGenerator {
   def isNumber(x: String):Boolean = x forall Character.isDigit
   def transpose(x:String):String = "t(" + x + ")"
   def write(varName:String, fileName:String, format:String):String = "write(" + varName + ", \"" + fileName + "\", format=\"" + format + "\")\n"
-  def read(varName:String, fileName:String, sep:String="/"):String =  varName + " = read(weights + \"" + sep + fileName + "\")\n"
+  def readWeight(varName:String, fileName:String, sep:String="/"):String =  varName + " = read(weights + \"" + sep + fileName + "\")\n"
   def asDMLString(str:String):String = "\"" + str + "\""
   def assign(dmlScript:StringBuilder, lhsVar:String, rhsVar:String):Unit = {
     dmlScript.append(lhsVar).append(" = ").append(rhsVar).append("\n")
@@ -132,6 +132,11 @@ trait BaseDMLGenerator {
   def nrow(m:String):String = "nrow(" + m + ")"
   def ncol(m:String):String = "ncol(" + m + ")"
   def customAssert(cond:Boolean, msg:String) = if(!cond) throw new DMLRuntimeException(msg)
+  def multiply(v1:String, v2:String):String = v1 + "*" + v2
+  def colSums(m:String):String = "colSums(" + m + ")"
+  def ifdef(cmdLineVar:String, defaultVal:String):String = "ifdef(" + cmdLineVar + ", " + defaultVal + ")"
+  def ifdef(cmdLineVar:String):String = ifdef(cmdLineVar, "\" \"")
+  def read(filePathVar:String, format:String):String = "read(" + filePathVar + ", format=\""+ format + "\")"
 }
 
 trait TabbedDMLGenerator extends BaseDMLGenerator {
@@ -229,14 +234,6 @@ trait VisualizeDMLGenerator extends TabbedDMLGenerator {
         + ");\n")
     dmlScript.append("viz_counter = viz_counter + viz_counter1\n")
   }
-  def appendVisualizationHeaders(dmlScript:StringBuilder, numTabs:Int): Unit = {
-    if(doVisualize) {
-	    tabDMLScript(dmlScript, numTabs).append("visualize = externalFunction(String layerName, String varType, String aggFn, Double x, Double y, String logDir) return (Double B) " +
-	        "implemented in (classname=\"org.apache.sysml.udf.lib.Caffe2DMLVisualizeWrapper\",exectype=\"mem\"); \n")
-	    tabDMLScript(dmlScript, numTabs).append("viz_counter = 0\n")
-	    System.out.println("Please use the following command for visualizing: tensorboard --logdir=" + tensorboardLogDir)
-	  }
-  }
   def visualizeLayer(net:CaffeNetwork, layerName:String, varType:String, aggFn:String): Unit = {
 	  // 'weight', 'bias', 'dweight', 'dbias', 'output' or 'doutput'
 	  // 'sum', 'mean', 'var' or 'sd'
@@ -316,4 +313,101 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator with Visua
 	  tabDMLScript.append("}\n")
 	}
 	
+	def printClassificationReport():Unit = {
+    // Emits a per-class precision / recall / F1 report; all statements are emitted inside ifBlock("debug").
+    ifBlock("debug") {
+      // One row per class, capped at 10 rows (ncol(yb) is presumably the number of classes — TODO confirm).
+      assign(tabDMLScript, "num_rows_error_measures", min("10", ncol("yb")))
+      assign(tabDMLScript, "error_measures", matrix("0", "num_rows_error_measures", "5"))
+      forBlock("class_i", "1", "num_rows_error_measures") {
+        // Per-class counts over true_yb / predicted_yb, followed by the derived scores.
+        // NOTE(review): tp_plus_fp, tp_plus_fn or precision+recall can be 0 here, giving a division by zero.
+        val perClassStatements = Seq(
+          "tp"         -> "sum( (true_yb == predicted_yb) * (true_yb == class_i) )",
+          "tp_plus_fp" -> "sum( (predicted_yb == class_i) )",
+          "tp_plus_fn" -> "sum( (true_yb == class_i) )",
+          "precision"  -> "tp / tp_plus_fp",
+          "recall"     -> "tp / tp_plus_fn",
+          "f1Score"    -> "2*precision*recall / (precision+recall)")
+        perClassStatements.foreach { case (lhs, rhs) => assign(tabDMLScript, lhs, rhs) }
+        // Row layout of error_measures: class, precision, recall, f1-score, number of true labels.
+        val rowColumns = Seq("class_i", "precision", "recall", "f1Score", "tp_plus_fn")
+        for((rhs, col) <- rowColumns.zip(1 to rowColumns.length))
+          assign(tabDMLScript, "error_measures[class_i," + col + "]", rhs)
+      }
+      // Backslash escapes are kept literal here so they end up inside the emitted DML string literal.
+      val dmlTab = "\\t"
+      val header = Seq("class    ", "precision", "recall  ", "f1-score", "num_true_labels").mkString(dmlTab) + "\\n"
+      val errorMeasures = "toString(error_measures, decimal=7, sep=" + asDMLString(dmlTab) + ")"
+      tabDMLScript.append(print(dmlConcat(asDMLString(header), errorMeasures)))
+    }
+  }
+	
+	// Appends DML corresponding to source and externalFunction statements.
+  def appendHeaders(net:CaffeNetwork, solver:CaffeSolver, isTraining:Boolean):Unit = {
+    // Source statements for the layers and the solver; the training script additionally sources l2_reg.
+    val additionalSources:Array[String] = if(isTraining) Array("l2_reg") else null
+    source(net, solver, additionalSources)
+    // External built-in function headers are only declared for training.
+    if(isTraining) {
+      // 1. `visualize` external function header plus its counter; also tells the user how to start tensorboard.
+      if(doVisualize) {
+        tabDMLScript.append("visualize = externalFunction(String layerName, String varType, String aggFn, Double x, Double y, String logDir) return (Double B) " +
+          "implemented in (classname=\"org.apache.sysml.udf.lib.Caffe2DMLVisualizeWrapper\",exectype=\"mem\"); \n")
+        tabDMLScript.append("viz_counter = 0\n")
+        System.out.println("Please use the following command for visualizing: tensorboard --logdir=" + tensorboardLogDir)
+      }
+      // 2. `update_nesterov` external function header, only when the Nesterov UDF is enabled.
+      if(Caffe2DML.USE_NESTEROV_UDF) {
+        tabDMLScript.append("update_nesterov = externalFunction(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v, double lambda) return (matrix[double] X, matrix[double] v) implemented in (classname=\"org.apache.sysml.udf.lib.SGDNesterovUpdate\",exectype=\"mem\");  \n")
+      }
+    }
+  }
+  
+  def readMatrix(varName:String, cmdLineVar:String):Unit = {
+    // Emits two DML statements:
+    //   <varName>_path = ifdef(<cmdLineVar>, " ")
+    //   <varName> = read(<varName>_path)
+    // No format argument is passed to read(). To let the user choose the format instead, emit
+    // <varName>_fmt = ifdef(<cmdLineVar>_fmt, "csv") and read(<varName>_path, format=<varName>_fmt).
+    val pathVar = s"${varName}_path"
+    assign(tabDMLScript, pathVar, ifdef(cmdLineVar))
+    assign(tabDMLScript, varName, s"read($pathVar)")
+  }
+  
+  def readInputData(net:CaffeNetwork, isTraining:Boolean):Unit = {
+    // Emits DML that reads X_full (from $X) and sets the image count.
+    // In training mode the labels y_full (from $y) are also read and converted to one-hot encoding.
+    readMatrix("X_full", "$X")
+    if(!isTraining) {
+      // Scoring: the image count comes from the features.
+      tabDMLScript.append(s"${Caffe2DML.numImages} = nrow(X_full)\n")
+    }
+    else {
+      readMatrix("y_full", "$y")
+      tabDMLScript.append(s"${Caffe2DML.numImages} = nrow(y_full)\n")
+      tabDMLScript.append("# Convert to one-hot encoding (Assumption: 1-based labels) \n")
+      // table(seq(1,N,1), y_full, N, numClasses) yields the one-hot label matrix.
+      tabDMLScript.append(s"y_full = table(seq(1,${Caffe2DML.numImages},1), y_full, ${Caffe2DML.numImages}, ${Utils.numClasses(net)})\n")
+    }
+  }
+  
+  // Same as the four-argument overload with an empty ignore set: every layer's weights/biases are read when readWeights is true.
+  def initWeights(net:CaffeNetwork, solver:CaffeSolver, readWeights:Boolean): Unit =
+    initWeights(net, solver, readWeights, new HashSet[String]())
+  
+  /**
+   * Appends DML that initializes every layer and then the solver for every layer. Optionally, it first overwrites
+   * weights/biases with matrices read relative to the DML variable `weights` (set from `$weights`).
+   *
+   * @param net            network whose layers are initialized
+   * @param solver         solver whose init code is emitted for each layer, after layer init and weight reads
+   * @param readWeights    if true, emit reads of `<layer>_weight.mtx` and `<layer>_bias.mtx`
+   * @param layersToIgnore names of layers whose weights/biases are not read (they keep their initialization)
+   */
+  def initWeights(net:CaffeNetwork, solver:CaffeSolver, readWeights:Boolean, layersToIgnore:HashSet[String]): Unit = {
+    tabDMLScript.append("weights = ifdef($weights, \" \")\n")
+    // Initialize the layers and solvers
+    tabDMLScript.append("# Initialize the layers and solvers\n")
+    net.getLayers.foreach(layer => net.getCaffeLayer(layer).init(tabDMLScript))
+    if(readWeights) {
+      // The initialization code above is kept in case a layer initializes state other than weights and bias.
+      tabDMLScript.append("# Load the weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias\n")
+      // Resolve the non-ignored layers once. All weight reads are still emitted before all bias reads.
+      val layersToLoad = net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_))
+      layersToLoad.filter(_.weight != null).foreach(l => tabDMLScript.append(readWeight(l.weight, l.param.getName + "_weight.mtx")))
+      layersToLoad.filter(_.bias != null).foreach(l => tabDMLScript.append(readWeight(l.bias, l.param.getName + "_bias.mtx")))
+    }
+    net.getLayers.foreach(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
+  }
+  
+  /**
+   * Returns the network's loss layers as a one-element list.
+   *
+   * @throws DMLRuntimeException if the network does not contain exactly one loss layer
+   */
+  def getLossLayers(net:CaffeNetwork):List[IsLossLayer] = {
+    val lossLayerNames = net.getLayers.filter(name => net.getCaffeLayer(name).isInstanceOf[IsLossLayer])
+    if(lossLayerNames.length != 1)
+      throw new DMLRuntimeException("Expected exactly one loss layer, but found " + lossLayerNames.length + ":" + lossLayerNames)
+    lossLayerNames.map(name => net.getCaffeLayer(name).asInstanceOf[IsLossLayer])
+  }
+  
+  /**
+   * Sets `update_mean_var` to `value` on every BatchNorm layer of `net`; other layers are left untouched.
+   */
+  def updateMeanVarianceForBatchNorm(net:CaffeNetwork, value:Boolean):Unit = {
+    // foreach (not map): this loop exists only for its side effect, and each layer is looked up once.
+    net.getLayers.foreach { layerName =>
+      net.getCaffeLayer(layerName) match {
+        case bn:BatchNorm => bn.update_mean_var = value
+        case _ =>
+      }
+    }
+  }
 }
\ No newline at end of file