You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/09/15 18:03:09 UTC

[2/4] systemml git commit: [SYSTEMML-540] Support loading of batch normalization weights in .caffemodel file using Caffe2DML

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
index 67682d5..2b07788 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -33,205 +33,197 @@ import org.apache.commons.logging.LogFactory
 
 trait Network {
   def getLayers(): List[String]
-  def getCaffeLayer(layerName:String):CaffeLayer
-  def getBottomLayers(layerName:String): Set[String]
-  def getTopLayers(layerName:String): Set[String]
-  def getLayerID(layerName:String): Int
+  def getCaffeLayer(layerName: String): CaffeLayer
+  def getBottomLayers(layerName: String): Set[String]
+  def getTopLayers(layerName: String): Set[String]
+  def getLayerID(layerName: String): Int
 }
 
 object CaffeNetwork {
   val LOG = LogFactory.getLog(classOf[CaffeNetwork].getName)
 }
 
-class CaffeNetwork(netFilePath:String, val currentPhase:Phase, 
-     var numChannels:String, var height:String, var width:String
-    ) extends Network {
-  private def isIncludedInCurrentPhase(l:LayerParameter): Boolean = {
-    if(currentPhase == null) return true // while deployment
-    else if(l.getIncludeCount == 0) true 
+class CaffeNetwork(netFilePath: String, val currentPhase: Phase, var numChannels: String, var height: String, var width: String) extends Network {
+  private def isIncludedInCurrentPhase(l: LayerParameter): Boolean =
+    if (currentPhase == null) return true // while deployment
+    else if (l.getIncludeCount == 0) true
     else l.getIncludeList.filter(r => r.hasPhase() && r.getPhase != currentPhase).length == 0
-  }
   private var id = 1
-  def this(deployFilePath:String) {
+  def this(deployFilePath: String) {
     this(deployFilePath, null, null, null, null)
   }
   // --------------------------------------------------------------------------------
-  private var _net:NetParameter = Utils.readCaffeNet(netFilePath)
-  private var _caffeLayerParams:List[LayerParameter] = _net.getLayerList.filter(l => isIncludedInCurrentPhase(l)).toList
+  private var _net: NetParameter                      = Utils.readCaffeNet(netFilePath)
+  private var _caffeLayerParams: List[LayerParameter] = _net.getLayerList.filter(l => isIncludedInCurrentPhase(l)).toList
   // This method is used if the user doesnot provide number of channels, height and width
-  private def setCHW(inputShapes:java.util.List[caffe.Caffe.BlobShape]):Unit = {
-    if(inputShapes.size != 1)
-        throw new DMLRuntimeException("Expected only one input shape")
+  private def setCHW(inputShapes: java.util.List[caffe.Caffe.BlobShape]): Unit = {
+    if (inputShapes.size != 1)
+      throw new DMLRuntimeException("Expected only one input shape")
     val inputShape = inputShapes.get(0)
-    if(inputShape.getDimCount != 4)
+    if (inputShape.getDimCount != 4)
       throw new DMLRuntimeException("Expected the input shape of dimension 4")
     numChannels = inputShape.getDim(1).toString
     height = inputShape.getDim(2).toString
     width = inputShape.getDim(3).toString
   }
-  if(numChannels == null && height == null && width == null) {
-    val inputLayer:List[LayerParameter] = _caffeLayerParams.filter(_.getType.toLowerCase.equals("input"))
-    if(inputLayer.size == 1) {
+  if (numChannels == null && height == null && width == null) {
+    val inputLayer: List[LayerParameter] = _caffeLayerParams.filter(_.getType.toLowerCase.equals("input"))
+    if (inputLayer.size == 1) {
       setCHW(inputLayer(0).getInputParam.getShapeList)
-    }
-    else if(inputLayer.size == 0) {
-      throw new DMLRuntimeException("Input shape (number of channels, height, width) is unknown. Hint: If you are using deprecated input/input_shape API, we recommend you use Input layer.")
-    }
-    else {
+    } else if (inputLayer.size == 0) {
+      throw new DMLRuntimeException(
+        "Input shape (number of channels, height, width) is unknown. Hint: If you are using deprecated input/input_shape API, we recommend you use Input layer."
+      )
+    } else {
       throw new DMLRuntimeException("Multiple Input layer is not supported")
     }
   }
   // --------------------------------------------------------------------------------
-  
+
   private var _layerNames: List[String] = _caffeLayerParams.map(l => l.getName).toList
   CaffeNetwork.LOG.debug("Layers in current phase:" + _layerNames)
-  
+
   // Condition 1: assert that each name is unique
   private val _duplicateLayerNames = _layerNames.diff(_layerNames.distinct)
-  if(_duplicateLayerNames.size != 0) throw new LanguageException("Duplicate layer names is not supported:" + _duplicateLayerNames)
-  
+  if (_duplicateLayerNames.size != 0) throw new LanguageException("Duplicate layer names is not supported:" + _duplicateLayerNames)
+
   // Condition 2: only 1 top name, except Data layer
   private val _condition2Exceptions = Set("data")
-  _caffeLayerParams.filter(l => !_condition2Exceptions.contains(l.getType.toLowerCase)).map(l => if(l.getTopCount != 1) throw new LanguageException("Multiple top layers is not supported for " + l.getName))
+  _caffeLayerParams
+    .filter(l => !_condition2Exceptions.contains(l.getType.toLowerCase))
+    .map(l => if (l.getTopCount != 1) throw new LanguageException("Multiple top layers is not supported for " + l.getName))
 
   // Condition 3: Replace top layer names referring to a Data layer with its name
   // Example: layer{ name: mnist, top: data, top: label, ... }
-  private val _topToNameMappingForDataLayer = new HashMap[String, String]()
-  private def containsOnly(list:java.util.List[String], v:String): Boolean = list.toSet.diff(Set(v)).size() == 0
-  private def isData(l:LayerParameter):Boolean = l.getType.equalsIgnoreCase("data")
-  private def replaceTopWithNameOfDataLayer(l:LayerParameter):LayerParameter =  {
-    if(containsOnly(l.getTopList,l.getName))
+  private val _topToNameMappingForDataLayer                                  = new HashMap[String, String]()
+  private def containsOnly(list: java.util.List[String], v: String): Boolean = list.toSet.diff(Set(v)).size() == 0
+  private def isData(l: LayerParameter): Boolean                             = l.getType.equalsIgnoreCase("data")
+  private def replaceTopWithNameOfDataLayer(l: LayerParameter): LayerParameter =
+    if (containsOnly(l.getTopList, l.getName))
       return l
     else {
-      val builder = l.toBuilder(); 
-      for(i <- 0 until l.getTopCount) {
-        if(! l.getTop(i).equals(l.getName)) { _topToNameMappingForDataLayer.put(l.getTop(i), l.getName) }
-        builder.setTop(i, l.getName) 
+      val builder = l.toBuilder();
+      for (i <- 0 until l.getTopCount) {
+        if (!l.getTop(i).equals(l.getName)) { _topToNameMappingForDataLayer.put(l.getTop(i), l.getName) }
+        builder.setTop(i, l.getName)
       }
-      return builder.build() 
+      return builder.build()
     }
-  }
   // 3a: Replace top of DataLayer with its names
   // Example: layer{ name: mnist, top: mnist, top: mnist, ... }
-  _caffeLayerParams = _caffeLayerParams.map(l => if(isData(l)) replaceTopWithNameOfDataLayer(l) else l)
-  private def replaceBottomOfNonDataLayers(l:LayerParameter):LayerParameter = {
+  _caffeLayerParams = _caffeLayerParams.map(l => if (isData(l)) replaceTopWithNameOfDataLayer(l) else l)
+  private def replaceBottomOfNonDataLayers(l: LayerParameter): LayerParameter = {
     val builder = l.toBuilder();
     // Note: Top will never be Data layer
-    for(i <- 0 until l.getBottomCount) {
-      if(_topToNameMappingForDataLayer.containsKey(l.getBottom(i))) 
+    for (i <- 0 until l.getBottomCount) {
+      if (_topToNameMappingForDataLayer.containsKey(l.getBottom(i)))
         builder.setBottom(i, _topToNameMappingForDataLayer.get(l.getBottom(i)))
     }
     return builder.build()
   }
   // 3a: If top/bottom of other layers refer DataLayer, then replace them
   // layer { name: "conv1_1", type: "Convolution", bottom: "data"
-  _caffeLayerParams = if(_topToNameMappingForDataLayer.size == 0) _caffeLayerParams else _caffeLayerParams.map(l => if(isData(l)) l else replaceBottomOfNonDataLayers(l))
-  
+  _caffeLayerParams = if (_topToNameMappingForDataLayer.size == 0) _caffeLayerParams else _caffeLayerParams.map(l => if (isData(l)) l else replaceBottomOfNonDataLayers(l))
+
   // Condition 4: Deal with fused layer
   // Example: layer { name: conv1, top: conv1, ... } layer { name: foo, bottom: conv1, top: conv1 }
-  private def isFusedLayer(l:LayerParameter): Boolean = l.getTopCount == 1 && l.getBottomCount == 1 && l.getTop(0).equalsIgnoreCase(l.getBottom(0))
-  private def containsReferencesToFusedLayer(l:LayerParameter):Boolean = l.getBottomList.foldLeft(false)((prev, bLayer) => prev || _fusedTopLayer.containsKey(bLayer))
-  private val _fusedTopLayer = new HashMap[String, String]()
+  private def isFusedLayer(l: LayerParameter): Boolean                   = l.getTopCount == 1 && l.getBottomCount == 1 && l.getTop(0).equalsIgnoreCase(l.getBottom(0))
+  private def containsReferencesToFusedLayer(l: LayerParameter): Boolean = l.getBottomList.foldLeft(false)((prev, bLayer) => prev || _fusedTopLayer.containsKey(bLayer))
+  private val _fusedTopLayer                                             = new HashMap[String, String]()
   _caffeLayerParams = _caffeLayerParams.map(l => {
-    if(isFusedLayer(l)) {
+    if (isFusedLayer(l)) {
       val builder = l.toBuilder();
-      if(_fusedTopLayer.containsKey(l.getBottom(0))) {
+      if (_fusedTopLayer.containsKey(l.getBottom(0))) {
         builder.setBottom(0, _fusedTopLayer.get(l.getBottom(0)))
       }
       builder.setTop(0, l.getName)
       _fusedTopLayer.put(l.getBottom(0), l.getName)
       builder.build()
-    }
-    else if(containsReferencesToFusedLayer(l)) {
+    } else if (containsReferencesToFusedLayer(l)) {
       val builder = l.toBuilder();
-      for(i <- 0 until l.getBottomCount) {
-        if(_fusedTopLayer.containsKey(l.getBottomList.get(i))) {
+      for (i <- 0 until l.getBottomCount) {
+        if (_fusedTopLayer.containsKey(l.getBottomList.get(i))) {
           builder.setBottom(i, _fusedTopLayer.get(l.getBottomList.get(i)))
         }
       }
       builder.build()
-    }
-    else l
+    } else l
   })
-  
+
   // Used while reading caffemodel
   val replacedLayerNames = new HashMap[String, String]();
-  
+
   // Condition 5: Deal with incorrect naming
   // Example: layer { name: foo, bottom: arbitrary, top: bar } ... Rename the layer to bar
-  private def isIncorrectNamingLayer(l:LayerParameter): Boolean = l.getTopCount == 1 && !l.getTop(0).equalsIgnoreCase(l.getName)
+  private def isIncorrectNamingLayer(l: LayerParameter): Boolean = l.getTopCount == 1 && !l.getTop(0).equalsIgnoreCase(l.getName)
   _caffeLayerParams = _caffeLayerParams.map(l => {
-    if(isIncorrectNamingLayer(l)) {
+    if (isIncorrectNamingLayer(l)) {
       val builder = l.toBuilder();
       replacedLayerNames.put(l.getName, l.getTop(0))
       builder.setName(l.getTop(0))
       builder.build()
-    }
-    else l
+    } else l
   })
 
   // --------------------------------------------------------------------------------
-  
+
   // Helper functions to extract bottom and top layers
-  private def convertTupleListToMap(m:List[(String, String)]):Map[String, Set[String]] = m.groupBy(_._1).map(x => (x._1, x._2.map(y => y._2).toSet)).toMap
-  private def flipKeyValues(t:List[(String, Set[String])]): List[(String, String)] = t.flatMap(x => x._2.map(b => b -> x._1))
-  private def expandBottomList(layerName:String, bottomList:java.util.List[String]): List[(String, String)] = bottomList.filter(b => !b.equals(layerName)).map(b => layerName -> b).toList 
-  
+  private def convertTupleListToMap(m: List[(String, String)]): Map[String, Set[String]] = m.groupBy(_._1).map(x => (x._1, x._2.map(y => y._2).toSet)).toMap
+  private def flipKeyValues(t: List[(String, Set[String])]): List[(String, String)]      = t.flatMap(x => x._2.map(b => b -> x._1))
+  private def expandBottomList(layerName: String, bottomList: java.util.List[String]): List[(String, String)] =
+    bottomList.filter(b => !b.equals(layerName)).map(b => layerName -> b).toList
+
   // The bottom layers are the layers available in the getBottomList (from Caffe .proto files)
-  private val _bottomLayers:Map[String, Set[String]] = convertTupleListToMap(
-      _caffeLayerParams.flatMap(l => expandBottomList(l.getName, l.getBottomList)))
+  private val _bottomLayers: Map[String, Set[String]] = convertTupleListToMap(_caffeLayerParams.flatMap(l => expandBottomList(l.getName, l.getBottomList)))
   CaffeNetwork.LOG.debug("Bottom layers:" + _bottomLayers)
-  
+
   // Find the top layers by reversing the bottom list
-  private val _topLayers:Map[String, Set[String]] = convertTupleListToMap(flipKeyValues(_bottomLayers.toList))
+  private val _topLayers: Map[String, Set[String]] = convertTupleListToMap(flipKeyValues(_bottomLayers.toList))
   CaffeNetwork.LOG.debug("Top layers:" + _topLayers)
-  
+
   private val _layers: Map[String, CaffeLayer] = _caffeLayerParams.map(l => l.getName -> convertLayerParameterToCaffeLayer(l)).toMap
   CaffeNetwork.LOG.debug("Layers:" + _layers)
   private val _layerIDs: Map[String, Int] = _layers.entrySet().map(x => x.getKey -> x.getValue.id).toMap
-  
-  
-  private def throwException(layerName:String) = throw new LanguageException("Layer with name " + layerName + " not found")                              
-  def getLayers(): List[String] =  _layerNames
-  def getCaffeLayer(layerName:String):CaffeLayer = {
-    if(checkKey(_layers, layerName)) _layers.get(layerName).get
+
+  private def throwException(layerName: String) = throw new LanguageException("Layer with name " + layerName + " not found")
+  def getLayers(): List[String]                 = _layerNames
+  def getCaffeLayer(layerName: String): CaffeLayer =
+    if (checkKey(_layers, layerName)) _layers.get(layerName).get
     else {
-      if(replacedLayerNames.contains(layerName) && checkKey(_layers, replacedLayerNames.get(layerName))) {
+      if (replacedLayerNames.contains(layerName) && checkKey(_layers, replacedLayerNames.get(layerName))) {
         _layers.get(replacedLayerNames.get(layerName)).get
-      }
-      else throwException(layerName)
+      } else throwException(layerName)
     }
-  }
-  def getBottomLayers(layerName:String): Set[String] =  if(checkKey(_bottomLayers, layerName)) _bottomLayers.get(layerName).get else throwException(layerName)
-  def getTopLayers(layerName:String): Set[String] = if(checkKey(_topLayers, layerName)) _topLayers.get(layerName).get else throwException(layerName)
-  def getLayerID(layerName:String): Int = if(checkKey(_layerIDs, layerName))  _layerIDs.get(layerName).get else throwException(layerName)
-  
+  def getBottomLayers(layerName: String): Set[String] = if (checkKey(_bottomLayers, layerName)) _bottomLayers.get(layerName).get else throwException(layerName)
+  def getTopLayers(layerName: String): Set[String]    = if (checkKey(_topLayers, layerName)) _topLayers.get(layerName).get else throwException(layerName)
+  def getLayerID(layerName: String): Int              = if (checkKey(_layerIDs, layerName)) _layerIDs.get(layerName).get else throwException(layerName)
+
   // Helper functions
-  private def checkKey(m:Map[String, Any], key:String): Boolean = {
-    if(m == null) throw new LanguageException("Map is null (key=" + key + ")")
-    else if(key == null) throw new LanguageException("key is null (map=" + m + ")")
+  private def checkKey(m: Map[String, Any], key: String): Boolean =
+    if (m == null) throw new LanguageException("Map is null (key=" + key + ")")
+    else if (key == null) throw new LanguageException("key is null (map=" + m + ")")
     else m.containsKey(key)
-  }
-  private def convertLayerParameterToCaffeLayer(param:LayerParameter):CaffeLayer = {
+  private def convertLayerParameterToCaffeLayer(param: LayerParameter): CaffeLayer = {
     id = id + 1
     param.getType.toLowerCase() match {
       case "convolution" => new Convolution(param, id, this)
-      case "pooling" => if(param.getPoolingParam.getPool == PoolingParameter.PoolMethod.MAX)  new MaxPooling(param, id, this)
-                        else throw new LanguageException("Only maxpooling is supported:" + param.getPoolingParam.getPool.name)
-      case "innerproduct" => new InnerProduct(param, id, this)
-      case "relu" => new ReLU(param, id, this)
+      case "pooling" =>
+        if (param.getPoolingParam.getPool == PoolingParameter.PoolMethod.MAX) new MaxPooling(param, id, this)
+        else throw new LanguageException("Only maxpooling is supported:" + param.getPoolingParam.getPool.name)
+      case "innerproduct"    => new InnerProduct(param, id, this)
+      case "relu"            => new ReLU(param, id, this)
       case "softmaxwithloss" => new SoftmaxWithLoss(param, id, this)
-      case "dropout" => new Dropout(param, id, this)
-      case "data" => new Data(param, id, this, numChannels, height, width)
-      case "input" => new Data(param, id, this, numChannels, height, width)
-      case "batchnorm" => new BatchNorm(param, id, this)
-      case "scale" => new Scale(param, id, this)
-      case "eltwise" => new Elementwise(param, id, this)
-      case "concat" => new Concat(param, id, this)
-      case "deconvolution" => new DeConvolution(param, id, this)
-      case "threshold" => new Threshold(param, id, this)
-      case "softmax" => new Softmax(param, id, this)
-      case _ => throw new LanguageException("Layer of type " + param.getType + " is not supported")
+      case "dropout"         => new Dropout(param, id, this)
+      case "data"            => new Data(param, id, this, numChannels, height, width)
+      case "input"           => new Data(param, id, this, numChannels, height, width)
+      case "batchnorm"       => new BatchNorm(param, id, this)
+      case "scale"           => new Scale(param, id, this)
+      case "eltwise"         => new Elementwise(param, id, this)
+      case "concat"          => new Concat(param, id, this)
+      case "deconvolution"   => new DeConvolution(param, id, this)
+      case "threshold"       => new Threshold(param, id, this)
+      case "softmax"         => new Softmax(param, id, this)
+      case _                 => throw new LanguageException("Layer of type " + param.getType + " is not supported")
     }
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
index 0e39192..a61ff10 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -23,72 +23,71 @@ import org.apache.sysml.runtime.DMLRuntimeException
 import caffe.Caffe
 
 trait CaffeSolver {
-  def sourceFileName:String;
-  def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit;
-  def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit;
-  
+  def sourceFileName: String;
+  def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit;
+  def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit;
+
   // ----------------------------------------------------------------
   // Used for Fine-tuning
-  private def getLayerLr(layer:CaffeLayer, paramIndex:Int):String = {
+  private def getLayerLr(layer: CaffeLayer, paramIndex: Int): String = {
     val param = layer.param.getParamList
-    if(param == null || param.size() <= paramIndex)
+    if (param == null || param.size() <= paramIndex)
       return "lr"
     else
       // TODO: Ignoring param.get(index).getDecayMult for now
       return "(lr * " + param.get(paramIndex).getLrMult + ")"
   }
   // the first param { } is for the weights and the second is for the biases.
-  def getWeightLr(layer:CaffeLayer):String = getLayerLr(layer, 0)
-  def getBiasLr(layer:CaffeLayer):String = getLayerLr(layer, 1)
+  def getWeightLr(layer: CaffeLayer): String = getLayerLr(layer, 0)
+  def getBiasLr(layer: CaffeLayer): String   = getLayerLr(layer, 1)
   // ----------------------------------------------------------------
-  
-  def commaSep(arr:String*):String = {
-	  if(arr.length == 1) arr(0) else {
-	    var ret = arr(0)
-	    for(i <- 1 until arr.length) {
-	      ret = ret + "," + arr(i)
-	    }
-	    ret
-	  }
-	}
-  
-  def l2reg_update(lambda:Double, dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
+
+  def commaSep(arr: String*): String =
+    if (arr.length == 1) arr(0)
+    else {
+      var ret = arr(0)
+      for (i <- 1 until arr.length) {
+        ret = ret + "," + arr(i)
+      }
+      ret
+    }
+
+  def l2reg_update(lambda: Double, dmlScript: StringBuilder, layer: CaffeLayer): Unit =
     // val donotRegularizeLayers:Boolean = layer.isInstanceOf[BatchNorm] || layer.isInstanceOf[Scale];
-    if(lambda != 0 && layer.shouldUpdateWeight) {
+    if (lambda != 0 && layer.shouldUpdateWeight) {
       dmlScript.append("\t").append(layer.dWeight + "_reg = l2_reg::backward(" + layer.weight + ", " + lambda + ")\n")
       dmlScript.append("\t").append(layer.dWeight + " = " + layer.dWeight + " + " + layer.dWeight + "_reg\n")
     }
-  }
 }
 
-class LearningRatePolicy(lr_policy:String="exp", base_lr:Double=0.01) {
-  def this(solver:Caffe.SolverParameter) {
+class LearningRatePolicy(lr_policy: String = "exp", base_lr: Double = 0.01) {
+  def this(solver: Caffe.SolverParameter) {
     this(solver.getLrPolicy, solver.getBaseLr)
-    if(solver.hasGamma) setGamma(solver.getGamma)
-    if(solver.hasStepsize) setStepsize(solver.getStepsize)
-    if(solver.hasPower()) setPower(solver.getPower)
+    if (solver.hasGamma) setGamma(solver.getGamma)
+    if (solver.hasStepsize) setStepsize(solver.getStepsize)
+    if (solver.hasPower()) setPower(solver.getPower)
   }
-  var gamma:Double = 0.95
-  var step:Double = 100000
-  var power:Double = 0.75
-  def setGamma(gamma1:Double):Unit = { gamma = gamma1 } 
-  def setStepsize(step1:Double):Unit = { step = step1 } 
-  def setPower(power1:Double): Unit = { power = power1 }
-  def updateLearningRate(dmlScript:StringBuilder):Unit = {
+  var gamma: Double                    = 0.95
+  var step: Double                     = 100000
+  var power: Double                    = 0.75
+  def setGamma(gamma1: Double): Unit   = gamma = gamma1
+  def setStepsize(step1: Double): Unit = step = step1
+  def setPower(power1: Double): Unit   = power = power1
+  def updateLearningRate(dmlScript: StringBuilder): Unit = {
     val new_lr = lr_policy.toLowerCase match {
-      case "fixed" => base_lr.toString
-      case "step" => "(" + base_lr + " * " +  gamma + " ^ " + " floor(e/" + step + "))"
-      case "exp" => "(" + base_lr + " * " + gamma + "^e)"
-      case "inv" =>  "(" + base_lr + "* (1 + " + gamma + " * e) ^ (-" + power + "))"
-      case "poly" => "(" + base_lr  + " * (1 - e/ max_epochs) ^ " + power + ")"
+      case "fixed"   => base_lr.toString
+      case "step"    => "(" + base_lr + " * " + gamma + " ^ " + " floor(e/" + step + "))"
+      case "exp"     => "(" + base_lr + " * " + gamma + "^e)"
+      case "inv"     => "(" + base_lr + "* (1 + " + gamma + " * e) ^ (-" + power + "))"
+      case "poly"    => "(" + base_lr + " * (1 - e/ max_epochs) ^ " + power + ")"
       case "sigmoid" => "(" + base_lr + "( 1/(1 + exp(-" + gamma + "* (e - " + step + "))))"
-      case _ => throw new DMLRuntimeException("The lr policy is not supported:" + lr_policy)
+      case _         => throw new DMLRuntimeException("The lr policy is not supported:" + lr_policy)
     }
     dmlScript.append("lr = " + new_lr + "\n")
   }
 }
 
-class SGD(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
+class SGD(lambda: Double = 5e-04, momentum: Double = 0.9) extends CaffeSolver {
   /*
    * Performs an SGD update with momentum.
    *
@@ -112,31 +111,39 @@ class SGD(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
    *  - v: Updated velocity of the parameters `X`, of same shape as
    *      input `X`.
    */
-  def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
+  def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
     l2reg_update(lambda, dmlScript, layer)
-    if(momentum == 0) {
+    if (momentum == 0) {
       // Use sgd
-      if(layer.shouldUpdateWeight) dmlScript.append("\t").append(layer.weight + " = sgd::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer)) + ")\n")
-      if(layer.shouldUpdateBias) dmlScript.append("\t").append(layer.bias + " = sgd::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer)) + ")\n")
-    }
-    else {
+      if (layer.shouldUpdateWeight) dmlScript.append("\t").append(layer.weight + " = sgd::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer)) + ")\n")
+      if (layer.shouldUpdateBias) dmlScript.append("\t").append(layer.bias + " = sgd::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer)) + ")\n")
+    } else {
       // Use sgd_momentum
-      if(layer.shouldUpdateWeight) dmlScript.append("\t").append("["+ commaSep(layer.weight, layer.weight+"_v") + "] " + 
-          "= sgd_momentum::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight+"_v") + ")\n")
-      if(layer.shouldUpdateBias) dmlScript.append("\t").append("["+ commaSep(layer.bias, layer.bias+"_v") + "] " + 
-          "= sgd_momentum::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias+"_v") + ")\n")
+      if (layer.shouldUpdateWeight)
+        dmlScript
+          .append("\t")
+          .append(
+            "[" + commaSep(layer.weight, layer.weight + "_v") + "] " +
+            "= sgd_momentum::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight + "_v") + ")\n"
+          )
+      if (layer.shouldUpdateBias)
+        dmlScript
+          .append("\t")
+          .append(
+            "[" + commaSep(layer.bias, layer.bias + "_v") + "] " +
+            "= sgd_momentum::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias + "_v") + ")\n"
+          )
     }
   }
-  def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
-    if(momentum != 0) {
-      if(layer.shouldUpdateWeight) dmlScript.append(layer.weight+"_v = sgd_momentum::init(" + layer.weight + ")\n")
-      if(layer.shouldUpdateBias) dmlScript.append(layer.bias+"_v = sgd_momentum::init(" + layer.bias + ")\n")
+  def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit =
+    if (momentum != 0) {
+      if (layer.shouldUpdateWeight) dmlScript.append(layer.weight + "_v = sgd_momentum::init(" + layer.weight + ")\n")
+      if (layer.shouldUpdateBias) dmlScript.append(layer.bias + "_v = sgd_momentum::init(" + layer.bias + ")\n")
     }
-  }
-  def sourceFileName:String = if(momentum == 0) "sgd" else "sgd_momentum" 
+  def sourceFileName: String = if (momentum == 0) "sgd" else "sgd_momentum"
 }
 
-class AdaGrad(lambda:Double=5e-04, epsilon:Double=1e-6) extends CaffeSolver {
+class AdaGrad(lambda: Double = 5e-04, epsilon: Double = 1e-6) extends CaffeSolver {
   /*
    * Performs an Adagrad update.
    *
@@ -164,21 +171,31 @@ class AdaGrad(lambda:Double=5e-04, epsilon:Double=1e-6) extends CaffeSolver {
    *  - cache: State that maintains per-parameter sum of squared
    *      gradients, of same shape as `X`.
    */
-  def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
+  def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
     l2reg_update(lambda, dmlScript, layer)
-    if(layer.shouldUpdateWeight) dmlScript.append("\t").append("["+ commaSep(layer.weight, layer.weight+"_cache") + "] " + 
-        "= adagrad::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), epsilon.toString, layer.weight+"_cache") + ")\n")
-    if(layer.shouldUpdateBias) dmlScript.append("\t").append("["+ commaSep(layer.bias, layer.bias+"_cache") + "] " + 
-        "= adagrad::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), epsilon.toString, layer.bias+"_cache") + ")\n")
+    if (layer.shouldUpdateWeight)
+      dmlScript
+        .append("\t")
+        .append(
+          "[" + commaSep(layer.weight, layer.weight + "_cache") + "] " +
+          "= adagrad::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), epsilon.toString, layer.weight + "_cache") + ")\n"
+        )
+    if (layer.shouldUpdateBias)
+      dmlScript
+        .append("\t")
+        .append(
+          "[" + commaSep(layer.bias, layer.bias + "_cache") + "] " +
+          "= adagrad::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), epsilon.toString, layer.bias + "_cache") + ")\n"
+        )
   }
-  def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
-    if(layer.shouldUpdateWeight) dmlScript.append(layer.weight+"_cache = adagrad::init(" + layer.weight + ")\n")
-    if(layer.shouldUpdateBias) dmlScript.append(layer.bias+"_cache = adagrad::init(" + layer.bias + ")\n")
+  def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
+    if (layer.shouldUpdateWeight) dmlScript.append(layer.weight + "_cache = adagrad::init(" + layer.weight + ")\n")
+    if (layer.shouldUpdateBias) dmlScript.append(layer.bias + "_cache = adagrad::init(" + layer.bias + ")\n")
   }
-  def sourceFileName:String = "adagrad"
+  def sourceFileName: String = "adagrad"
 }
 
-class Nesterov(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
+class Nesterov(lambda: Double = 5e-04, momentum: Double = 0.9) extends CaffeSolver {
   /*
    * Performs an SGD update with Nesterov momentum.
    *
@@ -211,20 +228,30 @@ class Nesterov(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
    *  - v: Updated velocity of the parameters X, of same shape as
    *      input v.
    */
-  def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
-    val fn = if(Caffe2DML.USE_NESTEROV_UDF) "update_nesterov" else "sgd_nesterov::update"
-    val lastParameter = if(Caffe2DML.USE_NESTEROV_UDF) (", " + lambda) else ""
-    if(!Caffe2DML.USE_NESTEROV_UDF) {
+  def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
+    val fn            = if (Caffe2DML.USE_NESTEROV_UDF) "update_nesterov" else "sgd_nesterov::update"
+    val lastParameter = if (Caffe2DML.USE_NESTEROV_UDF) (", " + lambda) else ""
+    if (!Caffe2DML.USE_NESTEROV_UDF) {
       l2reg_update(lambda, dmlScript, layer)
     }
-    if(layer.shouldUpdateWeight) dmlScript.append("\t").append("["+ commaSep(layer.weight, layer.weight+"_v") + "] " + 
-        "= " + fn + "(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight+"_v") + lastParameter + ")\n")
-    if(layer.shouldUpdateBias) dmlScript.append("\t").append("["+ commaSep(layer.bias, layer.bias+"_v") + "] " + 
-        "= " + fn + "(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias+"_v") + lastParameter + ")\n")
+    if (layer.shouldUpdateWeight)
+      dmlScript
+        .append("\t")
+        .append(
+          "[" + commaSep(layer.weight, layer.weight + "_v") + "] " +
+          "= " + fn + "(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight + "_v") + lastParameter + ")\n"
+        )
+    if (layer.shouldUpdateBias)
+      dmlScript
+        .append("\t")
+        .append(
+          "[" + commaSep(layer.bias, layer.bias + "_v") + "] " +
+          "= " + fn + "(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias + "_v") + lastParameter + ")\n"
+        )
   }
-  def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
-    if(layer.shouldUpdateWeight) dmlScript.append(layer.weight+"_v = sgd_nesterov::init(" + layer.weight + ")\n")
-    if(layer.shouldUpdateBias) dmlScript.append(layer.bias+"_v = sgd_nesterov::init(" + layer.bias + ")\n")
+  def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
+    if (layer.shouldUpdateWeight) dmlScript.append(layer.weight + "_v = sgd_nesterov::init(" + layer.weight + ")\n")
+    if (layer.shouldUpdateBias) dmlScript.append(layer.bias + "_v = sgd_nesterov::init(" + layer.bias + ")\n")
   }
-  def sourceFileName:String = "sgd_nesterov"
-}
\ No newline at end of file
+  def sourceFileName: String = "sgd_nesterov"
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index 6b06c26..b68d493 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -6,9 +6,9 @@
  * to you under the Apache License, Version 2.0 (the
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
- * 
+ *
  *   http://www.apache.org/licenses/LICENSE-2.0
- * 
+ *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -27,301 +27,292 @@ import scala.collection.JavaConversions._
 import caffe.Caffe
 
 trait BaseDMLGenerator {
-  def commaSep(arr:List[String]):String = {
-	  if(arr.length == 1) arr(0) else {
-	    var ret = arr(0)
-	    for(i <- 1 until arr.length) {
-	      ret = ret + "," + arr(i)
-	    }
-	    ret
-	  }
-	}
-  def commaSep(arr:String*):String = {
-	  if(arr.length == 1) arr(0) else {
-	    var ret = arr(0)
-	    for(i <- 1 until arr.length) {
-	      ret = ret + "," + arr(i)
-	    }
-	    ret
-	  }
-	}
-  def int_add(v1:String, v2:String):String = {
-    try { (v1.toDouble + v2.toDouble).toInt.toString } catch { case _:Throwable => "("+v1+"+"+v2+")"}
-  }
-  def int_mult(v1:String, v2:String, v3:String):String = {
-    try { (v1.toDouble * v2.toDouble * v3.toDouble).toInt.toString } catch { case _:Throwable => "("+v1+"*"+v2+"*"+v3+")"}
-  }
-  def isNumber(x: String):Boolean = x forall Character.isDigit
-  def transpose(x:String):String = "t(" + x + ")"
-  def write(varName:String, fileName:String, format:String):String = "write(" + varName + ", \"" + fileName + "\", format=\"" + format + "\")\n"
-  def readWeight(varName:String, fileName:String, sep:String="/"):String =  varName + " = read(weights + \"" + sep + fileName + "\")\n"
-  def asDMLString(str:String):String = "\"" + str + "\""
-  def assign(dmlScript:StringBuilder, lhsVar:String, rhsVar:String):Unit = {
+  def commaSep(arr: List[String]): String =
+    if (arr.length == 1) arr(0)
+    else {
+      var ret = arr(0)
+      for (i <- 1 until arr.length) {
+        ret = ret + "," + arr(i)
+      }
+      ret
+    }
+  def commaSep(arr: String*): String =
+    if (arr.length == 1) arr(0)
+    else {
+      var ret = arr(0)
+      for (i <- 1 until arr.length) {
+        ret = ret + "," + arr(i)
+      }
+      ret
+    }
+  def int_add(v1: String, v2: String): String =
+    try { (v1.toDouble + v2.toDouble).toInt.toString } catch { case _: Throwable => "(" + v1 + "+" + v2 + ")" }
+  def int_mult(v1: String, v2: String, v3: String): String =
+    try { (v1.toDouble * v2.toDouble * v3.toDouble).toInt.toString } catch { case _: Throwable => "(" + v1 + "*" + v2 + "*" + v3 + ")" }
+  def isNumber(x: String): Boolean                                                   = x forall Character.isDigit
+  def transpose(x: String): String                                                   = "t(" + x + ")"
+  def write(varName: String, fileName: String, format: String): String               = "write(" + varName + ", \"" + fileName + "\", format=\"" + format + "\")\n"
+  def readWeight(varName: String, fileName: String, sep: String = "/"): String       = varName + " = read(weights + \"" + sep + fileName + "\")\n"
+  def readScalarWeight(varName: String, fileName: String, sep: String = "/"): String = varName + " = as.scalar(read(weights + \"" + sep + fileName + "\"))\n"
+  def asDMLString(str: String): String                                               = "\"" + str + "\""
+  def assign(dmlScript: StringBuilder, lhsVar: String, rhsVar: String): Unit =
     dmlScript.append(lhsVar).append(" = ").append(rhsVar).append("\n")
-  }
-  def sum(dmlScript:StringBuilder, variables:List[String]):StringBuilder = {
-    if(variables.length > 1) dmlScript.append("(")
+  def sum(dmlScript: StringBuilder, variables: List[String]): StringBuilder = {
+    if (variables.length > 1) dmlScript.append("(")
     dmlScript.append(variables(0))
-    for(i <- 1 until variables.length) {
+    for (i <- 1 until variables.length) {
       dmlScript.append(" + ").append(variables(i))
     }
-    if(variables.length > 1) dmlScript.append(")")
+    if (variables.length > 1) dmlScript.append(")")
     return dmlScript
   }
-  def addAndAssign(dmlScript:StringBuilder, lhsVar:String, rhsVars:List[String]):Unit = {
+  def addAndAssign(dmlScript: StringBuilder, lhsVar: String, rhsVars: List[String]): Unit = {
     dmlScript.append(lhsVar).append(" = ")
     sum(dmlScript, rhsVars)
     dmlScript.append("\n")
   }
-  def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:List[String]):Unit = {
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String]): Unit =
     invoke(dmlScript, namespace1, returnVariables, functionName, arguments, true)
-  }
-  def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:List[String], appendNewLine:Boolean):Unit = {
-    if(returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have atleast one return value")
-    if(returnVariables.length > 1) dmlScript.append("[")
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String], appendNewLine: Boolean): Unit = {
+    if (returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have atleast one return value")
+    if (returnVariables.length > 1) dmlScript.append("[")
     dmlScript.append(returnVariables(0))
-    if(returnVariables.length > 1) {
-      for(i <- 1 until returnVariables.length) {
-	      dmlScript.append(",").append(returnVariables(i))
-	    }
+    if (returnVariables.length > 1) {
+      for (i <- 1 until returnVariables.length) {
+        dmlScript.append(",").append(returnVariables(i))
+      }
       dmlScript.append("]")
     }
     dmlScript.append(" = ")
     dmlScript.append(namespace1)
     dmlScript.append(functionName)
     dmlScript.append("(")
-    if(arguments != null) {
-      if(arguments.length != 0) 
+    if (arguments != null) {
+      if (arguments.length != 0)
         dmlScript.append(arguments(0))
-      if(arguments.length > 1) {
-        for(i <- 1 until arguments.length) {
-  	      dmlScript.append(",").append(arguments(i))
-  	    }
+      if (arguments.length > 1) {
+        for (i <- 1 until arguments.length) {
+          dmlScript.append(",").append(arguments(i))
+        }
       }
     }
     dmlScript.append(")")
-    if(appendNewLine) 
+    if (appendNewLine)
       dmlScript.append("\n")
   }
-  def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, appendNewLine:Boolean, arguments:String*):Unit = {
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, appendNewLine: Boolean, arguments: String*): Unit =
     invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, appendNewLine)
-  }
-  def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:String*):Unit = {
+  def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: String*): Unit =
     invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, true)
-  }
-  def rightIndexing(dmlScript:StringBuilder, varName:String, rl:String, ru:String, cl:String, cu:String):StringBuilder = {
+  def rightIndexing(dmlScript: StringBuilder, varName: String, rl: String, ru: String, cl: String, cu: String): StringBuilder = {
     dmlScript.append(varName).append("[")
-    if(rl != null && ru != null) dmlScript.append(rl).append(":").append(ru)
+    if (rl != null && ru != null) dmlScript.append(rl).append(":").append(ru)
     dmlScript.append(",")
-    if(cl != null && cu != null) dmlScript.append(cl).append(":").append(cu)
+    if (cl != null && cu != null) dmlScript.append(cl).append(":").append(cu)
     dmlScript.append("]")
   }
   // Performs assignVar = ceil(lhsVar/rhsVar)
-  def ceilDivide(dmlScript:StringBuilder, assignVar:String, lhsVar:String, rhsVar:String):Unit = 
+  def ceilDivide(dmlScript: StringBuilder, assignVar: String, lhsVar: String, rhsVar: String): Unit =
     dmlScript.append(assignVar).append(" = ").append("ceil(").append(lhsVar).append(" / ").append(rhsVar).append(")\n")
-  def print(arg:String):String = "print(" + arg + ")\n"
-  def dmlConcat(arg:String*):String = {
+  def print(arg: String): String = "print(" + arg + ")\n"
+  def dmlConcat(arg: String*): String = {
     val ret = new StringBuilder
     ret.append(arg(0))
-    for(i <- 1 until arg.length) {
+    for (i <- 1 until arg.length) {
       ret.append(" + ").append(arg(i))
     }
     ret.toString
   }
-  def matrix(init:String, rows:String, cols:String):String = "matrix(" + init + ", rows=" + rows + ", cols=" + cols + ")" 
-  def nrow(m:String):String = "nrow(" + m + ")"
-  def ncol(m:String):String = "ncol(" + m + ")"
-  def customAssert(cond:Boolean, msg:String) = if(!cond) throw new DMLRuntimeException(msg)
-  def multiply(v1:String, v2:String):String = v1 + "*" + v2
-  def colSums(m:String):String = "colSums(" + m + ")"
-  def ifdef(cmdLineVar:String, defaultVal:String):String = "ifdef(" + cmdLineVar + ", " + defaultVal + ")"
-  def ifdef(cmdLineVar:String):String = ifdef(cmdLineVar, "\" \"")
-  def read(filePathVar:String, format:String):String = "read(" + filePathVar + ", format=\""+ format + "\")"
+  def matrix(init: String, rows: String, cols: String): String = "matrix(" + init + ", rows=" + rows + ", cols=" + cols + ")"
+  def nrow(m: String): String                                  = "nrow(" + m + ")"
+  def ncol(m: String): String                                  = "ncol(" + m + ")"
+  def customAssert(cond: Boolean, msg: String)                 = if (!cond) throw new DMLRuntimeException(msg)
+  def multiply(v1: String, v2: String): String                 = v1 + "*" + v2
+  def colSums(m: String): String                               = "colSums(" + m + ")"
+  def ifdef(cmdLineVar: String, defaultVal: String): String    = "ifdef(" + cmdLineVar + ", " + defaultVal + ")"
+  def ifdef(cmdLineVar: String): String                        = ifdef(cmdLineVar, "\" \"")
+  def read(filePathVar: String, format: String): String        = "read(" + filePathVar + ", format=\"" + format + "\")"
 }
 
 trait TabbedDMLGenerator extends BaseDMLGenerator {
-  def tabDMLScript(dmlScript:StringBuilder, numTabs:Int):StringBuilder =  tabDMLScript(dmlScript, numTabs, false)
-  def tabDMLScript(dmlScript:StringBuilder, numTabs:Int, prependNewLine:Boolean):StringBuilder =  {
-    if(prependNewLine) dmlScript.append("\n")
-	  for(i <- 0 until numTabs) dmlScript.append("\t")
-	  dmlScript
+  def tabDMLScript(dmlScript: StringBuilder, numTabs: Int): StringBuilder = tabDMLScript(dmlScript, numTabs, false)
+  def tabDMLScript(dmlScript: StringBuilder, numTabs: Int, prependNewLine: Boolean): StringBuilder = {
+    if (prependNewLine) dmlScript.append("\n")
+    for (i <- 0 until numTabs) dmlScript.append("\t")
+    dmlScript
   }
 }
 
 trait SourceDMLGenerator extends TabbedDMLGenerator {
-  val alreadyImported:HashSet[String] = new HashSet[String]
-  def source(dmlScript:StringBuilder, numTabs:Int, sourceFileName:String, dir:String):Unit = {
-	  if(sourceFileName != null && !alreadyImported.contains(sourceFileName)) {
-      tabDMLScript(dmlScript, numTabs).append("source(\"" + dir +  sourceFileName + ".dml\") as " + sourceFileName + "\n")
+  val alreadyImported: HashSet[String] = new HashSet[String]
+  def source(dmlScript: StringBuilder, numTabs: Int, sourceFileName: String, dir: String): Unit =
+    if (sourceFileName != null && !alreadyImported.contains(sourceFileName)) {
+      tabDMLScript(dmlScript, numTabs).append("source(\"" + dir + sourceFileName + ".dml\") as " + sourceFileName + "\n")
       alreadyImported.add(sourceFileName)
-	  }
+    }
+  def source(dmlScript: StringBuilder, numTabs: Int, net: CaffeNetwork, solver: CaffeSolver, otherFiles: Array[String]): Unit = {
+    // Add layers with multiple source files
+    if (net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[SoftmaxWithLoss]).length > 0) {
+      source(dmlScript, numTabs, "softmax", Caffe2DML.layerDir)
+      source(dmlScript, numTabs, "cross_entropy_loss", Caffe2DML.layerDir)
+    }
+    net.getLayers.map(layer => source(dmlScript, numTabs, net.getCaffeLayer(layer).sourceFileName, Caffe2DML.layerDir))
+    if (solver != null)
+      source(dmlScript, numTabs, solver.sourceFileName, Caffe2DML.optimDir)
+    if (otherFiles != null)
+      otherFiles.map(sourceFileName => source(dmlScript, numTabs, sourceFileName, Caffe2DML.layerDir))
   }
-  def source(dmlScript:StringBuilder, numTabs:Int, net:CaffeNetwork, solver:CaffeSolver, otherFiles:Array[String]):Unit = {
-	  // Add layers with multiple source files
-	  if(net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[SoftmaxWithLoss]).length > 0) {
-	    source(dmlScript, numTabs, "softmax", Caffe2DML.layerDir)
-	    source(dmlScript, numTabs, "cross_entropy_loss", Caffe2DML.layerDir)
-	  }
-	  net.getLayers.map(layer =>  source(dmlScript, numTabs, net.getCaffeLayer(layer).sourceFileName, Caffe2DML.layerDir))
-	  if(solver != null)
-	    source(dmlScript, numTabs, solver.sourceFileName, Caffe2DML.optimDir)
-	  if(otherFiles != null)
-	    otherFiles.map(sourceFileName => source(dmlScript, numTabs, sourceFileName, Caffe2DML.layerDir))
-	}
 }
 
 trait NextBatchGenerator extends TabbedDMLGenerator {
-	def min(lhs:String, rhs:String): String = "min(" + lhs + ", " + rhs + ")"
-	
-	def assignBatch(dmlScript:StringBuilder, Xb:String, X:String, yb:String, y:String, indexPrefix:String, N:String, i:String):StringBuilder = {
-	  dmlScript.append(indexPrefix).append("beg = ((" + i + "-1) * " + Caffe2DML.batchSize + ") %% " + N + " + 1; ")
-	  dmlScript.append(indexPrefix).append("end = min(beg + " +  Caffe2DML.batchSize + " - 1, " + N + "); ")
-	  dmlScript.append(Xb).append(" = ").append(X).append("[").append(indexPrefix).append("beg:").append(indexPrefix).append("end,]; ")
-	  if(yb != null && y != null)
-	    dmlScript.append(yb).append(" = ").append(y).append("[").append(indexPrefix).append("beg:").append(indexPrefix).append("end,]; ")
-	  dmlScript.append("\n")
-	}
-	def getTestBatch(tabDMLScript:StringBuilder):Unit = {
+  def min(lhs: String, rhs: String): String = "min(" + lhs + ", " + rhs + ")"
+
+  def assignBatch(dmlScript: StringBuilder, Xb: String, X: String, yb: String, y: String, indexPrefix: String, N: String, i: String): StringBuilder = {
+    dmlScript.append(indexPrefix).append("beg = ((" + i + "-1) * " + Caffe2DML.batchSize + ") %% " + N + " + 1; ")
+    dmlScript.append(indexPrefix).append("end = min(beg + " + Caffe2DML.batchSize + " - 1, " + N + "); ")
+    dmlScript.append(Xb).append(" = ").append(X).append("[").append(indexPrefix).append("beg:").append(indexPrefix).append("end,]; ")
+    if (yb != null && y != null)
+      dmlScript.append(yb).append(" = ").append(y).append("[").append(indexPrefix).append("beg:").append(indexPrefix).append("end,]; ")
+    dmlScript.append("\n")
+  }
+  def getTestBatch(tabDMLScript: StringBuilder): Unit =
     assignBatch(tabDMLScript, "Xb", Caffe2DML.X, null, null, "", Caffe2DML.numImages, "iter")
-  } 
-  def getTrainingBatch(tabDMLScript:StringBuilder):Unit = {
+
+  def getTrainingBatch(tabDMLScript: StringBuilder): Unit =
     assignBatch(tabDMLScript, "Xb", Caffe2DML.X, "yb", Caffe2DML.y, "", Caffe2DML.numImages, "iter")
-  }
-	def getTrainingBatch(tabDMLScript:StringBuilder, X:String, y:String, numImages:String):Unit = {
-	  assignBatch(tabDMLScript, "Xb", X, "yb", y, "", numImages, "i")
-  }
-  def getTrainingMaxiBatch(tabDMLScript:StringBuilder):Unit = {
+  def getTrainingBatch(tabDMLScript: StringBuilder, X: String, y: String, numImages: String): Unit =
+    assignBatch(tabDMLScript, "Xb", X, "yb", y, "", numImages, "i")
+  def getTrainingMaxiBatch(tabDMLScript: StringBuilder): Unit =
     assignBatch(tabDMLScript, "X_group_batch", Caffe2DML.X, "y_group_batch", Caffe2DML.y, "group_", Caffe2DML.numImages, "g")
-  }
-  def getValidationBatch(tabDMLScript:StringBuilder):Unit = {
+  def getValidationBatch(tabDMLScript: StringBuilder): Unit =
     assignBatch(tabDMLScript, "Xb", Caffe2DML.XVal, "yb", Caffe2DML.yVal, "", Caffe2DML.numValidationImages, "iVal")
-  }
 }
 
 trait VisualizeDMLGenerator extends TabbedDMLGenerator {
-  var doVisualize = false
-  var _tensorboardLogDir:String = null
-  def setTensorBoardLogDir(log:String): Unit = { _tensorboardLogDir = log }
-  def tensorboardLogDir:String = {
-    if(_tensorboardLogDir == null) {
+  var doVisualize                             = false
+  var _tensorboardLogDir: String              = null
+  def setTensorBoardLogDir(log: String): Unit = _tensorboardLogDir = log
+  def tensorboardLogDir: String = {
+    if (_tensorboardLogDir == null) {
       _tensorboardLogDir = java.io.File.createTempFile("temp", System.nanoTime().toString()).getAbsolutePath
     }
     _tensorboardLogDir
   }
   def visualizeLoss(): Unit = {
-	   checkTensorBoardDependency()
-	   doVisualize = true
-	   // Visualize for both training and validation
-	   visualize(" ", " ", "training_loss", "iter", "training_loss", true)
-	   visualize(" ", " ", "training_accuracy", "iter", "training_accuracy", true)
-	   visualize(" ", " ", "validation_loss", "iter", "validation_loss", false)
-	   visualize(" ", " ", "validation_accuracy", "iter", "validation_accuracy", false)
-	}
-  val visTrainingDMLScript: StringBuilder = new StringBuilder 
+    checkTensorBoardDependency()
+    doVisualize = true
+    // Visualize for both training and validation
+    visualize(" ", " ", "training_loss", "iter", "training_loss", true)
+    visualize(" ", " ", "training_accuracy", "iter", "training_accuracy", true)
+    visualize(" ", " ", "validation_loss", "iter", "validation_loss", false)
+    visualize(" ", " ", "validation_accuracy", "iter", "validation_accuracy", false)
+  }
+  val visTrainingDMLScript: StringBuilder   = new StringBuilder
   val visValidationDMLScript: StringBuilder = new StringBuilder
-	def checkTensorBoardDependency():Unit = {
-	  try {
-	    if(!doVisualize)
-	      Class.forName( "com.google.protobuf.GeneratedMessageV3")
-	  } catch {
-	    case _:ClassNotFoundException => throw new DMLRuntimeException("To use visualize() feature, you will have to include protobuf-java-3.2.0.jar in your classpath. Hint: you can download the jar from http://central.maven.org/maven2/com/google/protobuf/protobuf-java/3.2.0/protobuf-java-3.2.0.jar")   
-	  }
-	}
-  private def visualize(layerName:String, varType:String, aggFn:String, x:String, y:String,  isTraining:Boolean) = {
-    val dmlScript = if(isTraining) visTrainingDMLScript else visValidationDMLScript
-    dmlScript.append("viz_counter1 = visualize(" + 
-        commaSep(asDMLString(layerName), asDMLString(varType), asDMLString(aggFn), x, y, asDMLString(tensorboardLogDir))
-        + ");\n")
+  def checkTensorBoardDependency(): Unit =
+    try {
+      if (!doVisualize)
+        Class.forName("com.google.protobuf.GeneratedMessageV3")
+    } catch {
+      case _: ClassNotFoundException =>
+        throw new DMLRuntimeException(
+          "To use visualize() feature, you will have to include protobuf-java-3.2.0.jar in your classpath. Hint: you can download the jar from http://central.maven.org/maven2/com/google/protobuf/protobuf-java/3.2.0/protobuf-java-3.2.0.jar"
+        )
+    }
+  private def visualize(layerName: String, varType: String, aggFn: String, x: String, y: String, isTraining: Boolean) = {
+    val dmlScript = if (isTraining) visTrainingDMLScript else visValidationDMLScript
+    dmlScript.append(
+      "viz_counter1 = visualize(" +
+      commaSep(asDMLString(layerName), asDMLString(varType), asDMLString(aggFn), x, y, asDMLString(tensorboardLogDir))
+      + ");\n"
+    )
     dmlScript.append("viz_counter = viz_counter + viz_counter1\n")
   }
-  def visualizeLayer(net:CaffeNetwork, layerName:String, varType:String, aggFn:String): Unit = {
-	  // 'weight', 'bias', 'dweight', 'dbias', 'output' or 'doutput'
-	  // 'sum', 'mean', 'var' or 'sd'
-	  checkTensorBoardDependency()
-	  doVisualize = true
-	  if(net.getLayers.filter(_.equals(layerName)).size == 0)
-	    throw new DMLRuntimeException("Cannot visualize the layer:" + layerName)
-	  val dmlVar = {
-	    val l = net.getCaffeLayer(layerName)
-	    varType match {
-	      case "weight" => l.weight
-	      case "bias" => l.bias
-	      case "dweight" => l.dWeight
-	      case "dbias" => l.dBias
-	      case "output" => l.out
-	      // case "doutput" => l.dX
-	      case _ => throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
-	    }
-	   }
-	  if(dmlVar == null)
-	    throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
-	  // Visualize for both training and validation
-	  visualize(layerName, varType, aggFn, "iter", aggFn + "(" + dmlVar + ")", true)
-	  visualize(layerName, varType, aggFn, "iter", aggFn + "(" + dmlVar + ")", false)
-	}
-  
-  def appendTrainingVisualizationBody(dmlScript:StringBuilder, numTabs:Int): Unit = {
-    if(doVisualize)
-        tabDMLScript(dmlScript, numTabs).append(visTrainingDMLScript.toString)
-  }
-  def appendValidationVisualizationBody(dmlScript:StringBuilder, numTabs:Int): Unit = {
-    if(doVisualize)
-        tabDMLScript(dmlScript, numTabs).append(visValidationDMLScript.toString)
+  def visualizeLayer(net: CaffeNetwork, layerName: String, varType: String, aggFn: String): Unit = {
+    // 'weight', 'bias', 'dweight', 'dbias', 'output' or 'doutput'
+    // 'sum', 'mean', 'var' or 'sd'
+    checkTensorBoardDependency()
+    doVisualize = true
+    if (net.getLayers.filter(_.equals(layerName)).size == 0)
+      throw new DMLRuntimeException("Cannot visualize the layer:" + layerName)
+    val dmlVar = {
+      val l = net.getCaffeLayer(layerName)
+      varType match {
+        case "weight"  => l.weight
+        case "bias"    => l.bias
+        case "dweight" => l.dWeight
+        case "dbias"   => l.dBias
+        case "output"  => l.out
+        // case "doutput" => l.dX
+        case _ => throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
+      }
+    }
+    if (dmlVar == null)
+      throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
+    // Visualize for both training and validation
+    visualize(layerName, varType, aggFn, "iter", aggFn + "(" + dmlVar + ")", true)
+    visualize(layerName, varType, aggFn, "iter", aggFn + "(" + dmlVar + ")", false)
   }
+
+  def appendTrainingVisualizationBody(dmlScript: StringBuilder, numTabs: Int): Unit =
+    if (doVisualize)
+      tabDMLScript(dmlScript, numTabs).append(visTrainingDMLScript.toString)
+  def appendValidationVisualizationBody(dmlScript: StringBuilder, numTabs: Int): Unit =
+    if (doVisualize)
+      tabDMLScript(dmlScript, numTabs).append(visValidationDMLScript.toString)
 }
 
 trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator with VisualizeDMLGenerator {
   // Also makes "code reading" possible for Caffe2DML :)
-	var dmlScript = new StringBuilder
-	var numTabs = 0
-	def reset():Unit = {
-	  dmlScript.clear()
-	  alreadyImported.clear()
-	  numTabs = 0
-	  visTrainingDMLScript.clear()
-	  visValidationDMLScript.clear()
-	  doVisualize = false
-	}
-	// -------------------------------------------------------------------------------------------------
-	// Helper functions that calls super class methods and simplifies the code of this trait
-	def tabDMLScript():StringBuilder = tabDMLScript(dmlScript, numTabs, false)
-	def tabDMLScript(prependNewLine:Boolean):StringBuilder = tabDMLScript(dmlScript, numTabs, prependNewLine)
-	def source(net:CaffeNetwork, solver:CaffeSolver, otherFiles:Array[String]):Unit = {
-	  source(dmlScript, numTabs, net, solver, otherFiles)
-	}
-	// -------------------------------------------------------------------------------------------------
-	
-	def ifBlock(cond:String)(op: => Unit) {
-	  tabDMLScript.append("if(" + cond + ") {\n")
-	  numTabs += 1
-	  op
-	  numTabs -= 1
-	  tabDMLScript.append("}\n")
-	}
-	def whileBlock(cond:String)(op: => Unit) {
-	  tabDMLScript.append("while(" + cond + ") {\n")
-	  numTabs += 1
-	  op
-	  numTabs -= 1
-	  tabDMLScript.append("}\n")
-	}
-	def forBlock(iterVarName:String, startVal:String, endVal:String)(op: => Unit) {
-	  tabDMLScript.append("for(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
-	  numTabs += 1
-	  op
-	  numTabs -= 1
-	  tabDMLScript.append("}\n")
-	}
-	def parForBlock(iterVarName:String, startVal:String, endVal:String)(op: => Unit) {
-	  tabDMLScript.append("parfor(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
-	  numTabs += 1
-	  op
-	  numTabs -= 1
-	  tabDMLScript.append("}\n")
-	}
-	
-	def printClassificationReport():Unit = {
-    ifBlock("debug"){
+  var dmlScript = new StringBuilder
+  var numTabs   = 0
+  def reset(): Unit = {
+    dmlScript.clear()
+    alreadyImported.clear()
+    numTabs = 0
+    visTrainingDMLScript.clear()
+    visValidationDMLScript.clear()
+    doVisualize = false
+  }
+  // -------------------------------------------------------------------------------------------------
+  // Helper functions that calls super class methods and simplifies the code of this trait
+  def tabDMLScript(): StringBuilder                        = tabDMLScript(dmlScript, numTabs, false)
+  def tabDMLScript(prependNewLine: Boolean): StringBuilder = tabDMLScript(dmlScript, numTabs, prependNewLine)
+  def source(net: CaffeNetwork, solver: CaffeSolver, otherFiles: Array[String]): Unit =
+    source(dmlScript, numTabs, net, solver, otherFiles)
+  // -------------------------------------------------------------------------------------------------
+
+  def ifBlock(cond: String)(op: => Unit) {
+    tabDMLScript.append("if(" + cond + ") {\n")
+    numTabs += 1
+    op
+    numTabs -= 1
+    tabDMLScript.append("}\n")
+  }
+  def whileBlock(cond: String)(op: => Unit) {
+    tabDMLScript.append("while(" + cond + ") {\n")
+    numTabs += 1
+    op
+    numTabs -= 1
+    tabDMLScript.append("}\n")
+  }
+  def forBlock(iterVarName: String, startVal: String, endVal: String)(op: => Unit) {
+    tabDMLScript.append("for(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
+    numTabs += 1
+    op
+    numTabs -= 1
+    tabDMLScript.append("}\n")
+  }
+  def parForBlock(iterVarName: String, startVal: String, endVal: String)(op: => Unit) {
+    tabDMLScript.append("parfor(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
+    numTabs += 1
+    op
+    numTabs -= 1
+    tabDMLScript.append("}\n")
+  }
+
+  def printClassificationReport(): Unit =
+    ifBlock("debug") {
       assign(tabDMLScript, "num_rows_error_measures", min("10", ncol("yb")))
       assign(tabDMLScript, "error_measures", matrix("0", "num_rows_error_measures", "5"))
       forBlock("class_i", "1", "num_rows_error_measures") {
@@ -337,35 +328,38 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator with Visua
         assign(tabDMLScript, "error_measures[class_i,4]", "f1Score")
         assign(tabDMLScript, "error_measures[class_i,5]", "tp_plus_fn")
       }
-      val dmlTab = "\\t"
-      val header = "class    " + dmlTab + "precision" + dmlTab + "recall  " + dmlTab + "f1-score" + dmlTab + "num_true_labels\\n"
+      val dmlTab        = "\\t"
+      val header        = "class    " + dmlTab + "precision" + dmlTab + "recall  " + dmlTab + "f1-score" + dmlTab + "num_true_labels\\n"
       val errorMeasures = "toString(error_measures, decimal=7, sep=" + asDMLString(dmlTab) + ")"
       tabDMLScript.append(print(dmlConcat(asDMLString(header), errorMeasures)))
     }
-  }
-	
-	// Appends DML corresponding to source and externalFunction statements. 
-  def appendHeaders(net:CaffeNetwork, solver:CaffeSolver, isTraining:Boolean):Unit = {
+
+  // Appends DML corresponding to source and externalFunction statements.
+  def appendHeaders(net: CaffeNetwork, solver: CaffeSolver, isTraining: Boolean): Unit = {
     // Append source statements for layers as well as solver
-	  source(net, solver, if(isTraining) Array[String]("l2_reg") else null)
-	  
-	  if(isTraining) {
-  	  // Append external built-in function headers:
-  	  // 1. visualize external built-in function header
-      if(doVisualize) {
-  	    tabDMLScript.append("visualize = externalFunction(String layerName, String varType, String aggFn, Double x, Double y, String logDir) return (Double B) " +
-  	        "implemented in (classname=\"org.apache.sysml.udf.lib.Caffe2DMLVisualizeWrapper\",exectype=\"mem\"); \n")
-  	    tabDMLScript.append("viz_counter = 0\n")
-  	    System.out.println("Please use the following command for visualizing: tensorboard --logdir=" + tensorboardLogDir)
-  	  }
-  	  // 2. update_nesterov external built-in function header
-  	  if(Caffe2DML.USE_NESTEROV_UDF) {
-  	    tabDMLScript.append("update_nesterov = externalFunction(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v, double lambda) return (matrix[double] X, matrix[double] v) implemented in (classname=\"org.apache.sysml.udf.lib.SGDNesterovUpdate\",exectype=\"mem\");  \n")
-  	  }
-	  }
+    source(net, solver, if (isTraining) Array[String]("l2_reg") else null)
+
+    if (isTraining) {
+      // Append external built-in function headers:
+      // 1. visualize external built-in function header
+      if (doVisualize) {
+        tabDMLScript.append(
+          "visualize = externalFunction(String layerName, String varType, String aggFn, Double x, Double y, String logDir) return (Double B) " +
+          "implemented in (classname=\"org.apache.sysml.udf.lib.Caffe2DMLVisualizeWrapper\",exectype=\"mem\"); \n"
+        )
+        tabDMLScript.append("viz_counter = 0\n")
+        System.out.println("Please use the following command for visualizing: tensorboard --logdir=" + tensorboardLogDir)
+      }
+      // 2. update_nesterov external built-in function header
+      if (Caffe2DML.USE_NESTEROV_UDF) {
+        tabDMLScript.append(
+          "update_nesterov = externalFunction(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v, double lambda) return (matrix[double] X, matrix[double] v) implemented in (classname=\"org.apache.sysml.udf.lib.SGDNesterovUpdate\",exectype=\"mem\");  \n"
+        )
+      }
+    }
   }
-  
-  def readMatrix(varName:String, cmdLineVar:String):Unit = {
+
+  def readMatrix(varName: String, cmdLineVar: String): Unit = {
     val pathVar = varName + "_path"
     assign(tabDMLScript, pathVar, ifdef(cmdLineVar))
     // Uncomment the following lines if we want to the user to pass the format
@@ -374,47 +368,47 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator with Visua
     // assign(tabDMLScript, varName, "read(" + pathVar + ", format=" + formatVar + ")")
     assign(tabDMLScript, varName, "read(" + pathVar + ")")
   }
-  
-  def readInputData(net:CaffeNetwork, isTraining:Boolean):Unit = {
+
+  def readInputData(net: CaffeNetwork, isTraining: Boolean): Unit = {
     // Read and convert to one-hot encoding
     readMatrix("X_full", "$X")
-	  if(isTraining) {
-	    readMatrix("y_full", "$y")
-  	  tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y_full)\n")
-  	  tabDMLScript.append("# Convert to one-hot encoding (Assumption: 1-based labels) \n")
-	    tabDMLScript.append("y_full = table(seq(1," + Caffe2DML.numImages + ",1), y_full, " + Caffe2DML.numImages + ", " + Utils.numClasses(net) + ")\n")
-	  }
-	  else {
-	    tabDMLScript.append(Caffe2DML.numImages + " = nrow(X_full)\n")
-	  }
+    if (isTraining) {
+      readMatrix("y_full", "$y")
+      tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y_full)\n")
+      tabDMLScript.append("# Convert to one-hot encoding (Assumption: 1-based labels) \n")
+      tabDMLScript.append("y_full = table(seq(1," + Caffe2DML.numImages + ",1), y_full, " + Caffe2DML.numImages + ", " + Utils.numClasses(net) + ")\n")
+    } else {
+      tabDMLScript.append(Caffe2DML.numImages + " = nrow(X_full)\n")
+    }
   }
-  
-  def initWeights(net:CaffeNetwork, solver:CaffeSolver, readWeights:Boolean): Unit = {
+
+  def initWeights(net: CaffeNetwork, solver: CaffeSolver, readWeights: Boolean): Unit =
     initWeights(net, solver, readWeights, new HashSet[String]())
-  }
-  
-  def initWeights(net:CaffeNetwork, solver:CaffeSolver, readWeights:Boolean, layersToIgnore:HashSet[String]): Unit = {
+
+  def initWeights(net: CaffeNetwork, solver: CaffeSolver, readWeights: Boolean, layersToIgnore: HashSet[String]): Unit = {
     tabDMLScript.append("weights = ifdef($weights, \" \")\n")
-	  // Initialize the layers and solvers
-	  tabDMLScript.append("# Initialize the layers and solvers\n")
-	  net.getLayers.map(layer => net.getCaffeLayer(layer).init(tabDMLScript))
-	  if(readWeights) {
-		  // Loading existing weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias
-		  tabDMLScript.append("# Load the weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias\n")
-		  net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(readWeight(l.weight, l.param.getName + "_weight.mtx")))
-		  net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(readWeight(l.bias, l.param.getName + "_bias.mtx")))
-	  }
-	  net.getLayers.map(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
+    // Initialize the layers and solvers
+    tabDMLScript.append("# Initialize the layers and solvers\n")
+    net.getLayers.map(layer => net.getCaffeLayer(layer).init(tabDMLScript))
+    if (readWeights) {
+      // Loading existing weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias
+      tabDMLScript.append("# Load the weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias\n")
+      val allLayers = net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_))
+      allLayers.filter(_.weight != null).map(l => tabDMLScript.append(readWeight(l.weight, l.param.getName + "_weight.mtx")))
+      allLayers.filter(_.bias != null).map(l => tabDMLScript.append(readWeight(l.bias, l.param.getName + "_bias.mtx")))
+    }
+    net.getLayers.map(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
   }
-  
-  def getLossLayers(net:CaffeNetwork):List[IsLossLayer] = {
+
+  def getLossLayers(net: CaffeNetwork): List[IsLossLayer] = {
     val lossLayers = net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]).map(layer => net.getCaffeLayer(layer).asInstanceOf[IsLossLayer])
-	  if(lossLayers.length != 1) 
-	    throw new DMLRuntimeException("Expected exactly one loss layer, but found " + lossLayers.length + ":" + net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]))
-	  lossLayers
+    if (lossLayers.length != 1)
+      throw new DMLRuntimeException(
+        "Expected exactly one loss layer, but found " + lossLayers.length + ":" + net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer])
+      )
+    lossLayers
   }
-  
-  def updateMeanVarianceForBatchNorm(net:CaffeNetwork, value:Boolean):Unit = {
+
+  def updateMeanVarianceForBatchNorm(net: CaffeNetwork, value: Boolean): Unit =
     net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).map(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var = value)
-  }
-}
\ No newline at end of file
+}