You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/05/13 17:47:41 UTC

incubator-systemml git commit: [SYSTEMML-1479] Integrated DeConvolution and Concat layer in Caffe2DML

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 59ff8b3d4 -> 7ed8e3f49


[SYSTEMML-1479] Integrated DeConvolution and Concat layer in Caffe2DML

Closes #492.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/7ed8e3f4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/7ed8e3f4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/7ed8e3f4

Branch: refs/heads/master
Commit: 7ed8e3f4994b80f8e98f16e40f12667180fd6a02
Parents: 59ff8b3
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Sat May 13 10:45:55 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Sat May 13 10:45:55 2017 -0700

----------------------------------------------------------------------
 .../org/apache/sysml/api/dl/CaffeLayer.scala    | 713 ++++++++++++++++++-
 .../org/apache/sysml/api/dl/CaffeNetwork.scala  |  14 +
 .../org/apache/sysml/api/dl/DMLGenerator.scala  |  14 +-
 3 files changed, 713 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ed8e3f4/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
index 4faa203..0d1740e 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
@@ -76,12 +76,12 @@ trait CaffeLayer extends BaseDMLGenerator {
     if(computedDout == null) {
       val ret = net.getTopLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
       if(ret.size == 0) throw new LanguageException("Expected atleast 1 top layer for " + param.getName)
-      else if(ret.size == 1)     computedDout = ret(0).dX
-      else                       computedDout = sum(new StringBuilder, ret.map(_.dX).toList).toString()
+      else if(ret.size == 1)     computedDout = ret(0).dX(id)
+      else                       computedDout = sum(new StringBuilder, ret.map(_.dX(id)).toList).toString()
     }
     computedDout
   }
-  val dX = "dOut" + id
+  def dX(bottomLayerID:Int) = "dOut" + id + "_" + bottomLayerID // name of the gradient this layer sends to its bottom layer 'bottomLayerID'
   // --------------------------------------------------------------------------------------
   // No need to override these methods in subclasses, instead classes that have weights and biases 
   // should implement HasWeight and HasBias traits.
@@ -99,8 +99,33 @@ trait CaffeLayer extends BaseDMLGenerator {
   def invokeForward(dmlScript:StringBuilder, returnVariables:List[String], arguments:String*):Unit = {
     invoke(dmlScript, sourceFileName + "::", returnVariables, "forward", arguments.toList)
   }
+  // -----------------------------------------------------------------------------------
+  // Every layer except Concat calls one of the methods below from its backward function.
+  // Each bottom (preceding) layer expects the variable 'dX(bottomLayerID) + outSuffix' to be assigned:
+  // l1 <--- dX(1) <-----|
+  //                     |-- [current layer: dOut3 (computed by backward)] <---- "dOut" + id + outSuffix
+  // l2 <--- dX(2) <-----|
+  // The methods below therefore do two things:
+  // 1. Compute the backward pass: either call the dml file's backward (invokeBackward) or pass the next layers' errors through unchanged (assignDoutToDX).
+  // 2. Hand the resulting errors to every bottom layer using:
+  //        bottomLayerIDs.foreach(bottomLayerID => dmlScript.append( dX(bottomLayerID) + outSuffix + " = " + "dOut" + id + outSuffix + "; "))
+  
+  // Layers that have a corresponding dml script call this method.
+  // Assumption: the first variable of resultVariables is always this layer's dX (i.e. "dOut" + id).
   def invokeBackward(dmlScript:StringBuilder, outSuffix:String, resultVariables:List[String],  arguments:String*):Unit = {
-    invoke(dmlScript, sourceFileName + "::", resultVariables.map(_ + outSuffix), "backward", arguments.toList)
+    invoke(dmlScript, sourceFileName + "::", resultVariables.map(_ + outSuffix), "backward", arguments.toList, false)
+    val bottomLayerIDs = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l).id)
+    dmlScript.append("; ")
+    bottomLayerIDs.foreach(bottomLayerID => dmlScript.append( dX(bottomLayerID) + outSuffix + " = " + resultVariables(0) + outSuffix + "; "))
+    dmlScript.append("\n")
+  }
+  // On-the-fly layers (such as Scale and Elementwise) call this method to pass the next layers' errors through to their bottom layers unchanged.
+  def assignDoutToDX(dmlScript:StringBuilder, outSuffix:String):Unit = {
+    dmlScript.append("dOut" + id  + outSuffix + " = " + dout)
+    val bottomLayerIDs = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l).id)
+    dmlScript.append("; ")
+    bottomLayerIDs.foreach(bottomLayerID => dmlScript.append( dX(bottomLayerID) + outSuffix + " = " + "dOut" + id + outSuffix + "; "))
+    dmlScript.append("\n")
   }
   // --------------------------------------------------------------------------------------
 }
