You are viewing a plain-text version of this content; the link to its canonical version was lost in this plain-text rendering.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/09/15 18:03:08 UTC

[1/4] systemml git commit: [SYSTEMML-540] Support loading of batch normalization weights in .caffemodel file using Caffe2DML

Repository: systemml
Updated Branches:
  refs/heads/master ebb6ea612 -> f07b5a2d9


http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/Utils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Utils.scala b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
index 2684261..5d43730 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Utils.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -43,264 +43,266 @@ import org.apache.spark.api.java.JavaSparkContext
 object Utils {
   // ---------------------------------------------------------------------------------------------
   // Helper methods for DML generation
-  
+
   // Returns number of classes if inferred from the network
-  def numClasses(net:CaffeNetwork):String = {
-  	try {
-  		return "" + net.getCaffeLayer(net.getLayers().last).outputShape._1.toLong
-  	} catch {
-  		case _:Throwable => {
-  			Caffe2DML.LOG.warn("Cannot infer the number of classes from network definition. User needs to pass it via set(num_classes=...) method.")
-  			return "$num_classes" // Expect users to provide it
-  		}
-  	}
-  }
-  def prettyPrintDMLScript(script:String) {
-	  val bufReader = new BufferedReader(new StringReader(script))
-	  var line = bufReader.readLine();
-	  var lineNum = 1
-	  while( line != null ) {
-		  System.out.println( "%03d".format(lineNum) + "|" + line)
-		  lineNum = lineNum + 1
-		  line = bufReader.readLine()
-	  }
+  def numClasses(net: CaffeNetwork): String =
+    try {
+      return "" + net.getCaffeLayer(net.getLayers().last).outputShape._1.toLong
+    } catch {
+      case _: Throwable => {
+        Caffe2DML.LOG.warn("Cannot infer the number of classes from network definition. User needs to pass it via set(num_classes=...) method.")
+        return "$num_classes" // Expect users to provide it
+      }
+    }
+  def prettyPrintDMLScript(script: String) {
+    val bufReader = new BufferedReader(new StringReader(script))
+    var line      = bufReader.readLine();
+    var lineNum   = 1
+    while (line != null) {
+      System.out.println("%03d".format(lineNum) + "|" + line)
+      lineNum = lineNum + 1
+      line = bufReader.readLine()
+    }
   }
-  
+
   // ---------------------------------------------------------------------------------------------
-  def parseSolver(solverFilePath:String): CaffeSolver = parseSolver(readCaffeSolver(solverFilePath))
-	def parseSolver(solver:SolverParameter): CaffeSolver = {
-	  val momentum = if(solver.hasMomentum) solver.getMomentum else 0.0
-	  val lambda = if(solver.hasWeightDecay) solver.getWeightDecay else 0.0
-	  val delta = if(solver.hasDelta) solver.getDelta else 0.0
-	  
-	  solver.getType.toLowerCase match {
-	    case "sgd" => new SGD(lambda, momentum)
-	    case "adagrad" => new AdaGrad(lambda, delta)
-	    case "nesterov" => new Nesterov(lambda, momentum)
-	    case _ => throw new DMLRuntimeException("The solver type is not supported: " + solver.getType + ". Try: SGD, AdaGrad or Nesterov.")
-	  }
-    
+  def parseSolver(solverFilePath: String): CaffeSolver = parseSolver(readCaffeSolver(solverFilePath))
+  def parseSolver(solver: SolverParameter): CaffeSolver = {
+    val momentum = if (solver.hasMomentum) solver.getMomentum else 0.0
+    val lambda   = if (solver.hasWeightDecay) solver.getWeightDecay else 0.0
+    val delta    = if (solver.hasDelta) solver.getDelta else 0.0
+
+    solver.getType.toLowerCase match {
+      case "sgd"      => new SGD(lambda, momentum)
+      case "adagrad"  => new AdaGrad(lambda, delta)
+      case "nesterov" => new Nesterov(lambda, momentum)
+      case _          => throw new DMLRuntimeException("The solver type is not supported: " + solver.getType + ". Try: SGD, AdaGrad or Nesterov.")
+    }
+
   }
-  
-	// --------------------------------------------------------------
-	// Caffe utility functions
-	def readCaffeNet(netFilePath:String):NetParameter = {
-	  // Load network
-		val reader:InputStreamReader = getInputStreamReader(netFilePath); 
-  	val builder:NetParameter.Builder =  NetParameter.newBuilder();
-  	TextFormat.merge(reader, builder);
-  	return builder.build();
-	}
-	
-	class CopyFloatToDoubleArray(data:java.util.List[java.lang.Float], rows:Int, cols:Int, transpose:Boolean, arr:Array[Double]) extends Thread {
-	  override def run(): Unit = {
-	    if(transpose) {
+
+  // --------------------------------------------------------------
+  // Caffe utility functions
+  def readCaffeNet(netFilePath: String): NetParameter = {
+    // Load network
+    val reader: InputStreamReader     = getInputStreamReader(netFilePath);
+    val builder: NetParameter.Builder = NetParameter.newBuilder();
+    TextFormat.merge(reader, builder);
+    return builder.build();
+  }
+
+  class CopyFloatToDoubleArray(data: java.util.List[java.lang.Float], rows: Int, cols: Int, transpose: Boolean, arr: Array[Double]) extends Thread {
+    override def run(): Unit =
+      if (transpose) {
         var iter = 0
-        for(i <- 0 until cols) {
-          for(j <- 0 until rows) {
-            arr(j*cols + i) = data.get(iter).doubleValue()
+        for (i <- 0 until cols) {
+          for (j <- 0 until rows) {
+            arr(j * cols + i) = data.get(iter).doubleValue()
             iter += 1
           }
         }
-      }
-      else {
-        for(i <- 0 until data.size()) {
+      } else {
+        for (i <- 0 until data.size()) {
           arr(i) = data.get(i).doubleValue()
         }
       }
-	  }
-	}
-	
-	class CopyCaffeDeconvFloatToSystemMLDeconvDoubleArray(data:java.util.List[java.lang.Float], F:Int, C:Int, H:Int, W:Int, arr:Array[Double]) 
-	    extends CopyFloatToDoubleArray(data, C, F*H*W, false, arr) {
-	  override def run(): Unit = {
-	    var i = 0
-	    for(f <- 0 until F) {
-	      for(c <- 0 until C) {
-	        for(hw <- 0 until H*W) {
-	          arr(c*F*H*W + f*H*W + hw) = data.get(i).doubleValue()
-	          i = i+1
-	        }
-	      }
-	    }
-	  }
-	}
-	
-	def allocateDeconvolutionWeight(data:java.util.List[java.lang.Float], F:Int, C:Int, H:Int, W:Int):(MatrixBlock,CopyFloatToDoubleArray) = {
-	  val mb =  new MatrixBlock(C, F*H*W, false)
+  }
+
+  class CopyCaffeDeconvFloatToSystemMLDeconvDoubleArray(data: java.util.List[java.lang.Float], F: Int, C: Int, H: Int, W: Int, arr: Array[Double])
+      extends CopyFloatToDoubleArray(data, C, F * H * W, false, arr) {
+    override def run(): Unit = {
+      var i = 0
+      for (f <- 0 until F) {
+        for (c <- 0 until C) {
+          for (hw <- 0 until H * W) {
+            arr(c * F * H * W + f * H * W + hw) = data.get(i).doubleValue()
+            i = i + 1
+          }
+        }
+      }
+    }
+  }
+
+  def allocateDeconvolutionWeight(data: java.util.List[java.lang.Float], F: Int, C: Int, H: Int, W: Int): (MatrixBlock, CopyFloatToDoubleArray) = {
+    val mb = new MatrixBlock(C, F * H * W, false)
     mb.allocateDenseBlock()
-    val arr = mb.getDenseBlock
+    val arr    = mb.getDenseBlock
     val thread = new CopyCaffeDeconvFloatToSystemMLDeconvDoubleArray(data, F, C, H, W, arr)
-	  thread.start
-	  return (mb, thread)
-	}
-	
-	def allocateMatrixBlock(data:java.util.List[java.lang.Float], rows:Int, cols:Int, transpose:Boolean):(MatrixBlock,CopyFloatToDoubleArray) = {
-	  val mb =  new MatrixBlock(rows, cols, false)
+    thread.start
+    return (mb, thread)
+  }
+
+  def allocateMatrixBlock(data: java.util.List[java.lang.Float], rows: Int, cols: Int, transpose: Boolean): (MatrixBlock, CopyFloatToDoubleArray) = {
+    val mb = new MatrixBlock(rows, cols, false)
     mb.allocateDenseBlock()
-    val arr = mb.getDenseBlock
+    val arr    = mb.getDenseBlock
     val thread = new CopyFloatToDoubleArray(data, rows, cols, transpose, arr)
-	  thread.start
-	  return (mb, thread)
-	}
-	def validateShape(shape:Array[Int], data:java.util.List[java.lang.Float], layerName:String): Unit = {
-	  if(shape == null) 
+    thread.start
+    return (mb, thread)
+  }
+  def validateShape(shape: Array[Int], data: java.util.List[java.lang.Float], layerName: String): Unit =
+    if (shape == null)
       throw new DMLRuntimeException("Unexpected weight for layer: " + layerName)
-    else if(shape.length != 2) 
+    else if (shape.length != 2)
       throw new DMLRuntimeException("Expected shape to be of length 2:" + layerName)
-    else if(shape(0)*shape(1) != data.size())
-      throw new DMLRuntimeException("Incorrect size of blob from caffemodel for the layer " + layerName + ". Expected of size " + shape(0)*shape(1) + ", but found " + data.size())
-	}
-	
-	def saveCaffeModelFile(sc:JavaSparkContext, deployFilePath:String, 
-	    caffeModelFilePath:String, outputDirectory:String, format:String):Unit = {
-	  saveCaffeModelFile(sc.sc, deployFilePath, caffeModelFilePath, outputDirectory, format)
-	}
-	
-	def saveCaffeModelFile(sc:SparkContext, deployFilePath:String, caffeModelFilePath:String, outputDirectory:String, format:String):Unit = {
-	  val inputVariables = new java.util.HashMap[String, MatrixBlock]()
-	  readCaffeNet(new CaffeNetwork(deployFilePath), deployFilePath, caffeModelFilePath, inputVariables)
-	  val ml = new MLContext(sc)
-	  val dmlScript = new StringBuilder
-	  if(inputVariables.keys.size == 0)
-	    throw new DMLRuntimeException("No weights found in the file " + caffeModelFilePath)
-	  for(input <- inputVariables.keys) {
-	    dmlScript.append("write(" + input + ", \"" + outputDirectory + "/" + input + ".mtx\", format=\"" + format + "\");\n")
-	  }
-	  if(Caffe2DML.LOG.isDebugEnabled())
-	    Caffe2DML.LOG.debug("Executing the script:" + dmlScript.toString)
-	  val script = org.apache.sysml.api.mlcontext.ScriptFactory.dml(dmlScript.toString()).in(inputVariables)
-	  ml.execute(script)
-	}
-	
-	def readCaffeNet(net:CaffeNetwork, netFilePath:String, weightsFilePath:String, inputVariables:java.util.HashMap[String, MatrixBlock]):NetParameter = {
-	  // Load network
-		val reader:InputStreamReader = getInputStreamReader(netFilePath); 
-  	val builder:NetParameter.Builder =  NetParameter.newBuilder();
-  	TextFormat.merge(reader, builder);
-  	// Load weights
-	  val inputStream = CodedInputStream.newInstance(new FileInputStream(weightsFilePath))
-	  inputStream.setSizeLimit(Integer.MAX_VALUE)
-	  builder.mergeFrom(inputStream)
-	  val net1 = builder.build();
-	  
-	  val asyncThreads = new java.util.ArrayList[CopyFloatToDoubleArray]()
-	  val v1Layers = net1.getLayersList.map(layer => layer.getName -> layer).toMap
-	  for(layer <- net1.getLayerList) {
-	    val blobs = if(layer.getBlobsCount != 0) layer.getBlobsList else if(v1Layers.contains(layer.getName)) v1Layers.get(layer.getName).get.getBlobsList else null
-	      
-	    if(blobs == null || blobs.size == 0) {
-	      // No weight or bias
-	      Caffe2DML.LOG.debug("The layer:" + layer.getName + " has no blobs")
-	    }
-	    else if(blobs.size == 2) {
-	      // Both weight and bias
-	      val caffe2DMLLayer = net.getCaffeLayer(layer.getName)
-	      val transpose = caffe2DMLLayer.isInstanceOf[InnerProduct]
-	      
-	      // weight
-	      val data = blobs(0).getDataList
-	      val shape = caffe2DMLLayer.weightShape()
-	      if(shape == null)
-  	        throw new DMLRuntimeException("Didnot expect weights for the layer " + layer.getName)
-	      validateShape(shape, data, layer.getName)
-	      
-	      val ret1 = if(caffe2DMLLayer.isInstanceOf[DeConvolution]) {
-	        // Swap dimensions: Caffe's format (F, C*Hf*Wf) to NN layer's format (C, F*Hf*Wf).
-	        val deconvLayer = caffe2DMLLayer.asInstanceOf[DeConvolution]
-	        val C = shape(0)
-	        val F = deconvLayer.numKernels.toInt
-	        val Hf = deconvLayer.kernel_h.toInt
-	        val Wf = deconvLayer.kernel_w.toInt
-	        allocateDeconvolutionWeight(data, F, C, Hf, Wf)
-	      }
-	      else {
-  	      allocateMatrixBlock(data, shape(0), shape(1), transpose)
-	      }
-	      asyncThreads.add(ret1._2)
-	      inputVariables.put(caffe2DMLLayer.weight, ret1._1)
-	      
-	      // bias
-	      val biasData = blobs(1).getDataList
-	      val biasShape = caffe2DMLLayer.biasShape()
-	      if(biasShape == null)
-	        throw new DMLRuntimeException("Didnot expect bias for the layer " + layer.getName)
-	      validateShape(biasShape, biasData, layer.getName)
-	      val ret2 = allocateMatrixBlock(biasData, biasShape(0), biasShape(1), transpose)
-	      asyncThreads.add(ret2._2)
-	      inputVariables.put(caffe2DMLLayer.bias, ret2._1)
-	      Caffe2DML.LOG.debug("Read weights/bias for layer:" + layer.getName)
-	    }
-	    else if(blobs.size == 1) {
-	      // Special case: convolution/deconvolution without bias
-	      // TODO: Extend nn layers to handle this situation + Generalize this to other layers, for example: InnerProduct
-	      val caffe2DMLLayer = net.getCaffeLayer(layer.getName)
-	      val convParam = if((caffe2DMLLayer.isInstanceOf[Convolution] || caffe2DMLLayer.isInstanceOf[DeConvolution]) && caffe2DMLLayer.param.hasConvolutionParam())  caffe2DMLLayer.param.getConvolutionParam else null  
-	      if(convParam == null)
-	        throw new DMLRuntimeException("Layer with blob count " + layer.getBlobsCount + " is not supported for the layer " + layer.getName)
-	      else if(convParam.hasBiasTerm && convParam.getBiasTerm)
-	        throw new DMLRuntimeException("Layer with blob count " + layer.getBlobsCount + " and with bias term is not supported for the layer " + layer.getName)
-	     
-	      val data = blobs(0).getDataList
-	      val shape = caffe2DMLLayer.weightShape()
-	      validateShape(shape, data, layer.getName)
-	      val ret1 = allocateMatrixBlock(data, shape(0), shape(1), false)
-	      asyncThreads.add(ret1._2)
-	      inputVariables.put(caffe2DMLLayer.weight, ret1._1)
-	      inputVariables.put(caffe2DMLLayer.bias, new MatrixBlock(convParam.getNumOutput, 1, false))
-	      Caffe2DML.LOG.debug("Read only weight for layer:" + layer.getName)
-	    }
-	    else {
-	      throw new DMLRuntimeException("Layer with blob count " + layer.getBlobsCount + " is not supported for the layer " + layer.getName)
-	    }
-	  }
-	  
-	  // Wait for the copy to be finished
-	  for(t <- asyncThreads) {
-	    t.join()
-	  }
-	  
-	  for(mb <- inputVariables.values()) {
-	    mb.recomputeNonZeros();
-	  }
-	  
-	  // Return the NetParameter without
-	  return readCaffeNet(netFilePath)
-	}
-	
-	def readCaffeSolver(solverFilePath:String):SolverParameter = {
-		val reader = getInputStreamReader(solverFilePath);
-		val builder =  SolverParameter.newBuilder();
-		TextFormat.merge(reader, builder);
-		return builder.build();
-	}
-	
-	// --------------------------------------------------------------
-	// File IO utility functions
-	def writeToFile(content:String, filePath:String): Unit = {
-		val pw = new java.io.PrintWriter(new File(filePath))
-		pw.write(content)
-		pw.close
-	}
-	def getInputStreamReader(filePath:String ):InputStreamReader = {
-		//read solver script from file
-		if(filePath == null)
-			throw new LanguageException("file path was not specified!");
-		if(filePath.startsWith("hdfs:")  || filePath.startsWith("gpfs:")) { 
-			val fs = FileSystem.get(ConfigurationManager.getCachedJobConf());
-			return new InputStreamReader(fs.open(new Path(filePath)));
-		}
-		else { 
-			return new InputStreamReader(new FileInputStream(new File(filePath)), "ASCII");
-		}
-	}
-	// --------------------------------------------------------------
+    else if (shape(0) * shape(1) != data.size())
+      throw new DMLRuntimeException(
+        "Incorrect size of blob from caffemodel for the layer " + layerName + ". Expected of size " + shape(0) * shape(1) + ", but found " + data.size()
+      )
+
+  def saveCaffeModelFile(sc: JavaSparkContext, deployFilePath: String, caffeModelFilePath: String, outputDirectory: String, format: String): Unit =
+    saveCaffeModelFile(sc.sc, deployFilePath, caffeModelFilePath, outputDirectory, format)
+
+  def saveCaffeModelFile(sc: SparkContext, deployFilePath: String, caffeModelFilePath: String, outputDirectory: String, format: String): Unit = {
+    val inputVariables = new java.util.HashMap[String, MatrixBlock]()
+    readCaffeNet(new CaffeNetwork(deployFilePath), deployFilePath, caffeModelFilePath, inputVariables)
+    val ml        = new MLContext(sc)
+    val dmlScript = new StringBuilder
+    if (inputVariables.keys.size == 0)
+      throw new DMLRuntimeException("No weights found in the file " + caffeModelFilePath)
+    for (input <- inputVariables.keys) {
+      dmlScript.append("write(" + input + ", \"" + outputDirectory + "/" + input + ".mtx\", format=\"" + format + "\");\n")
+    }
+    if (Caffe2DML.LOG.isDebugEnabled())
+      Caffe2DML.LOG.debug("Executing the script:" + dmlScript.toString)
+    val script = org.apache.sysml.api.mlcontext.ScriptFactory.dml(dmlScript.toString()).in(inputVariables)
+    ml.execute(script)
+  }
+
+  def readCaffeNet(net: CaffeNetwork, netFilePath: String, weightsFilePath: String, inputVariables: java.util.HashMap[String, MatrixBlock]): NetParameter = {
+    // Load network
+    val reader: InputStreamReader     = getInputStreamReader(netFilePath);
+    val builder: NetParameter.Builder = NetParameter.newBuilder();
+    TextFormat.merge(reader, builder);
+    // Load weights
+    val inputStream = CodedInputStream.newInstance(new FileInputStream(weightsFilePath))
+    inputStream.setSizeLimit(Integer.MAX_VALUE)
+    builder.mergeFrom(inputStream)
+    val net1 = builder.build();
+
+    val asyncThreads = new java.util.ArrayList[CopyFloatToDoubleArray]()
+    val v1Layers     = net1.getLayersList.map(layer => layer.getName -> layer).toMap
+
+    for (i <- 0 until net1.getLayerList.length) {
+      val layer = net1.getLayerList.get(i)
+      val blobs = getBlobs(layer, v1Layers)
+
+      if (blobs == null || blobs.size == 0) {
+        // No weight or bias
+        Caffe2DML.LOG.debug("The layer:" + layer.getName + " has no blobs")
+      } else if (blobs.size == 2 || (blobs.size == 3 && net.getCaffeLayer(layer.getName).isInstanceOf[BatchNorm])) {
+        // Both weight and bias (for a 3-blob BatchNorm layer, only blobs(0) and blobs(1) are read here)
+        val caffe2DMLLayer = net.getCaffeLayer(layer.getName)
+        val transpose      = caffe2DMLLayer.isInstanceOf[InnerProduct]
+
+        // weight
+        val shape = caffe2DMLLayer.weightShape()
+        if (shape == null)
+          throw new DMLRuntimeException("Didnot expect weights for the layer " + layer.getName)
+        if (caffe2DMLLayer.isInstanceOf[DeConvolution]) {
+          val data = blobs(0).getDataList
+          validateShape(shape, data, layer.getName)
+          // Swap dimensions: Caffe's format (F, C*Hf*Wf) to NN layer's format (C, F*Hf*Wf).
+          val deconvLayer = caffe2DMLLayer.asInstanceOf[DeConvolution]
+          val C           = shape(0)
+          val F           = deconvLayer.numKernels.toInt
+          val Hf          = deconvLayer.kernel_h.toInt
+          val Wf          = deconvLayer.kernel_w.toInt
+          val ret1        = allocateDeconvolutionWeight(data, F, C, Hf, Wf)
+          asyncThreads.add(ret1._2)
+          inputVariables.put(caffe2DMLLayer.weight, ret1._1)
+        } else {
+          inputVariables.put(caffe2DMLLayer.weight, getMBFromBlob(blobs(0), shape, layer.getName, transpose, asyncThreads))
+        }
+
+        // bias
+        val biasShape = caffe2DMLLayer.biasShape()
+        if (biasShape == null)
+          throw new DMLRuntimeException("Didnot expect bias for the layer " + layer.getName)
+        inputVariables.put(caffe2DMLLayer.bias, getMBFromBlob(blobs(1), biasShape, layer.getName, transpose, asyncThreads))
+        Caffe2DML.LOG.debug("Read weights/bias for layer:" + layer.getName)
+      } else if (blobs.size == 1) {
+        // Special case: convolution/deconvolution without bias
+        // TODO: Extend nn layers to handle this situation + Generalize this to other layers, for example: InnerProduct
+        val caffe2DMLLayer = net.getCaffeLayer(layer.getName)
+        val convParam =
+          if ((caffe2DMLLayer.isInstanceOf[Convolution] || caffe2DMLLayer.isInstanceOf[DeConvolution]) && caffe2DMLLayer.param.hasConvolutionParam())
+            caffe2DMLLayer.param.getConvolutionParam
+          else null
+        if (convParam == null)
+          throw new DMLRuntimeException("Layer with blob count " + layer.getBlobsCount + " is not supported for the layer " + layer.getName)
+        else if (convParam.hasBiasTerm && convParam.getBiasTerm)
+          throw new DMLRuntimeException("Layer with blob count " + layer.getBlobsCount + " and with bias term is not supported for the layer " + layer.getName)
+
+        inputVariables.put(caffe2DMLLayer.weight, getMBFromBlob(blobs(0), caffe2DMLLayer.weightShape(), layer.getName, false, asyncThreads))
+        inputVariables.put(caffe2DMLLayer.bias, new MatrixBlock(convParam.getNumOutput, 1, false))
+        Caffe2DML.LOG.debug("Read only weight for layer:" + layer.getName)
+      } else {
+        throw new DMLRuntimeException("Layer with blob count " + layer.getBlobsCount + " is not supported for the layer " + layer.getName)
+      }
+    }
+
+    // Wait for the copy to be finished
+    for (t <- asyncThreads) {
+      t.join()
+    }
+
+    for (mb <- inputVariables.values()) {
+      mb.recomputeNonZeros();
+    }
+
+    // Return the NetParameter without the weights: re-read it from the network definition file only
+    return readCaffeNet(netFilePath)
+  }
+
+  def getBlobs(layer: LayerParameter, v1Layers: scala.collection.immutable.Map[String, caffe.Caffe.V1LayerParameter]): java.util.List[caffe.Caffe.BlobProto] =
+    if (layer.getBlobsCount != 0)
+      layer.getBlobsList
+    else if (v1Layers.contains(layer.getName)) v1Layers.get(layer.getName).get.getBlobsList
+    else null
+
+  def getMBFromBlob(blob: caffe.Caffe.BlobProto,
+                    shape: Array[Int],
+                    layerName: String,
+                    transpose: Boolean,
+                    asyncThreads: java.util.ArrayList[CopyFloatToDoubleArray]): MatrixBlock = {
+    val data = blob.getDataList
+    validateShape(shape, data, layerName)
+    val ret1 = allocateMatrixBlock(data, shape(0), shape(1), transpose)
+    asyncThreads.add(ret1._2)
+    return ret1._1
+  }
+
+  def readCaffeSolver(solverFilePath: String): SolverParameter = {
+    val reader  = getInputStreamReader(solverFilePath);
+    val builder = SolverParameter.newBuilder();
+    TextFormat.merge(reader, builder);
+    return builder.build();
+  }
+
+  // --------------------------------------------------------------
+  // File IO utility functions
+  def writeToFile(content: String, filePath: String): Unit = {
+    val pw = new java.io.PrintWriter(new File(filePath))
+    pw.write(content)
+    pw.close
+  }
+  def getInputStreamReader(filePath: String): InputStreamReader = {
+    // Open filePath as a reader: via the Hadoop FileSystem for hdfs:/gpfs: paths, otherwise as a local ASCII file
+    if (filePath == null)
+      throw new LanguageException("file path was not specified!");
+    if (filePath.startsWith("hdfs:") || filePath.startsWith("gpfs:")) {
+      val fs = FileSystem.get(ConfigurationManager.getCachedJobConf());
+      return new InputStreamReader(fs.open(new Path(filePath)));
+    } else {
+      return new InputStreamReader(new FileInputStream(new File(filePath)), "ASCII");
+    }
+  }
+  // --------------------------------------------------------------
 }
 
 class Utils {
-  def saveCaffeModelFile(sc:JavaSparkContext, deployFilePath:String, 
-	    caffeModelFilePath:String, outputDirectory:String, format:String):Unit = {
+  def saveCaffeModelFile(sc: JavaSparkContext, deployFilePath: String, caffeModelFilePath: String, outputDirectory: String, format: String): Unit =
     Utils.saveCaffeModelFile(sc, deployFilePath, caffeModelFilePath, outputDirectory, format)
-  }
-  
+
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
index f42acb5..ec086eb 100644
--- a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -23,13 +23,13 @@ import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.rdd.RDD
 import java.io.File
 import org.apache.spark.SparkContext
-import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.ml.{ Estimator, Model }
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.spark.ml.param.{ DoubleParam, Param, ParamMap, Params }
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics
 import org.apache.sysml.runtime.matrix.data.MatrixBlock
 import org.apache.sysml.runtime.DMLRuntimeException
-import org.apache.sysml.runtime.instructions.spark.utils.{ RDDConverterUtilsExt, RDDConverterUtils }
+import org.apache.sysml.runtime.instructions.spark.utils.{ RDDConverterUtils, RDDConverterUtilsExt }
 import org.apache.sysml.api.mlcontext._
 import org.apache.sysml.api.mlcontext.ScriptFactory._
 import org.apache.spark.sql._
@@ -38,12 +38,11 @@ import java.util.HashMap
 import scala.collection.JavaConversions._
 import java.util.Random
 
-
 /****************************************************
 DESIGN DOCUMENT for MLLEARN API:
-The mllearn API supports LogisticRegression, LinearRegression, SVM, NaiveBayes 
+The mllearn API supports LogisticRegression, LinearRegression, SVM, NaiveBayes
 and Caffe2DML. Every algorithm in this API has a python wrapper (implemented in the mllearn python package)
-and a Scala class where the actual logic is implementation. 
+and a Scala class where the actual logic is implemented.
 Both wrapper and scala class follow the below hierarchy to reuse code and simplify the implementation.
 
 
@@ -72,7 +71,6 @@ get the DML script. To enable this, each wrapper class has to implement followin
 2. getPredictionScript(isSingleNode:Boolean): (Script object of mlcontext, variable name of X in the script:String)
 
 ****************************************************/
-
 trait HasLaplace extends Params {
   final val laplace: Param[Double] = new Param[Double](this, "laplace", "Laplace smoothing specified by the user to avoid creation of 0 probabilities.")
   setDefault(laplace, 1.0)
@@ -105,27 +103,27 @@ trait HasRegParam extends Params {
 }
 
 trait BaseSystemMLEstimatorOrModel {
-  var enableGPU:Boolean = false
-  var forceGPU:Boolean = false
-  var explain:Boolean = false
-  var explainLevel:String = "runtime"
-  var statistics:Boolean = false
-  var statisticsMaxHeavyHitters:Int = 10
-  val config:HashMap[String, String] = new HashMap[String, String]()
-  def setGPU(enableGPU1:Boolean):BaseSystemMLEstimatorOrModel = { enableGPU = enableGPU1; this}
-  def setForceGPU(enableGPU1:Boolean):BaseSystemMLEstimatorOrModel = { forceGPU = enableGPU1; this}
-  def setExplain(explain1:Boolean):BaseSystemMLEstimatorOrModel = { explain = explain1; this}
-  def setExplainLevel(explainLevel1:String):BaseSystemMLEstimatorOrModel = { explainLevel = explainLevel1; this  }
-  def setStatistics(statistics1:Boolean):BaseSystemMLEstimatorOrModel = { statistics = statistics1; this}
-  def setStatisticsMaxHeavyHitters(statisticsMaxHeavyHitters1:Int):BaseSystemMLEstimatorOrModel = { statisticsMaxHeavyHitters = statisticsMaxHeavyHitters1; this}
-  def setConfigProperty(key:String, value:String):BaseSystemMLEstimatorOrModel = { config.put(key, value); this}
-  def updateML(ml:MLContext):Unit = {
+  var enableGPU: Boolean                                                                          = false
+  var forceGPU: Boolean                                                                           = false
+  var explain: Boolean                                                                            = false
+  var explainLevel: String                                                                        = "runtime"
+  var statistics: Boolean                                                                         = false
+  var statisticsMaxHeavyHitters: Int                                                              = 10
+  val config: HashMap[String, String]                                                             = new HashMap[String, String]()
+  def setGPU(enableGPU1: Boolean): BaseSystemMLEstimatorOrModel                                   = { enableGPU = enableGPU1; this }
+  def setForceGPU(enableGPU1: Boolean): BaseSystemMLEstimatorOrModel                              = { forceGPU = enableGPU1; this }
+  def setExplain(explain1: Boolean): BaseSystemMLEstimatorOrModel                                 = { explain = explain1; this }
+  def setExplainLevel(explainLevel1: String): BaseSystemMLEstimatorOrModel                        = { explainLevel = explainLevel1; this }
+  def setStatistics(statistics1: Boolean): BaseSystemMLEstimatorOrModel                           = { statistics = statistics1; this }
+  def setStatisticsMaxHeavyHitters(statisticsMaxHeavyHitters1: Int): BaseSystemMLEstimatorOrModel = { statisticsMaxHeavyHitters = statisticsMaxHeavyHitters1; this }
+  def setConfigProperty(key: String, value: String): BaseSystemMLEstimatorOrModel                 = { config.put(key, value); this }
+  def updateML(ml: MLContext): Unit = {
     ml.setGPU(enableGPU); ml.setForceGPU(forceGPU);
     ml.setExplain(explain); ml.setExplainLevel(explainLevel);
-    ml.setStatistics(statistics); ml.setStatisticsMaxHeavyHitters(statisticsMaxHeavyHitters); 
+    ml.setStatistics(statistics); ml.setStatisticsMaxHeavyHitters(statisticsMaxHeavyHitters);
     config.map(x => ml.setConfigProperty(x._1, x._2))
   }
-  def copyProperties(other:BaseSystemMLEstimatorOrModel):BaseSystemMLEstimatorOrModel = {
+  def copyProperties(other: BaseSystemMLEstimatorOrModel): BaseSystemMLEstimatorOrModel = {
     other.setGPU(enableGPU); other.setForceGPU(forceGPU);
     other.setExplain(explain); other.setExplainLevel(explainLevel);
     other.setStatistics(statistics); other.setStatisticsMaxHeavyHitters(statisticsMaxHeavyHitters);
@@ -136,172 +134,168 @@ trait BaseSystemMLEstimatorOrModel {
 
 trait BaseSystemMLEstimator extends BaseSystemMLEstimatorOrModel {
   def transformSchema(schema: StructType): StructType = schema
-  var mloutput:MLResults = null
+  var mloutput: MLResults                             = null
   // Returns the script and variables for X and y
-  def getTrainingScript(isSingleNode:Boolean):(Script, String, String)
-  
-  def toDouble(i:Int): java.lang.Double = {
+  def getTrainingScript(isSingleNode: Boolean): (Script, String, String)
+
+  def toDouble(i: Int): java.lang.Double =
     double2Double(i.toDouble)
-  }
-  
-  def toDouble(d:Double): java.lang.Double = {
+
+  def toDouble(d: Double): java.lang.Double =
     double2Double(d)
-  }
-  
+
 }
 
 trait BaseSystemMLEstimatorModel extends BaseSystemMLEstimatorOrModel {
-  def toDouble(i:Int): java.lang.Double = {
+  def toDouble(i: Int): java.lang.Double =
     double2Double(i.toDouble)
-  }
-  def toDouble(d:Double): java.lang.Double = {
+  def toDouble(d: Double): java.lang.Double =
     double2Double(d)
-  }
-  
+
   def transform_probability(X: MatrixBlock): MatrixBlock;
-  
+
   def transformSchema(schema: StructType): StructType = schema
-  
+
   // Returns the script and variable for X
-  def getPredictionScript(isSingleNode:Boolean): (Script, String)
-  def baseEstimator():BaseSystemMLEstimator
-  def modelVariables():List[String]
+  def getPredictionScript(isSingleNode: Boolean): (Script, String)
+  def baseEstimator(): BaseSystemMLEstimator
+  def modelVariables(): List[String]
   // self.model.load(self.sc._jsc, weights, format, sep)
-  def load(sc:JavaSparkContext, outputDir:String, sep:String, eager:Boolean=false):Unit = {
-  	val dmlScript = new StringBuilder
-  	dmlScript.append("print(\"Loading the model from " + outputDir + "...\")\n")
-  	val tmpSum = "tmp_sum_var" + Math.abs((new Random()).nextInt())
-  	if(eager)
-  	  dmlScript.append(tmpSum + " = 0\n")
-		for(varName <- modelVariables) {
-			dmlScript.append(varName + " = read(\"" + outputDir + sep + varName + ".mtx\")\n")
-			if(eager)
-			  dmlScript.append(tmpSum + " = " + tmpSum + " + 0.001*mean(" + varName + ")\n")
-		}
-  	if(eager) {
-  	  dmlScript.append("if(" + tmpSum + " > 0) { print(\"Loaded the model\"); } else {  print(\"Loaded the model.\"); }")
-  	}
-  	val script = dml(dmlScript.toString)
-		for(varName <- modelVariables) {
-			script.out(varName)
-		}
-	  val ml = new MLContext(sc)
-	  baseEstimator.mloutput = ml.execute(script)
+  def load(sc: JavaSparkContext, outputDir: String, sep: String, eager: Boolean = false): Unit = {
+    val dmlScript = new StringBuilder
+    dmlScript.append("print(\"Loading the model from " + outputDir + "...\")\n")
+    val tmpSum = "tmp_sum_var" + Math.abs((new Random()).nextInt())
+    if (eager)
+      dmlScript.append(tmpSum + " = 0\n")
+    for (varName <- modelVariables) {
+      dmlScript.append(varName + " = read(\"" + outputDir + sep + varName + ".mtx\")\n")
+      if (eager)
+        dmlScript.append(tmpSum + " = " + tmpSum + " + 0.001*mean(" + varName + ")\n")
+    }
+    if (eager) {
+      dmlScript.append("if(" + tmpSum + " > 0) { print(\"Loaded the model\"); } else {  print(\"Loaded the model.\"); }")
+    }
+    val script = dml(dmlScript.toString)
+    for (varName <- modelVariables) {
+      script.out(varName)
+    }
+    val ml = new MLContext(sc)
+    baseEstimator.mloutput = ml.execute(script)
+  }
+  def save(sc: JavaSparkContext, outputDir: String, format: String = "binary", sep: String = "/"): Unit = {
+    if (baseEstimator.mloutput == null) throw new DMLRuntimeException("Cannot save as you need to train the model first using fit")
+    val dmlScript = new StringBuilder
+    dmlScript.append("print(\"Saving the model to " + outputDir + "...\")\n")
+    for (varName <- modelVariables) {
+      dmlScript.append("write(" + varName + ", \"" + outputDir + sep + varName + ".mtx\", format=\"" + format + "\")\n")
+    }
+    val script = dml(dmlScript.toString)
+    for (varName <- modelVariables) {
+      script.in(varName, baseEstimator.mloutput.getMatrix(varName))
+    }
+    val ml = new MLContext(sc)
+    ml.execute(script)
   }
-  def save(sc:JavaSparkContext, outputDir:String, format:String="binary", sep:String="/"):Unit = {
-	  if(baseEstimator.mloutput == null) throw new DMLRuntimeException("Cannot save as you need to train the model first using fit")
-	  val dmlScript = new StringBuilder
-	  dmlScript.append("print(\"Saving the model to " + outputDir + "...\")\n")
-	  for(varName <- modelVariables) {
-	  	dmlScript.append("write(" + varName + ", \"" + outputDir + sep + varName + ".mtx\", format=\"" + format + "\")\n")
-	  }
-	  val script = dml(dmlScript.toString)
-		for(varName <- modelVariables) {
-			script.in(varName, baseEstimator.mloutput.getMatrix(varName))
-		}
-	  val ml = new MLContext(sc)
-	  ml.execute(script)
-	}
 }
 
 trait BaseSystemMLClassifier extends BaseSystemMLEstimator {
   def baseFit(X_mb: MatrixBlock, y_mb: MatrixBlock, sc: SparkContext): MLResults = {
     val isSingleNode = true
-    val ml = new MLContext(sc)
+    val ml           = new MLContext(sc)
     updateML(ml)
     y_mb.recomputeNonZeros();
-    val ret = getTrainingScript(isSingleNode)
+    val ret    = getTrainingScript(isSingleNode)
     val script = ret._1.in(ret._2, X_mb).in(ret._3, y_mb)
     ml.execute(script)
   }
   def baseFit(df: ScriptsUtils.SparkDataType, sc: SparkContext): MLResults = {
     val isSingleNode = false
-    val ml = new MLContext(df.rdd.sparkContext)
+    val ml           = new MLContext(df.rdd.sparkContext)
     updateML(ml)
-    val mcXin = new MatrixCharacteristics()
-    val Xin = RDDConverterUtils.dataFrameToBinaryBlock(sc, df.asInstanceOf[DataFrame].select("features"), mcXin, false, true)
+    val mcXin           = new MatrixCharacteristics()
+    val Xin             = RDDConverterUtils.dataFrameToBinaryBlock(sc, df.asInstanceOf[DataFrame].select("features"), mcXin, false, true)
     val revLabelMapping = new java.util.HashMap[Int, String]
-    val yin = df.select("label")
-    val ret = getTrainingScript(isSingleNode)
-    val mmXin = new MatrixMetadata(mcXin)
-    val Xbin = new Matrix(Xin, mmXin)
-    val script = ret._1.in(ret._2, Xbin).in(ret._3, yin)
+    val yin             = df.select("label")
+    val ret             = getTrainingScript(isSingleNode)
+    val mmXin           = new MatrixMetadata(mcXin)
+    val Xbin            = new Matrix(Xin, mmXin)
+    val script          = ret._1.in(ret._2, Xbin).in(ret._3, yin)
     ml.execute(script)
   }
 }
 
 trait BaseSystemMLClassifierModel extends BaseSystemMLEstimatorModel {
 
-	def baseTransform(X: MatrixBlock, sc: SparkContext, probVar:String): MatrixBlock = baseTransform(X, sc, probVar, -1, 1, 1)
-	
-	def baseTransform(X: MatrixBlock, sc: SparkContext, probVar:String, C:Int, H: Int, W:Int): MatrixBlock = {
+  def baseTransform(X: MatrixBlock, sc: SparkContext, probVar: String): MatrixBlock = baseTransform(X, sc, probVar, -1, 1, 1)
+
+  def baseTransform(X: MatrixBlock, sc: SparkContext, probVar: String, C: Int, H: Int, W: Int): MatrixBlock = {
     val Prob = baseTransformHelper(X, sc, probVar, C, H, W)
     val script1 = dml("source(\"nn/util.dml\") as util; Prediction = util::predict_class(Prob, C, H, W);")
-    							.out("Prediction").in("Prob", Prob.toMatrixBlock, Prob.getMatrixMetadata).in("C", C).in("H", H).in("W", W)
+      .out("Prediction")
+      .in("Prob", Prob.toMatrixBlock, Prob.getMatrixMetadata)
+      .in("C", C)
+      .in("H", H)
+      .in("W", W)
     val ret = (new MLContext(sc)).execute(script1).getMatrix("Prediction").toMatrixBlock
-              
-    if(ret.getNumColumns != 1 && H == 1 && W == 1) {
+
+    if (ret.getNumColumns != 1 && H == 1 && W == 1) {
       throw new RuntimeException("Expected predicted label to be a column vector")
     }
     return ret
   }
-	
-	def baseTransformHelper(X: MatrixBlock, sc: SparkContext, probVar:String, C:Int, H: Int, W:Int): Matrix = {
-	  val isSingleNode = true
-    val ml = new MLContext(sc)
+
+  def baseTransformHelper(X: MatrixBlock, sc: SparkContext, probVar: String, C: Int, H: Int, W: Int): Matrix = {
+    val isSingleNode = true
+    val ml           = new MLContext(sc)
     updateML(ml)
     val script = getPredictionScript(isSingleNode)
     // Uncomment for debugging
     // ml.setExplainLevel(ExplainLevel.RECOMPILE_RUNTIME)
     val modelPredict = ml.execute(script._1.in(script._2, X, new MatrixMetadata(X.getNumRows, X.getNumColumns, X.getNonZeros)))
     return modelPredict.getMatrix(probVar)
-	}
-	
-	def baseTransformProbability(X: MatrixBlock, sc: SparkContext, probVar:String): MatrixBlock = {
-	  baseTransformProbability(X, sc, probVar, -1, 1, 1)
-	}
-	
-	def baseTransformProbability(X: MatrixBlock, sc: SparkContext, probVar:String, C:Int, H: Int, W:Int): MatrixBlock = {
-	  return baseTransformHelper(X, sc, probVar, C, H, W).toMatrixBlock
-	}
-	
-	
-	def baseTransform(df: ScriptsUtils.SparkDataType, sc: SparkContext, 
-      probVar:String, outputProb:Boolean=true): DataFrame = {
-		baseTransform(df, sc, probVar, outputProb, -1, 1, 1)
-	}
-	
-	def baseTransformHelper(df: ScriptsUtils.SparkDataType, sc: SparkContext, 
-      probVar:String, outputProb:Boolean, C:Int, H: Int, W:Int): Matrix = {
-	  val isSingleNode = false
-    val ml = new MLContext(sc)
+  }
+
+  def baseTransformProbability(X: MatrixBlock, sc: SparkContext, probVar: String): MatrixBlock =
+    baseTransformProbability(X, sc, probVar, -1, 1, 1)
+
+  def baseTransformProbability(X: MatrixBlock, sc: SparkContext, probVar: String, C: Int, H: Int, W: Int): MatrixBlock =
+    return baseTransformHelper(X, sc, probVar, C, H, W).toMatrixBlock
+
+  def baseTransform(df: ScriptsUtils.SparkDataType, sc: SparkContext, probVar: String, outputProb: Boolean = true): DataFrame =
+    baseTransform(df, sc, probVar, outputProb, -1, 1, 1)
+
+  def baseTransformHelper(df: ScriptsUtils.SparkDataType, sc: SparkContext, probVar: String, outputProb: Boolean, C: Int, H: Int, W: Int): Matrix = {
+    val isSingleNode = false
+    val ml           = new MLContext(sc)
     updateML(ml)
-    val mcXin = new MatrixCharacteristics()
-    val Xin = RDDConverterUtils.dataFrameToBinaryBlock(df.rdd.sparkContext, df.asInstanceOf[DataFrame].select("features"), mcXin, false, true)
-    val script = getPredictionScript(isSingleNode)
-    val mmXin = new MatrixMetadata(mcXin)
-    val Xin_bin = new Matrix(Xin, mmXin)
+    val mcXin        = new MatrixCharacteristics()
+    val Xin          = RDDConverterUtils.dataFrameToBinaryBlock(df.rdd.sparkContext, df.asInstanceOf[DataFrame].select("features"), mcXin, false, true)
+    val script       = getPredictionScript(isSingleNode)
+    val mmXin        = new MatrixMetadata(mcXin)
+    val Xin_bin      = new Matrix(Xin, mmXin)
     val modelPredict = ml.execute(script._1.in(script._2, Xin_bin))
     return modelPredict.getMatrix(probVar)
-	}
+  }
 
-  def baseTransform(df: ScriptsUtils.SparkDataType, sc: SparkContext, 
-      probVar:String, outputProb:Boolean, C:Int, H: Int, W:Int): DataFrame = {
+  def baseTransform(df: ScriptsUtils.SparkDataType, sc: SparkContext, probVar: String, outputProb: Boolean, C: Int, H: Int, W: Int): DataFrame = {
     val Prob = baseTransformHelper(df, sc, probVar, outputProb, C, H, W)
     val script1 = dml("source(\"nn/util.dml\") as util; Prediction = util::predict_class(Prob, C, H, W);")
-    							.out("Prediction").in("Prob", Prob).in("C", C).in("H", H).in("W", W)
+      .out("Prediction")
+      .in("Prob", Prob)
+      .in("C", C)
+      .in("H", H)
+      .in("W", W)
     val predLabelOut = (new MLContext(sc)).execute(script1)
-    val predictedDF = predLabelOut.getDataFrame("Prediction").select(RDDConverterUtils.DF_ID_COLUMN, "C1").withColumnRenamed("C1", "prediction")
-      
-    if(outputProb) {
-      val prob = Prob.toDFVectorWithIDColumn().withColumnRenamed("C1", "probability").select(RDDConverterUtils.DF_ID_COLUMN, "probability")
+    val predictedDF  = predLabelOut.getDataFrame("Prediction").select(RDDConverterUtils.DF_ID_COLUMN, "C1").withColumnRenamed("C1", "prediction")
+
+    if (outputProb) {
+      val prob    = Prob.toDFVectorWithIDColumn().withColumnRenamed("C1", "probability").select(RDDConverterUtils.DF_ID_COLUMN, "probability")
       val dataset = RDDConverterUtilsExt.addIDToDataFrame(df.asInstanceOf[DataFrame], df.sparkSession, RDDConverterUtils.DF_ID_COLUMN)
       return PredictionUtils.joinUsingID(dataset, PredictionUtils.joinUsingID(prob, predictedDF))
-    }
-    else {
+    } else {
       val dataset = RDDConverterUtilsExt.addIDToDataFrame(df.asInstanceOf[DataFrame], df.sparkSession, RDDConverterUtils.DF_ID_COLUMN)
       return PredictionUtils.joinUsingID(dataset, predictedDF)
     }
-    
+
   }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala
index 5610bf3..d94655b 100644
--- a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,71 +22,71 @@ package org.apache.sysml.api.ml
 import org.apache.spark.rdd.RDD
 import java.io.File
 import org.apache.spark.SparkContext
-import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.ml.{ Estimator, Model }
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.spark.ml.param.{ DoubleParam, Param, ParamMap, Params }
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics
 import org.apache.sysml.runtime.matrix.data.MatrixBlock
 import org.apache.sysml.runtime.DMLRuntimeException
-import org.apache.sysml.runtime.instructions.spark.utils.{ RDDConverterUtilsExt, RDDConverterUtils }
+import org.apache.sysml.runtime.instructions.spark.utils.{ RDDConverterUtils, RDDConverterUtilsExt }
 import org.apache.sysml.api.mlcontext._
 import org.apache.sysml.api.mlcontext.ScriptFactory._
 
 trait BaseSystemMLRegressor extends BaseSystemMLEstimator {
-  
+
   def baseFit(X_mb: MatrixBlock, y_mb: MatrixBlock, sc: SparkContext): MLResults = {
     val isSingleNode = true
-    val ml = new MLContext(sc)
+    val ml           = new MLContext(sc)
     updateML(ml)
-    val ret = getTrainingScript(isSingleNode)
+    val ret    = getTrainingScript(isSingleNode)
     val script = ret._1.in(ret._2, X_mb).in(ret._3, y_mb)
     ml.execute(script)
   }
-  
+
   def baseFit(df: ScriptsUtils.SparkDataType, sc: SparkContext): MLResults = {
     val isSingleNode = false
-    val ml = new MLContext(df.rdd.sparkContext)
+    val ml           = new MLContext(df.rdd.sparkContext)
     updateML(ml)
-    val mcXin = new MatrixCharacteristics()
-    val Xin = RDDConverterUtils.dataFrameToBinaryBlock(sc, df.asInstanceOf[DataFrame], mcXin, false, true)
-    val yin = df.select("label")
-    val ret = getTrainingScript(isSingleNode)
-    val mmXin = new MatrixMetadata(mcXin)
-    val Xbin = new Matrix(Xin, mmXin)
+    val mcXin  = new MatrixCharacteristics()
+    val Xin    = RDDConverterUtils.dataFrameToBinaryBlock(sc, df.asInstanceOf[DataFrame], mcXin, false, true)
+    val yin    = df.select("label")
+    val ret    = getTrainingScript(isSingleNode)
+    val mmXin  = new MatrixMetadata(mcXin)
+    val Xbin   = new Matrix(Xin, mmXin)
     val script = ret._1.in(ret._2, Xbin).in(ret._3, yin)
     ml.execute(script)
   }
 }
 
 trait BaseSystemMLRegressorModel extends BaseSystemMLEstimatorModel {
-  
-  def baseTransform(X: MatrixBlock, sc: SparkContext, predictionVar:String): MatrixBlock = {
+
+  def baseTransform(X: MatrixBlock, sc: SparkContext, predictionVar: String): MatrixBlock = {
     val isSingleNode = true
-    val ml = new MLContext(sc)
+    val ml           = new MLContext(sc)
     updateML(ml)
-    val script = getPredictionScript(isSingleNode)
+    val script       = getPredictionScript(isSingleNode)
     val modelPredict = ml.execute(script._1.in(script._2, X))
-    val ret = modelPredict.getMatrix(predictionVar).toMatrixBlock
-              
-    if(ret.getNumColumns != 1) {
+    val ret          = modelPredict.getMatrix(predictionVar).toMatrixBlock
+
+    if (ret.getNumColumns != 1) {
       throw new RuntimeException("Expected prediction to be a column vector")
     }
     return ret
   }
-  
-  def baseTransform(df: ScriptsUtils.SparkDataType, sc: SparkContext, predictionVar:String): DataFrame = {
+
+  def baseTransform(df: ScriptsUtils.SparkDataType, sc: SparkContext, predictionVar: String): DataFrame = {
     val isSingleNode = false
-    val ml = new MLContext(sc)
+    val ml           = new MLContext(sc)
     updateML(ml)
-    val mcXin = new MatrixCharacteristics()
-    val Xin = RDDConverterUtils.dataFrameToBinaryBlock(df.rdd.sparkContext, df.asInstanceOf[DataFrame], mcXin, false, true)
-    val script = getPredictionScript(isSingleNode)
-    val mmXin = new MatrixMetadata(mcXin)
-    val Xin_bin = new Matrix(Xin, mmXin)
+    val mcXin        = new MatrixCharacteristics()
+    val Xin          = RDDConverterUtils.dataFrameToBinaryBlock(df.rdd.sparkContext, df.asInstanceOf[DataFrame], mcXin, false, true)
+    val script       = getPredictionScript(isSingleNode)
+    val mmXin        = new MatrixMetadata(mcXin)
+    val Xin_bin      = new Matrix(Xin, mmXin)
     val modelPredict = ml.execute(script._1.in(script._2, Xin_bin))
-    val predictedDF = modelPredict.getDataFrame(predictionVar).select(RDDConverterUtils.DF_ID_COLUMN, "C1").withColumnRenamed("C1", "prediction")
-    val dataset = RDDConverterUtilsExt.addIDToDataFrame(df.asInstanceOf[DataFrame], df.sparkSession, RDDConverterUtils.DF_ID_COLUMN)
+    val predictedDF  = modelPredict.getDataFrame(predictionVar).select(RDDConverterUtils.DF_ID_COLUMN, "C1").withColumnRenamed("C1", "prediction")
+    val dataset      = RDDConverterUtilsExt.addIDToDataFrame(df.asInstanceOf[DataFrame], df.sparkSession, RDDConverterUtils.DF_ID_COLUMN)
     return PredictionUtils.joinUsingID(dataset, predictedDF)
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala b/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala
index b7634d7..b6f4966 100644
--- a/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,10 +22,10 @@ package org.apache.sysml.api.ml
 import org.apache.spark.rdd.RDD
 import java.io.File
 import org.apache.spark.SparkContext
-import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.ml.{ Estimator, Model }
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.spark.ml.param.{ DoubleParam, Param, ParamMap, Params }
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics
 import org.apache.sysml.runtime.matrix.data.MatrixBlock
 import org.apache.sysml.runtime.DMLRuntimeException
@@ -39,28 +39,32 @@ object LinearRegression {
 }
 
 // algorithm = "direct-solve", "conjugate-gradient"
-class LinearRegression(override val uid: String, val sc: SparkContext, val solver:String="direct-solve") 
-  extends Estimator[LinearRegressionModel] with HasIcpt
-    with HasRegParam with HasTol with HasMaxOuterIter with BaseSystemMLRegressor {
-  
-  def setIcpt(value: Int) = set(icpt, value)
-  def setMaxIter(value: Int) = set(maxOuterIter, value)
+class LinearRegression(override val uid: String, val sc: SparkContext, val solver: String = "direct-solve")
+    extends Estimator[LinearRegressionModel]
+    with HasIcpt
+    with HasRegParam
+    with HasTol
+    with HasMaxOuterIter
+    with BaseSystemMLRegressor {
+
+  def setIcpt(value: Int)        = set(icpt, value)
+  def setMaxIter(value: Int)     = set(maxOuterIter, value)
   def setRegParam(value: Double) = set(regParam, value)
-  def setTol(value: Double) = set(tol, value)
-  
+  def setTol(value: Double)      = set(tol, value)
 
   override def copy(extra: ParamMap): Estimator[LinearRegressionModel] = {
     val that = new LinearRegression(uid, sc, solver)
     copyValues(that, extra)
   }
-  
-          
-  def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
-    val script = dml(ScriptsUtils.getDMLScript(
-        if(solver.compareTo("direct-solve") == 0) LinearRegression.scriptPathDS 
-        else if(solver.compareTo("newton-cg") == 0) LinearRegression.scriptPathCG
-        else throw new DMLRuntimeException("The algorithm should be direct-solve or newton-cg")))
-      .in("$X", " ")
+
+  def getTrainingScript(isSingleNode: Boolean): (Script, String, String) = {
+    val script = dml(
+      ScriptsUtils.getDMLScript(
+        if (solver.compareTo("direct-solve") == 0) LinearRegression.scriptPathDS
+        else if (solver.compareTo("newton-cg") == 0) LinearRegression.scriptPathCG
+        else throw new DMLRuntimeException("The algorithm should be direct-solve or newton-cg")
+      )
+    ).in("$X", " ")
       .in("$Y", " ")
       .in("$B", " ")
       .in("$Log", " ")
@@ -72,41 +76,46 @@ class LinearRegression(override val uid: String, val sc: SparkContext, val solve
       .out("beta_out")
     (script, "X", "y")
   }
-  
-  def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): LinearRegressionModel =  {
+
+  def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): LinearRegressionModel = {
     mloutput = baseFit(X_mb, y_mb, sc)
     new LinearRegressionModel(this)
   }
-    
-  def fit(df: ScriptsUtils.SparkDataType): LinearRegressionModel = { 
+
+  def fit(df: ScriptsUtils.SparkDataType): LinearRegressionModel = {
     mloutput = baseFit(df, sc)
     new LinearRegressionModel(this)
   }
-  
+
 }
 
-class LinearRegressionModel(override val uid: String)(estimator:LinearRegression, val sc: SparkContext) extends Model[LinearRegressionModel] with HasIcpt
-    with HasRegParam with HasTol with HasMaxOuterIter with BaseSystemMLRegressorModel {
+class LinearRegressionModel(override val uid: String)(estimator: LinearRegression, val sc: SparkContext)
+    extends Model[LinearRegressionModel]
+    with HasIcpt
+    with HasRegParam
+    with HasTol
+    with HasMaxOuterIter
+    with BaseSystemMLRegressorModel {
   override def copy(extra: ParamMap): LinearRegressionModel = {
     val that = new LinearRegressionModel(uid)(estimator, sc)
     copyValues(that, extra)
   }
-  
+
   def transform_probability(X: MatrixBlock): MatrixBlock = throw new DMLRuntimeException("Unsupported method")
-  
-  def baseEstimator():BaseSystemMLEstimator = estimator
-  
-  def this(estimator:LinearRegression) =  {
-  	this("model")(estimator, estimator.sc)
+
+  def baseEstimator(): BaseSystemMLEstimator = estimator
+
+  def this(estimator: LinearRegression) = {
+    this("model")(estimator, estimator.sc)
   }
-  
-  def getPredictionScript(isSingleNode:Boolean): (Script, String) =
+
+  def getPredictionScript(isSingleNode: Boolean): (Script, String) =
     PredictionUtils.getGLMPredictionScript(estimator.mloutput.getMatrix("beta_out"), isSingleNode)
-  
-  def modelVariables():List[String] = List[String]("beta_out")
-  
+
+  def modelVariables(): List[String] = List[String]("beta_out")
+
   def transform(df: ScriptsUtils.SparkDataType): DataFrame = baseTransform(df, sc, "means")
-  
-  def transform(X: MatrixBlock): MatrixBlock =  baseTransform(X, sc, "means")
-  
-}
\ No newline at end of file
+
+  def transform(X: MatrixBlock): MatrixBlock = baseTransform(X, sc, "means")
+
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala b/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala
index b04acd1..98b6dd4 100644
--- a/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,10 +22,10 @@ package org.apache.sysml.api.ml
 import org.apache.spark.rdd.RDD
 import java.io.File
 import org.apache.spark.SparkContext
-import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.ml.{ Estimator, Model }
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.spark.ml.param.{ DoubleParam, Param, ParamMap, Params }
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics
 import org.apache.sysml.runtime.matrix.data.MatrixBlock
 import org.apache.sysml.runtime.DMLRuntimeException
@@ -38,36 +38,40 @@ object LogisticRegression {
 }
 
 /**
- * Logistic Regression Scala API
- */
-class LogisticRegression(override val uid: String, val sc: SparkContext) extends Estimator[LogisticRegressionModel] with HasIcpt
-    with HasRegParam with HasTol with HasMaxOuterIter with HasMaxInnerIter with BaseSystemMLClassifier {
+  * Logistic Regression Scala API
+  */
+class LogisticRegression(override val uid: String, val sc: SparkContext)
+    extends Estimator[LogisticRegressionModel]
+    with HasIcpt
+    with HasRegParam
+    with HasTol
+    with HasMaxOuterIter
+    with HasMaxInnerIter
+    with BaseSystemMLClassifier {
 
-  def setIcpt(value: Int) = set(icpt, value)
+  def setIcpt(value: Int)         = set(icpt, value)
   def setMaxOuterIter(value: Int) = set(maxOuterIter, value)
   def setMaxInnerIter(value: Int) = set(maxInnerIter, value)
-  def setRegParam(value: Double) = set(regParam, value)
-  def setTol(value: Double) = set(tol, value)
+  def setRegParam(value: Double)  = set(regParam, value)
+  def setTol(value: Double)       = set(tol, value)
 
   override def copy(extra: ParamMap): LogisticRegression = {
     val that = new LogisticRegression(uid, sc)
     copyValues(that, extra)
   }
-  
 
   // Note: will update the y_mb as this will be called by Python mllearn
   def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): LogisticRegressionModel = {
     mloutput = baseFit(X_mb, y_mb, sc)
     new LogisticRegressionModel(this)
   }
-  
+
   def fit(df: ScriptsUtils.SparkDataType): LogisticRegressionModel = {
     mloutput = baseFit(df, sc)
     new LogisticRegressionModel(this)
   }
-  
-  
-  def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
+
+  def getTrainingScript(isSingleNode: Boolean): (Script, String, String) = {
     val script = dml(ScriptsUtils.getDMLScript(LogisticRegression.scriptPath))
       .in("$X", " ")
       .in("$Y", " ")
@@ -86,36 +90,39 @@ object LogisticRegressionModel {
 }
 
 /**
- * Logistic Regression Scala API
- */
-
-class LogisticRegressionModel(override val uid: String)(
-    estimator: LogisticRegression, val sc: SparkContext) 
-    extends Model[LogisticRegressionModel] with HasIcpt
-    with HasRegParam with HasTol with HasMaxOuterIter with HasMaxInnerIter with BaseSystemMLClassifierModel {
+  * Logistic Regression Scala API
+  */
+class LogisticRegressionModel(override val uid: String)(estimator: LogisticRegression, val sc: SparkContext)
+    extends Model[LogisticRegressionModel]
+    with HasIcpt
+    with HasRegParam
+    with HasTol
+    with HasMaxOuterIter
+    with HasMaxInnerIter
+    with BaseSystemMLClassifierModel {
   override def copy(extra: ParamMap): LogisticRegressionModel = {
     val that = new LogisticRegressionModel(uid)(estimator, sc)
     copyValues(that, extra)
   }
-  var outputRawPredictions = true
-  def setOutputRawPredictions(outRawPred:Boolean): Unit = { outputRawPredictions = outRawPred }
-  def this(estimator:LogisticRegression) =  {
-  	this("model")(estimator, estimator.sc)
+  var outputRawPredictions                               = true
+  def setOutputRawPredictions(outRawPred: Boolean): Unit = outputRawPredictions = outRawPred
+  def this(estimator: LogisticRegression) = {
+    this("model")(estimator, estimator.sc)
   }
-  def getPredictionScript(isSingleNode:Boolean): (Script, String) =
+  def getPredictionScript(isSingleNode: Boolean): (Script, String) =
     PredictionUtils.getGLMPredictionScript(estimator.mloutput.getMatrix("B_out"), isSingleNode, 3)
-  
-  def baseEstimator():BaseSystemMLEstimator = estimator
-  def modelVariables():List[String] = List[String]("B_out")
-  
-  def transform(X: MatrixBlock): MatrixBlock = baseTransform(X, sc, "means")
-  def transform_probability(X: MatrixBlock): MatrixBlock = baseTransformProbability(X, sc, "means")
+
+  def baseEstimator(): BaseSystemMLEstimator = estimator
+  def modelVariables(): List[String]         = List[String]("B_out")
+
+  def transform(X: MatrixBlock): MatrixBlock               = baseTransform(X, sc, "means")
+  def transform_probability(X: MatrixBlock): MatrixBlock   = baseTransformProbability(X, sc, "means")
   def transform(df: ScriptsUtils.SparkDataType): DataFrame = baseTransform(df, sc, "means")
 }
 
 /**
- * Example code for Logistic Regression
- */
+  * Example code for Logistic Regression
+  */
 object LogisticRegressionExample {
   import org.apache.spark.{ SparkConf, SparkContext }
   import org.apache.spark.sql._
@@ -124,28 +131,34 @@ object LogisticRegressionExample {
   import org.apache.spark.ml.feature.LabeledPoint
 
   def main(args: Array[String]) = {
-    val sparkSession = SparkSession.builder().master("local").appName("TestLocal").getOrCreate();
+    val sparkSession     = SparkSession.builder().master("local").appName("TestLocal").getOrCreate();
     val sc: SparkContext = sparkSession.sparkContext;
 
     import sparkSession.implicits._
-    val training = sc.parallelize(Seq(
-      LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0)),
-      LabeledPoint(1.0, Vectors.dense(1.0, 0.4, 2.1)),
-      LabeledPoint(2.0, Vectors.dense(1.2, 0.0, 3.5)),
-      LabeledPoint(1.0, Vectors.dense(1.0, 0.5, 2.2)),
-      LabeledPoint(2.0, Vectors.dense(1.6, 0.8, 3.6)),
-      LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 2.3))))
-    val lr = new LogisticRegression("log", sc)
+    val training = sc.parallelize(
+      Seq(
+        LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0)),
+        LabeledPoint(1.0, Vectors.dense(1.0, 0.4, 2.1)),
+        LabeledPoint(2.0, Vectors.dense(1.2, 0.0, 3.5)),
+        LabeledPoint(1.0, Vectors.dense(1.0, 0.5, 2.2)),
+        LabeledPoint(2.0, Vectors.dense(1.6, 0.8, 3.6)),
+        LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 2.3))
+      )
+    )
+    val lr      = new LogisticRegression("log", sc)
     val lrmodel = lr.fit(training.toDF)
     // lrmodel.mloutput.getDF(sparkSession, "B_out").show()
 
-    val testing = sc.parallelize(Seq(
-      LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0)),
-      LabeledPoint(1.0, Vectors.dense(1.0, 0.4, 2.1)),
-      LabeledPoint(2.0, Vectors.dense(1.2, 0.0, 3.5)),
-      LabeledPoint(1.0, Vectors.dense(1.0, 0.5, 2.2)),
-      LabeledPoint(2.0, Vectors.dense(1.6, 0.8, 3.6)),
-      LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 2.3))))
+    val testing = sc.parallelize(
+      Seq(
+        LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0)),
+        LabeledPoint(1.0, Vectors.dense(1.0, 0.4, 2.1)),
+        LabeledPoint(2.0, Vectors.dense(1.2, 0.0, 3.5)),
+        LabeledPoint(1.0, Vectors.dense(1.0, 0.5, 2.2)),
+        LabeledPoint(2.0, Vectors.dense(1.6, 0.8, 3.6)),
+        LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 2.3))
+      )
+    )
 
     lrmodel.transform(testing.toDF).show
   }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala b/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala
index 990ab52..8ecd4f0 100644
--- a/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,10 +22,10 @@ package org.apache.sysml.api.ml
 import org.apache.spark.rdd.RDD
 import java.io.File
 import org.apache.spark.SparkContext
-import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.ml.{ Estimator, Model }
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.spark.ml.param.{ DoubleParam, Param, ParamMap, Params }
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics
 import org.apache.sysml.runtime.matrix.data.MatrixBlock
 import org.apache.sysml.runtime.DMLRuntimeException
@@ -43,19 +43,19 @@ class NaiveBayes(override val uid: String, val sc: SparkContext) extends Estimat
     copyValues(that, extra)
   }
   def setLaplace(value: Double) = set(laplace, value)
-  
+
   // Note: will update the y_mb as this will be called by Python mllearn
   def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): NaiveBayesModel = {
     mloutput = baseFit(X_mb, y_mb, sc)
     new NaiveBayesModel(this)
   }
-  
+
   def fit(df: ScriptsUtils.SparkDataType): NaiveBayesModel = {
     mloutput = baseFit(df, sc)
     new NaiveBayesModel(this)
   }
-  
-  def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
+
+  def getTrainingScript(isSingleNode: Boolean): (Script, String, String) = {
     val script = dml(ScriptsUtils.getDMLScript(NaiveBayes.scriptPath))
       .in("$X", " ")
       .in("$Y", " ")
@@ -68,49 +68,47 @@ class NaiveBayes(override val uid: String, val sc: SparkContext) extends Estimat
   }
 }
 
-
 object NaiveBayesModel {
   final val scriptPath = "scripts" + File.separator + "algorithms" + File.separator + "naive-bayes-predict.dml"
 }
 
-class NaiveBayesModel(override val uid: String)
-  (estimator:NaiveBayes, val sc: SparkContext) 
-  extends Model[NaiveBayesModel] with HasLaplace with BaseSystemMLClassifierModel {
-  
-  def this(estimator:NaiveBayes) =  {
+class NaiveBayesModel(override val uid: String)(estimator: NaiveBayes, val sc: SparkContext) extends Model[NaiveBayesModel] with HasLaplace with BaseSystemMLClassifierModel {
+
+  def this(estimator: NaiveBayes) = {
     this("model")(estimator, estimator.sc)
   }
-  
+
   override def copy(extra: ParamMap): NaiveBayesModel = {
     val that = new NaiveBayesModel(uid)(estimator, sc)
     copyValues(that, extra)
   }
-  
-  def modelVariables():List[String] = List[String]("classPrior", "classConditionals")
-  def getPredictionScript(isSingleNode:Boolean): (Script, String)  = {
+
+  def modelVariables(): List[String] = List[String]("classPrior", "classConditionals")
+  def getPredictionScript(isSingleNode: Boolean): (Script, String) = {
     val script = dml(ScriptsUtils.getDMLScript(NaiveBayesModel.scriptPath))
       .in("$X", " ")
       .in("$prior", " ")
       .in("$conditionals", " ")
       .in("$probabilities", " ")
       .out("probs")
-    
-    val classPrior = estimator.mloutput.getMatrix("classPrior")
+
+    val classPrior        = estimator.mloutput.getMatrix("classPrior")
     val classConditionals = estimator.mloutput.getMatrix("classConditionals")
-    val ret = if(isSingleNode) {
-      script.in("prior", classPrior.toMatrixBlock, classPrior.getMatrixMetadata)
-            .in("conditionals", classConditionals.toMatrixBlock, classConditionals.getMatrixMetadata)
-    }
-    else {
-      script.in("prior", classPrior.toBinaryBlocks, classPrior.getMatrixMetadata)
-            .in("conditionals", classConditionals.toBinaryBlocks, classConditionals.getMatrixMetadata)
+    val ret = if (isSingleNode) {
+      script
+        .in("prior", classPrior.toMatrixBlock, classPrior.getMatrixMetadata)
+        .in("conditionals", classConditionals.toMatrixBlock, classConditionals.getMatrixMetadata)
+    } else {
+      script
+        .in("prior", classPrior.toBinaryBlocks, classPrior.getMatrixMetadata)
+        .in("conditionals", classConditionals.toBinaryBlocks, classConditionals.getMatrixMetadata)
     }
     (ret, "D")
   }
-  
-  def baseEstimator():BaseSystemMLEstimator = estimator
-  def transform(X: MatrixBlock): MatrixBlock = baseTransform(X, sc, "probs")
-  def transform_probability(X: MatrixBlock): MatrixBlock = baseTransformProbability(X, sc, "probs")
+
+  def baseEstimator(): BaseSystemMLEstimator               = estimator
+  def transform(X: MatrixBlock): MatrixBlock               = baseTransform(X, sc, "probs")
+  def transform_probability(X: MatrixBlock): MatrixBlock   = baseTransformProbability(X, sc, "probs")
   def transform(df: ScriptsUtils.SparkDataType): DataFrame = baseTransform(df, sc, "probs")
-  
+
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/ml/PredictionUtils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/PredictionUtils.scala b/src/main/scala/org/apache/sysml/api/ml/PredictionUtils.scala
index 3406169..72e82e8 100644
--- a/src/main/scala/org/apache/sysml/api/ml/PredictionUtils.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/PredictionUtils.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -33,39 +33,35 @@ import org.apache.sysml.api.mlcontext.Script
 import org.apache.sysml.api.mlcontext.Matrix
 
 object PredictionUtils {
-  
-  def getGLMPredictionScript(B_full: Matrix, isSingleNode:Boolean, dfam:java.lang.Integer=1): (Script, String)  = {
+
+  def getGLMPredictionScript(B_full: Matrix, isSingleNode: Boolean, dfam: java.lang.Integer = 1): (Script, String) = {
     val script = dml(ScriptsUtils.getDMLScript(LogisticRegressionModel.scriptPath))
       .in("$X", " ")
       .in("$B", " ")
       .in("$dfam", dfam)
       .out("means")
-    val ret = if(isSingleNode) {
+    val ret = if (isSingleNode) {
       script.in("B_full", B_full.toMatrixBlock, B_full.getMatrixMetadata)
-    }
-    else {
+    } else {
       script.in("B_full", B_full)
     }
     (ret, "X")
   }
-  
-  def joinUsingID(df1:DataFrame, df2:DataFrame):DataFrame = {
+
+  def joinUsingID(df1: DataFrame, df2: DataFrame): DataFrame =
     df1.join(df2, RDDConverterUtils.DF_ID_COLUMN)
-  }
-  
-  def computePredictedClassLabelsFromProbability(mlscoreoutput:MLResults, isSingleNode:Boolean, sc:SparkContext, inProbVar:String): MLResults = {
-    val ml = new org.apache.sysml.api.mlcontext.MLContext(sc)
-    val script = dml(
-        """
+
+  def computePredictedClassLabelsFromProbability(mlscoreoutput: MLResults, isSingleNode: Boolean, sc: SparkContext, inProbVar: String): MLResults = {
+    val ml      = new org.apache.sysml.api.mlcontext.MLContext(sc)
+    val script  = dml("""
         Prob = read("temp1");
         Prediction = rowIndexMax(Prob); # assuming one-based label mapping
         write(Prediction, "tempOut", "csv");
         """).out("Prediction")
     val probVar = mlscoreoutput.getMatrix(inProbVar)
-    if(isSingleNode) {
+    if (isSingleNode) {
       ml.execute(script.in("Prob", probVar.toMatrixBlock, probVar.getMatrixMetadata))
-    }
-    else {
+    } else {
       ml.execute(script.in("Prob", probVar))
     }
   }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/ml/SVM.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/SVM.scala b/src/main/scala/org/apache/sysml/api/ml/SVM.scala
index 9107836..2013385 100644
--- a/src/main/scala/org/apache/sysml/api/ml/SVM.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/SVM.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,10 +22,10 @@ package org.apache.sysml.api.ml
 import org.apache.spark.rdd.RDD
 import java.io.File
 import org.apache.spark.SparkContext
-import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.ml.{ Estimator, Model }
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.spark.ml.param.{ DoubleParam, Param, ParamMap, Params }
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics
 import org.apache.sysml.runtime.matrix.data.MatrixBlock
 import org.apache.sysml.runtime.DMLRuntimeException
@@ -34,25 +34,30 @@ import org.apache.sysml.api.mlcontext._
 import org.apache.sysml.api.mlcontext.ScriptFactory._
 
 object SVM {
-  final val scriptPathBinary = "scripts" + File.separator + "algorithms" + File.separator + "l2-svm.dml"
+  final val scriptPathBinary     = "scripts" + File.separator + "algorithms" + File.separator + "l2-svm.dml"
   final val scriptPathMulticlass = "scripts" + File.separator + "algorithms" + File.separator + "m-svm.dml"
 }
 
-class SVM (override val uid: String, val sc: SparkContext, val isMultiClass:Boolean=false) extends Estimator[SVMModel] with HasIcpt
-    with HasRegParam with HasTol with HasMaxOuterIter with BaseSystemMLClassifier {
+class SVM(override val uid: String, val sc: SparkContext, val isMultiClass: Boolean = false)
+    extends Estimator[SVMModel]
+    with HasIcpt
+    with HasRegParam
+    with HasTol
+    with HasMaxOuterIter
+    with BaseSystemMLClassifier {
 
-  def setIcpt(value: Int) = set(icpt, value)
-  def setMaxIter(value: Int) = set(maxOuterIter, value)
+  def setIcpt(value: Int)        = set(icpt, value)
+  def setMaxIter(value: Int)     = set(maxOuterIter, value)
   def setRegParam(value: Double) = set(regParam, value)
-  def setTol(value: Double) = set(tol, value)
-  
+  def setTol(value: Double)      = set(tol, value)
+
   override def copy(extra: ParamMap): Estimator[SVMModel] = {
     val that = new SVM(uid, sc, isMultiClass)
     copyValues(that, extra)
   }
-  
-  def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
-    val script = dml(ScriptsUtils.getDMLScript(if(isMultiClass) SVM.scriptPathMulticlass else SVM.scriptPathBinary))
+
+  def getTrainingScript(isSingleNode: Boolean): (Script, String, String) = {
+    val script = dml(ScriptsUtils.getDMLScript(if (isMultiClass) SVM.scriptPathMulticlass else SVM.scriptPathBinary))
       .in("$X", " ")
       .in("$Y", " ")
       .in("$model", " ")
@@ -64,58 +69,56 @@ class SVM (override val uid: String, val sc: SparkContext, val isMultiClass:Bool
       .out("w")
     (script, "X", "Y")
   }
-  
+
   // Note: will update the y_mb as this will be called by Python mllearn
   def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): SVMModel = {
     mloutput = baseFit(X_mb, y_mb, sc)
     new SVMModel(this, isMultiClass)
   }
-  
+
   def fit(df: ScriptsUtils.SparkDataType): SVMModel = {
     mloutput = baseFit(df, sc)
     new SVMModel(this, isMultiClass)
   }
-  
+
 }
 
 object SVMModel {
-  final val predictionScriptPathBinary = "scripts" + File.separator + "algorithms" + File.separator + "l2-svm-predict.dml"
+  final val predictionScriptPathBinary     = "scripts" + File.separator + "algorithms" + File.separator + "l2-svm-predict.dml"
   final val predictionScriptPathMulticlass = "scripts" + File.separator + "algorithms" + File.separator + "m-svm-predict.dml"
 }
 
-class SVMModel (override val uid: String)(estimator:SVM, val sc: SparkContext, val isMultiClass:Boolean) 
-  extends Model[SVMModel] with BaseSystemMLClassifierModel {
+class SVMModel(override val uid: String)(estimator: SVM, val sc: SparkContext, val isMultiClass: Boolean) extends Model[SVMModel] with BaseSystemMLClassifierModel {
   override def copy(extra: ParamMap): SVMModel = {
     val that = new SVMModel(uid)(estimator, sc, isMultiClass)
     copyValues(that, extra)
   }
-  
-  def this(estimator:SVM, isMultiClass:Boolean) =  {
-  	this("model")(estimator, estimator.sc, isMultiClass)
+
+  def this(estimator: SVM, isMultiClass: Boolean) = {
+    this("model")(estimator, estimator.sc, isMultiClass)
   }
-  
-  def baseEstimator():BaseSystemMLEstimator = estimator
-  def modelVariables():List[String] = List[String]("w")
-  
-  def getPredictionScript(isSingleNode:Boolean): (Script, String)  = {
-    val script = dml(ScriptsUtils.getDMLScript(if(isMultiClass) SVMModel.predictionScriptPathMulticlass else SVMModel.predictionScriptPathBinary))
+
+  def baseEstimator(): BaseSystemMLEstimator = estimator
+  def modelVariables(): List[String]         = List[String]("w")
+
+  def getPredictionScript(isSingleNode: Boolean): (Script, String) = {
+    val script = dml(ScriptsUtils.getDMLScript(if (isMultiClass) SVMModel.predictionScriptPathMulticlass else SVMModel.predictionScriptPathBinary))
       .in("$X", " ")
       .in("$model", " ")
       .out("scores")
-    
-    val w = estimator.mloutput.getMatrix("w")
-    val wVar = if(isMultiClass) "W" else "w"
-      
-    val ret = if(isSingleNode) {
+
+    val w    = estimator.mloutput.getMatrix("w")
+    val wVar = if (isMultiClass) "W" else "w"
+
+    val ret = if (isSingleNode) {
       script.in(wVar, w.toMatrixBlock, w.getMatrixMetadata)
-    }
-    else {
+    } else {
       script.in(wVar, w)
     }
     (ret, "X")
   }
-  
-  def transform(X: MatrixBlock): MatrixBlock = baseTransform(X, sc, "scores")
-  def transform_probability(X: MatrixBlock): MatrixBlock = baseTransformProbability(X, sc, "scores")
+
+  def transform(X: MatrixBlock): MatrixBlock               = baseTransform(X, sc, "scores")
+  def transform_probability(X: MatrixBlock): MatrixBlock   = baseTransformProbability(X, sc, "scores")
   def transform(df: ScriptsUtils.SparkDataType): DataFrame = baseTransform(df, sc, "scores")
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/ml/ScriptsUtils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/ScriptsUtils.scala b/src/main/scala/org/apache/sysml/api/ml/ScriptsUtils.scala
index 2ba0f2b..016457e 100644
--- a/src/main/scala/org/apache/sysml/api/ml/ScriptsUtils.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/ScriptsUtils.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -26,16 +26,16 @@ import org.apache.sysml.runtime.DMLRuntimeException
 
 object ScriptsUtils {
   var systemmlHome = System.getenv("SYSTEMML_HOME")
-		  
+
   type SparkDataType = org.apache.spark.sql.Dataset[_] // org.apache.spark.sql.DataFrame for Spark 1.x
 
   /**
-   * set SystemML home
-   */
+    * set SystemML home
+    */
   def setSystemmlHome(path: String) {
     systemmlHome = path
   }
-  
+
   /*
    * Internal function to get dml path
    */
@@ -49,7 +49,7 @@ object ScriptsUtils {
    */
   private[sysml] def getDMLScript(scriptPath: String): String = {
     var reader: BufferedReader = null
-    val out = new StringBuilder()
+    val out                    = new StringBuilder()
     try {
       val in = {
         if (systemmlHome == null || systemmlHome.equals("")) {
@@ -60,7 +60,7 @@ object ScriptsUtils {
         }
       }
       var reader = new BufferedReader(new InputStreamReader(in))
-      var line = reader.readLine()
+      var line   = reader.readLine()
       while (line != null) {
         out.append(line);
         out.append(System.getProperty("line.separator"));
@@ -75,4 +75,4 @@ object ScriptsUtils {
     }
     out.toString()
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/ml/Utils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/Utils.scala b/src/main/scala/org/apache/sysml/api/ml/Utils.scala
index a804f64..77bd17a 100644
--- a/src/main/scala/org/apache/sysml/api/ml/Utils.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/Utils.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -27,18 +27,17 @@ object Utils {
   val originalErr = System.err
 }
 class Utils {
-  def checkIfFileExists(filePath:String):Boolean = {
+  def checkIfFileExists(filePath: String): Boolean =
     return org.apache.sysml.runtime.util.MapReduceTool.existsFileOnHDFS(filePath)
-  }
-  
+
   // --------------------------------------------------------------------------------
   // Simple utility function to print the information about our binary blocked format
-  def getBinaryBlockInfo(binaryBlocks:JavaPairRDD[MatrixIndexes, MatrixBlock]):String = {
-    val sb = new StringBuilder
+  def getBinaryBlockInfo(binaryBlocks: JavaPairRDD[MatrixIndexes, MatrixBlock]): String = {
+    val sb             = new StringBuilder
     var partitionIndex = 0
-    for(str <- binaryBlocks.rdd.mapPartitions(binaryBlockIteratorToString(_), true).collect) {
+    for (str <- binaryBlocks.rdd.mapPartitions(binaryBlockIteratorToString(_), true).collect) {
       sb.append("-------------------------------------\n")
-      sb.append("Partition " + partitionIndex  + ":\n")
+      sb.append("Partition " + partitionIndex + ":\n")
       sb.append(str)
       partitionIndex = partitionIndex + 1
     }
@@ -47,40 +46,40 @@ class Utils {
   }
   def binaryBlockIteratorToString(it: Iterator[(MatrixIndexes, MatrixBlock)]): Iterator[String] = {
     val sb = new StringBuilder
-    for(entry <- it) {
+    for (entry <- it) {
       val mi = entry._1
       val mb = entry._2
       sb.append(mi.toString);
-  		sb.append(" sparse? = ");
-  		sb.append(mb.isInSparseFormat());
-  		if(mb.isUltraSparse)
-  		  sb.append(" (ultra-sparse)") 
-  		sb.append(", nonzeros = ");
-  		sb.append(mb.getNonZeros);
-  		sb.append(", dimensions = ");
-  		sb.append(mb.getNumRows);
-  		sb.append(" X ");
-  		sb.append(mb.getNumColumns);
-  		sb.append("\n");
+      sb.append(" sparse? = ");
+      sb.append(mb.isInSparseFormat());
+      if (mb.isUltraSparse)
+        sb.append(" (ultra-sparse)")
+      sb.append(", nonzeros = ");
+      sb.append(mb.getNonZeros);
+      sb.append(", dimensions = ");
+      sb.append(mb.getNumRows);
+      sb.append(" X ");
+      sb.append(mb.getNumColumns);
+      sb.append("\n");
     }
     List[String](sb.toString).iterator
   }
   val baos = new java.io.ByteArrayOutputStream()
   val baes = new java.io.ByteArrayOutputStream()
-  def startRedirectStdOut():Unit = {  
+  def startRedirectStdOut(): Unit = {
     System.setOut(new java.io.PrintStream(baos));
     System.setErr(new java.io.PrintStream(baes));
   }
-  def flushStdOut():String = {
+  def flushStdOut(): String = {
     val ret = baos.toString() + baes.toString()
     baos.reset(); baes.reset()
     return ret
   }
-  def stopRedirectStdOut():String = {
+  def stopRedirectStdOut(): String = {
     val ret = baos.toString() + baes.toString()
     System.setOut(Utils.originalOut)
     System.setErr(Utils.originalErr)
     return ret
   }
   // --------------------------------------------------------------------------------
-}
\ No newline at end of file
+}


[3/4] systemml git commit: [SYSTEMML-540] Support loading of batch normalization weights in .caffemodel file using Caffe2DML

Posted by ni...@apache.org.
http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
index cbd5fa3..c8159be 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -32,81 +32,79 @@ import java.util.ArrayList
 trait CaffeLayer extends BaseDMLGenerator {
   // -------------------------------------------------
   // Any layer that wants to reuse SystemML-NN has to override following methods that help in generating the DML for the given layer:
-  def sourceFileName:String;
-  def init(dmlScript:StringBuilder):Unit;
-  def forward(dmlScript:StringBuilder, isPrediction:Boolean):Unit;
-  def backward(dmlScript:StringBuilder, outSuffix:String):Unit;
-  var computedOutputShape:(String, String, String) = null
-  def outputShape:(String, String, String) = {
-    if(computedOutputShape == null) computedOutputShape = bottomLayerOutputShape
+  def sourceFileName: String;
+  def init(dmlScript: StringBuilder): Unit;
+  def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit;
+  def backward(dmlScript: StringBuilder, outSuffix: String): Unit;
+  var computedOutputShape: (String, String, String) = null
+  def outputShape: (String, String, String) = {
+    if (computedOutputShape == null) computedOutputShape = bottomLayerOutputShape
     computedOutputShape
   }
   // -------------------------------------------------
-  var computedBottomLayerOutputShape:(String, String, String) = null
-  def bottomLayerOutputShape:(String, String, String) = {
-    if(computedBottomLayerOutputShape == null) {
+  var computedBottomLayerOutputShape: (String, String, String) = null
+  def bottomLayerOutputShape: (String, String, String) = {
+    if (computedBottomLayerOutputShape == null) {
       // Note: if you get org.apache.sysml.parser.LanguageException: Map is null exception
       // from org.apache.sysml.api.dl.CaffeNetwork.org$apache$sysml$api$dl$CaffeNetwork$$convertLayerParameterToCaffeLayer
       // you are attempting to get traverse the network (for example: bottomLayerOutputShape) before it is created.
       val ret = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
-      if(ret.size == 0) throw new LanguageException("Expected atleast 1 bottom layer for " + param.getName)
+      if (ret.size == 0) throw new LanguageException("Expected atleast 1 bottom layer for " + param.getName)
       computedBottomLayerOutputShape = ret(0).outputShape
     }
     computedBottomLayerOutputShape
   }
-  def param:LayerParameter
-  def id:Int
-  def net:CaffeNetwork
+  def param: LayerParameter
+  def id: Int
+  def net: CaffeNetwork
   // --------------------------------------------------------------------------------------
   // No need to override these methods in subclasses
   // Exception: Only Data layer overrides "out" method to use 'Xb' for consistency
   // Naming of the below methods is consistent with the nn library:
   // X (feature map from the previous layer) ----> Forward pass  ----> out (feature map to the next layer)
   // dX (errors to the previous layer)       <---- Backward pass <---- dout (errors from the next layer)
-  def out = "out" + id  
-  var computedX:String = null
-  def X:String = {
-    if(computedX == null) {
+  def out               = "out" + id
+  var computedX: String = null
+  def X: String = {
+    if (computedX == null) {
       val ret = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
-      if(ret.size == 0) throw new LanguageException("Expected atleast 1 bottom layer for " + param.getName)
-      else if(ret.size == 1)    computedX = ret(0).out
-      else                      computedX = sum(new StringBuilder, ret.map(_.out).toList).toString()
+      if (ret.size == 0) throw new LanguageException("Expected atleast 1 bottom layer for " + param.getName)
+      else if (ret.size == 1) computedX = ret(0).out
+      else computedX = sum(new StringBuilder, ret.map(_.out).toList).toString()
     }
     computedX
   }
-  var computedDout:String = null
+  var computedDout: String = null
   def dout: String = {
-    if(computedDout == null) {
+    if (computedDout == null) {
       val ret = net.getTopLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
-      if(ret.size == 0) throw new LanguageException("Expected atleast 1 top layer for " + param.getName)
-      else if(ret.size == 1)     computedDout = ret(0).dX(id)
-      else                       computedDout = sum(new StringBuilder, ret.map(_.dX(id)).toList).toString()
+      if (ret.size == 0) throw new LanguageException("Expected atleast 1 top layer for " + param.getName)
+      else if (ret.size == 1) computedDout = ret(0).dX(id)
+      else computedDout = sum(new StringBuilder, ret.map(_.dX(id)).toList).toString()
     }
     computedDout
   }
-  def dX(bottomLayerID:Int) = "dOut" + id + "_" + bottomLayerID
+  def dX(bottomLayerID: Int) = "dOut" + id + "_" + bottomLayerID
   // --------------------------------------------------------------------------------------
-  // No need to override these methods in subclasses, instead classes that have weights and biases 
+  // No need to override these methods in subclasses, instead classes that have weights and biases
   // should implement HasWeight and HasBias traits.
-  def dWeight():String = throw new DMLRuntimeException("dWeight is not implemented in super class")
-  def dBias():String = throw new DMLRuntimeException("dBias is not implemented in super class")
-  def weight():String = null;
-  def weightShape():Array[Int];
-  def bias():String = null;
-  def biasShape():Array[Int];
-  def shouldUpdateWeight():Boolean = if(weight != null) true else false
-  def shouldUpdateBias():Boolean = if(bias != null) true else false
+  def dWeight(): String = throw new DMLRuntimeException("dWeight is not implemented in super class")
+  def dBias(): String   = throw new DMLRuntimeException("dBias is not implemented in super class")
+  def weight(): String  = null;
+  def weightShape(): Array[Int];
+  def bias(): String = null;
+  def biasShape(): Array[Int];
+  def shouldUpdateWeight(): Boolean = if (weight != null) true else false
+  def shouldUpdateBias(): Boolean   = if (bias != null) true else false
   // --------------------------------------------------------------------------------------
   // Helper methods to simplify the code of subclasses
-  def invokeInit(dmlScript:StringBuilder, returnVariables:List[String], arguments:String*):Unit = {
+  def invokeInit(dmlScript: StringBuilder, returnVariables: List[String], arguments: String*): Unit =
     invoke(dmlScript, sourceFileName + "::", returnVariables, "init", arguments.toList)
-  }
-  def invokeForward(dmlScript:StringBuilder, returnVariables:List[String], arguments:String*):Unit = {
+  def invokeForward(dmlScript: StringBuilder, returnVariables: List[String], arguments: String*): Unit =
     invoke(dmlScript, sourceFileName + "::", returnVariables, "forward", arguments.toList)
-  }
   // -----------------------------------------------------------------------------------
   // All the layers (with the exception of Concat) call one of the below methods in the backward function.
-  // The preceding layer expects that 'dX(bottomLayerID) + outSuffix' is assigned. 
+  // The preceding layer expects that 'dX(bottomLayerID) + outSuffix' is assigned.
   // l1 <--- dX(1) <-----|
   //                     |-- [current layer: dOut3 (computed by backward)] <---- "dOut" + id + outSuffix
   // l2 <--- dX(2) <-----|
@@ -114,73 +112,70 @@ trait CaffeLayer extends BaseDMLGenerator {
   // 1. Compute backward: either call dml file's backward (for example: invokeBackward) or just propagate next layers errors (assignDoutToDX)
   // 2. Then make sure that all the preceding layer get the errors using:
   //        bottomLayerIDs.map(bottomLayerID => dmlScript.append( dX(bottomLayerID) + outSuffix + " = " + "dOut" + id + outSuffix + "; "))
-  
+
   // The layers that have a corresponding dml script call this method.
   // Assumption: the first variable of resultVariables is always dX
-  def invokeBackward(dmlScript:StringBuilder, outSuffix:String, resultVariables:List[String],  arguments:String*):Unit = {
+  def invokeBackward(dmlScript: StringBuilder, outSuffix: String, resultVariables: List[String], arguments: String*): Unit = {
     invoke(dmlScript, sourceFileName + "::", resultVariables.map(_ + outSuffix), "backward", arguments.toList, false)
     val bottomLayerIDs = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l).id)
     dmlScript.append("; ")
-    bottomLayerIDs.map(bottomLayerID => dmlScript.append( dX(bottomLayerID) + outSuffix + " = " + resultVariables(0) + outSuffix + "; "))
+    bottomLayerIDs.map(bottomLayerID => dmlScript.append(dX(bottomLayerID) + outSuffix + " = " + resultVariables(0) + outSuffix + "; "))
     dmlScript.append("\n")
   }
   // On-the-fly layers (such as Scale and Elementwise) call this function to propagate next layers errors to the previous layer
-  def assignDoutToDX(dmlScript:StringBuilder, outSuffix:String):Unit = {
-    dmlScript.append("dOut" + id  + outSuffix + " = " + dout)
+  def assignDoutToDX(dmlScript: StringBuilder, outSuffix: String): Unit = {
+    dmlScript.append("dOut" + id + outSuffix + " = " + dout)
     val bottomLayerIDs = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l).id)
     dmlScript.append("; ")
-    bottomLayerIDs.map(bottomLayerID => dmlScript.append( dX(bottomLayerID) + outSuffix + " = " + "dOut" + id + outSuffix + "; "))
+    bottomLayerIDs.map(bottomLayerID => dmlScript.append(dX(bottomLayerID) + outSuffix + " = " + "dOut" + id + outSuffix + "; "))
     dmlScript.append("\n")
   }
   // --------------------------------------------------------------------------------------
 }
 
-
 trait IsLossLayer extends CaffeLayer {
-  def computeLoss(dmlScript:StringBuilder, numTabs:Int):Unit
+  def computeLoss(dmlScript: StringBuilder, numTabs: Int): Unit
 }
 
 trait HasWeight extends CaffeLayer {
-  override def weight = param.getName + "_weight"
+  override def weight  = param.getName + "_weight"
   override def dWeight = param.getName + "_dWeight"
 }
 
 trait HasBias extends CaffeLayer {
-  override def bias = param.getName + "_bias"
+  override def bias  = param.getName + "_bias"
   override def dBias = param.getName + "_dBias"
 }
 
-class Data(val param:LayerParameter, val id:Int, val net:CaffeNetwork, val numChannels:String, val height:String, val width:String) extends CaffeLayer {
+class Data(val param: LayerParameter, val id: Int, val net: CaffeNetwork, val numChannels: String, val height: String, val width: String) extends CaffeLayer {
   // -------------------------------------------------
   override def sourceFileName = null
-  override def init(dmlScript:StringBuilder) = {
-    if(param.hasTransformParam && param.getTransformParam.hasScale) {
+  override def init(dmlScript: StringBuilder) = {
+    if (param.hasTransformParam && param.getTransformParam.hasScale) {
       dmlScript.append("X_full = X_full * " + param.getTransformParam.getScale + "\n")
     }
-    if(param.hasDataParam && param.getDataParam.hasBatchSize) {
+    if (param.hasDataParam && param.getDataParam.hasBatchSize) {
       dmlScript.append("BATCH_SIZE = " + param.getDataParam.getBatchSize + "\n")
-    }
-    else {
+    } else {
       Caffe2DML.LOG.debug("Using default batch size of 64 as batch size is not set with DataParam")
       dmlScript.append("BATCH_SIZE = 64\n")
     }
   }
-  var dataOutputShape = ("$num_channels", "$height", "$width")
-  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = { }
-  override def out = "Xb"
-  override def backward(dmlScript:StringBuilder, outSuffix:String) = { }
-  override def outputShape = (numChannels, height, width)
-  override def weightShape():Array[Int] = null
-  override def biasShape():Array[Int] = null
+  var dataOutputShape                                                   = ("$num_channels", "$height", "$width")
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) = {}
+  override def out                                                      = "Xb"
+  override def backward(dmlScript: StringBuilder, outSuffix: String)    = {}
+  override def outputShape                                              = (numChannels, height, width)
+  override def weightShape(): Array[Int]                                = null
+  override def biasShape(): Array[Int]                                  = null
   // -------------------------------------------------
 }
 
-
 // ------------------------------------------------------------------
 // weight is ema_mean and bias is ema_var
-// Fuse 
-class BatchNorm(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
-  // val scale =  
+// Fuse
+class BatchNorm(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
+  // val scale =
   override def sourceFileName = "batch_norm2d"
   /*
    * Initialize the parameters of this layer.
@@ -199,8 +194,8 @@ class BatchNorm(val param:LayerParameter, val id:Int, val net:CaffeNetwork) exte
    *  - ema_var: Exponential moving average of the variance, of
    *      shape (C, 1).
    */
-  override def init(dmlScript:StringBuilder) = invokeInit(dmlScript, List[String](gamma, beta, ema_mean, ema_var), numChannels)
-  var update_mean_var = true
+  override def init(dmlScript: StringBuilder) = invokeInit(dmlScript, List[String](gamma, beta, ema_mean, ema_var), numChannels)
+  var update_mean_var                         = true
   /*
    * Computes the forward pass for a 2D (spatial) batch normalization
    * layer.  The input data has N examples, each represented as a 3D
@@ -258,9 +253,22 @@ class BatchNorm(val param:LayerParameter, val id:Int, val net:CaffeNetwork) exte
    *      during training.
    */
   def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
-    val mode = if(isPrediction) "\"test\"" else "\"train\""
-    invokeForward(dmlScript, List[String](out, withSuffix(ema_mean), withSuffix(ema_var), withSuffix(cache_mean), withSuffix(cache_var), withSuffix(cache_norm)), 
-        X, gamma, beta, numChannels, Hin, Win, mode, ema_mean, ema_var,  ma_fraction, eps)  
+    val mode = if (isPrediction) "\"test\"" else "\"train\""
+    invokeForward(
+      dmlScript,
+      List[String](out, withSuffix(ema_mean), withSuffix(ema_var), withSuffix(cache_mean), withSuffix(cache_var), withSuffix(cache_norm)),
+      X,
+      gamma,
+      beta,
+      numChannels,
+      Hin,
+      Win,
+      mode,
+      ema_mean,
+      ema_var,
+      ma_fraction,
+      eps
+    )
   }
   /*
    * Computes the backward pass for a 2D (spatial) batch normalization
@@ -309,89 +317,105 @@ class BatchNorm(val param:LayerParameter, val id:Int, val net:CaffeNetwork) exte
    *  - dbeta: Gradient wrt `b`, of shape (C, 1).
    *
    */
-  def backward(dmlScript: StringBuilder, outSuffix:String): Unit = {
-    invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dgamma, dbeta), dout, out, ema_mean, ema_var, cache_mean, cache_var, cache_norm, X, gamma, beta, numChannels, 
-          Hin, Win, "\"train\"", ema_mean, ema_var,  ma_fraction, eps)
-  }
-  
-  private def withSuffix(str:String):String = if(update_mean_var) str else str + "_ignore"
-  override def weight = "ema_mean" + id
-  override def weightShape():Array[Int] = Array(numChannels.toInt, 1)
-  override def biasShape():Array[Int] = Array(numChannels.toInt, 1)
-  override def bias = "ema_var" + id
-  def cache_mean(): String = "cache_mean" + id
-  def cache_var():String = "cache_mean" + id
-  def cache_norm():String = "cache_norm" + id
-  var scaleLayer:Scale = null
-  def gamma():String = { checkNextLayer(); scaleLayer.weight }
-  def ma_fraction():String = if(param.getBatchNormParam.hasMovingAverageFraction()) param.getBatchNormParam.getMovingAverageFraction.toString else "0.999"
-  def eps():String = if(param.getBatchNormParam.hasEps()) param.getBatchNormParam.getEps.toString else "1e-5"
-  def beta():String = { checkNextLayer(); scaleLayer.bias }
-  def dgamma():String = { checkNextLayer();  scaleLayer.dWeight }
-  def dbeta():String = { checkNextLayer();  scaleLayer.dBias }
-  override def shouldUpdateWeight():Boolean = false
-  override def shouldUpdateBias():Boolean = false
-  def ema_mean(): String = weight
-  def ema_var(): String = bias
-  def checkNextLayer(): Unit = {
-    if(scaleLayer == null) {
+  def backward(dmlScript: StringBuilder, outSuffix: String): Unit =
+    invokeBackward(
+      dmlScript,
+      outSuffix,
+      List[String]("dOut" + id, dgamma, dbeta),
+      dout,
+      out,
+      ema_mean,
+      ema_var,
+      cache_mean,
+      cache_var,
+      cache_norm,
+      X,
+      gamma,
+      beta,
+      numChannels,
+      Hin,
+      Win,
+      "\"train\"",
+      ema_mean,
+      ema_var,
+      ma_fraction,
+      eps
+    )
+
+  private def withSuffix(str: String): String = if (update_mean_var) str else str + "_ignore"
+  override def weightShape(): Array[Int]      = Array(numChannels.toInt, 1)
+  override def biasShape(): Array[Int]        = Array(numChannels.toInt, 1)
+  def cache_mean(): String                    = "cache_mean" + id
+  def cache_var(): String                     = "cache_mean" + id
+  def cache_norm(): String                    = "cache_norm" + id
+  var scaleLayer: Scale                       = null
+  def gamma(): String                         = { checkNextLayer(); scaleLayer.weight }
+  def ma_fraction(): String                   = if (param.getBatchNormParam.hasMovingAverageFraction()) param.getBatchNormParam.getMovingAverageFraction.toString else "0.999"
+  def eps(): String                           = if (param.getBatchNormParam.hasEps()) param.getBatchNormParam.getEps.toString else "1e-5"
+  def beta(): String                          = { checkNextLayer(); scaleLayer.bias }
+  def dgamma(): String                        = { checkNextLayer(); scaleLayer.dWeight }
+  def dbeta(): String                         = { checkNextLayer(); scaleLayer.dBias }
+  override def shouldUpdateWeight(): Boolean  = false
+  override def shouldUpdateBias(): Boolean    = false
+  def ema_mean(): String                      = weight
+  def ema_var(): String                       = bias
+  def checkNextLayer(): Unit =
+    if (scaleLayer == null) {
       val topLayers = net.getTopLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
-      if(topLayers.length != 1 && !topLayers(0).isInstanceOf[Scale]) throw new LanguageException("Only one top layer of type Scale allowed for BatchNorm")
+      if (topLayers.length != 1 && !topLayers(0).isInstanceOf[Scale]) throw new LanguageException("Only one top layer of type Scale allowed for BatchNorm")
       scaleLayer = topLayers(0).asInstanceOf[Scale]
     }
-  }
   def numChannels = bottomLayerOutputShape._1
-  def Hin = bottomLayerOutputShape._2
-  def Win = bottomLayerOutputShape._3
+  def Hin         = bottomLayerOutputShape._2
+  def Win         = bottomLayerOutputShape._3
 }
 // weight is gamma and bias is beta
-class Scale(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
-  if(!param.getScaleParam.getBiasTerm) throw new LanguageException("Add \"scale_param { bias_term: true }\" to the layer " + param.getName)
-  override def sourceFileName = null
+class Scale(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
+  if (!param.getScaleParam.getBiasTerm) throw new LanguageException("Add \"scale_param { bias_term: true }\" to the layer " + param.getName)
+  override def sourceFileName                       = null
   override def init(dmlScript: StringBuilder): Unit = {}
   // TODO: Generalize this !!
-  def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = assign(dmlScript, out, X)
-  override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = assignDoutToDX(dmlScript, outSuffix)
-  override def weightShape():Array[Int] = Array(bottomLayerOutputShape._1.toInt, 1)
-  override def biasShape():Array[Int] = Array(bottomLayerOutputShape._1.toInt, 1)
+  def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit       = assign(dmlScript, out, X)
+  override def backward(dmlScript: StringBuilder, outSuffix: String): Unit = assignDoutToDX(dmlScript, outSuffix)
+  override def weightShape(): Array[Int]                                   = Array(bottomLayerOutputShape._1.toInt, 1)
+  override def biasShape(): Array[Int]                                     = Array(bottomLayerOutputShape._1.toInt, 1)
 }
 // ------------------------------------------------------------------
 
-class Elementwise(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
-  override def sourceFileName = null
+class Elementwise(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer {
+  override def sourceFileName                       = null
   override def init(dmlScript: StringBuilder): Unit = {}
-  if(param.getEltwiseParam.hasOperation && param.getEltwiseParam.getOperation != EltwiseOp.SUM)
+  if (param.getEltwiseParam.hasOperation && param.getEltwiseParam.getOperation != EltwiseOp.SUM)
     throw new LanguageException("Currently only elementwise sum operation supported")
-  override def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit =
     addAndAssign(dmlScript, out, param.getBottomList.map(b => net.getCaffeLayer(b).out).toList)
-  }
-  override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = assignDoutToDX(dmlScript, outSuffix)
+  override def backward(dmlScript: StringBuilder, outSuffix: String): Unit = assignDoutToDX(dmlScript, outSuffix)
   override def outputShape = {
-    if(_out == null) _out = net.getCaffeLayer(net.getBottomLayers(param.getName).take(1).toSeq.get(0)).outputShape
+    if (_out == null) _out = net.getCaffeLayer(net.getBottomLayers(param.getName).take(1).toSeq.get(0)).outputShape
     _out
   }
-  var _out:(String, String, String) = null
-  override def weightShape():Array[Int] = null
-  override def biasShape():Array[Int] = null
+  var _out: (String, String, String)     = null
+  override def weightShape(): Array[Int] = null
+  override def biasShape(): Array[Int]   = null
 }
 
-class Concat(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
-  override def sourceFileName = null
+class Concat(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer {
+  override def sourceFileName                       = null
   override def init(dmlScript: StringBuilder): Unit = {}
-  var _childLayers:List[CaffeLayer] = null
-  
+  var _childLayers: List[CaffeLayer]                = null
+
   // Utility function to create string of format:
   // fn(fn(fn(_childLayers(0).out, _childLayers(1).out), _childLayers(2).out), ...)
   // This is useful because we do not support multi-input cbind and rbind in DML.
-  def _getMultiFn(fn:String):String = {
-    if(_childLayers == null) _childLayers = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+  def _getMultiFn(fn: String): String = {
+    if (_childLayers == null) _childLayers = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
     var tmp = fn + "(" + _childLayers(0).out + ", " + _childLayers(1).out + ")"
-    for(i <- 2 until _childLayers.size) {
-      tmp = fn + "(" + tmp + ", " +  _childLayers(i).out + ")"
+    for (i <- 2 until _childLayers.size) {
+      tmp = fn + "(" + tmp + ", " + _childLayers(i).out + ")"
     }
     tmp
   }
-  
+
   /*
    * Computes the forward pass for a concatenation layer.
    *
@@ -399,174 +423,162 @@ class Concat(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends
    *  - n_i * c_i * h * w for each input blob i from 1 to K.
    *
    * Outputs:
-   *  - out: Outputs, of shape 
+   *  - out: Outputs, of shape
    *    - if axis = 0: (n_1 + n_2 + ... + n_K) * c_1 * h * w, and all input c_i should be the same.
    *    - if axis = 1: n_1 * (c_1 + c_2 + ... + c_K) * h * w, and all input n_i should be the same.
    */
-  override def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
-    if(param.getConcatParam.getAxis == 0) {
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit =
+    if (param.getConcatParam.getAxis == 0) {
       // rbind the inputs
       assign(dmlScript, out, _getMultiFn("rbind"))
-    }
-    else if(param.getConcatParam.getAxis == 1) {
+    } else if (param.getConcatParam.getAxis == 1) {
       // cbind the inputs
       assign(dmlScript, out, _getMultiFn("cbind"))
-    }
-    else {
+    } else {
       throw new DMLRuntimeException("Incorrect axis parameter for the layer " + param.getName)
     }
-  }
-  
-  def startIndex(outSuffix:String):String = "concat_start_index_" + outSuffix
-  def endIndex(outSuffix:String):String = "concat_start_index_" + outSuffix
-  def getConcatIndex(bottomLayerOut:String, outSuffix:String):String = 
-     startIndex(outSuffix) + " = " + endIndex(outSuffix) + " + 1; " +
-     endIndex(outSuffix) + " = " + startIndex(outSuffix) + " + nrow(" + bottomLayerOut + "); "
-  
+
+  def startIndex(outSuffix: String): String = "concat_start_index_" + outSuffix
+  def endIndex(outSuffix: String): String   = "concat_start_index_" + outSuffix
+  def getConcatIndex(bottomLayerOut: String, outSuffix: String): String =
+    startIndex(outSuffix) + " = " + endIndex(outSuffix) + " + 1; " +
+    endIndex(outSuffix) + " = " + startIndex(outSuffix) + " + nrow(" + bottomLayerOut + "); "
+
   /*
    * Computes the backward pass for a concatenation layer.
    *
    * The top gradients are deconcatenated back to the inputs.
    *
    */
-  override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = {
+  override def backward(dmlScript: StringBuilder, outSuffix: String): Unit = {
     val bottomLayers = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
-    val dOutVar = "dOut" + id  + outSuffix
+    val dOutVar      = "dOut" + id + outSuffix
     // concat_end_index = 0
     dmlScript.append(dOutVar + " = " + dout + "; concat_end_index" + outSuffix + " = 0; ")
-    
+
     val indexString = "concat_start_index" + outSuffix + " : concat_end_index" + outSuffix
-    val doutVarAssignment = if(param.getConcatParam.getAxis == 0) " = " + dOutVar + "[" +  indexString + ", ]; "
-                            else " = " + dOutVar + "[," +  indexString + " ]; "
-    
+    val doutVarAssignment =
+      if (param.getConcatParam.getAxis == 0) " = " + dOutVar + "[" + indexString + ", ]; "
+      else " = " + dOutVar + "[," + indexString + " ]; "
+
     // concat_start_index = concat_end_index + 1
     // concat_end_index = concat_start_index + $$ - 1
-    val initializeIndexString = "concat_start_index" + outSuffix + " = concat_end_index" + outSuffix + " + 1; concat_end_index" + outSuffix + 
-        " = concat_start_index" + outSuffix + " + $$ - 1; "
-    if(param.getConcatParam.getAxis == 0) {
+    val initializeIndexString = "concat_start_index" + outSuffix + " = concat_end_index" + outSuffix + " + 1; concat_end_index" + outSuffix +
+    " = concat_start_index" + outSuffix + " + $$ - 1; "
+    if (param.getConcatParam.getAxis == 0) {
       bottomLayers.map(l => {
-        dmlScript.append(initializeIndexString.replaceAll("$$", nrow(l.out)))
-                  // X1 = Z[concat_start_index:concat_end_index,]
-                 .append( dX(l.id) + outSuffix + doutVarAssignment)
+        dmlScript
+          .append(initializeIndexString.replaceAll("$$", nrow(l.out)))
+          // X1 = Z[concat_start_index:concat_end_index,]
+          .append(dX(l.id) + outSuffix + doutVarAssignment)
       })
-    }
-    else {
+    } else {
       bottomLayers.map(l => {
-        dmlScript.append(initializeIndexString.replaceAll("$$", int_mult(l.outputShape._1, l.outputShape._2, l.outputShape._3) ))
-                  // X1 = Z[concat_start_index:concat_end_index,]
-                 .append( dX(l.id) + outSuffix + doutVarAssignment)
+        dmlScript
+          .append(initializeIndexString.replaceAll("$$", int_mult(l.outputShape._1, l.outputShape._2, l.outputShape._3)))
+          // X1 = Z[concat_start_index:concat_end_index,]
+          .append(dX(l.id) + outSuffix + doutVarAssignment)
       })
     }
     dmlScript.append("\n")
   }
-  def sumChannels():String = {
+  def sumChannels(): String = {
     val channels = _childLayers.map(_.outputShape._1)
-    try { 
+    try {
       channels.reduce((c1, c2) => (c1.toInt + c2.toInt).toString())
-    }
-    catch { 
-      case _:Throwable => sum(new StringBuilder, channels).toString
+    } catch {
+      case _: Throwable => sum(new StringBuilder, channels).toString
     }
   }
   override def outputShape = {
-    if(_out == null) {
-      if(_childLayers == null) _childLayers = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
-      if(param.getConcatParam.getAxis == 0) {
+    if (_out == null) {
+      if (_childLayers == null) _childLayers = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+      if (param.getConcatParam.getAxis == 0) {
         _out = _childLayers(0).outputShape
-      }
-      else if(param.getConcatParam.getAxis == 1) {
+      } else if (param.getConcatParam.getAxis == 1) {
         _out = (sumChannels(), _childLayers(0).outputShape._2, _childLayers(0).outputShape._3)
-      }
-      else {
+      } else {
         throw new DMLRuntimeException("Incorrect axis parameter for the layer " + param.getName)
       }
     }
     _out
   }
-  var _out:(String, String, String) = null
-  override def weightShape():Array[Int] = null
-  override def biasShape():Array[Int] = null
+  var _out: (String, String, String)     = null
+  override def weightShape(): Array[Int] = null
+  override def biasShape(): Array[Int]   = null
 }
 
-class SoftmaxWithLoss(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with IsLossLayer {
+class SoftmaxWithLoss(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer with IsLossLayer {
   // -------------------------------------------------
-  override def sourceFileName = if(!isSegmentationProblem()) "softmax" else "softmax2d" 
-  override def init(dmlScript:StringBuilder) = {}
-  def isSegmentationProblem():Boolean = {
+  override def sourceFileName                 = if (!isSegmentationProblem()) "softmax" else "softmax2d"
+  override def init(dmlScript: StringBuilder) = {}
+  def isSegmentationProblem(): Boolean =
     try {
       return outputShape._2.toInt != 1 && outputShape._3.toInt != 1
-    } catch { 
-      case _:Throwable => throw new RuntimeException("Cannot infer the output dimensions:" + outputShape)
+    } catch {
+      case _: Throwable => throw new RuntimeException("Cannot infer the output dimensions:" + outputShape)
     }
-  }
-  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = {
-    if(!isSegmentationProblem()) {
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) =
+    if (!isSegmentationProblem()) {
       invokeForward(dmlScript, List[String](out), scores)
-    }
-    else {
+    } else {
       invokeForward(dmlScript, List[String](out), scores, outputShape._1)
     }
-  }
-  override def backward(dmlScript:StringBuilder, outSuffix:String) =  {
-    if(!isSegmentationProblem()) {
+  override def backward(dmlScript: StringBuilder, outSuffix: String) =
+    if (!isSegmentationProblem()) {
       invoke(dmlScript, "cross_entropy_loss::", List[String]("dProbs" + outSuffix), "backward", false, out, "yb")
-      dmlScript.append("; ") 
+      dmlScript.append("; ")
       invoke(dmlScript, "softmax::", List[String]("dOut" + id + outSuffix), "backward", false, "dProbs", scores)
       val bottomLayerIDs = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l).id)
       dmlScript.append("; ")
-      bottomLayerIDs.map(bottomLayerID => dmlScript.append( dX(bottomLayerID) + outSuffix + " = " + "dOut" + id + outSuffix + "; "))
+      bottomLayerIDs.map(bottomLayerID => dmlScript.append(dX(bottomLayerID) + outSuffix + " = " + "dOut" + id + outSuffix + "; "))
       dmlScript.append("\n")
-    }
-    else {
+    } else {
       throw new RuntimeException("backward for SoftmaxWithLoss is not implemented for segmentation problem")
     }
-  }
-  override def computeLoss(dmlScript:StringBuilder, numTabs:Int) = {
-    if(!isSegmentationProblem()) {
+  override def computeLoss(dmlScript: StringBuilder, numTabs: Int) =
+    if (!isSegmentationProblem()) {
       val tabBuilder = new StringBuilder
-      for(i <- 0 until numTabs) tabBuilder.append("\t")
+      for (i <- 0 until numTabs) tabBuilder.append("\t")
       val tabs = tabBuilder.toString
       dmlScript.append("tmp_loss = cross_entropy_loss::forward(" + commaSep(out, "yb") + ")\n")
       dmlScript.append(tabs).append("loss = loss + tmp_loss\n")
       dmlScript.append(tabs).append("true_yb = rowIndexMax(yb)\n")
       dmlScript.append(tabs).append("predicted_yb = rowIndexMax(" + out + ")\n")
       dmlScript.append(tabs).append("accuracy = mean(predicted_yb == true_yb)*100\n")
-    }
-    else {
+    } else {
       throw new RuntimeException("Computation of loss for SoftmaxWithLoss is not implemented for segmentation problem")
     }
+  def scores(): String = {
+    val ret = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+    if (ret.size == 1) return ret.get(0).out
+    else if (ret.size == 2) {
+      val ret1 = if (!ret.get(0).out.equals("Xb")) ret.get(0).out else "";
+      val ret2 = if (!ret.get(1).out.equals("Xb")) ret.get(1).out else "";
+      if (!ret1.equals("") && !ret2.equals("")) throw new LanguageException("Atleast one of the output of previous layer should be Xb")
+      else if (!ret1.equals("")) return ret1
+      else return ret2
+    } else
+      throw new LanguageException("More than 2 bottom layers is not supported")
   }
-  def scores():String = {
-	  val ret = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
-	  if(ret.size == 1) return ret.get(0).out
-	  else if(ret.size == 2) {
-		  val ret1 = if(!ret.get(0).out.equals("Xb")) ret.get(0).out else ""; 
-		  val ret2 = if(!ret.get(1).out.equals("Xb")) ret.get(1).out else "";
-		  if(!ret1.equals("") && !ret2.equals("")) throw new LanguageException("Atleast one of the output of previous layer should be Xb")
-		  else if(!ret1.equals("")) return ret1
-		  else return ret2
-	  }
-	  else 
-		  throw new LanguageException("More than 2 bottom layers is not supported")
-  }
-  override def weightShape():Array[Int] = null
-  override def biasShape():Array[Int] = null
+  override def weightShape(): Array[Int] = null
+  override def biasShape(): Array[Int]   = null
   // -------------------------------------------------
-  override def bottomLayerOutputShape:(String, String, String) = {
-    if(computedBottomLayerOutputShape == null) {
+  override def bottomLayerOutputShape: (String, String, String) = {
+    if (computedBottomLayerOutputShape == null) {
       val ret = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).filter(l => !l.isInstanceOf[Data]).toList
-      if(ret.size != 1) throw new LanguageException("Expected exactly 1 bottom non-Data layer for " + param.getName)
+      if (ret.size != 1) throw new LanguageException("Expected exactly 1 bottom non-Data layer for " + param.getName)
       computedBottomLayerOutputShape = ret(0).outputShape
     }
     computedBottomLayerOutputShape
   }
 }
 
-class ReLU(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+class ReLU(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer {
   // TODO: Leaky ReLU: negative_slope [default 0]: specifies whether to leak the negative part by multiplying it with the slope value rather than setting it to 0.
   // -------------------------------------------------
-  override def sourceFileName = "relu"
-  override def init(dmlScript:StringBuilder) = { }
+  override def sourceFileName                 = "relu"
+  override def init(dmlScript: StringBuilder) = {}
   /*
    * Computes the forward pass for a ReLU nonlinearity layer.
    *
@@ -578,7 +590,7 @@ class ReLU(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends C
    * Outputs:
    *  - out: Outputs, of same shape as `X`.
    */
-  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = invokeForward(dmlScript, List[String](out), X)
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) = invokeForward(dmlScript, List[String](out), X)
   /*
    * Computes the backward pass for a ReLU nonlinearity layer.
    *
@@ -592,16 +604,16 @@ class ReLU(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends C
    * Outputs:
    *  - dX: Gradient wrt `X`, of same shape as `X`.
    */
-  override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id), dout, X)
-  override def weightShape():Array[Int] = null
-  override def biasShape():Array[Int] = null
+  override def backward(dmlScript: StringBuilder, outSuffix: String) = invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id), dout, X)
+  override def weightShape(): Array[Int]                             = null
+  override def biasShape(): Array[Int]                               = null
   // -------------------------------------------------
 }
 
-class Softmax(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+class Softmax(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer {
   // -------------------------------------------------
-  override def sourceFileName = "softmax"
-  override def init(dmlScript:StringBuilder) = { }
+  override def sourceFileName                 = "softmax"
+  override def init(dmlScript: StringBuilder) = {}
   /*
    * Computes the forward pass for a softmax classifier.  The inputs
    * are interpreted as unnormalized, log-probabilities for each of
@@ -619,7 +631,7 @@ class Softmax(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extend
    * Outputs:
    *  - probs: Outputs, of shape (N, D).
    */
-  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = invokeForward(dmlScript, List[String](out), X)
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) = invokeForward(dmlScript, List[String](out), X)
   /*
    * Computes the backward pass for a softmax classifier.
    *
@@ -641,28 +653,26 @@ class Softmax(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extend
    * Outputs:
    *  - dscores: Gradient wrt `scores`, of shape (N, D).
    */
-  override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id), dout, X)
-  override def weightShape():Array[Int] = null
-  override def biasShape():Array[Int] = null
+  override def backward(dmlScript: StringBuilder, outSuffix: String) = invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id), dout, X)
+  override def weightShape(): Array[Int]                             = null
+  override def biasShape(): Array[Int]                               = null
   // -------------------------------------------------
 }
 
-
-class Threshold(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
-  override def sourceFileName = null
-  override def init(dmlScript:StringBuilder) = { }
-  val threshold = if(param.getThresholdParam.hasThreshold) param.getThresholdParam.getThreshold else 0
-  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = assign(dmlScript, out, X + " > " + threshold)
-  override def backward(dmlScript:StringBuilder, outSuffix:String) = throw new DMLRuntimeException("Backward operation for Threshold layer is not supported.")
-  override def weightShape():Array[Int] = null
-  override def biasShape():Array[Int] = null
+class Threshold(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer {
+  override def sourceFileName                                           = null
+  override def init(dmlScript: StringBuilder)                           = {}
+  val threshold                                                         = if (param.getThresholdParam.hasThreshold) param.getThresholdParam.getThreshold else 0
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) = assign(dmlScript, out, X + " > " + threshold)
+  override def backward(dmlScript: StringBuilder, outSuffix: String)    = throw new DMLRuntimeException("Backward operation for Threshold layer is not supported.")
+  override def weightShape(): Array[Int]                                = null
+  override def biasShape(): Array[Int]                                  = null
 }
 
-
-class Dropout(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+class Dropout(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer {
   // -------------------------------------------------
-  override def sourceFileName = "dropout"
-  override def init(dmlScript:StringBuilder) = { }
+  override def sourceFileName                 = "dropout"
+  override def init(dmlScript: StringBuilder) = {}
   /*
    * Computes the forward pass for an inverted dropout layer.
    *
@@ -680,8 +690,8 @@ class Dropout(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extend
    *  - out: Outputs, of same shape as `X`.
    *  - mask: Dropout mask used to compute the output.
    */
-  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) =
-    if(!isPrediction)
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) =
+    if (!isPrediction)
       invokeForward(dmlScript, List[String](out, mask), X, p, seed)
     else
       assign(dmlScript, out, X) // Forward-pass not required to be performed during prediction for Dropout layer
@@ -700,18 +710,17 @@ class Dropout(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extend
    * Outputs:
    *  - dX: Gradient wrt `X`, of same shape as `X`.
    */
-  override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix, 
-      List[String]("dOut" + id), dout, X, p, mask)
+  override def backward(dmlScript: StringBuilder, outSuffix: String) = invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id), dout, X, p, mask)
   // -------------------------------------------------
   def mask = "mask" + id
   // dropout ratio
-  def p = if(param.getDropoutParam.hasDropoutRatio()) param.getDropoutParam.getDropoutRatio.toString else "0.5"
-  def seed = "-1"
-  override def weightShape():Array[Int] = null
-  override def biasShape():Array[Int] = null
+  def p                                  = if (param.getDropoutParam.hasDropoutRatio()) param.getDropoutParam.getDropoutRatio.toString else "0.5"
+  def seed                               = "-1"
+  override def weightShape(): Array[Int] = null
+  override def biasShape(): Array[Int]   = null
 }
 
-class InnerProduct(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
+class InnerProduct(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
   // -------------------------------------------------
   // TODO: bias_filler [default type: 'constant' value: 0]; bias_term [default true]: specifies whether to learn and apply a set of additive biases to the filter outputs
   override def sourceFileName = "affine"
@@ -735,7 +744,7 @@ class InnerProduct(val param:LayerParameter, val id:Int, val net:CaffeNetwork) e
    *  - W: Weights, of shape (D, M).
    *  - b: Biases, of shape (1, M).
    */
-  override def init(dmlScript:StringBuilder) = invokeInit(dmlScript, List[String](weight, bias), numFeatures, numNeurons)
+  override def init(dmlScript: StringBuilder) = invokeInit(dmlScript, List[String](weight, bias), numFeatures, numNeurons)
   /*
    * Computes the forward pass for an affine (fully-connected) layer
    * with M neurons.  The input data has N examples, each with D
@@ -749,8 +758,8 @@ class InnerProduct(val param:LayerParameter, val id:Int, val net:CaffeNetwork) e
    * Outputs:
    *  - out: Outputs, of shape (N, M).
    */
-  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = 
-      invokeForward(dmlScript, List[String](out), X, weight, bias)
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) =
+    invokeForward(dmlScript, List[String](out), X, weight, bias)
   /*
    * Computes the backward pass for a fully-connected (affine) layer
    * with M neurons.
@@ -766,23 +775,22 @@ class InnerProduct(val param:LayerParameter, val id:Int, val net:CaffeNetwork) e
    *  - dW: Gradient wrt `W`, of shape (D, M).
    *  - db: Gradient wrt `b`, of shape (1, M).
    */
-  override def backward(dmlScript:StringBuilder, outSuffix:String) = 
-      invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias), dout, X, weight, bias)
+  override def backward(dmlScript: StringBuilder, outSuffix: String) =
+    invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias), dout, X, weight, bias)
   // -------------------------------------------------
   // num_output (c_o): the number of filters
-  def numNeurons = param.getInnerProductParam.getNumOutput.toString
+  def numNeurons  = param.getInnerProductParam.getNumOutput.toString
   def numFeatures = int_mult(bottomLayerOutputShape._1, bottomLayerOutputShape._2, bottomLayerOutputShape._3)
   // n * c_o * 1 * 1
-  override def outputShape = ( param.getInnerProductParam.getNumOutput.toString, "1", "1" )
-  override def weightShape():Array[Int] = Array(numFeatures.toInt, numNeurons.toInt)
-  override def biasShape():Array[Int] = Array(1, numNeurons.toInt)
+  override def outputShape               = (param.getInnerProductParam.getNumOutput.toString, "1", "1")
+  override def weightShape(): Array[Int] = Array(numFeatures.toInt, numNeurons.toInt)
+  override def biasShape(): Array[Int]   = Array(1, numNeurons.toInt)
 }
 
-
-class MaxPooling(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+class MaxPooling(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer {
   // -------------------------------------------------
-  override def sourceFileName = "max_pool2d_builtin"
-  override def init(dmlScript:StringBuilder) = {}
+  override def sourceFileName                 = "max_pool2d_builtin"
+  override def init(dmlScript: StringBuilder) = {}
   /*
    * Computes the forward pass for a 2D spatial max pooling layer.
    * The input data has N examples, each represented as a 3D volume
@@ -810,9 +818,8 @@ class MaxPooling(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ext
    *  - Hout: Output height.
    *  - Wout: Output width.
    */
-  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = 
-    invokeForward(dmlScript, List[String](out, "ignoreHout_"+id, "ignoreWout_"+id), 
-        X, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) =
+    invokeForward(dmlScript, List[String](out, "ignoreHout_" + id, "ignoreWout_" + id), X, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
   /*
    * Computes the backward pass for a 2D spatial max pooling layer.
    * The input data has N examples, each represented as a 3D volume
@@ -839,50 +846,58 @@ class MaxPooling(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ext
    * Outputs:
    *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
    */
-  override def backward(dmlScript:StringBuilder, outSuffix:String) = 
+  override def backward(dmlScript: StringBuilder, outSuffix: String) =
     invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id), dout, Hout, Wout, X, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
   // n * c * h_o * w_o, where h_o and w_o are computed in the same way as convolution.
-  override def outputShape = ( numChannels, Hout, Wout )
+  override def outputShape = (numChannels, Hout, Wout)
   // -------------------------------------------------
-  def Hin = bottomLayerOutputShape._2
-  def Win = bottomLayerOutputShape._3
-  def Hout = ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._2, kernel_h, stride_h, pad_h)
-  def Wout =  ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._3, kernel_w, stride_w, pad_w)
+  def Hin          = bottomLayerOutputShape._2
+  def Win          = bottomLayerOutputShape._3
+  def Hout         = ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._2, kernel_h, stride_h, pad_h)
+  def Wout         = ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._3, kernel_w, stride_w, pad_w)
   def poolingParam = param.getPoolingParam
-  def numChannels = bottomLayerOutputShape._1
+  def numChannels  = bottomLayerOutputShape._1
   // kernel_size (or kernel_h and kernel_w): specifies height and width of each filter
-  def kernel_h = if(poolingParam.hasKernelH) poolingParam.getKernelH.toString 
-                   else poolingParam.getKernelSize.toString 
-  def kernel_w = if(poolingParam.hasKernelW) poolingParam.getKernelW.toString 
-                   else poolingParam.getKernelSize.toString
+  def kernel_h =
+    if (poolingParam.hasKernelH) poolingParam.getKernelH.toString
+    else poolingParam.getKernelSize.toString
+  def kernel_w =
+    if (poolingParam.hasKernelW) poolingParam.getKernelW.toString
+    else poolingParam.getKernelSize.toString
   // stride (or stride_h and stride_w) [default 1]: specifies the intervals at which to apply the filters to the input
-  def stride_h = if(poolingParam.hasStrideH) poolingParam.getStrideH.toString 
-                   else if(poolingParam.hasStride) poolingParam.getStride.toString
-                   else "1"
-  def stride_w = if(poolingParam.hasStrideW) poolingParam.getStrideW.toString 
-                   else if(poolingParam.hasStride) poolingParam.getStride.toString
-                   else "1"
+  def stride_h =
+    if (poolingParam.hasStrideH) poolingParam.getStrideH.toString
+    else if (poolingParam.hasStride) poolingParam.getStride.toString
+    else "1"
+  def stride_w =
+    if (poolingParam.hasStrideW) poolingParam.getStrideW.toString
+    else if (poolingParam.hasStride) poolingParam.getStride.toString
+    else "1"
   // pad (or pad_h and pad_w) [default 0]: specifies the number of pixels to (implicitly) add to each side of the input
-  def pad_h =   if(poolingParam.hasPadH) poolingParam.getPadH.toString 
-                   else if(poolingParam.hasPad) poolingParam.getPad.toString
-                   else "0"
-  def pad_w =   if(poolingParam.hasPadW) poolingParam.getPadW.toString 
-                   else if(poolingParam.hasPad) poolingParam.getPad.toString
-                   else "0"
-  override def weightShape():Array[Int] = null
-  override def biasShape():Array[Int] = null
+  def pad_h =
+    if (poolingParam.hasPadH) poolingParam.getPadH.toString
+    else if (poolingParam.hasPad) poolingParam.getPad.toString
+    else "0"
+  def pad_w =
+    if (poolingParam.hasPadW) poolingParam.getPadW.toString
+    else if (poolingParam.hasPad) poolingParam.getPad.toString
+    else "0"
+  override def weightShape(): Array[Int] = null
+  override def biasShape(): Array[Int]   = null
 }
 
-class Convolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
-  def isDepthWise():Boolean = {
-    if(param.getConvolutionParam.hasGroup && param.getConvolutionParam.getGroup != 1 && numChannels.toInt % param.getConvolutionParam.getGroup != 0) 
-      throw new DMLRuntimeException("The number of groups=" + param.getConvolutionParam.getGroup + " is not supported as it is not divisible by number of channels" + numChannels + ".")
+class Convolution(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
+  def isDepthWise(): Boolean = {
+    if (param.getConvolutionParam.hasGroup && param.getConvolutionParam.getGroup != 1 && numChannels.toInt % param.getConvolutionParam.getGroup != 0)
+      throw new DMLRuntimeException(
+        "The number of groups=" + param.getConvolutionParam.getGroup + " is not supported as it is not divisible by number of channels" + numChannels + "."
+      )
     param.getConvolutionParam.hasGroup && param.getConvolutionParam.getGroup != 1
   }
-  def depthMultiplier():String = if(isDepthWise) (numChannels.toInt / param.getConvolutionParam.getGroup).toString else throw new DMLRuntimeException("Incorrect usage of depth")
-  
+  def depthMultiplier(): String = if (isDepthWise) (numChannels.toInt / param.getConvolutionParam.getGroup).toString else throw new DMLRuntimeException("Incorrect usage of depth")
+
   // -------------------------------------------------
-  override def sourceFileName = if(isDepthWise) "conv2d_builtin_depthwise" else "conv2d_builtin" 
+  override def sourceFileName = if (isDepthWise) "conv2d_builtin_depthwise" else "conv2d_builtin"
   /*
    * Initialize the parameters of this layer.
    *
@@ -911,12 +926,11 @@ class Convolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ex
    *  - W: Weights, of shape (F, C*Hf*Wf).
    *  - b: Biases, of shape (F, 1).
    */
-  override def init(dmlScript:StringBuilder) = {
-    if(isDepthWise)
+  override def init(dmlScript: StringBuilder) =
+    if (isDepthWise)
       invokeInit(dmlScript, List[String](weight, bias), numChannels, depthMultiplier, kernel_h, kernel_w)
     else
       invokeInit(dmlScript, List[String](weight, bias), numKernels, numChannels, kernel_h, kernel_w)
-  }
   /*
    * Computes the forward pass for a 2D spatial convolutional layer with
    * F filters.  The input data has N examples, each represented as a 3D
@@ -953,14 +967,40 @@ class Convolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ex
    *  - Hout: Output height.
    *  - Wout: Output width.
    */
-  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = {
-    if(isDepthWise)
-      invokeForward(dmlScript, List[String](out, "ignoreHout_"+id, "ignoreWout_"+id), 
-        X, weight, bias, numChannels, Hin, Win, depthMultiplier, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean) =
+    if (isDepthWise)
+      invokeForward(
+        dmlScript,
+        List[String](out, "ignoreHout_" + id, "ignoreWout_" + id),
+        X,
+        weight,
+        bias,
+        numChannels,
+        Hin,
+        Win,
+        depthMultiplier,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w
+      )
     else
-      invokeForward(dmlScript, List[String](out, "ignoreHout_"+id, "ignoreWout_"+id), 
-        X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
-  }
+      invokeForward(dmlScript,
+                    List[String](out, "ignoreHout_" + id, "ignoreWout_" + id),
+                    X,
+                    weight,
+                    bias,
+                    numChannels,
+                    Hin,
+                    Win,
+                    kernel_h,
+                    kernel_w,
+                    stride_h,
+                    stride_w,
+                    pad_h,
+                    pad_w)
   /*
    * Computes the backward pass for a 2D spatial convolutional layer
    * with F filters.
@@ -997,71 +1037,114 @@ class Convolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ex
    *  - dW: Gradient wrt `W`, of shape (F, C*Hf*Wf).
    *  - db: Gradient wrt `b`, of shape (F, 1).
    */
-  override def backward(dmlScript:StringBuilder, outSuffix:String) =  {
-    if(isDepthWise)
-      invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias), dout, Hout, Wout, X, weight, bias, numChannels, Hin, Win, depthMultiplier, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+  override def backward(dmlScript: StringBuilder, outSuffix: String) =
+    if (isDepthWise)
+      invokeBackward(
+        dmlScript,
+        outSuffix,
+        List[String]("dOut" + id, dWeight, dBias),
+        dout,
+        Hout,
+        Wout,
+        X,
+        weight,
+        bias,
+        numChannels,
+        Hin,
+        Win,
+        depthMultiplier,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w
+      )
     else
-      invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias), dout, Hout, Wout, X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
-  }
+      invokeBackward(
+        dmlScript,
+        outSuffix,
+        List[String]("dOut" + id, dWeight, dBias),
+        dout,
+        Hout,
+        Wout,
+        X,
+        weight,
+        bias,
+        numChannels,
+        Hin,
+        Win,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w
+      )
   // if not depthwise, n * c_o * h_o * w_o, where h_o = (h_i + 2 * pad_h - kernel_h) / stride_h + 1 and w_o likewise.
   // else (N, C*M*Hout*Wout)
-  override def outputShape = {
-    if(isDepthWise) ( (numChannels.toInt*depthMultiplier.toInt).toString, Hout, Wout )
-    else ( numKernels, Hout, Wout )
-  }
+  override def outputShape =
+    if (isDepthWise) ((numChannels.toInt * depthMultiplier.toInt).toString, Hout, Wout)
+    else (numKernels, Hout, Wout)
   // -------------------------------------------------
   def numChannels = bottomLayerOutputShape._1
-  def Hin = bottomLayerOutputShape._2
-  def Win = bottomLayerOutputShape._3
-  def Hout = ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._2, kernel_h, stride_h, pad_h) 
-  def Wout =  ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._3, kernel_w, stride_w, pad_w)
+  def Hin         = bottomLayerOutputShape._2
+  def Win         = bottomLayerOutputShape._3
+  def Hout        = ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._2, kernel_h, stride_h, pad_h)
+  def Wout        = ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._3, kernel_w, stride_w, pad_w)
   // -------------------------------------------------
   def convParam = param.getConvolutionParam
   // if depthwise (C, M*Hf*Wf) else (F, C*Hf*Wf)
-  override def weightShape():Array[Int] = {
-    if(isDepthWise) Array(numChannels.toInt, int_mult(depthMultiplier, kernel_h, kernel_w).toInt)
+  override def weightShape(): Array[Int] =
+    if (isDepthWise) Array(numChannels.toInt, int_mult(depthMultiplier, kernel_h, kernel_w).toInt)
     else Array(numKernels.toInt, int_mult(numChannels, kernel_h, kernel_w).toInt)
-  }
   // if depthwise (C*M, 1) else (F, 1)
-  override def biasShape():Array[Int] = {
-    if(isDepthWise) Array(numChannels.toInt*depthMultiplier.toInt, 1)
+  override def biasShape(): Array[Int] =
+    if (isDepthWise) Array(numChannels.toInt * depthMultiplier.toInt, 1)
     else Array(numKernels.toInt, 1)
-  }
   // num_output (c_o): the number of filters
   def numKernels = convParam.getNumOutput.toString
   // kernel_size (or kernel_h and kernel_w): specifies height and width of each filter
-  def kernel_h = if(convParam.hasKernelH) convParam.getKernelH.toString 
-                   else if(convParam.getKernelSizeCount > 0)  convParam.getKernelSize(0).toString 
-                   else throw new LanguageException("Incorrect kernel parameters")
-  def kernel_w = if(convParam.hasKernelW) convParam.getKernelW.toString 
-                   else if(convParam.getKernelSizeCount > 0)  convParam.getKernelSize(0).toString 
-                   else throw new LanguageException("Incorrect kernel parameters")
+  def kernel_h =
+    if (convParam.hasKernelH) convParam.getKernelH.toString
+    else if (convParam.getKernelSizeCount > 0) convParam.getKernelSize(0).toString
+    else throw new LanguageException("Incorrect kernel parameters")
+  def kernel_w =
+    if (convParam.hasKernelW) convParam.getKernelW.toString
+    else if (convParam.getKernelSizeCount > 0) convParam.getKernelSize(0).toString
+    else throw new LanguageException("Incorrect kernel parameters")
   // stride (or stride_h and stride_w) [default 1]: specifies the intervals at which to apply the filters to the input
-  def stride_h = if(convParam.hasStrideH) convParam.getStrideH.toString 
-                   else if(convParam.getStrideCount > 0)  convParam.getStride(0).toString 
-                   else "1"
-  def stride_w = if(convParam.hasStrideW) convParam.getStrideW.toString 
-                   else if(convParam.getStrideCount > 0)  convParam.getStride(0).toString 
-                   else "1"
+  def stride_h =
+    if (convParam.hasStrideH) convParam.getStrideH.toString
+    else if (convParam.getStrideCount > 0) convParam.getStride(0).toString
+    else "1"
+  def stride_w =
+    if (convParam.hasStrideW) convParam.getStrideW.toString
+    else if (convParam.getStrideCount > 0) convParam.getStride(0).toString
+    else "1"
   // pad (or pad_h and pad_w) [default 0]: specifies the number of pixels to (implicitly) add to each side of the input
-  def pad_h =   if(convParam.hasPadH) convParam.getPadH.toString 
-                   else if(convParam.getPadCount > 0)  convParam.getPad(0).toString 
-                   else "0"
-  def pad_w =   if(convParam.hasPadW) convParam.getPadW.toString 
-                   else if(convParam.getPadCount > 0)  convParam.getPad(0).toString 
-                   else "0"
+  def pad_h =
+    if (convParam.hasPadH) convParam.getPadH.toString
+    else if (convParam.getPadCount > 0) convParam.getPad(0).toString
+    else "0"
+  def pad_w =
+    if (convParam.hasPadW) convParam.getPadW.toString
+    else if (convParam.getPadCount > 0) convParam.getPad(0).toString
+    else "0"
 }
 
-class DeConvolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
-  def isDepthWise():Boolean = {
-    if(param.getConvolutionParam.hasGroup && param.getConvolutionParam.getGroup != 1 && numChannels.toInt % param.getConvolutionParam.getGroup != 0) 
-      throw new DMLRuntimeException("The number of groups=" + param.getConvolutionParam.getGroup + " is not supported as it is not divisible by number of channels" + numChannels + ".")
+class DeConvolution(val param: LayerParameter, val id: Int, val net: CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
+  def isDepthWise(): Boolean = {
+    if (param.getConvolutionParam.hasGroup && param.getConvolutionParam.getGroup != 1 && numChannels.toInt % param.getConvolutionParam.getGroup != 0)
+      throw new DMLRuntimeException(
+        "The number of groups=" + param.getConvolutionParam.getGroup + " is not supported as it is not divisible by number of channels" + numChannels + "."
+      )
     param.getConvolutionParam.hasGroup && param.getConvolutionParam.getGroup != 1
   }
-  def depthMultiplier():String = if(isDepthWise) (numChannels.toInt / param.getConvolutionParam.getGroup).toString else throw new DMLRuntimeException("Incorrect usage of depth")
-  
-  override def sourceFileName: String = if(isDepthWise) "conv2d_transpose_depthwise" else "conv2d_transpose" 
-  
+  def depthMultiplier(): String = if (isDepthWise) (numChannels.toInt / param.getConvolutionParam.getGroup).toString else throw new DMLRuntimeException("Incorrect usage of depth")
+
+  override def sourceFileName: String = if (isDepthWise) "conv2d_transpose_depthwise" else "conv2d_transpose"
+
   /*
    * Utility function to initialize the parameters of this layer.
    *
@@ -1082,37 +1165,34 @@ class DeConvolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork)
    *  - M: Depth of each filter (C must be divisible by M).
    *  - Hf: Filter height.
    *  - Wf: Filter width.
-   *  
+   *
    * Outputs:
    *  - W: Weights, of shape (C, F*Hf*Wf).
    *  - b: Biases, of shape (F, 1).
    */
-  override def init(dmlScript: StringBuilder): Unit = {
-    if(isDepthWise)
+  override def init(dmlScript: StringBuilder): Unit =
+    if (isDepthWise)
       invokeInit(dmlScript, List[String](weight, bias), numChannels, depthMultiplier, kernel_h, kernel_w)
     else
       invokeInit(dmlScript, List[String](weight, bias), numKernels, numChannels, kernel_h, kernel_w)
-  }
-  
-  private def C_DivideBy_M():Int = numChannels.toInt / depthMultiplier.toInt
-  
-  // if depthwise (C/M, M*Hf*Wf), else (C, F*Hf*Wf) 
-  override def weightShape():Array[Int] = { 
-    if(isDepthWise)
+
+  private def C_DivideBy_M(): Int = numChannels.toInt / depthMultiplier.toInt
+
+  // if depthwise (C/M, M*Hf*Wf), else (C, F*Hf*Wf)
+  override def weightShape(): Array[Int] =
+    if (isDepthWise)
       Array(C_DivideBy_M, int_mult(depthMultiplier, kernel_h, kernel_w).toInt)
     else
       Array(numChannels.toInt, int_mult(numKernels, kernel_h, kernel_w).toInt)
-  }
   // if depthwise (C/M, 1), else (F, 1)
-  override def biasShape():Array[Int] = {
-    if(isDepthWise)
+  override def biasShape(): Array[Int] =
+    if (isDepthWise)
       Array(C_DivideBy_M, 1)
     else
       Array(numKernels.toInt, 1)
-  }
-  
-  private def numGroups:Int = if(param.getConvolutionParam.hasGroup) param.getConvolutionParam.getGroup else 1
-  
+
+  private def numGroups: Int = if (param.getConvolutionParam.hasGroup) param.getConvolutionParam.getGroup else 1
+
   /*
    * Computes the forward pass for a 2D spatial transpose convolutional
    * layer with F filters.  The input data has N examples, each
@@ -1142,15 +1222,47 @@ class DeConvolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork)
    *  - Hout: Output height.
    *  - Wout: Output width.
    */
-  override def forward(dmlScript: StringBuilder,isPrediction: Boolean): Unit = {
-    if(isDepthWise)
-      invokeForward(dmlScript, List[String](out, "ignoreHout_"+id, "ignoreWout_"+id), 
-        X, weight, bias, numChannels, Hin, Win, depthMultiplier, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, "0", "0")
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit =
+    if (isDepthWise)
+      invokeForward(
+        dmlScript,
+        List[String](out, "ignoreHout_" + id, "ignoreWout_" + id),
+        X,
+        weight,
+        bias,
+        numChannels,
+        Hin,
+        Win,
+        depthMultiplier,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        "0",
+        "0"
+      )
     else
-      invokeForward(dmlScript, List[String](out, "ignoreHout_"+id, "ignoreWout_"+id), 
-        X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, "0", "0")
-  }
-        
+      invokeForward(
+        dmlScript,
+        List[String](out, "ignoreHout_" + id, "ignoreWout_" + id),
+        X,
+        weight,
+        bias,
+        numChannels,
+        Hin,
+        Win,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w,
+        "0",
+        "0"
+      )
+
   /*
    * Computes the backward pass for a 2D spatial transpose
    * convolutional layer with F filters.
@@ -1179,58 +1291,100 @@ class DeConvolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork)
    *  - dW: Gradient wrt `W`, of shape (C, F*Hf*Wf).
    *  - db: Gradient wrt `b`, of shape (F, 1).
    */
-  override def backward(dmlScript:StringBuilder, outSuffix:String) = {
-    if(isDepthWise)
-      invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias), 
-        dout, Hout, Wout, X, weight, bias, numChannels, Hin, Win, depthMultiplier, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+  override def backward(dmlScript: StringBuilder, outSuffix: String) =
+    if (isDepthWise)
+      invokeBackward(
+        dmlScript,
+        outSuffix,
+        List[String]("dOut" + id, dWeight, dBias),
+        dout,
+        Hout,
+        Wout,
+        X,
+        weight,
+        bias,
+        numChannels,
+        Hin,
+        Win,
+        depthMultiplier,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w
+      )
     else
-      invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias), 
-        dout, Hout, Wout, X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
-  }
+      invokeBackward(
+        dmlScript,
+        outSuffix,
+        List[String]("dOut" + id, dWeight, dBias),
+        dout,
+        Hout,
+        Wout,
+        X,
+        weight,
+        bias,
+        numChannels,
+        Hin,
+        Win,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        pad_h,
+        pad_w
+      )
   // if not depthwise n * c_o * h_o * w_o, where h_o = (h_i + 2 * pad_h - kernel_h) / stride_h + 1 and w_o likewise.
   // else (N, C/M*Hout*Wout)
-  override def outputShape = if(isDepthWise) ( C_DivideBy_M().toString, Hout, Wout ) else ( numChannels, Hout, Wout )
+  override def outputShape = if (isDepthWise) (C_DivideBy_M().toString, Hout, Wout) else (numChannels, Hout, Wout)
   // -------------------------------------------------
   def numChannels = bottomLayerOutputShape._1
-  def Hin = bottomLayerOutputShape._2
-  def Win = bottomLayerOutputShape._3
+  def Hin         = bottomLayerOutputShape._2
+  def Win         = bottomLayerOutputShape._3
   // Hout = strideh * (Hin-1) - 2*padh + Hf + out_padh
-  def Hout:String =  try { 
-    (stride_h.toInt * (Hin.toInt-1) - 2*pad_h.toInt + kernel_h.toInt).toString()
-  }
-  catch { 
-    case _:Throwable => stride_h + " * " +  "(" + Hin + "-1) - 2*" + pad_h + " + " + kernel_h
-  }
+  def Hout: String =
+    try {
+      (stride_h.toInt * (Hin.toInt - 1) - 2 * pad_h.toInt + kernel_h.toInt).toString()
+    } catch {
+      case _: Throwable => stride_h + " * " + "(" + Hin + "-1) - 2*" + pad_h + " + " + kernel_h
+    }
   // Wout = stridew * (Win-1) - 2*padw + Wf + out_padw
-  def Wout:String =  try { 
-    (stride_w.toInt * (Win.toInt-1) - 2*pad_w.toInt + kernel_w.toInt).toString()
-  }
-  catch { 
-    case _:Throwable => stride_w + " * " +  "(" + Win + "-1) - 2*" + pad_w + " + " + kernel_w
-  }
+  def Wout: String =
+    try {
+      (stride_w.toInt * (Win.toInt - 1) - 2 * pad_w.toInt + kernel_w.toInt).toString()
+    } catch {
+      case _: Throwable => stride_w + " * " + "(" + Win + "-1) - 2*" + pad_w + " + " + kernel_w
+    }
   // -------------------------------------------------
   def convParam = param.getConvolutionParam
   // num_output (c_o): the number of filters
   def numKernels = convParam.getNumOutput.toString
   // kernel_size (or kernel_h and kernel_w): specifies height and width of each filter
-  def kernel_h = if(convParam.hasKernelH) convParam.getKernelH.toString 
-                   else if(convParam.getKernelSizeCount > 0)  convParam.getKernelSize(0).toString 
-                   else throw new LanguageException("Incorrect kernel parameters")
-  def kernel_w = if(convParam.hasKernelW) convParam.getKernelW.toString 
-                   else if(convParam.getKernelSizeCount > 0)  convParam.getKernelSize(0).toString 
-                   else throw new LanguageException("Incorrect kernel parameters")
+  def kernel_h =
+    if (convParam.hasKernelH) convParam.getKernelH.toString
+    else if (convParam.getKernelSizeCount > 0) convParam.getKernelSize(0).toString
+    else throw new LanguageException("Incorrect kernel parameters")
+  def kernel_w =
+    if (convParam.hasKernelW) convParam.getKernelW.toString
+    else if (convParam.getKernelSizeCount > 0) convParam.getKernelSize(0).toString
+    else throw new LanguageException("Incorrect kernel parameters")
   // stride (or stride_h and stride_w) [default 1]: specifies the intervals at which to apply the filters to the input
-  def stride_h = if(convParam.hasStrideH) convParam.getStrideH.toString 
-                   else if(convParam.getStrideCount > 0)  convParam.getStride(0).toString 
-                   else "1"
-  def stride_w = if(convParam.hasStrideW) convParam.getStrideW.toString 
-                   else if(convParam.getStrideCount > 0)  convParam.getStride(0).toString 
-                   else "1"
+  def stride_h =
+    if (convParam.hasStrideH) convParam.getStrideH.toString
+    else if (convParam.getStrideCount > 0) convParam.getStride(0).toString
+    else "1"
+  def stride_w =
+    if (convParam.hasStrideW) convParam.getStrideW.toString
+    else if (convParam.getStrideCount > 0) convParam.getStride(0).toString
+    else "1"
   // pad (or pad_h and pad_w) [default 0]: specifies the number of pixels to (implicitly) add to each side of the input
-  def pad_h =   if(convParam.hasPadH) convParam.getPadH.toString 
-                   else if(convParam.getPadCount > 0)  convParam.getPad(0).toString 
-                   else "0"
-  def pad_w =   if(convParam.hasPadW) convParam.getPadW.toString 
-                   else if(convParam.getPadCount > 0)  convParam.getPad(0).toString 
-                   else "0"
+  def pad_h =
+    if (convParam.hasPadH) convParam.getPadH.toString
+    else if (convParam.getPadCount > 0) convParam.getPad(0).toString
+    else "0"
+  def pad_w =
+    if (convParam.hasPadW) convParam.getPadW.toString
+    else if (convParam.getPadCount > 0) convParam.getPad(0).toString
+    else "0"
 }


[4/4] systemml git commit: [SYSTEMML-540] Support loading of batch normalization weights in .caffemodel file using Caffe2DML

Posted by ni...@apache.org.
[SYSTEMML-540] Support loading of batch normalization weights in .caffemodel file using Caffe2DML

- Also fixed scala formatting.

Closes #662.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f07b5a2d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f07b5a2d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f07b5a2d

Branch: refs/heads/master
Commit: f07b5a2d92f95f28bcdf141d700fc1be0887d735
Parents: ebb6ea6
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Fri Sep 15 11:00:06 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Fri Sep 15 11:01:49 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/api/dl/Caffe2DML.scala     |  732 ++++++-------
 .../apache/sysml/api/dl/Caffe2DMLLoader.scala   |   20 +-
 .../org/apache/sysml/api/dl/CaffeLayer.scala    | 1002 ++++++++++--------
 .../org/apache/sysml/api/dl/CaffeNetwork.scala  |  216 ++--
 .../org/apache/sysml/api/dl/CaffeSolver.scala   |  193 ++--
 .../org/apache/sysml/api/dl/DMLGenerator.scala  |  566 +++++-----
 .../scala/org/apache/sysml/api/dl/Utils.scala   |  484 ++++-----
 .../sysml/api/ml/BaseSystemMLClassifier.scala   |  264 +++--
 .../sysml/api/ml/BaseSystemMLRegressor.scala    |   68 +-
 .../apache/sysml/api/ml/LinearRegression.scala  |   93 +-
 .../sysml/api/ml/LogisticRegression.scala       |  117 +-
 .../org/apache/sysml/api/ml/NaiveBayes.scala    |   62 +-
 .../apache/sysml/api/ml/PredictionUtils.scala   |   32 +-
 .../scala/org/apache/sysml/api/ml/SVM.scala     |   81 +-
 .../org/apache/sysml/api/ml/ScriptsUtils.scala  |   18 +-
 .../scala/org/apache/sysml/api/ml/Utils.scala   |   49 +-
 16 files changed, 2100 insertions(+), 1897 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index a62fae2..6e3e1dc 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -35,10 +35,10 @@ import java.util.HashSet
 import org.apache.sysml.api.DMLScript
 import java.io.File
 import org.apache.spark.SparkContext
-import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.ml.{ Estimator, Model }
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.spark.ml.param.{ DoubleParam, Param, ParamMap, Params }
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics
 import org.apache.sysml.runtime.matrix.data.MatrixBlock
 import org.apache.sysml.runtime.DMLRuntimeException
@@ -55,7 +55,7 @@ import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyze
 DESIGN OF CAFFE2DML:
 
 1. Caffe2DML is designed to fit well into the mllearn framework. Hence, the key methods that were to be implemented are:
-- `getTrainingScript` for the Estimator class. 
+- `getTrainingScript` for the Estimator class.
 - `getPredictionScript` for the Model class.
 
 These methods should be the starting point of any developer to understand the DML generated for training and prediction respectively.
@@ -74,7 +74,7 @@ caffe.proto ---> protoc ---> target/generated-sources/caffe/Caffe.java
 - Just like the classes generated by Dml.g4 are used to parse input DML file,
 the target/generated-sources/caffe/Caffe.java class is used to parse the input caffe network/deploy prototxt and solver files.
 
-- You can think of .caffemodel file as DML file with matrix values encoded in it (please see below example). 
+- You can think of .caffemodel file as DML file with matrix values encoded in it (please see below example).
 So it is possible to read .caffemodel file with the Caffe.java class. This is done in Utils.scala's readCaffeNet method.
 
 X = matrix("1.2 3.5 0.999 7.123", rows=2, cols=2)
@@ -91,7 +91,7 @@ trait CaffeLayer {
   def forward(dmlScript:StringBuilder, isPrediction:Boolean):Unit;
   def backward(dmlScript:StringBuilder, outSuffix:String):Unit;
   ...
-} 
+}
 trait CaffeSolver {
   def sourceFileName:String;
   def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit;
@@ -114,67 +114,85 @@ To shield from network files that violates this restriction, Caffe2DML performs
 6. Caffe2DML also expects the layers to be in sorted order.
 
 ***************************************************************************************/
-
-object Caffe2DML  {
-  val LOG = LogFactory.getLog(classOf[Caffe2DML].getName()) 
+object Caffe2DML {
+  val LOG = LogFactory.getLog(classOf[Caffe2DML].getName())
   // ------------------------------------------------------------------------
   def layerDir = "nn/layers/"
   def optimDir = "nn/optim/"
-  
+
   // Naming conventions:
-  val X = "X"; val y = "y"; val batchSize = "BATCH_SIZE"; val numImages = "num_images"; val numValidationImages = "num_validation"
+  val X    = "X"; val y        = "y"; val batchSize = "BATCH_SIZE"; val numImages = "num_images"; val numValidationImages = "num_validation"
   val XVal = "X_val"; val yVal = "y_val"
-  
+
   val USE_NESTEROV_UDF = {
     // Developer environment variable flag 'USE_NESTEROV_UDF' until codegen starts working.
     // Then, we will remove this flag and also the class org.apache.sysml.udf.lib.SGDNesterovUpdate
     val envFlagNesterovUDF = System.getenv("USE_NESTEROV_UDF")
     envFlagNesterovUDF != null && envFlagNesterovUDF.toBoolean
   }
-  
+
   def main(args: Array[String]): Unit = {
-	// Arguments: [train_script | predict_script] $OUTPUT_DML_FILE $SOLVER_FILE $INPUT_CHANNELS $INPUT_HEIGHT $INPUT_WIDTH $NUM_ITER
-	if(args.length < 6) throwUsageError
-	val outputDMLFile = args(1)
-	val solverFile = args(2)
-	val inputChannels = args(3)
-	val inputHeight = args(4)
-	val inputWidth = args(5)
-	val caffeObj = new Caffe2DML(new SparkContext(), solverFile, inputChannels, inputHeight, inputWidth)
-	if(args(0).equals("train_script")) {
-		Utils.writeToFile(caffeObj.getTrainingScript(true)._1.getScriptString, outputDMLFile)
-	}
-	else if(args(0).equals("predict_script")) {
-		Utils.writeToFile(new Caffe2DMLModel(caffeObj).getPredictionScript(true)._1.getScriptString, outputDMLFile)
-	}
-	else {
-		throwUsageError
-	}
-  }
-  def throwUsageError():Unit = {
-	throw new RuntimeException("Incorrect usage: train_script OUTPUT_DML_FILE SOLVER_FILE INPUT_CHANNELS INPUT_HEIGHT INPUT_WIDTH"); 
+    // Arguments: [train_script | predict_script] $OUTPUT_DML_FILE $SOLVER_FILE $INPUT_CHANNELS $INPUT_HEIGHT $INPUT_WIDTH $NUM_ITER
+    if (args.length < 6) throwUsageError
+    val outputDMLFile = args(1)
+    val solverFile    = args(2)
+    val inputChannels = args(3)
+    val inputHeight   = args(4)
+    val inputWidth    = args(5)
+    val caffeObj      = new Caffe2DML(new SparkContext(), solverFile, inputChannels, inputHeight, inputWidth)
+    if (args(0).equals("train_script")) {
+      Utils.writeToFile(caffeObj.getTrainingScript(true)._1.getScriptString, outputDMLFile)
+    } else if (args(0).equals("predict_script")) {
+      Utils.writeToFile(new Caffe2DMLModel(caffeObj).getPredictionScript(true)._1.getScriptString, outputDMLFile)
+    } else {
+      throwUsageError
+    }
   }
+  def throwUsageError(): Unit =
+    throw new RuntimeException("Incorrect usage: train_script OUTPUT_DML_FILE SOLVER_FILE INPUT_CHANNELS INPUT_HEIGHT INPUT_WIDTH");
 }
 
-class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter, 
-    val solver:CaffeSolver, val net:CaffeNetwork, 
-    val lrPolicy:LearningRatePolicy, val numChannels:String, val height:String, val width:String) extends Estimator[Caffe2DMLModel] 
-  with BaseSystemMLClassifier with DMLGenerator {
+class Caffe2DML(val sc: SparkContext,
+                val solverParam: Caffe.SolverParameter,
+                val solver: CaffeSolver,
+                val net: CaffeNetwork,
+                val lrPolicy: LearningRatePolicy,
+                val numChannels: String,
+                val height: String,
+                val width: String)
+    extends Estimator[Caffe2DMLModel]
+    with BaseSystemMLClassifier
+    with DMLGenerator {
   // --------------------------------------------------------------
   // Invoked by Python, MLPipeline
-  def this(sc: SparkContext, solver1:Caffe.SolverParameter, networkPath:String, numChannels:String, height:String, width:String) {
-    this(sc, solver1, Utils.parseSolver(solver1), 
-        new CaffeNetwork(networkPath, caffe.Caffe.Phase.TRAIN, numChannels, height, width),
-        new LearningRatePolicy(solver1), numChannels, height, width)
+  def this(sc: SparkContext, solver1: Caffe.SolverParameter, networkPath: String, numChannels: String, height: String, width: String) {
+    this(
+      sc,
+      solver1,
+      Utils.parseSolver(solver1),
+      new CaffeNetwork(networkPath, caffe.Caffe.Phase.TRAIN, numChannels, height, width),
+      new LearningRatePolicy(solver1),
+      numChannels,
+      height,
+      width
+    )
   }
-  def this(sc: SparkContext, solver1:Caffe.SolverParameter, numChannels:String, height:String, width:String) {
-    this(sc, solver1, Utils.parseSolver(solver1), new CaffeNetwork(solver1.getNet, caffe.Caffe.Phase.TRAIN, numChannels, height, width), 
-        new LearningRatePolicy(solver1), numChannels, height, width)
+  def this(sc: SparkContext, solver1: Caffe.SolverParameter, numChannels: String, height: String, width: String) {
+    this(
+      sc,
+      solver1,
+      Utils.parseSolver(solver1),
+      new CaffeNetwork(solver1.getNet, caffe.Caffe.Phase.TRAIN, numChannels, height, width),
+      new LearningRatePolicy(solver1),
+      numChannels,
+      height,
+      width
+    )
   }
-  def this(sc: SparkContext, solverPath:String, numChannels:String, height:String, width:String) {
+  def this(sc: SparkContext, solverPath: String, numChannels: String, height: String, width: String) {
     this(sc, Utils.readCaffeSolver(solverPath), numChannels, height, width)
   }
-  val uid:String = "caffe_classifier_" + (new Random).nextLong
+  val uid: String = "caffe_classifier_" + (new Random).nextLong
   override def copy(extra: org.apache.spark.ml.param.ParamMap): Estimator[Caffe2DMLModel] = {
     val that = new Caffe2DML(sc, solverParam, solver, net, lrPolicy, numChannels, height, width)
     copyValues(that, extra)
@@ -188,221 +206,223 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
     mloutput = baseFit(df, sc)
     new Caffe2DMLModel(this)
   }
-	// --------------------------------------------------------------
+  // --------------------------------------------------------------
   // Returns true if last 2 of 4 dimensions are 1.
   // The first dimension refers to number of input datapoints.
   // The second dimension refers to number of classes.
-  def isClassification():Boolean = {
+  def isClassification(): Boolean = {
     val outShape = getOutputShapeOfLastLayer
     return outShape._2 == 1 && outShape._3 == 1
   }
-  def getOutputShapeOfLastLayer():(Int, Int, Int) = {
+  def getOutputShapeOfLastLayer(): (Int, Int, Int) = {
     val out = net.getCaffeLayer(net.getLayers().last).outputShape
-    (out._1.toInt, out._2.toInt, out._3.toInt) 
+    (out._1.toInt, out._2.toInt, out._3.toInt)
   }
-  
+
   // Used for simplifying transfer learning
-  private val layersToIgnore:HashSet[String] = new HashSet[String]() 
-  def setWeightsToIgnore(layerName:String):Unit = layersToIgnore.add(layerName)
-  def setWeightsToIgnore(layerNames:ArrayList[String]):Unit = layersToIgnore.addAll(layerNames)
-  	  
+  private val layersToIgnore: HashSet[String]                 = new HashSet[String]()
+  def setWeightsToIgnore(layerName: String): Unit             = layersToIgnore.add(layerName)
+  def setWeightsToIgnore(layerNames: ArrayList[String]): Unit = layersToIgnore.addAll(layerNames)
+
   // Input parameters to prediction and scoring script
-  val inputs:java.util.HashMap[String, String] = new java.util.HashMap[String, String]()
-  def setInput(key: String, value:String):Unit = inputs.put(key, value)
+  val inputs: java.util.HashMap[String, String]  = new java.util.HashMap[String, String]()
+  def setInput(key: String, value: String): Unit = inputs.put(key, value)
   customAssert(solverParam.getTestIterCount <= 1, "Multiple test_iter variables are not supported")
   customAssert(solverParam.getMaxIter > 0, "Please set max_iter to a positive value")
   customAssert(net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[IsLossLayer]).length == 1, "Expected exactly one loss layer")
-    
+
   // TODO: throw error or warning if user tries to set solver_mode == GPU instead of using setGPU method
-  
+
   // Method called by Python mllearn to visualize variable of certain layer
-  def visualizeLayer(layerName:String, varType:String, aggFn:String): Unit = visualizeLayer(net, layerName, varType, aggFn)
-  
-  def getTrainAlgo():String = if(inputs.containsKey("$train_algo")) inputs.get("$train_algo") else "minibatch"
-  def getTestAlgo():String = if(inputs.containsKey("$test_algo")) inputs.get("$test_algo") else "minibatch"
+  def visualizeLayer(layerName: String, varType: String, aggFn: String): Unit = visualizeLayer(net, layerName, varType, aggFn)
 
-  def summary(sparkSession:org.apache.spark.sql.SparkSession):Unit = {
+  def getTrainAlgo(): String = if (inputs.containsKey("$train_algo")) inputs.get("$train_algo") else "minibatch"
+  def getTestAlgo(): String  = if (inputs.containsKey("$test_algo")) inputs.get("$test_algo") else "minibatch"
+
+  def summary(sparkSession: org.apache.spark.sql.SparkSession): Unit = {
     val header = Seq("Name", "Type", "Output", "Weight", "Bias", "Top", "Bottom")
-    val entries = net.getLayers.map(l => (l, net.getCaffeLayer(l))).map(l => {
-      val layer = l._2
-      (l._1, layer.param.getType, 
-          "(, " + layer.outputShape._1 + ", " + layer.outputShape._2 + ", " + layer.outputShape._3 + ")",
-          if(layer.weightShape != null) "[" + layer.weightShape()(0) + " X " + layer.weightShape()(1) + "]" else "",
-          if(layer.biasShape != null) "[" + layer.biasShape()(0) + " X " + layer.biasShape()(1) + "]" else "",
-          layer.param.getTopList.mkString(","),
-          layer.param.getBottomList.mkString(",")
-      )
-    })
+    val entries = net.getLayers
+      .map(l => (l, net.getCaffeLayer(l)))
+      .map(l => {
+        val layer = l._2
+        (l._1,
+         layer.param.getType,
+         "(, " + layer.outputShape._1 + ", " + layer.outputShape._2 + ", " + layer.outputShape._3 + ")",
+         if (layer.weightShape != null) "[" + layer.weightShape()(0) + " X " + layer.weightShape()(1) + "]" else "",
+         if (layer.biasShape != null) "[" + layer.biasShape()(0) + " X " + layer.biasShape()(1) + "]" else "",
+         layer.param.getTopList.mkString(","),
+         layer.param.getBottomList.mkString(","))
+      })
     import sparkSession.implicits._
-    sc.parallelize(entries).toDF(header : _*).show(net.getLayers.size)
+    sc.parallelize(entries).toDF(header: _*).show(net.getLayers.size)
   }
-  
+
   // ================================================================================================
   // The below method parses the provided network and solver file and generates DML script.
-	def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
-	  val startTrainingTime = System.nanoTime()
-	  
-    reset                                 // Reset the state of DML generator for training script.
-    
+  def getTrainingScript(isSingleNode: Boolean): (Script, String, String) = {
+    val startTrainingTime = System.nanoTime()
+
+    reset // Reset the state of DML generator for training script.
+
     // Flags passed by user
-	  val DEBUG_TRAINING = if(inputs.containsKey("$debug")) inputs.get("$debug").toLowerCase.toBoolean else false
-	  assign(tabDMLScript, "debug", if(DEBUG_TRAINING) "TRUE" else "FALSE")
-	  
-	  appendHeaders(net, solver, true)      // Appends DML corresponding to source and externalFunction statements.
-	  readInputData(net, true)              // Read X_full and y_full
-	  // Initialize the layers and solvers. Reads weights and bias if $weights is set.
-	  initWeights(net, solver, inputs.containsKey("$weights"), layersToIgnore)
-	  
-	  // Split into training and validation set
-	  // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
-	  val shouldValidate = solverParam.getTestInterval > 0 && solverParam.getTestIterCount > 0 && solverParam.getTestIter(0) > 0
-	  trainTestSplit(if(shouldValidate) solverParam.getTestIter(0) else 0)
-	  
-	  // Set iteration-related variables such as num_iters_per_epoch, lr, etc.
-	  ceilDivide(tabDMLScript, "num_iters_per_epoch", Caffe2DML.numImages, Caffe2DML.batchSize)
-	  assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
-	  assign(tabDMLScript, "max_iter", ifdef("$max_iter", solverParam.getMaxIter.toString))
-	  assign(tabDMLScript, "e", "0")
-	  
-	  val lossLayers = getLossLayers(net)
-	  // ----------------------------------------------------------------------------
-	  // Main logic
-	  forBlock("iter", "1", "max_iter") {
-		performTrainingIter(lossLayers, shouldValidate)
-		if(getTrainAlgo.toLowerCase.equals("batch")) {
-			assign(tabDMLScript, "e", "iter")
-			tabDMLScript.append("# Learning rate\n")
-			lrPolicy.updateLearningRate(tabDMLScript)
-		}
-		else {
-			ifBlock("iter %% num_iters_per_epoch == 0") {
-				// After every epoch, update the learning rate
-				assign(tabDMLScript, "e", "e + 1")
-				tabDMLScript.append("# Learning rate\n")
-				lrPolicy.updateLearningRate(tabDMLScript)
-			}
-		}
-	  }
-	  // ----------------------------------------------------------------------------
-	  
-	  // Check if this is necessary
-	  if(doVisualize) tabDMLScript.append("print(" + asDMLString("Visualization counter:") + " + viz_counter)")
-	  
-	  val trainingScript = tabDMLScript.toString()
-	  // Print script generation time and the DML script on stdout
-	  System.out.println("Time taken to generate training script from Caffe proto: " + ((System.nanoTime() - startTrainingTime)*1e-9) + " seconds." )
-	  if(DEBUG_TRAINING) Utils.prettyPrintDMLScript(trainingScript)
-	  
-	  // Set input/output variables and execute the script
-	  val script = dml(trainingScript).in(inputs)
-	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.out(l.weight))
-	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.out(l.bias))
-	  (script, "X_full", "y_full")
-	}
-	// ================================================================================================
-  
-  private def performTrainingIter(lossLayers:List[IsLossLayer], shouldValidate:Boolean):Unit = {
-	getTrainAlgo.toLowerCase match {
-      case "minibatch" => 
-          getTrainingBatch(tabDMLScript)
-          // -------------------------------------------------------
-          // Perform forward, backward and update on minibatch
-          forward; backward; update
-          // -------------------------------------------------------
-          displayLoss(lossLayers(0), shouldValidate)
-          performSnapshot
+    val DEBUG_TRAINING = if (inputs.containsKey("$debug")) inputs.get("$debug").toLowerCase.toBoolean else false
+    assign(tabDMLScript, "debug", if (DEBUG_TRAINING) "TRUE" else "FALSE")
+
+    appendHeaders(net, solver, true) // Appends DML corresponding to source and externalFunction statements.
+    readInputData(net, true)         // Read X_full and y_full
+    // Initialize the layers and solvers. Reads weights and bias if $weights is set.
+    initWeights(net, solver, inputs.containsKey("$weights"), layersToIgnore)
+
+    // Split into training and validation set
+    // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
+    val shouldValidate = solverParam.getTestInterval > 0 && solverParam.getTestIterCount > 0 && solverParam.getTestIter(0) > 0
+    trainTestSplit(if (shouldValidate) solverParam.getTestIter(0) else 0)
+
+    // Set iteration-related variables such as num_iters_per_epoch, lr, etc.
+    ceilDivide(tabDMLScript, "num_iters_per_epoch", Caffe2DML.numImages, Caffe2DML.batchSize)
+    assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
+    assign(tabDMLScript, "max_iter", ifdef("$max_iter", solverParam.getMaxIter.toString))
+    assign(tabDMLScript, "e", "0")
+
+    val lossLayers = getLossLayers(net)
+    // ----------------------------------------------------------------------------
+    // Main logic
+    forBlock("iter", "1", "max_iter") {
+      performTrainingIter(lossLayers, shouldValidate)
+      if (getTrainAlgo.toLowerCase.equals("batch")) {
+        assign(tabDMLScript, "e", "iter")
+        tabDMLScript.append("# Learning rate\n")
+        lrPolicy.updateLearningRate(tabDMLScript)
+      } else {
+        ifBlock("iter %% num_iters_per_epoch == 0") {
+          // After every epoch, update the learning rate
+          assign(tabDMLScript, "e", "e + 1")
+          tabDMLScript.append("# Learning rate\n")
+          lrPolicy.updateLearningRate(tabDMLScript)
+        }
+      }
+    }
+    // ----------------------------------------------------------------------------
+
+    // Check if this is necessary
+    if (doVisualize) tabDMLScript.append("print(" + asDMLString("Visualization counter:") + " + viz_counter)")
+
+    val trainingScript = tabDMLScript.toString()
+    // Print script generation time and the DML script on stdout
+    System.out.println("Time taken to generate training script from Caffe proto: " + ((System.nanoTime() - startTrainingTime) * 1e-9) + " seconds.")
+    if (DEBUG_TRAINING) Utils.prettyPrintDMLScript(trainingScript)
+
+    // Set input/output variables and execute the script
+    val script = dml(trainingScript).in(inputs)
+    net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.out(l.weight))
+    net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.out(l.bias))
+    (script, "X_full", "y_full")
+  }
+  // ================================================================================================
+
+  private def performTrainingIter(lossLayers: List[IsLossLayer], shouldValidate: Boolean): Unit =
+    getTrainAlgo.toLowerCase match {
+      case "minibatch" =>
+        getTrainingBatch(tabDMLScript)
+        // -------------------------------------------------------
+        // Perform forward, backward and update on minibatch
+        forward; backward; update
+        // -------------------------------------------------------
+        displayLoss(lossLayers(0), shouldValidate)
+        performSnapshot
       case "batch" => {
-	      // -------------------------------------------------------
-	      // Perform forward, backward and update on entire dataset
-	      forward; backward; update
-	      // -------------------------------------------------------
-	      displayLoss(lossLayers(0), shouldValidate)
-	      performSnapshot
+        // -------------------------------------------------------
+        // Perform forward, backward and update on entire dataset
+        forward; backward; update
+        // -------------------------------------------------------
+        displayLoss(lossLayers(0), shouldValidate)
+        performSnapshot
       }
       case "allreduce_parallel_batches" => {
-    	  // This setting uses the batch size provided by the user
-	      if(!inputs.containsKey("$parallel_batches")) {
-	        throw new RuntimeException("The parameter parallel_batches is required for allreduce_parallel_batches")
-	      }
-	      // The user specifies the number of parallel_batches
-	      // This ensures that the user of generated script remembers to provide the commandline parameter $parallel_batches
-	      assign(tabDMLScript, "parallel_batches", "$parallel_batches") 
-	      assign(tabDMLScript, "group_batch_size", "parallel_batches*" + Caffe2DML.batchSize)
-	      assign(tabDMLScript, "groups", "as.integer(ceil(" + Caffe2DML.numImages + "/group_batch_size))")
-	      // Grab groups of mini-batches
-	      forBlock("g", "1", "groups") {
-	        // Get next group of mini-batches
-	        assign(tabDMLScript, "group_beg", "((g-1) * group_batch_size) %% " + Caffe2DML.numImages + " + 1")
-	        assign(tabDMLScript, "group_end", "min(" + Caffe2DML.numImages + ", group_beg + group_batch_size - 1)")
-	        assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[group_beg:group_end,]")
-	        assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[group_beg:group_end,]")
-	        initializeGradients("parallel_batches")
-	        assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
-	        parForBlock("j", "1", "parallel_batches") {
-	          // Get a mini-batch in this group
-	          assign(tabDMLScript, "beg", "((j-1) * " + Caffe2DML.batchSize + ") %% nrow(X_group_batch) + 1")
-	          assign(tabDMLScript, "end", "min(nrow(X_group_batch), beg + " + Caffe2DML.batchSize + " - 1)")
-	          assign(tabDMLScript, "Xb", "X_group_batch[beg:end,]")
-	          assign(tabDMLScript, "yb", "y_group_batch[beg:end,]")
-	          forward; backward
-	          flattenGradients
-	        }
-	        aggregateAggGradients    
-	        update
-	        // -------------------------------------------------------
-	        assign(tabDMLScript, "Xb", "X_group_batch")
-	        assign(tabDMLScript, "yb", "y_group_batch")
-	        displayLoss(lossLayers(0), shouldValidate)
-	        performSnapshot
-	      }
-      }
-      case "allreduce" => {
-    	  // This is distributed synchronous gradient descent
-    	  // -------------------------------------------------------
-    	  // Perform forward, backward and update on minibatch in parallel
-    	  assign(tabDMLScript, "beg", "((iter-1) * " + Caffe2DML.batchSize + ") %% " + Caffe2DML.numImages + " + 1")
-    	  assign(tabDMLScript, "end", " min(beg +  " + Caffe2DML.batchSize + " - 1, " + Caffe2DML.numImages + ")")
-    	  assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[beg:end,]")
-    	  assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[beg:end,]")
-    	  assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
-          tabDMLScript.append("local_batch_size = nrow(y_group_batch)\n")
-          val localBatchSize = "local_batch_size"
-          initializeGradients(localBatchSize)
-          parForBlock("j", "1", localBatchSize) {
-            assign(tabDMLScript, "Xb", "X_group_batch[j,]")
-            assign(tabDMLScript, "yb", "y_group_batch[j,]")
+        // This setting uses the batch size provided by the user
+        if (!inputs.containsKey("$parallel_batches")) {
+          throw new RuntimeException("The parameter parallel_batches is required for allreduce_parallel_batches")
+        }
+        // The user specifies the number of parallel_batches
+        // This ensures that the user of generated script remembers to provide the commandline parameter $parallel_batches
+        assign(tabDMLScript, "parallel_batches", "$parallel_batches")
+        assign(tabDMLScript, "group_batch_size", "parallel_batches*" + Caffe2DML.batchSize)
+        assign(tabDMLScript, "groups", "as.integer(ceil(" + Caffe2DML.numImages + "/group_batch_size))")
+        // Grab groups of mini-batches
+        forBlock("g", "1", "groups") {
+          // Get next group of mini-batches
+          assign(tabDMLScript, "group_beg", "((g-1) * group_batch_size) %% " + Caffe2DML.numImages + " + 1")
+          assign(tabDMLScript, "group_end", "min(" + Caffe2DML.numImages + ", group_beg + group_batch_size - 1)")
+          assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[group_beg:group_end,]")
+          assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[group_beg:group_end,]")
+          initializeGradients("parallel_batches")
+          assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
+          parForBlock("j", "1", "parallel_batches") {
+            // Get a mini-batch in this group
+            assign(tabDMLScript, "beg", "((j-1) * " + Caffe2DML.batchSize + ") %% nrow(X_group_batch) + 1")
+            assign(tabDMLScript, "end", "min(nrow(X_group_batch), beg + " + Caffe2DML.batchSize + " - 1)")
+            assign(tabDMLScript, "Xb", "X_group_batch[beg:end,]")
+            assign(tabDMLScript, "yb", "y_group_batch[beg:end,]")
             forward; backward
-          flattenGradients
+            flattenGradients
           }
-          aggregateAggGradients    
+          aggregateAggGradients
           update
           // -------------------------------------------------------
           assign(tabDMLScript, "Xb", "X_group_batch")
           assign(tabDMLScript, "yb", "y_group_batch")
           displayLoss(lossLayers(0), shouldValidate)
           performSnapshot
+        }
+      }
+      case "allreduce" => {
+        // This is distributed synchronous gradient descent
+        // -------------------------------------------------------
+        // Perform forward, backward and update on minibatch in parallel
+        assign(tabDMLScript, "beg", "((iter-1) * " + Caffe2DML.batchSize + ") %% " + Caffe2DML.numImages + " + 1")
+        assign(tabDMLScript, "end", " min(beg +  " + Caffe2DML.batchSize + " - 1, " + Caffe2DML.numImages + ")")
+        assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[beg:end,]")
+        assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[beg:end,]")
+        assign(tabDMLScript, "X_group_batch_size", nrow("X_group_batch"))
+        tabDMLScript.append("local_batch_size = nrow(y_group_batch)\n")
+        val localBatchSize = "local_batch_size"
+        initializeGradients(localBatchSize)
+        parForBlock("j", "1", localBatchSize) {
+          assign(tabDMLScript, "Xb", "X_group_batch[j,]")
+          assign(tabDMLScript, "yb", "y_group_batch[j,]")
+          forward; backward
+          flattenGradients
+        }
+        aggregateAggGradients
+        update
+        // -------------------------------------------------------
+        assign(tabDMLScript, "Xb", "X_group_batch")
+        assign(tabDMLScript, "yb", "y_group_batch")
+        displayLoss(lossLayers(0), shouldValidate)
+        performSnapshot
       }
       case _ => throw new DMLRuntimeException("Unsupported train algo:" + getTrainAlgo)
     }
-  }
   // -------------------------------------------------------------------------------------------
   // Helper functions to generate DML
   // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
-  private def trainTestSplit(numValidationBatches:Int):Unit = {
-    if(numValidationBatches > 0) {
-      if(solverParam.getDisplay <= 0) 
+  private def trainTestSplit(numValidationBatches: Int): Unit =
+    if (numValidationBatches > 0) {
+      if (solverParam.getDisplay <= 0)
         throw new DMLRuntimeException("Since test_iter and test_interval is greater than zero, you should set display to be greater than zero")
       tabDMLScript.append(Caffe2DML.numValidationImages).append(" = " + numValidationBatches + " * " + Caffe2DML.batchSize + "\n")
       tabDMLScript.append("# Sanity check to ensure that validation set is not too large\n")
       val maxValidationSize = "ceil(0.3 * " + Caffe2DML.numImages + ")"
-      ifBlock(Caffe2DML.numValidationImages  + " > " + maxValidationSize) {
+      ifBlock(Caffe2DML.numValidationImages + " > " + maxValidationSize) {
         assign(tabDMLScript, "max_test_iter", "floor(" + maxValidationSize + " / " + Caffe2DML.batchSize + ")")
-        tabDMLScript.append("stop(" +
-            dmlConcat(asDMLString("Too large validation size. Please reduce test_iter to "), "max_test_iter") 
-            + ")\n")
+        tabDMLScript.append(
+          "stop(" +
+          dmlConcat(asDMLString("Too large validation size. Please reduce test_iter to "), "max_test_iter")
+          + ")\n"
+        )
       }
       val one = "1"
-      val rl = int_add(Caffe2DML.numValidationImages, one)
+      val rl  = int_add(Caffe2DML.numValidationImages, one)
       rightIndexing(tabDMLScript.append(Caffe2DML.X).append(" = "), "X_full", rl, Caffe2DML.numImages, null, null)
       tabDMLScript.append("; ")
       rightIndexing(tabDMLScript.append(Caffe2DML.y).append(" = "), "y_full", rl, Caffe2DML.numImages, null, null)
@@ -412,41 +432,39 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
       rightIndexing(tabDMLScript.append(Caffe2DML.yVal).append(" = "), "y_full", one, Caffe2DML.numValidationImages, null, null)
       tabDMLScript.append("; ")
       tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y)\n")
-    }
-    else {
+    } else {
       assign(tabDMLScript, Caffe2DML.X, "X_full")
-	    assign(tabDMLScript, Caffe2DML.y, "y_full")
-	    tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(" + Caffe2DML.y + ")\n")
+      assign(tabDMLScript, Caffe2DML.y, "y_full")
+      tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(" + Caffe2DML.y + ")\n")
     }
-  }
-  
+
   // Append the DML to display training and validation loss
-  private def displayLoss(lossLayer:IsLossLayer, shouldValidate:Boolean):Unit = {
-    if(solverParam.getDisplay > 0) {
+  private def displayLoss(lossLayer: IsLossLayer, shouldValidate: Boolean): Unit = {
+    if (solverParam.getDisplay > 0) {
       // Append the DML to compute training loss
-      if(!getTrainAlgo.toLowerCase.startsWith("allreduce")) {
+      if (!getTrainAlgo.toLowerCase.startsWith("allreduce")) {
         // Compute training loss for allreduce
         tabDMLScript.append("# Compute training loss & accuracy\n")
         ifBlock("iter  %% " + solverParam.getDisplay + " == 0") {
           assign(tabDMLScript, "loss", "0"); assign(tabDMLScript, "accuracy", "0")
           lossLayer.computeLoss(dmlScript, numTabs)
           assign(tabDMLScript, "training_loss", "loss"); assign(tabDMLScript, "training_accuracy", "accuracy")
-          tabDMLScript.append(print( dmlConcat( asDMLString("Iter:"), "iter", 
-              asDMLString(", training loss:"), "training_loss", asDMLString(", training accuracy:"), "training_accuracy" )))
+          tabDMLScript.append(
+            print(dmlConcat(asDMLString("Iter:"), "iter", asDMLString(", training loss:"), "training_loss", asDMLString(", training accuracy:"), "training_accuracy"))
+          )
           appendTrainingVisualizationBody(dmlScript, numTabs)
           printClassificationReport
         }
-      }
-      else {
+      } else {
         Caffe2DML.LOG.info("Training loss is not printed for train_algo=" + getTrainAlgo)
       }
-      if(shouldValidate) {
-        if(  getTrainAlgo.toLowerCase.startsWith("allreduce") &&
+      if (shouldValidate) {
+        if (getTrainAlgo.toLowerCase.startsWith("allreduce") &&
             getTestAlgo.toLowerCase.startsWith("allreduce")) {
           Caffe2DML.LOG.warn("The setting: train_algo=" + getTrainAlgo + " and test_algo=" + getTestAlgo + " is not recommended. Consider changing test_algo=minibatch")
         }
         // Append the DML to compute validation loss
-        val numValidationBatches = if(solverParam.getTestIterCount > 0) solverParam.getTestIter(0) else 0
+        val numValidationBatches = if (solverParam.getTestIterCount > 0) solverParam.getTestIter(0) else 0
         tabDMLScript.append("# Compute validation loss & accuracy\n")
         ifBlock("iter  %% " + solverParam.getTestInterval + " == 0") {
           assign(tabDMLScript, "loss", "0"); assign(tabDMLScript, "accuracy", "0")
@@ -455,11 +473,11 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
               assign(tabDMLScript, "validation_loss", "0")
               assign(tabDMLScript, "validation_accuracy", "0")
               forBlock("iVal", "1", "num_iters_per_epoch") {
-    	          getValidationBatch(tabDMLScript)
-    	          forward;  lossLayer.computeLoss(dmlScript, numTabs)
+                getValidationBatch(tabDMLScript)
+                forward; lossLayer.computeLoss(dmlScript, numTabs)
                 tabDMLScript.append("validation_loss = validation_loss + loss\n")
                 tabDMLScript.append("validation_accuracy = validation_accuracy + accuracy\n")
-    	        }
+              }
               tabDMLScript.append("validation_accuracy = validation_accuracy / num_iters_per_epoch\n")
             }
             case "batch" => {
@@ -467,16 +485,16 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
               net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, false))
               lossLayer.computeLoss(dmlScript, numTabs)
               assign(tabDMLScript, "validation_loss", "loss"); assign(tabDMLScript, "validation_accuracy", "accuracy")
-              
+
             }
             case "allreduce_parallel_batches" => {
               // This setting uses the batch size provided by the user
-              if(!inputs.containsKey("$parallel_batches")) {
+              if (!inputs.containsKey("$parallel_batches")) {
                 throw new RuntimeException("The parameter parallel_batches is required for allreduce_parallel_batches")
               }
               // The user specifies the number of parallel_batches
               // This ensures that the user of generated script remembers to provide the commandline parameter $parallel_batches
-              assign(tabDMLScript, "parallel_batches_val", "$parallel_batches") 
+              assign(tabDMLScript, "parallel_batches_val", "$parallel_batches")
               assign(tabDMLScript, "group_batch_size_val", "parallel_batches_val*" + Caffe2DML.batchSize)
               assign(tabDMLScript, "groups_val", "as.integer(ceil(" + Caffe2DML.numValidationImages + "/group_batch_size_val))")
               assign(tabDMLScript, "validation_accuracy", "0")
@@ -511,8 +529,8 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
               assign(tabDMLScript, "group_validation_loss", matrix("0", Caffe2DML.numValidationImages, "1"))
               assign(tabDMLScript, "group_validation_accuracy", matrix("0", Caffe2DML.numValidationImages, "1"))
               parForBlock("iVal", "1", Caffe2DML.numValidationImages) {
-                assign(tabDMLScript, "Xb",  Caffe2DML.XVal + "[iVal,]")
-                assign(tabDMLScript, "yb",  Caffe2DML.yVal + "[iVal,]")
+                assign(tabDMLScript, "Xb", Caffe2DML.XVal + "[iVal,]")
+                assign(tabDMLScript, "yb", Caffe2DML.yVal + "[iVal,]")
                 net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, false))
                 lossLayer.computeLoss(dmlScript, numTabs)
                 assign(tabDMLScript, "group_validation_loss[iVal,1]", "loss")
@@ -521,124 +539,132 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
               assign(tabDMLScript, "validation_loss", "sum(group_validation_loss)")
               assign(tabDMLScript, "validation_accuracy", "mean(group_validation_accuracy)")
             }
-            
+
             case _ => throw new DMLRuntimeException("Unsupported test algo:" + getTestAlgo)
           }
-          tabDMLScript.append(print( dmlConcat( asDMLString("Iter:"), "iter", 
-              asDMLString(", validation loss:"), "validation_loss", asDMLString(", validation accuracy:"), "validation_accuracy" )))
+          tabDMLScript.append(
+            print(dmlConcat(asDMLString("Iter:"), "iter", asDMLString(", validation loss:"), "validation_loss", asDMLString(", validation accuracy:"), "validation_accuracy"))
+          )
           appendValidationVisualizationBody(dmlScript, numTabs)
         }
       }
     }
   }
-  
-  private def performSnapshot():Unit = {
-    if(solverParam.getSnapshot > 0) {
+  private def appendSnapshotWrite(varName: String, fileName: String): Unit =
+    tabDMLScript.append(write(varName, "snapshot_dir + \"" + fileName + "\"", "binary"))
+  private def performSnapshot(): Unit =
+    if (solverParam.getSnapshot > 0) {
       ifBlock("iter %% " + solverParam.getSnapshot + " == 0") {
         tabDMLScript.append("snapshot_dir= \"" + solverParam.getSnapshotPrefix + "\" + \"/iter_\" + iter + \"/\"\n")
-        net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(
-        	"write(" + l.weight + ", snapshot_dir + \"" + l.param.getName + "_weight.mtx\", format=\"binary\")\n"))
-  		net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(
-  			"write(" + l.bias + ", snapshot_dir + \"" + l.param.getName + "_bias.mtx\", format=\"binary\")\n"))
+        val allLayers = net.getLayers.map(net.getCaffeLayer(_))
+        allLayers.filter(_.weight != null).map(l => appendSnapshotWrite(l.weight, l.param.getName + "_weight.mtx"))
+        allLayers.filter(_.bias != null).map(l => appendSnapshotWrite(l.bias, l.param.getName + "_bias.mtx"))
       }
-  	}
-  }
-  
-  private def forward():Unit = {
+    }
+
+  private def forward(): Unit = {
     tabDMLScript.append("# Perform forward pass\n")
-	  net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, false))
+    net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, false))
   }
-  private def backward():Unit = {
+  private def backward(): Unit = {
     tabDMLScript.append("# Perform backward pass\n")
     net.getLayers.reverse.map(layer => net.getCaffeLayer(layer).backward(tabDMLScript, ""))
   }
-  private def update():Unit = {
+  private def update(): Unit = {
     tabDMLScript.append("# Update the parameters\n")
     net.getLayers.map(layer => solver.update(tabDMLScript, net.getCaffeLayer(layer)))
   }
-  private def initializeGradients(parallel_batches:String):Unit = {
+  private def initializeGradients(parallel_batches: String): Unit = {
     tabDMLScript.append("# Data structure to store gradients computed in parallel\n")
-    net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg", matrix("0", parallel_batches, multiply(nrow(l.weight), ncol(l.weight))))
-      if(l.shouldUpdateBias) assign(tabDMLScript, l.dBias + "_agg", matrix("0", parallel_batches, multiply(nrow(l.bias), ncol(l.bias)))) 
-    })
+    net.getLayers
+      .map(layer => net.getCaffeLayer(layer))
+      .map(l => {
+        if (l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg", matrix("0", parallel_batches, multiply(nrow(l.weight), ncol(l.weight))))
+        if (l.shouldUpdateBias) assign(tabDMLScript, l.dBias + "_agg", matrix("0", parallel_batches, multiply(nrow(l.bias), ncol(l.bias))))
+      })
   }
-  private def flattenGradients():Unit = {
+  private def flattenGradients(): Unit = {
     tabDMLScript.append("# Flatten and store gradients for this parallel execution\n")
     // Note: We multiply by a weighting to allow for proper gradient averaging during the
     // aggregation even with uneven batch sizes.
     assign(tabDMLScript, "weighting", "nrow(Xb)/X_group_batch_size")
-    net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg[j,]", 
-          matrix(l.dWeight, "1", multiply(nrow(l.weight), ncol(l.weight))) + " * weighting") 
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dBias + "_agg[j,]", 
-          matrix(l.dBias, "1", multiply(nrow(l.bias), ncol(l.bias)))  + " * weighting")
-    })
+    net.getLayers
+      .map(layer => net.getCaffeLayer(layer))
+      .map(l => {
+        if (l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg[j,]", matrix(l.dWeight, "1", multiply(nrow(l.weight), ncol(l.weight))) + " * weighting")
+        if (l.shouldUpdateWeight) assign(tabDMLScript, l.dBias + "_agg[j,]", matrix(l.dBias, "1", multiply(nrow(l.bias), ncol(l.bias))) + " * weighting")
+      })
   }
-  private def aggregateAggGradients():Unit = {
+  private def aggregateAggGradients(): Unit = {
     tabDMLScript.append("# Aggregate the gradients\n")
-    net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight, 
-          matrix(colSums(l.dWeight + "_agg"), nrow(l.weight), ncol(l.weight))) 
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dBias, 
-          matrix(colSums(l.dBias + "_agg"), nrow(l.bias), ncol(l.bias)))
-    })
+    net.getLayers
+      .map(layer => net.getCaffeLayer(layer))
+      .map(l => {
+        if (l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight, matrix(colSums(l.dWeight + "_agg"), nrow(l.weight), ncol(l.weight)))
+        if (l.shouldUpdateWeight) assign(tabDMLScript, l.dBias, matrix(colSums(l.dBias + "_agg"), nrow(l.bias), ncol(l.bias)))
+      })
   }
   // -------------------------------------------------------------------------------------------
 }
 
-class Caffe2DMLModel(val numClasses:String, val sc: SparkContext, val solver:CaffeSolver,
-    val net:CaffeNetwork, val lrPolicy:LearningRatePolicy,
-    val estimator:Caffe2DML) 
-  extends Model[Caffe2DMLModel] with HasMaxOuterIter with BaseSystemMLClassifierModel with DMLGenerator {
+class Caffe2DMLModel(val numClasses: String, val sc: SparkContext, val solver: CaffeSolver, val net: CaffeNetwork, val lrPolicy: LearningRatePolicy, val estimator: Caffe2DML)
+    extends Model[Caffe2DMLModel]
+    with HasMaxOuterIter
+    with BaseSystemMLClassifierModel
+    with DMLGenerator {
   // --------------------------------------------------------------
   // Invoked by Python, MLPipeline
-  val uid:String = "caffe_model_" + (new Random).nextLong 
-  def this(estimator:Caffe2DML) =  {
-    this(Utils.numClasses(estimator.net), estimator.sc, estimator.solver,
-        estimator.net,
-        // new CaffeNetwork(estimator.solverParam.getNet, caffe.Caffe.Phase.TEST, estimator.numChannels, estimator.height, estimator.width), 
-        estimator.lrPolicy, estimator) 
+  val uid: String = "caffe_model_" + (new Random).nextLong
+  def this(estimator: Caffe2DML) = {
+    this(
+      Utils.numClasses(estimator.net),
+      estimator.sc,
+      estimator.solver,
+      estimator.net,
+      // new CaffeNetwork(estimator.solverParam.getNet, caffe.Caffe.Phase.TEST, estimator.numChannels, estimator.height, estimator.width),
+      estimator.lrPolicy,
+      estimator
+    )
   }
-      
+
   override def copy(extra: org.apache.spark.ml.param.ParamMap): Caffe2DMLModel = {
     val that = new Caffe2DMLModel(numClasses, sc, solver, net, lrPolicy, estimator)
     copyValues(that, extra)
   }
   // --------------------------------------------------------------
-  
-  def modelVariables():List[String] = {
-    net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(_.weight) ++
-    net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(_.bias)
+
+  def modelVariables(): List[String] = {
+    val allLayers = net.getLayers.map(net.getCaffeLayer(_))
+    allLayers.filter(_.weight != null).map(_.weight) ++ allLayers.filter(_.bias != null).map(_.bias)
   }
-    
+
   // ================================================================================================
   // The below method parses the provided network and solver file and generates DML script.
-  def getPredictionScript(isSingleNode:Boolean): (Script, String)  = {
+  def getPredictionScript(isSingleNode: Boolean): (Script, String) = {
     val startPredictionTime = System.nanoTime()
-    
-	  reset                                  // Reset the state of DML generator for training script.
-	  
-	  val DEBUG_PREDICTION = if(estimator.inputs.containsKey("$debug")) estimator.inputs.get("$debug").toLowerCase.toBoolean else false
-	  assign(tabDMLScript, "debug", if(DEBUG_PREDICTION) "TRUE" else "FALSE")
-    
-    appendHeaders(net, solver, false)      // Appends DML corresponding to source and externalFunction statements.
-    readInputData(net, false)              // Read X_full and y_full
+
+    reset // Reset the state of DML generator for training script.
+
+    val DEBUG_PREDICTION = if (estimator.inputs.containsKey("$debug")) estimator.inputs.get("$debug").toLowerCase.toBoolean else false
+    assign(tabDMLScript, "debug", if (DEBUG_PREDICTION) "TRUE" else "FALSE")
+
+    appendHeaders(net, solver, false) // Appends DML corresponding to source and externalFunction statements.
+    readInputData(net, false)         // Read X_full and y_full
     assign(tabDMLScript, "X", "X_full")
-    
+
     // Initialize the layers and solvers. Reads weights and bias if readWeights is true.
-    if(!estimator.inputs.containsKey("$weights") && estimator.mloutput == null) 
+    if (!estimator.inputs.containsKey("$weights") && estimator.mloutput == null)
       throw new DMLRuntimeException("Cannot call predict/score without calling either fit or by providing weights")
     val readWeights = estimator.inputs.containsKey("$weights") || estimator.mloutput != null
     initWeights(net, solver, readWeights)
-	  
-	  // Donot update mean and variance in batchnorm
-	  updateMeanVarianceForBatchNorm(net, false)
-	  
-	  val lossLayers = getLossLayers(net)
-	  val lastLayerShape = estimator.getOutputShapeOfLastLayer
-	  assign(tabDMLScript, "Prob", matrix("1", Caffe2DML.numImages, (lastLayerShape._1*lastLayerShape._2*lastLayerShape._3).toString))
-	  estimator.getTestAlgo.toLowerCase match {
+
+    // Donot update mean and variance in batchnorm
+    updateMeanVarianceForBatchNorm(net, false)
+
+    val lossLayers     = getLossLayers(net)
+    val lastLayerShape = estimator.getOutputShapeOfLastLayer
+    assign(tabDMLScript, "Prob", matrix("1", Caffe2DML.numImages, (lastLayerShape._1 * lastLayerShape._2 * lastLayerShape._3).toString))
+    estimator.getTestAlgo.toLowerCase match {
       case "minibatch" => {
         ceilDivide(tabDMLScript(), "num_iters", Caffe2DML.numImages, Caffe2DML.batchSize)
         forBlock("iter", "1", "num_iters") {
@@ -654,12 +680,12 @@ class Caffe2DMLModel(val numClasses:String, val sc: SparkContext, val solver:Caf
       }
       case "allreduce_parallel_batches" => {
         // This setting uses the batch size provided by the user
-        if(!estimator.inputs.containsKey("$parallel_batches")) {
+        if (!estimator.inputs.containsKey("$parallel_batches")) {
           throw new RuntimeException("The parameter parallel_batches is required for allreduce_parallel_batches")
         }
         // The user specifies the number of parallel_batches
         // This ensures that the user of generated script remembers to provide the commandline parameter $parallel_batches
-        assign(tabDMLScript, "parallel_batches", "$parallel_batches") 
+        assign(tabDMLScript, "parallel_batches", "$parallel_batches")
         assign(tabDMLScript, "group_batch_size", "parallel_batches*" + Caffe2DML.batchSize)
         assign(tabDMLScript, "groups", "as.integer(ceil(" + Caffe2DML.numImages + "/group_batch_size))")
         // Grab groups of mini-batches
@@ -688,70 +714,66 @@ class Caffe2DMLModel(val numClasses:String, val sc: SparkContext, val solver:Caf
       }
       case _ => throw new DMLRuntimeException("Unsupported test algo:" + estimator.getTestAlgo)
     }
-    
-    if(estimator.inputs.containsKey("$output_activations")) {
-      if(estimator.getTestAlgo.toLowerCase.equals("batch")) {
-        net.getLayers.map(layer => 
-          tabDMLScript.append(write(net.getCaffeLayer(layer).out, 
-              estimator.inputs.get("$output_activations") + "/" + net.getCaffeLayer(layer).param.getName + "_activations.mtx", "csv") + "\n")
-        )  
-      }
-      else {
+
+    if (estimator.inputs.containsKey("$output_activations")) {
+      if (estimator.getTestAlgo.toLowerCase.equals("batch")) {
+        net.getLayers.map(
+          layer =>
+            tabDMLScript.append(
+              write(net.getCaffeLayer(layer).out, estimator.inputs.get("$output_activations") + "/" + net.getCaffeLayer(layer).param.getName + "_activations.mtx", "csv") + "\n"
+          )
+        )
+      } else {
         throw new DMLRuntimeException("Incorrect usage of output_activations. It should be only used in batch mode.")
       }
     }
-		
-		val predictionScript = dmlScript.toString()
-		System.out.println("Time taken to generate prediction script from Caffe proto:" + ((System.nanoTime() - startPredictionTime)*1e-9) + "secs." )
-		if(DEBUG_PREDICTION) Utils.prettyPrintDMLScript(predictionScript)
-		
-		// Reset state of BatchNorm layer
-		updateMeanVarianceForBatchNorm(net, true)
-		
-	  val script = dml(predictionScript).out("Prob").in(estimator.inputs)
-	  if(estimator.mloutput != null) {
-	    // fit was called
-  	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.in(l.weight, estimator.mloutput.getMatrix(l.weight)))
-  	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.in(l.bias, estimator.mloutput.getMatrix(l.bias)))
-	  }
-	  (script, "X_full")
+
+    val predictionScript = dmlScript.toString()
+    System.out.println("Time taken to generate prediction script from Caffe proto:" + ((System.nanoTime() - startPredictionTime) * 1e-9) + "secs.")
+    if (DEBUG_PREDICTION) Utils.prettyPrintDMLScript(predictionScript)
+
+    // Reset state of BatchNorm layer
+    updateMeanVarianceForBatchNorm(net, true)
+
+    val script = dml(predictionScript).out("Prob").in(estimator.inputs)
+    if (estimator.mloutput != null) {
+      // fit was called
+      net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.in(l.weight, estimator.mloutput.getMatrix(l.weight)))
+      net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.in(l.bias, estimator.mloutput.getMatrix(l.bias)))
+    }
+    (script, "X_full")
   }
   // ================================================================================================
-  
-  def baseEstimator():BaseSystemMLEstimator = estimator
-  
+
+  def baseEstimator(): BaseSystemMLEstimator = estimator
+
   // Prediction
-  def transform(X: MatrixBlock): MatrixBlock = {
-    if(estimator.isClassification) {
+  def transform(X: MatrixBlock): MatrixBlock =
+    if (estimator.isClassification) {
       Caffe2DML.LOG.debug("Prediction assuming classification")
       baseTransform(X, sc, "Prob")
-    }
-    else {
+    } else {
       Caffe2DML.LOG.debug("Prediction assuming segmentation")
       val outShape = estimator.getOutputShapeOfLastLayer
       baseTransform(X, sc, "Prob", outShape._1.toInt, outShape._2.toInt, outShape._3.toInt)
     }
-  }
-  def transform_probability(X: MatrixBlock): MatrixBlock = {
-    if(estimator.isClassification) {
+  def transform_probability(X: MatrixBlock): MatrixBlock =
+    if (estimator.isClassification) {
       Caffe2DML.LOG.debug("Prediction of probability assuming classification")
       baseTransformProbability(X, sc, "Prob")
-    }
-    else {
+    } else {
       Caffe2DML.LOG.debug("Prediction of probability assuming segmentation")
       val outShape = estimator.getOutputShapeOfLastLayer
       baseTransformProbability(X, sc, "Prob", outShape._1.toInt, outShape._2.toInt, outShape._3.toInt)
     }
-  } 
-  def transform(df: ScriptsUtils.SparkDataType): DataFrame = {
-    if(estimator.isClassification) {
+
+  def transform(df: ScriptsUtils.SparkDataType): DataFrame =
+    if (estimator.isClassification) {
       Caffe2DML.LOG.debug("Prediction assuming classification")
       baseTransform(df, sc, "Prob", true)
-    }
-    else {
+    } else {
       Caffe2DML.LOG.debug("Prediction assuming segmentation")
       val outShape = estimator.getOutputShapeOfLastLayer
       baseTransform(df, sc, "Prob", true, outShape._1.toInt, outShape._2.toInt, outShape._3.toInt)
     }
-  }
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala
index 30d86fd..19aff63 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DMLLoader.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -22,16 +22,16 @@ package org.apache.sysml.api.dl
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.net.MalformedURLException;
-import java.net.URL; 
+import java.net.URL;
 import java.net.URLClassLoader;
 import java.io.File;
 
 class Caffe2DMLLoader {
-  def loadCaffe2DML(filePath:String):Unit = {
-    val url = new File(filePath).toURI().toURL();
-		val classLoader = ClassLoader.getSystemClassLoader().asInstanceOf[URLClassLoader];
-		val method = classOf[URLClassLoader].getDeclaredMethod("addURL", classOf[URL]);
-		method.setAccessible(true);
-	  method.invoke(classLoader, url);
+  def loadCaffe2DML(filePath: String): Unit = { // adds filePath to the system class loader's URL list via reflection
+    val url         = new File(filePath).toURI().toURL();
+    val classLoader = ClassLoader.getSystemClassLoader().asInstanceOf[URLClassLoader]; // NOTE(review): cast fails on Java 9+, where the system class loader is not a URLClassLoader
+    val method      = classOf[URLClassLoader].getDeclaredMethod("addURL", classOf[URL]); // addURL is protected, hence setAccessible below
+    method.setAccessible(true);
+    method.invoke(classLoader, url);
   }
-}
\ No newline at end of file
+}


[2/4] systemml git commit: [SYSTEMML-540] Support loading of batch normalization weights in .caffemodel file using Caffe2DML

Posted by ni...@apache.org.
http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
index 67682d5..2b07788 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -33,205 +33,197 @@ import org.apache.commons.logging.LogFactory
 
 trait Network {
   def getLayers(): List[String]
-  def getCaffeLayer(layerName:String):CaffeLayer
-  def getBottomLayers(layerName:String): Set[String]
-  def getTopLayers(layerName:String): Set[String]
-  def getLayerID(layerName:String): Int
+  def getCaffeLayer(layerName: String): CaffeLayer // the CaffeLayer built for layerName
+  def getBottomLayers(layerName: String): Set[String] // layers listed as "bottom" (inputs) of layerName
+  def getTopLayers(layerName: String): Set[String] // layers that list layerName as one of their bottoms
+  def getLayerID(layerName: String): Int // numeric id assigned to layerName
 }
 
 object CaffeNetwork {
   val LOG = LogFactory.getLog(classOf[CaffeNetwork].getName)
 }
-
+// Layer graph of a Caffe network read from netFilePath; layers are filtered by currentPhase (null keeps all) and normalized.
-class CaffeNetwork(netFilePath:String, val currentPhase:Phase, 
-     var numChannels:String, var height:String, var width:String
-    ) extends Network {
-  private def isIncludedInCurrentPhase(l:LayerParameter): Boolean = {
-    if(currentPhase == null) return true // while deployment
-    else if(l.getIncludeCount == 0) true 
+class CaffeNetwork(netFilePath: String, val currentPhase: Phase, var numChannels: String, var height: String, var width: String) extends Network {
+  private def isIncludedInCurrentPhase(l: LayerParameter): Boolean =
+    if (currentPhase == null) return true // while deployment
+    else if (l.getIncludeCount == 0) true
     else l.getIncludeList.filter(r => r.hasPhase() && r.getPhase != currentPhase).length == 0
-  }
   private var id = 1
-  def this(deployFilePath:String) {
+  def this(deployFilePath: String) {
     this(deployFilePath, null, null, null, null)
   }
   // --------------------------------------------------------------------------------
-  private var _net:NetParameter = Utils.readCaffeNet(netFilePath)
-  private var _caffeLayerParams:List[LayerParameter] = _net.getLayerList.filter(l => isIncludedInCurrentPhase(l)).toList
+  private var _net: NetParameter                      = Utils.readCaffeNet(netFilePath)
+  private var _caffeLayerParams: List[LayerParameter] = _net.getLayerList.filter(l => isIncludedInCurrentPhase(l)).toList
-  // This method is used if the user doesnot provide number of channels, height and width
+  // This method is used if the user does not provide number of channels, height and width
-  private def setCHW(inputShapes:java.util.List[caffe.Caffe.BlobShape]):Unit = {
-    if(inputShapes.size != 1)
-        throw new DMLRuntimeException("Expected only one input shape")
+  private def setCHW(inputShapes: java.util.List[caffe.Caffe.BlobShape]): Unit = {
+    if (inputShapes.size != 1)
+      throw new DMLRuntimeException("Expected only one input shape")
     val inputShape = inputShapes.get(0)
-    if(inputShape.getDimCount != 4)
+    if (inputShape.getDimCount != 4)
       throw new DMLRuntimeException("Expected the input shape of dimension 4")
     numChannels = inputShape.getDim(1).toString
     height = inputShape.getDim(2).toString
     width = inputShape.getDim(3).toString
   }
-  if(numChannels == null && height == null && width == null) {
-    val inputLayer:List[LayerParameter] = _caffeLayerParams.filter(_.getType.toLowerCase.equals("input"))
-    if(inputLayer.size == 1) {
+  if (numChannels == null && height == null && width == null) {
+    val inputLayer: List[LayerParameter] = _caffeLayerParams.filter(_.getType.toLowerCase.equals("input"))
+    if (inputLayer.size == 1) {
       setCHW(inputLayer(0).getInputParam.getShapeList)
-    }
-    else if(inputLayer.size == 0) {
-      throw new DMLRuntimeException("Input shape (number of channels, height, width) is unknown. Hint: If you are using deprecated input/input_shape API, we recommend you use Input layer.")
-    }
-    else {
+    } else if (inputLayer.size == 0) {
+      throw new DMLRuntimeException(
+        "Input shape (number of channels, height, width) is unknown. Hint: If you are using deprecated input/input_shape API, we recommend you use Input layer."
+      )
+    } else {
       throw new DMLRuntimeException("Multiple Input layer is not supported")
     }
   }
   // --------------------------------------------------------------------------------
-  
+  // Validate (Conditions 1-2) and rewrite (Conditions 3-5) the layer parameters of the current phase.
   private var _layerNames: List[String] = _caffeLayerParams.map(l => l.getName).toList
   CaffeNetwork.LOG.debug("Layers in current phase:" + _layerNames)
-  
+
   // Condition 1: assert that each name is unique
   private val _duplicateLayerNames = _layerNames.diff(_layerNames.distinct)
-  if(_duplicateLayerNames.size != 0) throw new LanguageException("Duplicate layer names is not supported:" + _duplicateLayerNames)
-  
+  if (_duplicateLayerNames.size != 0) throw new LanguageException("Duplicate layer names is not supported:" + _duplicateLayerNames)
+
   // Condition 2: only 1 top name, except Data layer
   private val _condition2Exceptions = Set("data")
-  _caffeLayerParams.filter(l => !_condition2Exceptions.contains(l.getType.toLowerCase)).map(l => if(l.getTopCount != 1) throw new LanguageException("Multiple top layers is not supported for " + l.getName))
+  _caffeLayerParams
+    .filter(l => !_condition2Exceptions.contains(l.getType.toLowerCase))
+    .foreach(l => if (l.getTopCount != 1) throw new LanguageException("Multiple top layers is not supported for " + l.getName))
 
   // Condition 3: Replace top layer names referring to a Data layer with its name
   // Example: layer{ name: mnist, top: data, top: label, ... }
-  private val _topToNameMappingForDataLayer = new HashMap[String, String]()
-  private def containsOnly(list:java.util.List[String], v:String): Boolean = list.toSet.diff(Set(v)).size() == 0
-  private def isData(l:LayerParameter):Boolean = l.getType.equalsIgnoreCase("data")
-  private def replaceTopWithNameOfDataLayer(l:LayerParameter):LayerParameter =  {
-    if(containsOnly(l.getTopList,l.getName))
+  private val _topToNameMappingForDataLayer                                  = new HashMap[String, String]()
+  private def containsOnly(list: java.util.List[String], v: String): Boolean = list.toSet.diff(Set(v)).size() == 0
+  private def isData(l: LayerParameter): Boolean                             = l.getType.equalsIgnoreCase("data")
+  private def replaceTopWithNameOfDataLayer(l: LayerParameter): LayerParameter =
+    if (containsOnly(l.getTopList, l.getName))
       return l
     else {
-      val builder = l.toBuilder(); 
-      for(i <- 0 until l.getTopCount) {
-        if(! l.getTop(i).equals(l.getName)) { _topToNameMappingForDataLayer.put(l.getTop(i), l.getName) }
-        builder.setTop(i, l.getName) 
+      val builder = l.toBuilder();
+      for (i <- 0 until l.getTopCount) {
+        if (!l.getTop(i).equals(l.getName)) { _topToNameMappingForDataLayer.put(l.getTop(i), l.getName) }
+        builder.setTop(i, l.getName)
       }
-      return builder.build() 
+      return builder.build()
     }
-  }
   // 3a: Replace top of DataLayer with its names
   // Example: layer{ name: mnist, top: mnist, top: mnist, ... }
-  _caffeLayerParams = _caffeLayerParams.map(l => if(isData(l)) replaceTopWithNameOfDataLayer(l) else l)
-  private def replaceBottomOfNonDataLayers(l:LayerParameter):LayerParameter = {
+  _caffeLayerParams = _caffeLayerParams.map(l => if (isData(l)) replaceTopWithNameOfDataLayer(l) else l)
+  private def replaceBottomOfNonDataLayers(l: LayerParameter): LayerParameter = {
     val builder = l.toBuilder();
     // Note: Top will never be Data layer
-    for(i <- 0 until l.getBottomCount) {
-      if(_topToNameMappingForDataLayer.containsKey(l.getBottom(i))) 
+    for (i <- 0 until l.getBottomCount) {
+      if (_topToNameMappingForDataLayer.containsKey(l.getBottom(i)))
         builder.setBottom(i, _topToNameMappingForDataLayer.get(l.getBottom(i)))
     }
     return builder.build()
   }
-  // 3a: If top/bottom of other layers refer DataLayer, then replace them
+  // 3b: If bottoms of other layers refer to a Data layer's tops, replace them with the Data layer's name
   // layer { name: "conv1_1", type: "Convolution", bottom: "data"
-  _caffeLayerParams = if(_topToNameMappingForDataLayer.size == 0) _caffeLayerParams else _caffeLayerParams.map(l => if(isData(l)) l else replaceBottomOfNonDataLayers(l))
-  
+  _caffeLayerParams = if (_topToNameMappingForDataLayer.size == 0) _caffeLayerParams else _caffeLayerParams.map(l => if (isData(l)) l else replaceBottomOfNonDataLayers(l))
+
   // Condition 4: Deal with fused layer
   // Example: layer { name: conv1, top: conv1, ... } layer { name: foo, bottom: conv1, top: conv1 }
-  private def isFusedLayer(l:LayerParameter): Boolean = l.getTopCount == 1 && l.getBottomCount == 1 && l.getTop(0).equalsIgnoreCase(l.getBottom(0))
-  private def containsReferencesToFusedLayer(l:LayerParameter):Boolean = l.getBottomList.foldLeft(false)((prev, bLayer) => prev || _fusedTopLayer.containsKey(bLayer))
-  private val _fusedTopLayer = new HashMap[String, String]()
+  private def isFusedLayer(l: LayerParameter): Boolean                   = l.getTopCount == 1 && l.getBottomCount == 1 && l.getTop(0).equalsIgnoreCase(l.getBottom(0))
+  private def containsReferencesToFusedLayer(l: LayerParameter): Boolean = l.getBottomList.foldLeft(false)((prev, bLayer) => prev || _fusedTopLayer.containsKey(bLayer))
+  private val _fusedTopLayer                                             = new HashMap[String, String]()
   _caffeLayerParams = _caffeLayerParams.map(l => {
-    if(isFusedLayer(l)) {
+    if (isFusedLayer(l)) {
       val builder = l.toBuilder();
-      if(_fusedTopLayer.containsKey(l.getBottom(0))) {
+      if (_fusedTopLayer.containsKey(l.getBottom(0))) {
         builder.setBottom(0, _fusedTopLayer.get(l.getBottom(0)))
       }
       builder.setTop(0, l.getName)
       _fusedTopLayer.put(l.getBottom(0), l.getName)
       builder.build()
-    }
-    else if(containsReferencesToFusedLayer(l)) {
+    } else if (containsReferencesToFusedLayer(l)) {
       val builder = l.toBuilder();
-      for(i <- 0 until l.getBottomCount) {
-        if(_fusedTopLayer.containsKey(l.getBottomList.get(i))) {
+      for (i <- 0 until l.getBottomCount) {
+        if (_fusedTopLayer.containsKey(l.getBottomList.get(i))) {
           builder.setBottom(i, _fusedTopLayer.get(l.getBottomList.get(i)))
         }
       }
       builder.build()
-    }
-    else l
+    } else l
   })
-  
+
   // Used while reading caffemodel
   val replacedLayerNames = new HashMap[String, String]();
-  
+
   // Condition 5: Deal with incorrect naming
   // Example: layer { name: foo, bottom: arbitrary, top: bar } ... Rename the layer to bar
-  private def isIncorrectNamingLayer(l:LayerParameter): Boolean = l.getTopCount == 1 && !l.getTop(0).equalsIgnoreCase(l.getName)
+  private def isIncorrectNamingLayer(l: LayerParameter): Boolean = l.getTopCount == 1 && !l.getTop(0).equalsIgnoreCase(l.getName)
   _caffeLayerParams = _caffeLayerParams.map(l => {
-    if(isIncorrectNamingLayer(l)) {
+    if (isIncorrectNamingLayer(l)) {
       val builder = l.toBuilder();
       replacedLayerNames.put(l.getName, l.getTop(0))
       builder.setName(l.getTop(0))
       builder.build()
-    }
-    else l
+    } else l
   })
 
   // --------------------------------------------------------------------------------
-  
+
   // Helper functions to extract bottom and top layers
-  private def convertTupleListToMap(m:List[(String, String)]):Map[String, Set[String]] = m.groupBy(_._1).map(x => (x._1, x._2.map(y => y._2).toSet)).toMap
-  private def flipKeyValues(t:List[(String, Set[String])]): List[(String, String)] = t.flatMap(x => x._2.map(b => b -> x._1))
-  private def expandBottomList(layerName:String, bottomList:java.util.List[String]): List[(String, String)] = bottomList.filter(b => !b.equals(layerName)).map(b => layerName -> b).toList 
-  
+  private def convertTupleListToMap(m: List[(String, String)]): Map[String, Set[String]] = m.groupBy(_._1).map(x => (x._1, x._2.map(y => y._2).toSet)).toMap
+  private def flipKeyValues(t: List[(String, Set[String])]): List[(String, String)]      = t.flatMap(x => x._2.map(b => b -> x._1))
+  private def expandBottomList(layerName: String, bottomList: java.util.List[String]): List[(String, String)] =
+    bottomList.filter(b => !b.equals(layerName)).map(b => layerName -> b).toList
+
   // The bottom layers are the layers available in the getBottomList (from Caffe .proto files)
-  private val _bottomLayers:Map[String, Set[String]] = convertTupleListToMap(
-      _caffeLayerParams.flatMap(l => expandBottomList(l.getName, l.getBottomList)))
+  private val _bottomLayers: Map[String, Set[String]] = convertTupleListToMap(_caffeLayerParams.flatMap(l => expandBottomList(l.getName, l.getBottomList)))
   CaffeNetwork.LOG.debug("Bottom layers:" + _bottomLayers)
-  
+
   // Find the top layers by reversing the bottom list
-  private val _topLayers:Map[String, Set[String]] = convertTupleListToMap(flipKeyValues(_bottomLayers.toList))
+  private val _topLayers: Map[String, Set[String]] = convertTupleListToMap(flipKeyValues(_bottomLayers.toList))
   CaffeNetwork.LOG.debug("Top layers:" + _topLayers)
-  
+
   private val _layers: Map[String, CaffeLayer] = _caffeLayerParams.map(l => l.getName -> convertLayerParameterToCaffeLayer(l)).toMap
   CaffeNetwork.LOG.debug("Layers:" + _layers)
   private val _layerIDs: Map[String, Int] = _layers.entrySet().map(x => x.getKey -> x.getValue.id).toMap
-  
-  
-  private def throwException(layerName:String) = throw new LanguageException("Layer with name " + layerName + " not found")                              
-  def getLayers(): List[String] =  _layerNames
-  def getCaffeLayer(layerName:String):CaffeLayer = {
-    if(checkKey(_layers, layerName)) _layers.get(layerName).get
+  // Lookups throw LanguageException("Layer with name ... not found") for unknown names; getCaffeLayer also tries renamed names.
+  private def throwException(layerName: String) = throw new LanguageException("Layer with name " + layerName + " not found")
+  def getLayers(): List[String]                 = _layerNames
+  def getCaffeLayer(layerName: String): CaffeLayer =
+    if (checkKey(_layers, layerName)) _layers.get(layerName).get
     else {
-      if(replacedLayerNames.contains(layerName) && checkKey(_layers, replacedLayerNames.get(layerName))) {
+      if (replacedLayerNames.contains(layerName) && checkKey(_layers, replacedLayerNames.get(layerName))) {
         _layers.get(replacedLayerNames.get(layerName)).get
-      }
-      else throwException(layerName)
+      } else throwException(layerName)
     }
-  }
-  def getBottomLayers(layerName:String): Set[String] =  if(checkKey(_bottomLayers, layerName)) _bottomLayers.get(layerName).get else throwException(layerName)
-  def getTopLayers(layerName:String): Set[String] = if(checkKey(_topLayers, layerName)) _topLayers.get(layerName).get else throwException(layerName)
-  def getLayerID(layerName:String): Int = if(checkKey(_layerIDs, layerName))  _layerIDs.get(layerName).get else throwException(layerName)
-  
+  def getBottomLayers(layerName: String): Set[String] = if (checkKey(_bottomLayers, layerName)) _bottomLayers.get(layerName).get else throwException(layerName)
+  def getTopLayers(layerName: String): Set[String]    = if (checkKey(_topLayers, layerName)) _topLayers.get(layerName).get else throwException(layerName)
+  def getLayerID(layerName: String): Int              = if (checkKey(_layerIDs, layerName)) _layerIDs.get(layerName).get else throwException(layerName)
+
   // Helper functions
-  private def checkKey(m:Map[String, Any], key:String): Boolean = {
-    if(m == null) throw new LanguageException("Map is null (key=" + key + ")")
-    else if(key == null) throw new LanguageException("key is null (map=" + m + ")")
+  private def checkKey(m: Map[String, Any], key: String): Boolean =
+    if (m == null) throw new LanguageException("Map is null (key=" + key + ")")
+    else if (key == null) throw new LanguageException("key is null (map=" + m + ")")
     else m.containsKey(key)
-  }
-  private def convertLayerParameterToCaffeLayer(param:LayerParameter):CaffeLayer = {
+  private def convertLayerParameterToCaffeLayer(param: LayerParameter): CaffeLayer = {
     id = id + 1
     param.getType.toLowerCase() match {
       case "convolution" => new Convolution(param, id, this)
-      case "pooling" => if(param.getPoolingParam.getPool == PoolingParameter.PoolMethod.MAX)  new MaxPooling(param, id, this)
-                        else throw new LanguageException("Only maxpooling is supported:" + param.getPoolingParam.getPool.name)
-      case "innerproduct" => new InnerProduct(param, id, this)
-      case "relu" => new ReLU(param, id, this)
+      case "pooling" =>
+        if (param.getPoolingParam.getPool == PoolingParameter.PoolMethod.MAX) new MaxPooling(param, id, this)
+        else throw new LanguageException("Only maxpooling is supported:" + param.getPoolingParam.getPool.name)
+      case "innerproduct"    => new InnerProduct(param, id, this)
+      case "relu"            => new ReLU(param, id, this)
       case "softmaxwithloss" => new SoftmaxWithLoss(param, id, this)
-      case "dropout" => new Dropout(param, id, this)
-      case "data" => new Data(param, id, this, numChannels, height, width)
-      case "input" => new Data(param, id, this, numChannels, height, width)
-      case "batchnorm" => new BatchNorm(param, id, this)
-      case "scale" => new Scale(param, id, this)
-      case "eltwise" => new Elementwise(param, id, this)
-      case "concat" => new Concat(param, id, this)
-      case "deconvolution" => new DeConvolution(param, id, this)
-      case "threshold" => new Threshold(param, id, this)
-      case "softmax" => new Softmax(param, id, this)
-      case _ => throw new LanguageException("Layer of type " + param.getType + " is not supported")
+      case "dropout"         => new Dropout(param, id, this)
+      case "data"            => new Data(param, id, this, numChannels, height, width)
+      case "input"           => new Data(param, id, this, numChannels, height, width)
+      case "batchnorm"       => new BatchNorm(param, id, this)
+      case "scale"           => new Scale(param, id, this)
+      case "eltwise"         => new Elementwise(param, id, this)
+      case "concat"          => new Concat(param, id, this)
+      case "deconvolution"   => new DeConvolution(param, id, this)
+      case "threshold"       => new Threshold(param, id, this)
+      case "softmax"         => new Softmax(param, id, this)
+      case _                 => throw new LanguageException("Layer of type " + param.getType + " is not supported")
     }
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
index 0e39192..a61ff10 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -23,72 +23,71 @@ import org.apache.sysml.runtime.DMLRuntimeException
 import caffe.Caffe
 
 trait CaffeSolver {
-  def sourceFileName:String;
-  def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit;
-  def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit;
-  
+  def sourceFileName: String;
+  def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit;
+  def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit;
+  // Implementations append DML to dmlScript: init() for optimizer state, update() for the parameter step.
   // ----------------------------------------------------------------
   // Used for Fine-tuning
-  private def getLayerLr(layer:CaffeLayer, paramIndex:Int):String = {
+  private def getLayerLr(layer: CaffeLayer, paramIndex: Int): String = { // "lr" scaled by the lr_mult of param(paramIndex), or plain "lr" if that param is absent
     val param = layer.param.getParamList
-    if(param == null || param.size() <= paramIndex)
+    if (param == null || param.size() <= paramIndex)
       return "lr"
     else
       // TODO: Ignoring param.get(index).getDecayMult for now
       return "(lr * " + param.get(paramIndex).getLrMult + ")"
   }
   // the first param { } is for the weights and the second is for the biases.
-  def getWeightLr(layer:CaffeLayer):String = getLayerLr(layer, 0)
-  def getBiasLr(layer:CaffeLayer):String = getLayerLr(layer, 1)
+  def getWeightLr(layer: CaffeLayer): String = getLayerLr(layer, 0)
+  def getBiasLr(layer: CaffeLayer): String   = getLayerLr(layer, 1)
   // ----------------------------------------------------------------
-  
-  def commaSep(arr:String*):String = {
-	  if(arr.length == 1) arr(0) else {
-	    var ret = arr(0)
-	    for(i <- 1 until arr.length) {
-	      ret = ret + "," + arr(i)
-	    }
-	    ret
-	  }
-	}
-  
-  def l2reg_update(lambda:Double, dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
+  // Joins DML argument strings with "," (no spaces); expects at least one argument.
+  def commaSep(arr: String*): String =
+    if (arr.length == 1) arr(0)
+    else {
+      var ret = arr(0)
+      for (i <- 1 until arr.length) {
+        ret = ret + "," + arr(i)
+      }
+      ret
+    }
+  // If lambda != 0 and the layer updates its weight, appends DML adding l2_reg::backward(weight, lambda) to dWeight.
+  def l2reg_update(lambda: Double, dmlScript: StringBuilder, layer: CaffeLayer): Unit =
     // val donotRegularizeLayers:Boolean = layer.isInstanceOf[BatchNorm] || layer.isInstanceOf[Scale];
-    if(lambda != 0 && layer.shouldUpdateWeight) {
+    if (lambda != 0 && layer.shouldUpdateWeight) {
       dmlScript.append("\t").append(layer.dWeight + "_reg = l2_reg::backward(" + layer.weight + ", " + lambda + ")\n")
       dmlScript.append("\t").append(layer.dWeight + " = " + layer.dWeight + " + " + layer.dWeight + "_reg\n")
     }
-  }
 }
 
-class LearningRatePolicy(lr_policy:String="exp", base_lr:Double=0.01) {
-  def this(solver:Caffe.SolverParameter) {
+class LearningRatePolicy(lr_policy: String = "exp", base_lr: Double = 0.01) { // Caffe-style lr_policy emitted as a DML expression
+  def this(solver: Caffe.SolverParameter) {
     this(solver.getLrPolicy, solver.getBaseLr)
-    if(solver.hasGamma) setGamma(solver.getGamma)
-    if(solver.hasStepsize) setStepsize(solver.getStepsize)
-    if(solver.hasPower()) setPower(solver.getPower)
+    if (solver.hasGamma) setGamma(solver.getGamma)
+    if (solver.hasStepsize) setStepsize(solver.getStepsize)
+    if (solver.hasPower()) setPower(solver.getPower)
   }
-  var gamma:Double = 0.95
-  var step:Double = 100000
-  var power:Double = 0.75
-  def setGamma(gamma1:Double):Unit = { gamma = gamma1 } 
-  def setStepsize(step1:Double):Unit = { step = step1 } 
-  def setPower(power1:Double): Unit = { power = power1 }
-  def updateLearningRate(dmlScript:StringBuilder):Unit = {
+  var gamma: Double                    = 0.95
+  var step: Double                     = 100000
+  var power: Double                    = 0.75
+  def setGamma(gamma1: Double): Unit   = gamma = gamma1
+  def setStepsize(step1: Double): Unit = step = step1
+  def setPower(power1: Double): Unit   = power = power1
+  def updateLearningRate(dmlScript: StringBuilder): Unit = { // appends "lr = EXPR\n"; EXPR uses the DML loop variable e (and max_epochs for poly)
     val new_lr = lr_policy.toLowerCase match {
-      case "fixed" => base_lr.toString
-      case "step" => "(" + base_lr + " * " +  gamma + " ^ " + " floor(e/" + step + "))"
-      case "exp" => "(" + base_lr + " * " + gamma + "^e)"
-      case "inv" =>  "(" + base_lr + "* (1 + " + gamma + " * e) ^ (-" + power + "))"
-      case "poly" => "(" + base_lr  + " * (1 - e/ max_epochs) ^ " + power + ")"
-      case "sigmoid" => "(" + base_lr + "( 1/(1 + exp(-" + gamma + "* (e - " + step + "))))"
-      case _ => throw new DMLRuntimeException("The lr policy is not supported:" + lr_policy)
+      case "fixed"   => base_lr.toString
+      case "step"    => "(" + base_lr + " * " + gamma + " ^ " + " floor(e/" + step + "))"
+      case "exp"     => "(" + base_lr + " * " + gamma + "^e)"
+      case "inv"     => "(" + base_lr + "* (1 + " + gamma + " * e) ^ (-" + power + "))"
+      case "poly"    => "(" + base_lr + " * (1 - e/ max_epochs) ^ " + power + ")"
+      case "sigmoid" => "(" + base_lr + " * ( 1/(1 + exp(-" + gamma + "* (e - " + step + ")))))" // base_lr * 1/(1 + exp(-gamma * (e - step)))
+      case _         => throw new DMLRuntimeException("The lr policy is not supported:" + lr_policy)
     }
     dmlScript.append("lr = " + new_lr + "\n")
   }
 }
 
-class SGD(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
+class SGD(lambda: Double = 5e-04, momentum: Double = 0.9) extends CaffeSolver {
   /*
    * Performs an SGD update with momentum.
    *
@@ -112,31 +111,39 @@ class SGD(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
    *  - v: Updated velocity of the parameters `X`, of same shape as
    *      input `X`.
    */
-  def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
+  def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = { // appends DML: L2 reg, then sgd::update (momentum == 0) or sgd_momentum::update
     l2reg_update(lambda, dmlScript, layer)
-    if(momentum == 0) {
+    if (momentum == 0) {
       // Use sgd
-      if(layer.shouldUpdateWeight) dmlScript.append("\t").append(layer.weight + " = sgd::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer)) + ")\n")
-      if(layer.shouldUpdateBias) dmlScript.append("\t").append(layer.bias + " = sgd::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer)) + ")\n")
-    }
-    else {
+      if (layer.shouldUpdateWeight) dmlScript.append("\t").append(layer.weight + " = sgd::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer)) + ")\n")
+      if (layer.shouldUpdateBias) dmlScript.append("\t").append(layer.bias + " = sgd::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer)) + ")\n")
+    } else {
       // Use sgd_momentum
-      if(layer.shouldUpdateWeight) dmlScript.append("\t").append("["+ commaSep(layer.weight, layer.weight+"_v") + "] " + 
-          "= sgd_momentum::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight+"_v") + ")\n")
-      if(layer.shouldUpdateBias) dmlScript.append("\t").append("["+ commaSep(layer.bias, layer.bias+"_v") + "] " + 
-          "= sgd_momentum::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias+"_v") + ")\n")
+      if (layer.shouldUpdateWeight)
+        dmlScript
+          .append("\t")
+          .append(
+            "[" + commaSep(layer.weight, layer.weight + "_v") + "] " +
+            "= sgd_momentum::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight + "_v") + ")\n"
+          )
+      if (layer.shouldUpdateBias)
+        dmlScript
+          .append("\t")
+          .append(
+            "[" + commaSep(layer.bias, layer.bias + "_v") + "] " +
+            "= sgd_momentum::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias + "_v") + ")\n"
+          )
     }
   }
-  def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
-    if(momentum != 0) {
-      if(layer.shouldUpdateWeight) dmlScript.append(layer.weight+"_v = sgd_momentum::init(" + layer.weight + ")\n")
-      if(layer.shouldUpdateBias) dmlScript.append(layer.bias+"_v = sgd_momentum::init(" + layer.bias + ")\n")
+  def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit = // with momentum, appends DML initializing the velocity PARAM_v via sgd_momentum::init
+    if (momentum != 0) {
+      if (layer.shouldUpdateWeight) dmlScript.append(layer.weight + "_v = sgd_momentum::init(" + layer.weight + ")\n")
+      if (layer.shouldUpdateBias) dmlScript.append(layer.bias + "_v = sgd_momentum::init(" + layer.bias + ")\n")
     }
-  }
-  def sourceFileName:String = if(momentum == 0) "sgd" else "sgd_momentum" 
+  def sourceFileName: String = if (momentum == 0) "sgd" else "sgd_momentum"
 }
 
-class AdaGrad(lambda:Double=5e-04, epsilon:Double=1e-6) extends CaffeSolver {
+class AdaGrad(lambda: Double = 5e-04, epsilon: Double = 1e-6) extends CaffeSolver {
   /*
    * Performs an Adagrad update.
    *
@@ -164,21 +171,31 @@ class AdaGrad(lambda:Double=5e-04, epsilon:Double=1e-6) extends CaffeSolver {
    *  - cache: State that maintains per-parameter sum of squared
    *      gradients, of same shape as `X`.
    */
-  def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
+  def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = { // appends DML: L2 reg, then adagrad::update using the per-parameter PARAM_cache
     l2reg_update(lambda, dmlScript, layer)
-    if(layer.shouldUpdateWeight) dmlScript.append("\t").append("["+ commaSep(layer.weight, layer.weight+"_cache") + "] " + 
-        "= adagrad::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), epsilon.toString, layer.weight+"_cache") + ")\n")
-    if(layer.shouldUpdateBias) dmlScript.append("\t").append("["+ commaSep(layer.bias, layer.bias+"_cache") + "] " + 
-        "= adagrad::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), epsilon.toString, layer.bias+"_cache") + ")\n")
+    if (layer.shouldUpdateWeight)
+      dmlScript
+        .append("\t")
+        .append(
+          "[" + commaSep(layer.weight, layer.weight + "_cache") + "] " +
+          "= adagrad::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), epsilon.toString, layer.weight + "_cache") + ")\n"
+        )
+    if (layer.shouldUpdateBias)
+      dmlScript
+        .append("\t")
+        .append(
+          "[" + commaSep(layer.bias, layer.bias + "_cache") + "] " +
+          "= adagrad::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), epsilon.toString, layer.bias + "_cache") + ")\n"
+        )
   }
-  def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
-    if(layer.shouldUpdateWeight) dmlScript.append(layer.weight+"_cache = adagrad::init(" + layer.weight + ")\n")
-    if(layer.shouldUpdateBias) dmlScript.append(layer.bias+"_cache = adagrad::init(" + layer.bias + ")\n")
+  def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit = { // appends DML initializing PARAM_cache via adagrad::init
+    if (layer.shouldUpdateWeight) dmlScript.append(layer.weight + "_cache = adagrad::init(" + layer.weight + ")\n")
+    if (layer.shouldUpdateBias) dmlScript.append(layer.bias + "_cache = adagrad::init(" + layer.bias + ")\n")
   }
-  def sourceFileName:String = "adagrad"
+  def sourceFileName: String = "adagrad"
 }
 
-class Nesterov(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
+class Nesterov(lambda: Double = 5e-04, momentum: Double = 0.9) extends CaffeSolver {
   /*
    * Performs an SGD update with Nesterov momentum.
    *
@@ -211,20 +228,30 @@ class Nesterov(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
    *  - v: Updated velocity of the parameters X, of same shape as
    *      input v.
    */
-  def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
-    val fn = if(Caffe2DML.USE_NESTEROV_UDF) "update_nesterov" else "sgd_nesterov::update"
-    val lastParameter = if(Caffe2DML.USE_NESTEROV_UDF) (", " + lambda) else ""
-    if(!Caffe2DML.USE_NESTEROV_UDF) {
+  def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
+    val fn            = if (Caffe2DML.USE_NESTEROV_UDF) "update_nesterov" else "sgd_nesterov::update"
+    val lastParameter = if (Caffe2DML.USE_NESTEROV_UDF) (", " + lambda) else ""
+    if (!Caffe2DML.USE_NESTEROV_UDF) {
       l2reg_update(lambda, dmlScript, layer)
     }
-    if(layer.shouldUpdateWeight) dmlScript.append("\t").append("["+ commaSep(layer.weight, layer.weight+"_v") + "] " + 
-        "= " + fn + "(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight+"_v") + lastParameter + ")\n")
-    if(layer.shouldUpdateBias) dmlScript.append("\t").append("["+ commaSep(layer.bias, layer.bias+"_v") + "] " + 
-        "= " + fn + "(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias+"_v") + lastParameter + ")\n")
+    if (layer.shouldUpdateWeight)
+      dmlScript
+        .append("\t")
+        .append(
+          "[" + commaSep(layer.weight, layer.weight + "_v") + "] " +
+          "= " + fn + "(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight + "_v") + lastParameter + ")\n"
+        )
+    if (layer.shouldUpdateBias)
+      dmlScript
+        .append("\t")
+        .append(
+          "[" + commaSep(layer.bias, layer.bias + "_v") + "] " +
+          "= " + fn + "(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias + "_v") + lastParameter + ")\n"
+        )
   }
-  def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
-    if(layer.shouldUpdateWeight) dmlScript.append(layer.weight+"_v = sgd_nesterov::init(" + layer.weight + ")\n")
-    if(layer.shouldUpdateBias) dmlScript.append(layer.bias+"_v = sgd_nesterov::init(" + layer.bias + ")\n")
+  def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
+    if (layer.shouldUpdateWeight) dmlScript.append(layer.weight + "_v = sgd_nesterov::init(" + layer.weight + ")\n")
+    if (layer.shouldUpdateBias) dmlScript.append(layer.bias + "_v = sgd_nesterov::init(" + layer.bias + ")\n")
   }
-  def sourceFileName:String = "sgd_nesterov"
-}
\ No newline at end of file
+  def sourceFileName: String = "sgd_nesterov"
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index 6b06c26..b68d493 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -27,301 +27,292 @@ import scala.collection.JavaConversions._
 import caffe.Caffe
 
 trait BaseDMLGenerator {
-  def commaSep(arr:List[String]):String = {
-	  if(arr.length == 1) arr(0) else {
-	    var ret = arr(0)
-	    for(i <- 1 until arr.length) {
-	      ret = ret + "," + arr(i)
-	    }
-	    ret
-	  }
-	}
-  def commaSep(arr:String*):String = {
-	  if(arr.length == 1) arr(0) else {
-	    var ret = arr(0)
-	    for(i <- 1 until arr.length) {
-	      ret = ret + "," + arr(i)
-	    }
-	    ret
-	  }
-	}
-  def int_add(v1:String, v2:String):String = {
-    try { (v1.toDouble + v2.toDouble).toInt.toString } catch { case _:Throwable => "("+v1+"+"+v2+")"}
-  }
-  def int_mult(v1:String, v2:String, v3:String):String = {
-    try { (v1.toDouble * v2.toDouble * v3.toDouble).toInt.toString } catch { case _:Throwable => "("+v1+"*"+v2+"*"+v3+")"}
-  }
-  def isNumber(x: String):Boolean = x forall Character.isDigit
-  def transpose(x:String):String = "t(" + x + ")"
-  def write(varName:String, fileName:String, format:String):String = "write(" + varName + ", \"" + fileName + "\", format=\"" + format + "\")\n"
-  def readWeight(varName:String, fileName:String, sep:String="/"):String =  varName + " = read(weights + \"" + sep + fileName + "\")\n"
-  def asDMLString(str:String):String = "\"" + str + "\""
-  def assign(dmlScript:StringBuilder, lhsVar:String, rhsVar:String):Unit = {
+  def commaSep(arr: List[String]): String =
+    if (arr.length == 1) arr(0)
+    else {
+      var ret = arr(0)
+      for (i <- 1 until arr.length) {
+        ret = ret + "," + arr(i)
+      }
+      ret
+    }
+  def commaSep(arr: String*): String =
+    if (arr.length == 1) arr(0)
+    else {
+      var ret = arr(0)
+      for (i <- 1 until arr.length) {
+        ret = ret + "," + arr(i)
+      }
+      ret
+    }
+  def int_add(v1: String, v2: String): String = // constant-folds when both operands parse as numbers, else emits "(v1+v2)"
+    try { (v1.toDouble + v2.toDouble).toInt.toString } catch { case scala.util.control.NonFatal(_) => "(" + v1 + "+" + v2 + ")" }
+  def int_mult(v1: String, v2: String, v3: String): String = // constant-folds when all operands parse as numbers, else emits "(v1*v2*v3)"
+    try { (v1.toDouble * v2.toDouble * v3.toDouble).toInt.toString } catch { case scala.util.control.NonFatal(_) => "(" + v1 + "*" + v2 + "*" + v3 + ")" }
+  def isNumber(x: String): Boolean                                                   = x.nonEmpty && x.forall(Character.isDigit) // "" is not a number
+  def transpose(x: String): String                                                   = "t(" + x + ")"
+  def write(varName: String, fileName: String, format: String): String               = "write(" + varName + ", \"" + fileName + "\", format=\"" + format + "\")\n"
+  def readWeight(varName: String, fileName: String, sep: String = "/"): String       = varName + " = read(weights + \"" + sep + fileName + "\")\n"
+  def readScalarWeight(varName: String, fileName: String, sep: String = "/"): String = varName + " = as.scalar(read(weights + \"" + sep + fileName + "\"))\n"
+  def asDMLString(str: String): String                                               = "\"" + str + "\""
+  def assign(dmlScript: StringBuilder, lhsVar: String, rhsVar: String): Unit =
     dmlScript.append(lhsVar).append(" = ").append(rhsVar).append("\n")
-  }
-  def sum(dmlScript:StringBuilder, variables:List[String]):StringBuilder = {
-    if(variables.length > 1) dmlScript.append("(")
+  def sum(dmlScript: StringBuilder, variables: List[String]): StringBuilder = {
+    if (variables.length > 1) dmlScript.append("(")
     dmlScript.append(variables(0))
-    for(i <- 1 until variables.length) {
+    for (i <- 1 until variables.length) {
       dmlScript.append(" + ").append(variables(i))
     }
-    if(variables.length > 1) dmlScript.append(")")
+    if (variables.length > 1) dmlScript.append(")")
     return dmlScript
   }
-  def addAndAssign(dmlScript:StringBuilder, lhsVar:String, rhsVars:List[String]):Unit = {
+  def addAndAssign(dmlScript: StringBuilder, lhsVar: String, rhsVars: List[String]): Unit = {
     dmlScript.append(lhsVar).append(" = ")
     sum(dmlScript, rhsVars)
     dmlScript.append("\n")
   }
-  def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:List[String]):Unit = {
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String]): Unit =
     invoke(dmlScript, namespace1, returnVariables, functionName, arguments, true)
-  }
-  def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:List[String], appendNewLine:Boolean):Unit = {
-    if(returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have atleast one return value")
-    if(returnVariables.length > 1) dmlScript.append("[")
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String], appendNewLine: Boolean): Unit = {
+    if (returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have atleast one return value")
+    if (returnVariables.length > 1) dmlScript.append("[")
     dmlScript.append(returnVariables(0))
-    if(returnVariables.length > 1) {
-      for(i <- 1 until returnVariables.length) {
-	      dmlScript.append(",").append(returnVariables(i))
-	    }
+    if (returnVariables.length > 1) {
+      for (i <- 1 until returnVariables.length) {
+        dmlScript.append(",").append(returnVariables(i))
+      }
       dmlScript.append("]")
     }
     dmlScript.append(" = ")
     dmlScript.append(namespace1)
     dmlScript.append(functionName)
     dmlScript.append("(")
-    if(arguments != null) {
-      if(arguments.length != 0) 
+    if (arguments != null) {
+      if (arguments.length != 0)
         dmlScript.append(arguments(0))
-      if(arguments.length > 1) {
-        for(i <- 1 until arguments.length) {
-  	      dmlScript.append(",").append(arguments(i))
-  	    }
+      if (arguments.length > 1) {
+        for (i <- 1 until arguments.length) {
+          dmlScript.append(",").append(arguments(i))
+        }
       }
     }
     dmlScript.append(")")
-    if(appendNewLine) 
+    if (appendNewLine)
       dmlScript.append("\n")
   }
-  def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, appendNewLine:Boolean, arguments:String*):Unit = {
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, appendNewLine: Boolean, arguments: String*): Unit =
     invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, appendNewLine)
-  }
-  def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:String*):Unit = {
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: String*): Unit =
     invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, true)
-  }
-  def rightIndexing(dmlScript:StringBuilder, varName:String, rl:String, ru:String, cl:String, cu:String):StringBuilder = {
+  def rightIndexing(dmlScript: StringBuilder, varName: String, rl: String, ru: String, cl: String, cu: String): StringBuilder = {
     dmlScript.append(varName).append("[")
-    if(rl != null && ru != null) dmlScript.append(rl).append(":").append(ru)
+    if (rl != null && ru != null) dmlScript.append(rl).append(":").append(ru)
     dmlScript.append(",")
-    if(cl != null && cu != null) dmlScript.append(cl).append(":").append(cu)
+    if (cl != null && cu != null) dmlScript.append(cl).append(":").append(cu)
     dmlScript.append("]")
   }
   // Performs assignVar = ceil(lhsVar/rhsVar)
-  def ceilDivide(dmlScript:StringBuilder, assignVar:String, lhsVar:String, rhsVar:String):Unit = 
+  def ceilDivide(dmlScript: StringBuilder, assignVar: String, lhsVar: String, rhsVar: String): Unit =
     dmlScript.append(assignVar).append(" = ").append("ceil(").append(lhsVar).append(" / ").append(rhsVar).append(")\n")
-  def print(arg:String):String = "print(" + arg + ")\n"
-  def dmlConcat(arg:String*):String = {
+  def print(arg: String): String = "print(" + arg + ")\n"
+  def dmlConcat(arg: String*): String = {
     val ret = new StringBuilder
     ret.append(arg(0))
-    for(i <- 1 until arg.length) {
+    for (i <- 1 until arg.length) {
       ret.append(" + ").append(arg(i))
     }
     ret.toString
   }
-  def matrix(init:String, rows:String, cols:String):String = "matrix(" + init + ", rows=" + rows + ", cols=" + cols + ")" 
-  def nrow(m:String):String = "nrow(" + m + ")"
-  def ncol(m:String):String = "ncol(" + m + ")"
-  def customAssert(cond:Boolean, msg:String) = if(!cond) throw new DMLRuntimeException(msg)
-  def multiply(v1:String, v2:String):String = v1 + "*" + v2
-  def colSums(m:String):String = "colSums(" + m + ")"
-  def ifdef(cmdLineVar:String, defaultVal:String):String = "ifdef(" + cmdLineVar + ", " + defaultVal + ")"
-  def ifdef(cmdLineVar:String):String = ifdef(cmdLineVar, "\" \"")
-  def read(filePathVar:String, format:String):String = "read(" + filePathVar + ", format=\""+ format + "\")"
+  def matrix(init: String, rows: String, cols: String): String = "matrix(" + init + ", rows=" + rows + ", cols=" + cols + ")"
+  def nrow(m: String): String                                  = "nrow(" + m + ")"
+  def ncol(m: String): String                                  = "ncol(" + m + ")"
+  def customAssert(cond: Boolean, msg: String)                 = if (!cond) throw new DMLRuntimeException(msg)
+  def multiply(v1: String, v2: String): String                 = v1 + "*" + v2
+  def colSums(m: String): String                               = "colSums(" + m + ")"
+  def ifdef(cmdLineVar: String, defaultVal: String): String    = "ifdef(" + cmdLineVar + ", " + defaultVal + ")"
+  def ifdef(cmdLineVar: String): String                        = ifdef(cmdLineVar, "\" \"")
+  def read(filePathVar: String, format: String): String        = "read(" + filePathVar + ", format=\"" + format + "\")"
 }
 
 trait TabbedDMLGenerator extends BaseDMLGenerator {
-  def tabDMLScript(dmlScript:StringBuilder, numTabs:Int):StringBuilder =  tabDMLScript(dmlScript, numTabs, false)
-  def tabDMLScript(dmlScript:StringBuilder, numTabs:Int, prependNewLine:Boolean):StringBuilder =  {
-    if(prependNewLine) dmlScript.append("\n")
-	  for(i <- 0 until numTabs) dmlScript.append("\t")
-	  dmlScript
+  def tabDMLScript(dmlScript: StringBuilder, numTabs: Int): StringBuilder = tabDMLScript(dmlScript, numTabs, false)
+  def tabDMLScript(dmlScript: StringBuilder, numTabs: Int, prependNewLine: Boolean): StringBuilder = {
+    if (prependNewLine) dmlScript.append("\n")
+    for (i <- 0 until numTabs) dmlScript.append("\t")
+    dmlScript
   }
 }
 
 trait SourceDMLGenerator extends TabbedDMLGenerator {
-  val alreadyImported:HashSet[String] = new HashSet[String]
-  def source(dmlScript:StringBuilder, numTabs:Int, sourceFileName:String, dir:String):Unit = {
-	  if(sourceFileName != null && !alreadyImported.contains(sourceFileName)) {
-      tabDMLScript(dmlScript, numTabs).append("source(\"" + dir +  sourceFileName + ".dml\") as " + sourceFileName + "\n")
+  val alreadyImported: HashSet[String] = new HashSet[String]
+  def source(dmlScript: StringBuilder, numTabs: Int, sourceFileName: String, dir: String): Unit =
+    if (sourceFileName != null && !alreadyImported.contains(sourceFileName)) {
+      tabDMLScript(dmlScript, numTabs).append("source(\"" + dir + sourceFileName + ".dml\") as " + sourceFileName + "\n")
       alreadyImported.add(sourceFileName)
-	  }
+    }
+  def source(dmlScript: StringBuilder, numTabs: Int, net: CaffeNetwork, solver: CaffeSolver, otherFiles: Array[String]): Unit = {
+    // SoftmaxWithLoss needs two DML files; the 4-argument source() skips null or already-imported names
+    if (net.getLayers.exists(layer => net.getCaffeLayer(layer).isInstanceOf[SoftmaxWithLoss])) {
+      source(dmlScript, numTabs, "softmax", Caffe2DML.layerDir)
+      source(dmlScript, numTabs, "cross_entropy_loss", Caffe2DML.layerDir)
+    }
+    net.getLayers.foreach(layer => source(dmlScript, numTabs, net.getCaffeLayer(layer).sourceFileName, Caffe2DML.layerDir)) // one import per layer's DML file
+    if (solver != null)
+      source(dmlScript, numTabs, solver.sourceFileName, Caffe2DML.optimDir)
+    if (otherFiles != null)
+      otherFiles.foreach(sourceFileName => source(dmlScript, numTabs, sourceFileName, Caffe2DML.layerDir))
   }
-  def source(dmlScript:StringBuilder, numTabs:Int, net:CaffeNetwork, solver:CaffeSolver, otherFiles:Array[String]):Unit = {
-	  // Add layers with multiple source files
-	  if(net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[SoftmaxWithLoss]).length > 0) {
-	    source(dmlScript, numTabs, "softmax", Caffe2DML.layerDir)
-	    source(dmlScript, numTabs, "cross_entropy_loss", Caffe2DML.layerDir)
-	  }
-	  net.getLayers.map(layer =>  source(dmlScript, numTabs, net.getCaffeLayer(layer).sourceFileName, Caffe2DML.layerDir))
-	  if(solver != null)
-	    source(dmlScript, numTabs, solver.sourceFileName, Caffe2DML.optimDir)
-	  if(otherFiles != null)
-	    otherFiles.map(sourceFileName => source(dmlScript, numTabs, sourceFileName, Caffe2DML.layerDir))
-	}
 }
 
 trait NextBatchGenerator extends TabbedDMLGenerator {
-	def min(lhs:String, rhs:String): String = "min(" + lhs + ", " + rhs + ")"
-	
-	def assignBatch(dmlScript:StringBuilder, Xb:String, X:String, yb:String, y:String, indexPrefix:String, N:String, i:String):StringBuilder = {
-	  dmlScript.append(indexPrefix).append("beg = ((" + i + "-1) * " + Caffe2DML.batchSize + ") %% " + N + " + 1; ")
-	  dmlScript.append(indexPrefix).append("end = min(beg + " +  Caffe2DML.batchSize + " - 1, " + N + "); ")
-	  dmlScript.append(Xb).append(" = ").append(X).append("[").append(indexPrefix).append("beg:").append(indexPrefix).append("end,]; ")
-	  if(yb != null && y != null)
-	    dmlScript.append(yb).append(" = ").append(y).append("[").append(indexPrefix).append("beg:").append(indexPrefix).append("end,]; ")
-	  dmlScript.append("\n")
-	}
-	def getTestBatch(tabDMLScript:StringBuilder):Unit = {
+  def min(lhs: String, rhs: String): String = s"min($lhs, $rhs)"
+
+  def assignBatch(dmlScript: StringBuilder, Xb: String, X: String, yb: String, y: String, indexPrefix: String, N: String, i: String): StringBuilder = {
+    val rows = s"[${indexPrefix}beg:${indexPrefix}end,]; " // row-range slice shared by the X and y batch assignments
+    dmlScript.append(s"${indexPrefix}beg = (($i-1) * ${Caffe2DML.batchSize}) %% $N + 1; ").append(s"${indexPrefix}end = min(beg + ${Caffe2DML.batchSize} - 1, $N); ")
+    dmlScript.append(s"${Xb} = ${X}${rows}")
+    if (yb != null && y != null)
+      dmlScript.append(s"${yb} = ${y}${rows}")
+    dmlScript.append("\n")
+  }
+  def getTestBatch(tabDMLScript: StringBuilder): Unit =
     assignBatch(tabDMLScript, "Xb", Caffe2DML.X, null, null, "", Caffe2DML.numImages, "iter")
-  } 
-  def getTrainingBatch(tabDMLScript:StringBuilder):Unit = {
+
+  def getTrainingBatch(tabDMLScript: StringBuilder): Unit =
     assignBatch(tabDMLScript, "Xb", Caffe2DML.X, "yb", Caffe2DML.y, "", Caffe2DML.numImages, "iter")
-  }
-	def getTrainingBatch(tabDMLScript:StringBuilder, X:String, y:String, numImages:String):Unit = {
-	  assignBatch(tabDMLScript, "Xb", X, "yb", y, "", numImages, "i")
-  }
-  def getTrainingMaxiBatch(tabDMLScript:StringBuilder):Unit = {
+  def getTrainingBatch(tabDMLScript: StringBuilder, X: String, y: String, numImages: String): Unit =
+    assignBatch(tabDMLScript, "Xb", X, "yb", y, "", numImages, "i")
+  def getTrainingMaxiBatch(tabDMLScript: StringBuilder): Unit =
     assignBatch(tabDMLScript, "X_group_batch", Caffe2DML.X, "y_group_batch", Caffe2DML.y, "group_", Caffe2DML.numImages, "g")
-  }
-  def getValidationBatch(tabDMLScript:StringBuilder):Unit = {
+  def getValidationBatch(tabDMLScript: StringBuilder): Unit =
     assignBatch(tabDMLScript, "Xb", Caffe2DML.XVal, "yb", Caffe2DML.yVal, "", Caffe2DML.numValidationImages, "iVal")
-  }
 }
 
 trait VisualizeDMLGenerator extends TabbedDMLGenerator {
-  var doVisualize = false
-  var _tensorboardLogDir:String = null
-  def setTensorBoardLogDir(log:String): Unit = { _tensorboardLogDir = log }
-  def tensorboardLogDir:String = {
-    if(_tensorboardLogDir == null) {
+  var doVisualize                             = false
+  var _tensorboardLogDir: String              = null
+  def setTensorBoardLogDir(log: String): Unit = _tensorboardLogDir = log
+  def tensorboardLogDir: String = {
+    if (_tensorboardLogDir == null) {
       _tensorboardLogDir = java.io.File.createTempFile("temp", System.nanoTime().toString()).getAbsolutePath
     }
     _tensorboardLogDir
   }
   def visualizeLoss(): Unit = {
-	   checkTensorBoardDependency()
-	   doVisualize = true
-	   // Visualize for both training and validation
-	   visualize(" ", " ", "training_loss", "iter", "training_loss", true)
-	   visualize(" ", " ", "training_accuracy", "iter", "training_accuracy", true)
-	   visualize(" ", " ", "validation_loss", "iter", "validation_loss", false)
-	   visualize(" ", " ", "validation_accuracy", "iter", "validation_accuracy", false)
-	}
-  val visTrainingDMLScript: StringBuilder = new StringBuilder 
+    checkTensorBoardDependency()
+    doVisualize = true
+    // Visualize for both training and validation
+    visualize(" ", " ", "training_loss", "iter", "training_loss", true)
+    visualize(" ", " ", "training_accuracy", "iter", "training_accuracy", true)
+    visualize(" ", " ", "validation_loss", "iter", "validation_loss", false)
+    visualize(" ", " ", "validation_accuracy", "iter", "validation_accuracy", false)
+  }
+  val visTrainingDMLScript: StringBuilder   = new StringBuilder
   val visValidationDMLScript: StringBuilder = new StringBuilder
-	def checkTensorBoardDependency():Unit = {
-	  try {
-	    if(!doVisualize)
-	      Class.forName( "com.google.protobuf.GeneratedMessageV3")
-	  } catch {
-	    case _:ClassNotFoundException => throw new DMLRuntimeException("To use visualize() feature, you will have to include protobuf-java-3.2.0.jar in your classpath. Hint: you can download the jar from http://central.maven.org/maven2/com/google/protobuf/protobuf-java/3.2.0/protobuf-java-3.2.0.jar")   
-	  }
-	}
-  private def visualize(layerName:String, varType:String, aggFn:String, x:String, y:String,  isTraining:Boolean) = {
-    val dmlScript = if(isTraining) visTrainingDMLScript else visValidationDMLScript
-    dmlScript.append("viz_counter1 = visualize(" + 
-        commaSep(asDMLString(layerName), asDMLString(varType), asDMLString(aggFn), x, y, asDMLString(tensorboardLogDir))
-        + ");\n")
+  def checkTensorBoardDependency(): Unit =
+    try {
+      if (!doVisualize)
+        Class.forName("com.google.protobuf.GeneratedMessageV3")
+    } catch {
+      case _: ClassNotFoundException =>
+        throw new DMLRuntimeException(
+          "To use visualize() feature, you will have to include protobuf-java-3.2.0.jar in your classpath. Hint: you can download the jar from http://central.maven.org/maven2/com/google/protobuf/protobuf-java/3.2.0/protobuf-java-3.2.0.jar"
+        )
+    }
+  private def visualize(layerName: String, varType: String, aggFn: String, x: String, y: String, isTraining: Boolean) = {
+    val dmlScript = if (isTraining) visTrainingDMLScript else visValidationDMLScript
+    dmlScript.append(
+      "viz_counter1 = visualize(" +
+      commaSep(asDMLString(layerName), asDMLString(varType), asDMLString(aggFn), x, y, asDMLString(tensorboardLogDir))
+      + ");\n"
+    )
     dmlScript.append("viz_counter = viz_counter + viz_counter1\n")
   }
-  def visualizeLayer(net:CaffeNetwork, layerName:String, varType:String, aggFn:String): Unit = {
-	  // 'weight', 'bias', 'dweight', 'dbias', 'output' or 'doutput'
-	  // 'sum', 'mean', 'var' or 'sd'
-	  checkTensorBoardDependency()
-	  doVisualize = true
-	  if(net.getLayers.filter(_.equals(layerName)).size == 0)
-	    throw new DMLRuntimeException("Cannot visualize the layer:" + layerName)
-	  val dmlVar = {
-	    val l = net.getCaffeLayer(layerName)
-	    varType match {
-	      case "weight" => l.weight
-	      case "bias" => l.bias
-	      case "dweight" => l.dWeight
-	      case "dbias" => l.dBias
-	      case "output" => l.out
-	      // case "doutput" => l.dX
-	      case _ => throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
-	    }
-	   }
-	  if(dmlVar == null)
-	    throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
-	  // Visualize for both training and validation
-	  visualize(layerName, varType, aggFn, "iter", aggFn + "(" + dmlVar + ")", true)
-	  visualize(layerName, varType, aggFn, "iter", aggFn + "(" + dmlVar + ")", false)
-	}
-  
-  def appendTrainingVisualizationBody(dmlScript:StringBuilder, numTabs:Int): Unit = {
-    if(doVisualize)
-        tabDMLScript(dmlScript, numTabs).append(visTrainingDMLScript.toString)
-  }
-  def appendValidationVisualizationBody(dmlScript:StringBuilder, numTabs:Int): Unit = {
-    if(doVisualize)
-        tabDMLScript(dmlScript, numTabs).append(visValidationDMLScript.toString)
+  def visualizeLayer(net: CaffeNetwork, layerName: String, varType: String, aggFn: String): Unit = {
+    // varType: 'weight', 'bias', 'dweight', 'dbias' or 'output' ('doutput' is not supported yet)
+    // aggFn: DML aggregate applied to the variable, e.g. 'sum', 'mean', 'var' or 'sd'
+    checkTensorBoardDependency()
+    if (!net.getLayers.exists(_ == layerName))
+      throw new DMLRuntimeException("Cannot visualize the layer:" + layerName)
+    val dmlVar = {
+      val l = net.getCaffeLayer(layerName)
+      varType match {
+        case "weight"  => l.weight
+        case "bias"    => l.bias
+        case "dweight" => l.dWeight
+        case "dbias"   => l.dBias
+        case "output"  => l.out
+        // case "doutput" => l.dX
+        case _ => throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
+      }
+    }
+    if (dmlVar == null)
+      throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
+    doVisualize = true // set only after validation so a rejected request does not enable visualization
+    // Visualize for both training and validation
+    visualize(layerName, varType, aggFn, "iter", aggFn + "(" + dmlVar + ")", true)
+    visualize(layerName, varType, aggFn, "iter", aggFn + "(" + dmlVar + ")", false)
   }
+
+  def appendTrainingVisualizationBody(dmlScript: StringBuilder, numTabs: Int): Unit =
+    if (doVisualize)
+      tabDMLScript(dmlScript, numTabs).append(visTrainingDMLScript.toString)
+  def appendValidationVisualizationBody(dmlScript: StringBuilder, numTabs: Int): Unit =
+    if (doVisualize)
+      tabDMLScript(dmlScript, numTabs).append(visValidationDMLScript.toString)
 }
 
 trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator with VisualizeDMLGenerator {
   // Also makes "code reading" possible for Caffe2DML :)
-	var dmlScript = new StringBuilder
-	var numTabs = 0
-	def reset():Unit = {
-	  dmlScript.clear()
-	  alreadyImported.clear()
-	  numTabs = 0
-	  visTrainingDMLScript.clear()
-	  visValidationDMLScript.clear()
-	  doVisualize = false
-	}
-	// -------------------------------------------------------------------------------------------------
-	// Helper functions that calls super class methods and simplifies the code of this trait
-	def tabDMLScript():StringBuilder = tabDMLScript(dmlScript, numTabs, false)
-	def tabDMLScript(prependNewLine:Boolean):StringBuilder = tabDMLScript(dmlScript, numTabs, prependNewLine)
-	def source(net:CaffeNetwork, solver:CaffeSolver, otherFiles:Array[String]):Unit = {
-	  source(dmlScript, numTabs, net, solver, otherFiles)
-	}
-	// -------------------------------------------------------------------------------------------------
-	
-	def ifBlock(cond:String)(op: => Unit) {
-	  tabDMLScript.append("if(" + cond + ") {\n")
-	  numTabs += 1
-	  op
-	  numTabs -= 1
-	  tabDMLScript.append("}\n")
-	}
-	def whileBlock(cond:String)(op: => Unit) {
-	  tabDMLScript.append("while(" + cond + ") {\n")
-	  numTabs += 1
-	  op
-	  numTabs -= 1
-	  tabDMLScript.append("}\n")
-	}
-	def forBlock(iterVarName:String, startVal:String, endVal:String)(op: => Unit) {
-	  tabDMLScript.append("for(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
-	  numTabs += 1
-	  op
-	  numTabs -= 1
-	  tabDMLScript.append("}\n")
-	}
-	def parForBlock(iterVarName:String, startVal:String, endVal:String)(op: => Unit) {
-	  tabDMLScript.append("parfor(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
-	  numTabs += 1
-	  op
-	  numTabs -= 1
-	  tabDMLScript.append("}\n")
-	}
-	
-	def printClassificationReport():Unit = {
-    ifBlock("debug"){
+  var dmlScript = new StringBuilder
+  var numTabs   = 0
+  def reset(): Unit = {
+    dmlScript.clear()
+    alreadyImported.clear()
+    numTabs = 0
+    visTrainingDMLScript.clear()
+    visValidationDMLScript.clear()
+    doVisualize = false
+  }
+  // -------------------------------------------------------------------------------------------------
+  // Helpers that forward to the inherited overloads, supplying this trait's own dmlScript and numTabs
+  def tabDMLScript(): StringBuilder                        = tabDMLScript(dmlScript, numTabs, false)
+  def tabDMLScript(prependNewLine: Boolean): StringBuilder = tabDMLScript(dmlScript, numTabs, prependNewLine)
+  def source(net: CaffeNetwork, solver: CaffeSolver, otherFiles: Array[String]): Unit =
+    source(dmlScript, numTabs, net, solver, otherFiles)
+  // -------------------------------------------------------------------------------------------------
+
+  def ifBlock(cond: String)(op: => Unit): Unit = { // emits `if(cond) { ... }`, running op one tab deeper
+    tabDMLScript.append("if(" + cond + ") {\n")
+    numTabs += 1
+    try op
+    finally numTabs -= 1 // restore indentation even if op throws
+    tabDMLScript.append("}\n")
+  }
+  def whileBlock(cond: String)(op: => Unit): Unit = { // emits `while(cond) { ... }`, running op one tab deeper
+    tabDMLScript.append("while(" + cond + ") {\n")
+    numTabs += 1
+    try op
+    finally numTabs -= 1 // restore indentation even if op throws
+    tabDMLScript.append("}\n")
+  }
+  def forBlock(iterVarName: String, startVal: String, endVal: String)(op: => Unit): Unit = { // emits `for(i in a:b) { ... }`
+    tabDMLScript.append("for(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
+    numTabs += 1
+    try op
+    finally numTabs -= 1 // restore indentation even if op throws
+    tabDMLScript.append("}\n")
+  }
+  def parForBlock(iterVarName: String, startVal: String, endVal: String)(op: => Unit): Unit = { // emits `parfor(i in a:b) { ... }`
+    tabDMLScript.append("parfor(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
+    numTabs += 1
+    try op
+    finally numTabs -= 1 // restore indentation even if op throws
+    tabDMLScript.append("}\n")
+  }
+
+  def printClassificationReport(): Unit =
+    ifBlock("debug") {
       assign(tabDMLScript, "num_rows_error_measures", min("10", ncol("yb")))
       assign(tabDMLScript, "error_measures", matrix("0", "num_rows_error_measures", "5"))
       forBlock("class_i", "1", "num_rows_error_measures") {
@@ -337,35 +328,38 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator with Visua
         assign(tabDMLScript, "error_measures[class_i,4]", "f1Score")
         assign(tabDMLScript, "error_measures[class_i,5]", "tp_plus_fn")
       }
-      val dmlTab = "\\t"
-      val header = "class    " + dmlTab + "precision" + dmlTab + "recall  " + dmlTab + "f1-score" + dmlTab + "num_true_labels\\n"
+      val dmlTab        = "\\t"
+      val header        = "class    " + dmlTab + "precision" + dmlTab + "recall  " + dmlTab + "f1-score" + dmlTab + "num_true_labels\\n"
       val errorMeasures = "toString(error_measures, decimal=7, sep=" + asDMLString(dmlTab) + ")"
       tabDMLScript.append(print(dmlConcat(asDMLString(header), errorMeasures)))
     }
-  }
-	
-	// Appends DML corresponding to source and externalFunction statements. 
-  def appendHeaders(net:CaffeNetwork, solver:CaffeSolver, isTraining:Boolean):Unit = {
+
+  // Appends the DML source statements for the layers and solver and, when training, the enabled externalFunction headers.
+  def appendHeaders(net: CaffeNetwork, solver: CaffeSolver, isTraining: Boolean): Unit = {
     // Append source statements for layers as well as solver
-	  source(net, solver, if(isTraining) Array[String]("l2_reg") else null)
-	  
-	  if(isTraining) {
-  	  // Append external built-in function headers:
-  	  // 1. visualize external built-in function header
-      if(doVisualize) {
-  	    tabDMLScript.append("visualize = externalFunction(String layerName, String varType, String aggFn, Double x, Double y, String logDir) return (Double B) " +
-  	        "implemented in (classname=\"org.apache.sysml.udf.lib.Caffe2DMLVisualizeWrapper\",exectype=\"mem\"); \n")
-  	    tabDMLScript.append("viz_counter = 0\n")
-  	    System.out.println("Please use the following command for visualizing: tensorboard --logdir=" + tensorboardLogDir)
-  	  }
-  	  // 2. update_nesterov external built-in function header
-  	  if(Caffe2DML.USE_NESTEROV_UDF) {
-  	    tabDMLScript.append("update_nesterov = externalFunction(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v, double lambda) return (matrix[double] X, matrix[double] v) implemented in (classname=\"org.apache.sysml.udf.lib.SGDNesterovUpdate\",exectype=\"mem\");  \n")
-  	  }
-	  }
+    source(net, solver, if (isTraining) Array[String]("l2_reg") else null) // otherFiles: Array("l2_reg") when training, null otherwise
+
+    if (isTraining) {
+      // Training only: append headers for the enabled external built-in functions:
+      // 1. visualize (only when doVisualize), plus viz_counter initialization and a tensorboard hint printed to stdout
+      if (doVisualize) {
+        tabDMLScript.append(
+          "visualize = externalFunction(String layerName, String varType, String aggFn, Double x, Double y, String logDir) return (Double B) " +
+          "implemented in (classname=\"org.apache.sysml.udf.lib.Caffe2DMLVisualizeWrapper\",exectype=\"mem\"); \n"
+        )
+        tabDMLScript.append("viz_counter = 0\n")
+        System.out.println("Please use the following command for visualizing: tensorboard --logdir=" + tensorboardLogDir)
+      }
+      // 2. update_nesterov (only when Caffe2DML.USE_NESTEROV_UDF is set)
+      if (Caffe2DML.USE_NESTEROV_UDF) {
+        tabDMLScript.append(
+          "update_nesterov = externalFunction(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v, double lambda) return (matrix[double] X, matrix[double] v) implemented in (classname=\"org.apache.sysml.udf.lib.SGDNesterovUpdate\",exectype=\"mem\");  \n"
+        )
+      }
+    }
   }
-  
-  def readMatrix(varName:String, cmdLineVar:String):Unit = {
+
+  def readMatrix(varName: String, cmdLineVar: String): Unit = {
     val pathVar = varName + "_path"
     assign(tabDMLScript, pathVar, ifdef(cmdLineVar))
     // Uncomment the following lines if we want to the user to pass the format
@@ -374,47 +368,47 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator with Visua
     // assign(tabDMLScript, varName, "read(" + pathVar + ", format=" + formatVar + ")")
     assign(tabDMLScript, varName, "read(" + pathVar + ")")
   }
-  
-  def readInputData(net:CaffeNetwork, isTraining:Boolean):Unit = {
+
+  def readInputData(net: CaffeNetwork, isTraining: Boolean): Unit = { // emits DML that reads X_full (and y_full when training) and sets Caffe2DML.numImages
     // Read and convert to one-hot encoding
     readMatrix("X_full", "$X")
-	  if(isTraining) {
-	    readMatrix("y_full", "$y")
-  	  tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y_full)\n")
-  	  tabDMLScript.append("# Convert to one-hot encoding (Assumption: 1-based labels) \n")
-	    tabDMLScript.append("y_full = table(seq(1," + Caffe2DML.numImages + ",1), y_full, " + Caffe2DML.numImages + ", " + Utils.numClasses(net) + ")\n")
-	  }
-	  else {
-	    tabDMLScript.append(Caffe2DML.numImages + " = nrow(X_full)\n")
-	  }
+    if (!isTraining) tabDMLScript.append(Caffe2DML.numImages + " = nrow(X_full)\n") // not training: row count comes from X_full
+    else {
+      // training: read the labels, take the row count from y_full, then one-hot encode y_full
+      readMatrix("y_full", "$y")
+      tabDMLScript.append(Caffe2DML.numImages + " = nrow(y_full)\n")
+      tabDMLScript.append("# Convert to one-hot encoding (Assumption: 1-based labels) \n")
+      tabDMLScript.append("y_full = table(seq(1," + Caffe2DML.numImages + ",1), y_full, " + Caffe2DML.numImages + ", " + Utils.numClasses(net) + ")\n")
+    }
   }
-  
-  def initWeights(net:CaffeNetwork, solver:CaffeSolver, readWeights:Boolean): Unit = {
+
+  def initWeights(net: CaffeNetwork, solver: CaffeSolver, readWeights: Boolean): Unit = // same as the 4-arg overload with an empty layersToIgnore set
     initWeights(net, solver, readWeights, new HashSet[String]())
-  }
-  
-  def initWeights(net:CaffeNetwork, solver:CaffeSolver, readWeights:Boolean, layersToIgnore:HashSet[String]): Unit = {
+
+  def initWeights(net: CaffeNetwork, solver: CaffeSolver, readWeights: Boolean, layersToIgnore: HashSet[String]): Unit = { // emits DML initializing all layers and solvers; optionally reads saved weights/biases for layers not in layersToIgnore
     tabDMLScript.append("weights = ifdef($weights, \" \")\n")
-	  // Initialize the layers and solvers
-	  tabDMLScript.append("# Initialize the layers and solvers\n")
-	  net.getLayers.map(layer => net.getCaffeLayer(layer).init(tabDMLScript))
-	  if(readWeights) {
-		  // Loading existing weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias
-		  tabDMLScript.append("# Load the weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias\n")
-		  net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(readWeight(l.weight, l.param.getName + "_weight.mtx")))
-		  net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(readWeight(l.bias, l.param.getName + "_bias.mtx")))
-	  }
-	  net.getLayers.map(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
+    // Initialize the layers and solvers (foreach: these calls are made only for their side effects on the generated DML)
+    tabDMLScript.append("# Initialize the layers and solvers\n")
+    net.getLayers.foreach(layer => net.getCaffeLayer(layer).init(tabDMLScript))
+    if (readWeights) {
+      // Load existing <name>_weight.mtx / <name>_bias.mtx for layers not in layersToIgnore; the init code above is kept in case a layer initializes non-weight, non-bias variables
+      tabDMLScript.append("# Load the weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias\n")
+      val layersToLoad = net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_))
+      layersToLoad.filter(_.weight != null).foreach(l => tabDMLScript.append(readWeight(l.weight, l.param.getName + "_weight.mtx")))
+      layersToLoad.filter(_.bias != null).foreach(l => tabDMLScript.append(readWeight(l.bias, l.param.getName + "_bias.mtx")))
+    }
+    net.getLayers.foreach(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
   }
-  
-  def getLossLayers(net:CaffeNetwork):List[IsLossLayer] = {
+
+  def getLossLayers(net: CaffeNetwork): List[IsLossLayer] = { // returns the IsLossLayer layers of net; throws DMLRuntimeException unless there is exactly one
     val lossLayers = net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]).map(layer => net.getCaffeLayer(layer).asInstanceOf[IsLossLayer])
-	  if(lossLayers.length != 1) 
-	    throw new DMLRuntimeException("Expected exactly one loss layer, but found " + lossLayers.length + ":" + net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]))
-	  lossLayers
+    // Exactly one loss layer is supported; otherwise fail, reporting the count and the matching net.getLayers entries.
+    if (lossLayers.length == 1) lossLayers
+    else throw new DMLRuntimeException(
+      "Expected exactly one loss layer, but found " + lossLayers.length + ":" + net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer])
+    )
   }
-  
-  def updateMeanVarianceForBatchNorm(net:CaffeNetwork, value:Boolean):Unit = {
+
+  def updateMeanVarianceForBatchNorm(net: CaffeNetwork, value: Boolean): Unit = // sets update_mean_var to `value` on every BatchNorm layer of net
     net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).map(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var = value)
-  }
-}
\ No newline at end of file
+}