@@ -144,16 +169,135 @@ class Data(val param:LayerParameter, val id:Int, val net:CaffeNetwork, val numCh
 class BatchNorm(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
   // val scale =  
   override def sourceFileName = "batch_norm2d"
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   *
+   * Inputs:
+   *  - C: Number of input channels (dimensionality of input depth).
+   *
+   * Outputs:
+   *  - gamma: Scale parameters, of shape (C, 1).
+   *  - beta: Shift parameters, of shape (C, 1).
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (C, 1).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (C, 1).
+   */
   override def init(dmlScript:StringBuilder) = invokeInit(dmlScript, List[String](gamma, beta, ema_mean, ema_var), numChannels)
   var update_mean_var = true
+  /*
+   * Computes the forward pass for a 2D (spatial) batch normalization
+   * layer.  The input data has N examples, each represented as a 3D
+   * volume unrolled into a single vector.
+   *
+   * A spatial batch normalization layer uses the per-channel sample
+   * mean and per-channel uncorrected sample variance during training
+   * to normalize each channel of the input data.  Additionally, it
+   * introduces learnable parameters (gamma, beta) to control the
+   * amount of normalization.
+   *
+   *   `y = ((x-mean) / sqrt(var+eps)) * gamma + beta`
+   *
+   * This implementation maintains exponential moving averages of the
+   * mean and variance during training for use during testing.
+   *
+   * Reference:
+   *  - Batch Normalization: Accelerating Deep Network Training by
+   *    Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015
+   *    - https://arxiv.org/abs/1502.03167
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - gamma: Scale parameters, of shape (C, 1).
+   *  - beta: Shift parameters, of shape (C, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - mode: 'train' or 'test' to indicate if the model is currently
+   *      being trained or tested.  During training, the current batch
+   *      mean and variance will be used to normalize the inputs, while
+   *      during testing, the exponential average of the mean and
+   *      variance over all previous batches will be used.
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (C, 1).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (C, 1).
+   *  - mu: Momentum value for moving averages.
+   *      Typical values are in the range of [0.9, 0.999].
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Typical values are in the range of [1e-5, 1e-3].
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, C*Hin*Win).
+   *  - ema_mean_upd: Updated exponential moving average of the mean,
+   *      of shape (C, 1).
+   *  - ema_var_upd: Updated exponential moving average of the variance,
+   *      of shape (C, 1).
+   *  - cache_mean: Cache of the batch mean, of shape (C, 1).
+   *      Note: This is used for performance during training.
+   *  - cache_var: Cache of the batch variance, of shape (C, 1).
+   *      Note: This is used for performance during training.
+   *  - cache_norm: Cache of the normalized inputs, of
+   *      shape (C, N*Hin*Win). Note: This is used for performance
+   *      during training.
+   */
   def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
     val mode = if(isPrediction) "\"test\"" else "\"train\""
     invokeForward(dmlScript, List[String](out, withSuffix(ema_mean), withSuffix(ema_var), withSuffix(cache_mean), withSuffix(cache_var), withSuffix(cache_norm)), 
         X, gamma, beta, numChannels, Hin, Win, mode, ema_mean, ema_var,  ma_fraction, eps)  
   }
-  
+  /*
+   * Computes the backward pass for a 2D (spatial) batch normalization
+   * layer.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win).
+   *  - out: Outputs from the forward pass, of shape (N, C*Hin*Win).
+   *  - ema_mean_upd: Updated exponential moving average of the mean
+   *      from the forward pass, of shape (C, 1).
+   *  - ema_var_upd: Updated exponential moving average of the variance
+   *      from the forward pass, of shape (C, 1).
+   *  - cache_mean: Cache of the batch mean from the forward pass, of
+   *      shape (C, 1).  Note: This is used for performance during
+   *      training.
+   *  - cache_var: Cache of the batch variance from the forward pass,
+   *      of shape (C, 1).  Note: This is used for performance during
+   *      training.
+   *  - cache_norm: Cache of the normalized inputs from the forward
+   *      pass, of shape (C, N*Hin*Win).  Note: This is used for
+   *      performance during training.
+   *  - X: Input data matrix to the forward pass, of
+   *      shape (N, C*Hin*Win).
+   *  - gamma: Scale parameters, of shape (C, 1).
+   *  - beta: Shift parameters, of shape (C, 1).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - mode: 'train' or 'test' to indicate if the model is currently
+   *      being trained or tested.  During training, the current batch
+   *      mean and variance will be used to normalize the inputs, while
+   *      during testing, the exponential average of the mean and
+   *      variance over all previous batches will be used.
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (C, 1).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (C, 1).
+   *  - mu: Momentum value for moving averages.
+   *      Typical values are in the range of [0.9, 0.999].
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Typical values are in the range of [1e-5, 1e-3].
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
+   *  - dgamma: Gradient wrt `gamma`, of shape (C, 1).
+   *  - dbeta: Gradient wrt `beta`, of shape (C, 1).
+   *
+   */
   def backward(dmlScript: StringBuilder, outSuffix:String): Unit = {
-    invokeBackward(dmlScript, outSuffix, List[String](dX, dgamma, dbeta), dout, out, ema_mean, ema_var, cache_mean, cache_var, cache_norm, X, gamma, beta, numChannels, 
+    invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dgamma, dbeta), dout, out, ema_mean, ema_var, cache_mean, cache_var, cache_norm, X, gamma, beta, numChannels, 
           Hin, Win, "\"train\"", ema_mean, ema_var,  ma_fraction, eps)
   }
   
@@ -190,8 +334,9 @@ class Scale(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends
   if(!param.getScaleParam.getBiasTerm) throw new LanguageException("Add \"scale_param { bias_term: true }\" to the layer " + param.getName)
   override def sourceFileName = null
   override def init(dmlScript: StringBuilder): Unit = {}
+  // TODO: Generalize this: forward is currently an identity copy (out = X), so no scale or bias is actually applied.
   def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = assign(dmlScript, out, X)
-  override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = assign(dmlScript, dX + outSuffix, dout)
+  override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = assignDoutToDX(dmlScript, outSuffix) // identity forward => gradient passes through unchanged
 }
 // ------------------------------------------------------------------
 
@@ -200,10 +345,10 @@ class Elementwise(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ex
   override def init(dmlScript: StringBuilder): Unit = {}
   if(param.getEltwiseParam.hasOperation && param.getEltwiseParam.getOperation != EltwiseOp.SUM)
     throw new LanguageException("Currently only elementwise sum operation supported")
-  def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
     addAndAssign(dmlScript, out, param.getBottomList.map(b => net.getCaffeLayer(b).out).toList)
   }
-  override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = assign(dmlScript, dX + outSuffix, dout)
+  override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = assignDoutToDX(dmlScript, outSuffix) // forward is a sum, so every bottom layer receives dout unchanged
   override def outputShape = {
     if(_out == null) _out = net.getCaffeLayer(net.getBottomLayers(param.getName).take(1).toSeq.get(0)).outputShape
     _out
@@ -212,6 +357,117 @@ class Elementwise(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ex
   
 }
 
+class Concat(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+  override def sourceFileName = null
+  override def init(dmlScript: StringBuilder): Unit = {}
+  var _childLayers:List[CaffeLayer] = null
+  
+  // Utility function to create a string of the format (or just _childLayers(0).out when there is a single bottom layer):
+  // fn(fn(fn(_childLayers(0).out, _childLayers(1).out), _childLayers(2).out), ...)
+  // This is useful because we do not support multi-input cbind and rbind in DML.
+  def _getMultiFn(fn:String):String = {
+    if(_childLayers == null) _childLayers = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+    var tmp = if(_childLayers.size == 1) _childLayers(0).out else fn + "(" + _childLayers(0).out + ", " + _childLayers(1).out + ")"
+    for(i <- 2 until _childLayers.size) {
+      tmp = fn + "(" + tmp + ", " +  _childLayers(i).out + ")"
+    }
+    tmp
+  }
+  
+  /*
+   * Computes the forward pass for a concatenation layer.
+   *
+   * Inputs:
+   *  - n_i * c_i * h * w for each input blob i from 1 to K.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape
+   *    - if axis = 0: (n_1 + n_2 + ... + n_K) * c_1 * h * w, and all input c_i should be the same.
+   *    - if axis = 1: n_1 * (c_1 + c_2 + ... + c_K) * h * w, and all input n_i should be the same.
+   */
+  override def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
+    if(param.getConcatParam.getAxis == 0) {
+      // rbind the inputs
+      assign(dmlScript, out, _getMultiFn("rbind"))
+    }
+    else if(param.getConcatParam.getAxis == 1) {
+      // cbind the inputs
+      assign(dmlScript, out, _getMultiFn("cbind"))
+    }
+    else {
+      throw new DMLRuntimeException("Incorrect axis parameter for the layer " + param.getName)
+    }
+  }
+  // DML variables holding the current slice boundaries (1-based, inclusive) used by backward.
+  def startIndex(outSuffix:String):String = "concat_start_index" + outSuffix
+  def endIndex(outSuffix:String):String = "concat_end_index" + outSuffix
+  def getConcatIndex(bottomLayerOut:String, outSuffix:String):String =
+     startIndex(outSuffix) + " = " + endIndex(outSuffix) + " + 1; " +
+     endIndex(outSuffix) + " = " + startIndex(outSuffix) + " + nrow(" + bottomLayerOut + ") - 1; "
+  
+  /*
+   * Computes the backward pass for a concatenation layer.
+   *
+   * The top gradients are deconcatenated back to the inputs.
+   * Bottom layer i receives its own slice of dout: rows for axis = 0, columns for axis = 1.
+   */
+  override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = {
+    val bottomLayers = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+    val dOutVar = "dOut" + id  + outSuffix
+    // concat_end_index = 0
+    dmlScript.append(dOutVar + " = " + dout + "; " + endIndex(outSuffix) + " = 0; ")
+    
+    val indexString = startIndex(outSuffix) + " : " + endIndex(outSuffix)
+    val doutVarAssignment = if(param.getConcatParam.getAxis == 0) " = " + dOutVar + "[" +  indexString + ", ]; "
+                            else " = " + dOutVar + "[," +  indexString + " ]; "
+    
+    // concat_start_index = concat_end_index + 1
+    // concat_end_index = concat_start_index + numRowsOrCols - 1   (built literally; no regex substitution)
+    def initializeIndexString(numRowsOrCols:String):String = startIndex(outSuffix) + " = " + endIndex(outSuffix) + " + 1; " +
+        endIndex(outSuffix) + " = " + startIndex(outSuffix) + " + " + numRowsOrCols + " - 1; "
+    if(param.getConcatParam.getAxis == 0) {
+      bottomLayers.foreach(l => {
+        dmlScript.append(initializeIndexString(nrow(l.out)))
+                  // X1 = Z[concat_start_index:concat_end_index,]
+                 .append( dX(l.id) + outSuffix + doutVarAssignment)
+      })
+    }
+    else {
+      bottomLayers.foreach(l => {
+        dmlScript.append(initializeIndexString(int_mult(l.outputShape._1, l.outputShape._2, l.outputShape._3)))
+                  // X1 = Z[,concat_start_index:concat_end_index]
+                 .append( dX(l.id) + outSuffix + doutVarAssignment)
+      })
+    }
+    dmlScript.append("\n")
+  }
+  def sumChannels():String = {
+    val channels = _childLayers.map(_.outputShape._1)
+    try { 
+      channels.reduce((c1, c2) => (c1.toInt + c2.toInt).toString())
+    }
+    catch { // symbolic (non-literal) channel counts cannot be parsed as Int: fall back to a DML sum expression
+      case _:NumberFormatException => sum(new StringBuilder, channels).toString
+    }
+  }
+  override def outputShape = {
+    if(_out == null) {
+      if(_childLayers == null) _childLayers = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+      if(param.getConcatParam.getAxis == 0) {
+        _out = _childLayers(0).outputShape
+      }
+      else if(param.getConcatParam.getAxis == 1) {
+        _out = (sumChannels(), _childLayers(0).outputShape._2, _childLayers(0).outputShape._3)
+      }
+      else {
+        throw new DMLRuntimeException("Incorrect axis parameter for the layer " + param.getName)
+      }
+    }
+    _out
+  }
+  var _out:(String, String, String) = null
+}
+
 class SoftmaxWithLoss(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with IsLossLayer {
   // -------------------------------------------------
   override def sourceFileName = "softmax"
@@ -219,8 +475,13 @@ class SoftmaxWithLoss(val param:LayerParameter, val id:Int, val net:CaffeNetwork
   override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = 
     invokeForward(dmlScript, List[String](out), scores)
   override def backward(dmlScript:StringBuilder, outSuffix:String) =  {
-    invoke(dmlScript, "cross_entropy_loss::", List[String]("dProbs" + outSuffix), "backward", out, "yb")
-    invoke(dmlScript.append("\t"), "softmax::", List[String](dX + outSuffix), "backward", "dProbs", scores)
+    invoke(dmlScript, "cross_entropy_loss::", List[String]("dProbs" + outSuffix), "backward", false, out, "yb")
+    dmlScript.append("; ")
+    invoke(dmlScript, "softmax::", List[String]("dOut" + id + outSuffix), "backward", false, "dProbs" + outSuffix, scores) // read the dProbs variable written just above (same suffix)
+    val bottomLayerIDs = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l).id)
+    dmlScript.append("; ")
+    bottomLayerIDs.foreach(bottomLayerID => dmlScript.append( dX(bottomLayerID) + outSuffix + " = " + "dOut" + id + outSuffix + "; "))
+    dmlScript.append("\n")
   }
   override def computeLoss(dmlScript:StringBuilder, numTabs:Int) = {
     val tabBuilder = new StringBuilder
@@ -249,11 +510,36 @@ class SoftmaxWithLoss(val param:LayerParameter, val id:Int, val net:CaffeNetwork
 }
 
 class ReLU(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+  // TODO: Leaky ReLU is not supported yet: Caffe's negative_slope [default 0] (multiply the negative part by the slope instead of setting it to 0) is currently ignored.
   // -------------------------------------------------
   override def sourceFileName = "relu"
   override def init(dmlScript:StringBuilder) = { }
+  /*
+   * Computes the forward pass for a ReLU nonlinearity layer.
+   *
+   * Performs an element-wise evaluation of `f(input) = max(0, input)`.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (any, any).
+   *
+   * Outputs:
+   *  - out: Outputs, of same shape as `X`.
+   */
   override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = invokeForward(dmlScript, List[String](out), X)
-  override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix, List[String](dX), dout, X)
+  /*
+   * Computes the backward pass for a ReLU nonlinearity layer.
+   *
+   * Essentially performs a pass-through of the upstream gradient
+   * for cells > 0.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of same shape as `X`.
+   *  - X: Previous input data matrix, of shape (any, any).
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of same shape as `X`.
+   */
+  override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id), dout, X)
   // -------------------------------------------------
 }
 
@@ -261,29 +547,114 @@ class Dropout(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extend
   // -------------------------------------------------
   override def sourceFileName = "dropout"
   override def init(dmlScript:StringBuilder) = { }
+  /*
+   * Computes the forward pass for an inverted dropout layer.
+   *
+   * Keeps each input element-wise with probability p (dropping it
+   * with probability 1 - p), and divides the kept inputs by p so that
+   * their expected values match those seen at test time.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (any, any).
+   *  - p: Probability of keeping a neuron output.
+   *  - seed: [Optional: -1] Random number generator seed to allow for
+   *      deterministic evaluation.  Set to -1 for a random seed.
+   *
+   * Outputs:
+   *  - out: Outputs, of same shape as `X`.
+   *  - mask: Dropout mask used to compute the output.
+   */
   override def forward(dmlScript:StringBuilder, isPrediction:Boolean) =
     if(!isPrediction)
       invokeForward(dmlScript, List[String](out, mask), X, p, seed)
     else
       assign(dmlScript, out, X) // Forward-pass not required to be performed during prediction for Dropout layer
-  override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix, List[String](dX), dout, X, p, mask)
+  /*
+   * Computes the backward pass for an inverted dropout layer.
+   *
+   * Applies the mask to the upstream gradient, and divides by p to
+   * maintain the expected values at test time.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out`, of same shape as `X`.
+   *  - X: Inputs, of shape (any, any).
+   *  - p: Probability of keeping a neuron output.
+   *  - mask: Dropout mask used to compute the output.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of same shape as `X`.
+   */
+  override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix,
+      List[String]("dOut" + id), dout, X, p, mask)
   // -------------------------------------------------
   def mask = "mask" + id
-  def p = param.getDropoutParam.getDropoutRatio.toString
+  // Keep probability for the DML dropout: p = 1 - dropout_ratio, since Caffe's dropout_ratio is the probability of dropping an input (default 0.5).
+  def p = if(param.getDropoutParam.hasDropoutRatio()) (1 - param.getDropoutParam.getDropoutRatio).toString else "0.5"
   def seed = "-1"
 }
 
 class InnerProduct(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
   // -------------------------------------------------
+  // TODO: bias_filler [default type: 'constant' value: 0]; bias_term [default true]: specifies whether to learn and apply a set of additive biases to the layer outputs (not yet honored: init always creates a bias)
   override def sourceFileName = "affine"
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   *
+   * We use the heuristic by He et al., which limits the magnification
+   * of inputs/gradients during forward/backward passes by scaling
+   * unit-Gaussian weights by a factor of sqrt(2/n), under the
+   * assumption of relu neurons.
+   *  - http://arxiv.org/abs/1502.01852
+   *
+   * Inputs:
+   *  - D: Dimensionality of the input features (number of features).
+   *  - M: Number of neurons in this layer.
+   *
+   * Outputs:
+   *  - W: Weights, of shape (D, M).
+   *  - b: Biases, of shape (1, M).
+   */
   override def init(dmlScript:StringBuilder) = invokeInit(dmlScript, List[String](weight, bias), numFeatures, numNeurons)
+  /*
+   * Computes the forward pass for an affine (fully-connected) layer
+   * with M neurons.  The input data has N examples, each with D
+   * features.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, D).
+   *  - W: Weights, of shape (D, M).
+   *  - b: Biases, of shape (1, M).
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, M).
+   */
   override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = 
       invokeForward(dmlScript, List[String](out), X, weight, bias)
+  /*
+   * Computes the backward pass for a fully-connected (affine) layer
+   * with M neurons.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of shape (N, M).
+   *  - X: Inputs, of shape (N, D).
+   *  - W: Weights, of shape (D, M).
+   *  - b: Biases, of shape (1, M).
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, D).
+   *  - dW: Gradient wrt `W`, of shape (D, M).
+   *  - db: Gradient wrt `b`, of shape (1, M).
+   */
   override def backward(dmlScript:StringBuilder, outSuffix:String) = 
-      invokeBackward(dmlScript, outSuffix, List[String](dX, dWeight, dBias), dout, X, weight, bias)
+      invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias), dout, X, weight, bias)
   // -------------------------------------------------
+  // num_output (c_o): the number of outputs (neurons) of this layer
   def numNeurons = param.getInnerProductParam.getNumOutput.toString
   def numFeatures = int_mult(bottomLayerOutputShape._1, bottomLayerOutputShape._2, bottomLayerOutputShape._3)
+  // output shape n * c_o * 1 * 1 (the tuple holds (c, h, w); the batch size n is implicit)
   override def outputShape = ( param.getInnerProductParam.getNumOutput.toString, "1", "1" )
 }
 
@@ -291,11 +662,65 @@ class MaxPooling(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ext
   // -------------------------------------------------
   override def sourceFileName = "max_pool2d_builtin"
   override def init(dmlScript:StringBuilder) = {}
+  /*
+   * Computes the forward pass for a 2D spatial max pooling layer.
+   * The input data has N examples, each represented as a 3D volume
+   * unrolled into a single vector.
+   *
+   * This implementation uses a built-in operator for higher
+   * performance.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *      A typical value is 0.
+   *  - padw: Padding for left and right sides.
+   *      A typical value is 0.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, C*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   */
   override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = 
     invokeForward(dmlScript, List[String](out, "ignoreHout_"+id, "ignoreWout_"+id), 
         X, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+  /*
+   * Computes the backward pass for a 2D spatial max pooling layer.
+   * The input data has N examples, each represented as a 3D volume
+   * unrolled into a single vector.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of
+   *      shape (N, C*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *      A typical value is 0.
+   *  - padw: Padding for left and right sides.
+   *      A typical value is 0.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
+   */
   override def backward(dmlScript:StringBuilder, outSuffix:String) = 
-    invokeBackward(dmlScript, outSuffix, List[String](dX), dout, Hout, Wout, X, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+    invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id), dout, Hout, Wout, X, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+  // n * c * h_o * w_o, where h_o and w_o are computed in the same way as convolution.
   override def outputShape = ( numChannels, Hout, Wout )
   // -------------------------------------------------
   def Hin = bottomLayerOutputShape._2
@@ -304,29 +729,129 @@ class MaxPooling(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ext
   def Wout =  ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._3, kernel_w, stride_w, pad_w)
   def poolingParam = param.getPoolingParam
   def numChannels = bottomLayerOutputShape._1
+  // kernel_size (or kernel_h and kernel_w): height and width of the pooling window
   def kernel_h = if(poolingParam.hasKernelH) poolingParam.getKernelH.toString 
                    else poolingParam.getKernelSize.toString 
   def kernel_w = if(poolingParam.hasKernelW) poolingParam.getKernelW.toString 
                    else poolingParam.getKernelSize.toString
+  // stride (or stride_h and stride_w) [default 1]: step between successive pooling windows over the input
   def stride_h = if(poolingParam.hasStrideH) poolingParam.getStrideH.toString 
-                   else poolingParam.getStride.toString
+                   else if(poolingParam.hasStride) poolingParam.getStride.toString
+                   else "1"
   def stride_w = if(poolingParam.hasStrideW) poolingParam.getStrideW.toString 
-                   else poolingParam.getStride.toString
+                   else if(poolingParam.hasStride) poolingParam.getStride.toString
+                   else "1"
+  // pad (or pad_h and pad_w) [default 0]: specifies the number of pixels to (implicitly) add to each side of the input
   def pad_h =   if(poolingParam.hasPadH) poolingParam.getPadH.toString 
-                   else poolingParam.getPad.toString
+                   else if(poolingParam.hasPad) poolingParam.getPad.toString
+                   else "0"
   def pad_w =   if(poolingParam.hasPadW) poolingParam.getPadW.toString 
-                   else poolingParam.getPad.toString
+                   else if(poolingParam.hasPad) poolingParam.getPad.toString
+                   else "0"
 }
 
 class Convolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
   // -------------------------------------------------
   override def sourceFileName = "conv2d_builtin";
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   *
+   * We use the heuristic by He et al., which limits the magnification
+   * of inputs/gradients during forward/backward passes by scaling
+   * unit-Gaussian weights by a factor of sqrt(2/n), under the
+   * assumption of relu neurons.
+   *  - http://arxiv.org/abs/1502.01852
+   *
+   * Inputs:
+   *  - F: Number of filters.
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *
+   * Outputs:
+   *  - W: Weights, of shape (F, C*Hf*Wf).
+   *  - b: Biases, of shape (F, 1).
+   */
   override def init(dmlScript:StringBuilder) = invokeInit(dmlScript, List[String](weight, bias), numKernels, numChannels, kernel_h, kernel_w)
+  /*
+   * Computes the forward pass for a 2D spatial convolutional layer with
+   * F filters.  The input data has N examples, each represented as a 3D
+   * volume unrolled into a single vector.
+   *
+   * This implementation uses a built-in operator for higher
+   * performance.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - W: Weights, of shape (F, C*Hf*Wf).
+   *  - b: Biases, of shape (F, 1).
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *      For same output height as input, set `padh = (Hf - 1) / 2`,
+   *      assuming `strideh = 1`.
+   *      More generally, `padh = (Hin*(strideh-1) + Hf - strideh) / 2`
+   *      preserves the spatial dimensions of the input.
+   *  - padw: Padding for left and right sides.
+   *      For same output width as input, set `padw = (Wf - 1) / 2`,
+   *      assuming `stridew = 1`.
+   *      More generally, `padw = (Win*(stridew-1) + Wf - stridew) / 2`
+   *      preserves the spatial dimensions of the input.
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, F*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   */
   override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = 
     invokeForward(dmlScript, List[String](out, "ignoreHout_"+id, "ignoreWout_"+id), 
         X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+  /*
+   * Computes the backward pass for a 2D spatial convolutional layer
+   * with F filters.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of
+   *      shape (N, F*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - W: Weights, of shape (F, C*Hf*Wf).
+   *  - b: Biases, of shape (F, 1).
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *      For same output height as input, set `padh = (Hf - 1) / 2`,
+   *      assuming `strideh = 1`.
+   *      More generally, `padh = (Hin*(strideh-1) + Hf - strideh) / 2`
+   *      preserves the spatial dimensions of the input.
+   *  - padw: Padding for left and right sides.
+   *      For same output width as input, set `padw = (Wf - 1) / 2`,
+   *      assuming `stridew = 1`.
+   *      More generally, `padw = (Win*(stridew-1) + Wf - stridew) / 2`
+   *      preserves the spatial dimensions of the input.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win); bound here to the DML variable "dOut" + id.
+   *  - dW: Gradient wrt `W`, of shape (F, C*Hf*Wf).
+   *  - db: Gradient wrt `b`, of shape (F, 1).
+   */
   override def backward(dmlScript:StringBuilder, outSuffix:String) = 
-    invokeBackward(dmlScript, outSuffix, List[String](dX, dWeight, dBias), dout, Hout, Wout, X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+    invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias), dout, Hout, Wout, X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+  // n * c_o * h_o * w_o, where h_o = (h_i + 2 * pad_h - kernel_h) / stride_h + 1 and w_o likewise.
   override def outputShape = ( numKernels, Hout, Wout )
   // -------------------------------------------------
   def numChannels = bottomLayerOutputShape._1
@@ -334,24 +859,162 @@ class Convolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ex
   def Win = bottomLayerOutputShape._3
   def Hout = ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._2, kernel_h, stride_h, pad_h) 
   def Wout =  ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._3, kernel_w, stride_w, pad_w)
+  // -------------------------------------------------
+  def convParam = param.getConvolutionParam
+  // num_output (c_o): the number of filters
+  def numKernels = convParam.getNumOutput.toString
+  // kernel_size (or kernel_h and kernel_w): filter height and width; one kernel_size value is shared by h and w, two are read as (h, w)
+  def kernel_h = if(convParam.hasKernelH) convParam.getKernelH.toString 
+                   else if(convParam.getKernelSizeCount > 0)  convParam.getKernelSize(0).toString 
+                   else throw new LanguageException("Incorrect kernel parameters")
+  def kernel_w = if(convParam.hasKernelW) convParam.getKernelW.toString 
+                   else if(convParam.getKernelSizeCount > 0)  convParam.getKernelSize(if(convParam.getKernelSizeCount > 1) 1 else 0).toString 
+                   else throw new LanguageException("Incorrect kernel parameters")
+  // stride (or stride_h and stride_w) [default 1]: filter step over the input; one stride value is shared by h and w, two are read as (h, w)
+  def stride_h = if(convParam.hasStrideH) convParam.getStrideH.toString 
+                   else if(convParam.getStrideCount > 0)  convParam.getStride(0).toString 
+                   else "1"
+  def stride_w = if(convParam.hasStrideW) convParam.getStrideW.toString 
+                   else if(convParam.getStrideCount > 0)  convParam.getStride(if(convParam.getStrideCount > 1) 1 else 0).toString 
+                   else "1"
+  // pad (or pad_h and pad_w) [default 0]: pixels (implicitly) added to each side of the input; one pad value is shared by h and w, two are read as (h, w)
+  def pad_h =   if(convParam.hasPadH) convParam.getPadH.toString 
+                   else if(convParam.getPadCount > 0)  convParam.getPad(0).toString 
+                   else "0"
+  def pad_w =   if(convParam.hasPadW) convParam.getPadW.toString 
+                   else if(convParam.getPadCount > 0)  convParam.getPad(if(convParam.getPadCount > 1) 1 else 0).toString 
+                   else "0"
+}
+
+class DeConvolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
+  override def sourceFileName: String = "conv2d_transpose"
+  /*
+   * Utility function to initialize the parameters of this layer.
+   *
+   * We use the heuristic by He et al., which limits the magnification
+   * of inputs/gradients during forward/backward passes by scaling
+   * unit-Gaussian weights by a factor of sqrt(2/n), under the
+   * assumption of relu neurons.
+   *  - http://arxiv.org/abs/1502.01852
+   *
+   * Inputs:
+   *  - F: Number of filters.
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *
+   * Outputs:
+   *  - W: Weights, of shape (F, C*Hf*Wf).
+   *  - b: Biases, of shape (F, 1).
+   */
+  override def init(dmlScript: StringBuilder): Unit = 
+    invokeInit(dmlScript, List[String](weight, bias), numKernels, numChannels, kernel_h, kernel_w)
+    
+  /*
+   * Computes the forward pass for a 2D spatial transpose convolutional
+   * layer with F filters.  The input data has N examples, each
+   * represented as a 3D tensor flattened into a single vector.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - W: Weights, of shape (F, C*Hf*Wf).
+   *  - b: Biases, of shape (F, 1).
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *  - out_padh: extra padding for top side. This should 
+   *      lie in [0, strideh-1].
+   *  - out_padw: extra padding for right side. This should
+   *      lie in [0, stridew-1].
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, F*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   */
+  override def forward(dmlScript: StringBuilder,isPrediction: Boolean): Unit =
+    invokeForward(dmlScript, List[String](out, "ignoreHout_"+id, "ignoreWout_"+id), 
+        X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, "0", "0") // out_padh = out_padw = 0
+        
+  /*
+   * Computes the backward pass for a 2D spatial transpose
+   * convolutional layer with F filters.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of
+   *      shape (N, F*Hout*Wout).
+   *  - Hout: Output height.
+   *  - Wout: Output width.
+   *  - X: Inputs, of shape (N, C*Hin*Win).
+   *  - W: Weights, of shape (F, C*Hf*Wf).
+   *  - b: Biases, of shape (F, 1).
+   *  - C: Number of input channels (dimensionality of depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - Hf: Filter height.
+   *  - Wf: Filter width.
+   *  - strideh: Stride over height.
+   *  - stridew: Stride over width.
+   *  - padh: Padding for top and bottom sides.
+   *  - padw: Padding for left and right sides.
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, C*Hin*Win); bound here to the DML variable "dOut" + id.
+   *  - dW: Gradient wrt `W`, of shape (F, C*Hf*Wf).
+   *  - db: Gradient wrt `b`, of shape (F, 1).
+   */
+  override def backward(dmlScript:StringBuilder, outSuffix:String) = 
+    invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias), 
+        dout, Hout, Wout, X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+  // n * c_o * h_o * w_o, where c_o = num_output (F) and h_o = stride_h * (h_i - 1) - 2 * pad_h + kernel_h (w_o likewise).
+  override def outputShape = ( numKernels, Hout, Wout )
+  // -------------------------------------------------
+  def numChannels = bottomLayerOutputShape._1
+  def Hin = bottomLayerOutputShape._2
+  def Win = bottomLayerOutputShape._3
+  // Hout = strideh * (Hin-1) - 2*padh + Hf + out_padh, with out_padh = 0 as passed by forward
+  def Hout:String =  try { 
+    (stride_h.toInt * (Hin.toInt-1) - 2*pad_h.toInt + kernel_h.toInt).toString()
+  }
+  catch { // Hin is not an integer literal (e.g. a DML expression): emit the formula as a parenthesized DML expression
+    case _:NumberFormatException => "(" + stride_h + " * (" + Hin + "-1) - 2*" + pad_h + " + " + kernel_h + ")"
+  }
+  // Wout = stridew * (Win-1) - 2*padw + Wf + out_padw, with out_padw = 0 as passed by forward
+  def Wout:String =  try { 
+    (stride_w.toInt * (Win.toInt-1) - 2*pad_w.toInt + kernel_w.toInt).toString()
+  }
+  catch { // Win is not an integer literal (e.g. a DML expression): emit the formula as a parenthesized DML expression
+    case _:NumberFormatException => "(" + stride_w + " * (" + Win + "-1) - 2*" + pad_w + " + " + kernel_w + ")"
+  }
+  // -------------------------------------------------
   def convParam = param.getConvolutionParam
+  // num_output (c_o): the number of filters
   def numKernels = convParam.getNumOutput.toString
+  // kernel_size (or kernel_h and kernel_w): filter height and width. NOTE(review): only kernel_size(0) is read, so a per-dimension (h, w) list uses its first entry for both.
   def kernel_h = if(convParam.hasKernelH) convParam.getKernelH.toString 
                    else if(convParam.getKernelSizeCount > 0)  convParam.getKernelSize(0).toString 
                    else throw new LanguageException("Incorrect kernel parameters")
   def kernel_w = if(convParam.hasKernelW) convParam.getKernelW.toString 
                    else if(convParam.getKernelSizeCount > 0)  convParam.getKernelSize(0).toString 
                    else throw new LanguageException("Incorrect kernel parameters")
+  // stride (or stride_h and stride_w) [default 1]: filter step over the input. NOTE(review): only stride(0) is read, even for a per-dimension list.
   def stride_h = if(convParam.hasStrideH) convParam.getStrideH.toString 
                    else if(convParam.getStrideCount > 0)  convParam.getStride(0).toString 
-                   else throw new LanguageException("Incorrect stride parameters:" + convParam.getStrideH + " " + convParam.getStrideList + " " + param.getName)
+                   else "1"
   def stride_w = if(convParam.hasStrideW) convParam.getStrideW.toString 
                    else if(convParam.getStrideCount > 0)  convParam.getStride(0).toString 
-                   else throw new LanguageException("Incorrect stride parameters")
+                   else "1"
+  // pad (or pad_h and pad_w) [default 0]: pixels (implicitly) added to each side. NOTE(review): only pad(0) is read, even for a per-dimension list.
   def pad_h =   if(convParam.hasPadH) convParam.getPadH.toString 
                    else if(convParam.getPadCount > 0)  convParam.getPad(0).toString 
-                   else throw new LanguageException("Incorrect pad parameters")
+                   else "0"
   def pad_w =   if(convParam.hasPadW) convParam.getPadW.toString 
                    else if(convParam.getPadCount > 0)  convParam.getPad(0).toString 
-                   else throw new LanguageException("Incorrect pad parameters")
+                   else "0"
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ed8e3f4/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
index 73490b4..c106cb7 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
@@ -125,6 +125,18 @@ class CaffeNetwork(netFilePath:String, val currentPhase:Phase,
     }
     else l
   })
+  // Condition 5: Deal with incorrect naming, i.e. a layer with exactly one top whose name differs (ignoring case) from that top.
+  // Example: layer { name: foo, bottom: arbitrary, top: bar } ... Rename the layer to bar
+  // NOTE(review): an in-place layer (bottom == top) would be renamed to its bottom blob's name, which may clash with the producing layer; confirm an earlier condition handles in-place layers.
+  private def isIncorrectNamingLayer(l:LayerParameter): Boolean = l.getTopCount == 1 && !l.getTop(0).equalsIgnoreCase(l.getName)
+  _caffeLayerParams = _caffeLayerParams.map(l => {
+    if(isIncorrectNamingLayer(l)) {
+      val builder = l.toBuilder();
+      builder.setName(l.getTop(0))
+      builder.build()
+    }
+    else l
+  })
 
   // --------------------------------------------------------------------------------
   
@@ -174,6 +186,8 @@ class CaffeNetwork(netFilePath:String, val currentPhase:Phase,
       case "batchnorm" => new BatchNorm(param, id, this)
       case "scale" => new Scale(param, id, this)
       case "eltwise" => new Elementwise(param, id, this)
+      case "concat" => new Concat(param, id, this)
+      case "deconvolution" => new DeConvolution(param, id, this)
       case _ => throw new LanguageException("Layer of type " + param.getType + " is not supported")
     }
   }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ed8e3f4/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index ec4269a..668d996 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -74,6 +74,9 @@ trait BaseDMLGenerator {
     dmlScript.append("\n")
   }
   def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:List[String]):Unit = {
+    invoke(dmlScript, namespace1, returnVariables, functionName, arguments, true) // default: terminate the generated call with "\n"
+  }
+  def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:List[String], appendNewLine:Boolean):Unit = {
     if(returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have atleast one return value")
     if(returnVariables.length > 1) dmlScript.append("[")
     dmlScript.append(returnVariables(0))
@@ -96,10 +99,15 @@ trait BaseDMLGenerator {
   	    }
       }
     }
-    dmlScript.append(")\n")
+    dmlScript.append(")")
+    if(appendNewLine) 
+      dmlScript.append("\n")
+  }
+  def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, appendNewLine:Boolean, arguments:String*):Unit = {
+    invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, appendNewLine) // varargs form; appendNewLine = false leaves the call unterminated (presumably to embed it in a larger statement)
   }
   def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:String*):Unit = {
-    invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList)
+    invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, true) // varargs form without appendNewLine: always newline-terminated
   }
   def rightIndexing(dmlScript:StringBuilder, varName:String, rl:String, ru:String, cl:String, cu:String):StringBuilder = {
     dmlScript.append(varName).append("[")
@@ -244,7 +252,7 @@ trait VisualizeDMLGenerator extends TabbedDMLGenerator {
 	      case "dweight" => l.dWeight
 	      case "dbias" => l.dBias
 	      case "output" => l.out
-	      case "doutput" => l.dX
+	      // case "doutput" => l.dX
 	      case _ => throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
 	    }
 	   }