Posted to commits@systemml.apache.org by ni...@apache.org on 2017/05/13 17:47:41 UTC
incubator-systemml git commit: [SYSTEMML-1479] Integrated DeConvolution and Concat layers in Caffe2DML
Repository: incubator-systemml
Updated Branches:
refs/heads/master 59ff8b3d4 -> 7ed8e3f49
[SYSTEMML-1479] Integrated DeConvolution and Concat layers in Caffe2DML
Closes #492.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/7ed8e3f4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/7ed8e3f4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/7ed8e3f4
Branch: refs/heads/master
Commit: 7ed8e3f4994b80f8e98f16e40f12667180fd6a02
Parents: 59ff8b3
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Sat May 13 10:45:55 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Sat May 13 10:45:55 2017 -0700
----------------------------------------------------------------------
.../org/apache/sysml/api/dl/CaffeLayer.scala | 713 ++++++++++++++++++-
.../org/apache/sysml/api/dl/CaffeNetwork.scala | 14 +
.../org/apache/sysml/api/dl/DMLGenerator.scala | 14 +-
3 files changed, 713 insertions(+), 28 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ed8e3f4/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
index 4faa203..0d1740e 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
@@ -76,12 +76,12 @@ trait CaffeLayer extends BaseDMLGenerator {
if(computedDout == null) {
val ret = net.getTopLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
if(ret.size == 0) throw new LanguageException("Expected at least 1 top layer for " + param.getName)
- else if(ret.size == 1) computedDout = ret(0).dX
- else computedDout = sum(new StringBuilder, ret.map(_.dX).toList).toString()
+ else if(ret.size == 1) computedDout = ret(0).dX(id)
+ else computedDout = sum(new StringBuilder, ret.map(_.dX(id)).toList).toString()
}
computedDout
}
- val dX = "dOut" + id
+ def dX(bottomLayerID:Int) = "dOut" + id + "_" + bottomLayerID
// --------------------------------------------------------------------------------------
// No need to override these methods in subclasses, instead classes that have weights and biases
// should implement HasWeight and HasBias traits.
@@ -99,8 +99,33 @@ trait CaffeLayer extends BaseDMLGenerator {
def invokeForward(dmlScript:StringBuilder, returnVariables:List[String], arguments:String*):Unit = {
invoke(dmlScript, sourceFileName + "::", returnVariables, "forward", arguments.toList)
}
+ // -----------------------------------------------------------------------------------
+ // All layers (with the exception of Concat) call one of the methods below in their backward function.
+ // Each preceding layer expects that 'dX(bottomLayerID) + outSuffix' is assigned.
+ // l1 <--- dX(1) <-----|
+ // |-- [current layer: dOut3 (computed by backward)] <---- "dOut" + id + outSuffix
+ // l2 <--- dX(2) <-----|
+ // The methods below do two things:
+ // 1. Compute the backward pass: either call the layer's DML backward function (invokeBackward) or simply propagate the next layer's errors (assignDoutToDX).
+ // 2. Ensure that all preceding layers receive the errors via:
+ // bottomLayerIDs.map(bottomLayerID => dmlScript.append( dX(bottomLayerID) + outSuffix + " = " + "dOut" + id + outSuffix + "; "))
+
+ // Layers that have a corresponding DML script call this method.
+ // Assumption: the first variable of resultVariables is always dX.
def invokeBackward(dmlScript:StringBuilder, outSuffix:String, resultVariables:List[String], arguments:String*):Unit = {
- invoke(dmlScript, sourceFileName + "::", resultVariables.map(_ + outSuffix), "backward", arguments.toList)
+ invoke(dmlScript, sourceFileName + "::", resultVariables.map(_ + outSuffix), "backward", arguments.toList, false)
+ val bottomLayerIDs = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l).id)
+ dmlScript.append("; ")
+ bottomLayerIDs.map(bottomLayerID => dmlScript.append( dX(bottomLayerID) + outSuffix + " = " + resultVariables(0) + outSuffix + "; "))
+ dmlScript.append("\n")
+ }
+ // On-the-fly layers (such as Scale and Elementwise) call this function to propagate the next layer's errors to the preceding layers.
+ def assignDoutToDX(dmlScript:StringBuilder, outSuffix:String):Unit = {
+ dmlScript.append("dOut" + id + outSuffix + " = " + dout)
+ val bottomLayerIDs = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l).id)
+ dmlScript.append("; ")
+ bottomLayerIDs.map(bottomLayerID => dmlScript.append( dX(bottomLayerID) + outSuffix + " = " + "dOut" + id + outSuffix + "; "))
+ dmlScript.append("\n")
}
// --------------------------------------------------------------------------------------
}
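As a rough illustration of the fan-out step described in the comments above (not part of the commit; the layer id, bottom-layer ids, and suffix below are invented), a layer with id 3 and bottom layers 1 and 2 would emit assignments of this shape:

  object DoutFanoutSketch {
    // Mimics the second step of invokeBackward/assignDoutToDX: once the layer's own
    // gradient dOut<id> has been computed, copy it into dOut<id>_<bottomLayerID> so
    // that each preceding layer can read it through dX(bottomLayerID).
    def fanout(id: Int, bottomLayerIDs: Seq[Int], outSuffix: String): String = {
      val sb = new StringBuilder
      bottomLayerIDs.foreach(b => sb.append("dOut" + id + "_" + b + outSuffix + " = " + "dOut" + id + outSuffix + "; "))
      sb.toString
    }

    def main(args: Array[String]): Unit =
      // prints: dOut3_1_iter = dOut3_iter; dOut3_2_iter = dOut3_iter;
      println(fanout(3, Seq(1, 2), "_iter"))
  }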
@@ -144,16 +169,135 @@ class Data(val param:LayerParameter, val id:Int, val net:CaffeNetwork, val numCh
class BatchNorm(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
// val scale =
override def sourceFileName = "batch_norm2d"
+ /*
+ * Initialize the parameters of this layer.
+ *
+ * Note: This is just a convenience function, and parameters
+ * may be initialized manually if needed.
+ *
+ * Inputs:
+ * - C: Number of input channels (dimensionality of input depth).
+ *
+ * Outputs:
+ * - gamma: Scale parameters, of shape (C, 1).
+ * - beta: Shift parameters, of shape (C, 1).
+ * - ema_mean: Exponential moving average of the mean, of
+ * shape (C, 1).
+ * - ema_var: Exponential moving average of the variance, of
+ * shape (C, 1).
+ */
override def init(dmlScript:StringBuilder) = invokeInit(dmlScript, List[String](gamma, beta, ema_mean, ema_var), numChannels)
var update_mean_var = true
+ /*
+ * Computes the forward pass for a 2D (spatial) batch normalization
+ * layer. The input data has N examples, each represented as a 3D
+ * volume unrolled into a single vector.
+ *
+ * A spatial batch normalization layer uses the per-channel sample
+ * mean and per-channel uncorrected sample variance during training
+ * to normalize each channel of the input data. Additionally, it
+ * introduces learnable parameters (gamma, beta) to control the
+ * amount of normalization.
+ *
+ * `y = ((x-mean) / sqrt(var+eps)) * gamma + beta`
+ *
+ * This implementation maintains exponential moving averages of the
+ * mean and variance during training for use during testing.
+ *
+ * Reference:
+ * - Batch Normalization: Accelerating Deep Network Training by
+ * Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015
+ * - https://arxiv.org/abs/1502.03167
+ *
+ * Inputs:
+ * - X: Inputs, of shape (N, C*Hin*Win).
+ * - gamma: Scale parameters, of shape (C, 1).
+ * - beta: Shift parameters, of shape (C, 1).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - mode: 'train' or 'test' to indicate if the model is currently
+ * being trained or tested. During training, the current batch
+ * mean and variance will be used to normalize the inputs, while
+ * during testing, the exponential average of the mean and
+ * variance over all previous batches will be used.
+ * - ema_mean: Exponential moving average of the mean, of
+ * shape (C, 1).
+ * - ema_var: Exponential moving average of the variance, of
+ * shape (C, 1).
+ * - mu: Momentum value for moving averages.
+ * Typical values are in the range of [0.9, 0.999].
+ * - epsilon: Smoothing term to avoid divide by zero errors.
+ * Typical values are in the range of [1e-5, 1e-3].
+ *
+ * Outputs:
+ * - out: Outputs, of shape (N, C*Hin*Win).
+ * - ema_mean_upd: Updated exponential moving average of the mean,
+ * of shape (C, 1).
+ * - ema_var_upd: Updated exponential moving average of the variance,
+ * of shape (C, 1).
+ * - cache_mean: Cache of the batch mean, of shape (C, 1).
+ * Note: This is used for performance during training.
+ * - cache_var: Cache of the batch variance, of shape (C, 1).
+ * Note: This is used for performance during training.
+ * - cache_norm: Cache of the normalized inputs, of
+ * shape (C, N*Hin*Win). Note: This is used for performance
+ * during training.
+ */
def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
val mode = if(isPrediction) "\"test\"" else "\"train\""
invokeForward(dmlScript, List[String](out, withSuffix(ema_mean), withSuffix(ema_var), withSuffix(cache_mean), withSuffix(cache_var), withSuffix(cache_norm)),
X, gamma, beta, numChannels, Hin, Win, mode, ema_mean, ema_var, ma_fraction, eps)
}
-
+ /*
+ * Computes the backward pass for a 2D (spatial) batch normalization
+ * layer.
+ *
+ * Inputs:
+ * - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win).
+ * - out: Outputs from the forward pass, of shape (N, C*Hin*Win).
+ * - ema_mean_upd: Updated exponential moving average of the mean
+ * from the forward pass, of shape (C, 1).
+ * - ema_var_upd: Updated exponential moving average of the variance
+ * from the forward pass, of shape (C, 1).
+ * - cache_mean: Cache of the batch mean from the forward pass, of
+ * shape (C, 1). Note: This is used for performance during
+ * training.
+ * - cache_var: Cache of the batch variance from the forward pass,
+ * of shape (C, 1). Note: This is used for performance during
+ * training.
+ * - cache_norm: Cache of the normalized inputs from the forward
+ * pass, of shape (C, N*Hin*Win). Note: This is used for
+ * performance during training.
+ * - X: Input data matrix to the forward pass, of
+ * shape (N, C*Hin*Win).
+ * - gamma: Scale parameters, of shape (C, 1).
+ * - beta: Shift parameters, of shape (C, 1).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - mode: 'train' or 'test' to indicate if the model is currently
+ * being trained or tested. During training, the current batch
+ * mean and variance will be used to normalize the inputs, while
+ * during testing, the exponential average of the mean and
+ * variance over all previous batches will be used.
+ * - ema_mean: Exponential moving average of the mean, of
+ * shape (C, 1).
+ * - ema_var: Exponential moving average of the variance, of
+ * shape (C, 1).
+ * - mu: Momentum value for moving averages.
+ * Typical values are in the range of [0.9, 0.999].
+ * - epsilon: Smoothing term to avoid divide by zero errors.
+ * Typical values are in the range of [1e-5, 1e-3].
+ *
+ * Outputs:
+ * - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
+ * - dgamma: Gradient wrt `W`, of shape (C, 1).
+ * - dbeta: Gradient wrt `b`, of shape (C, 1).
+ *
+ */
def backward(dmlScript: StringBuilder, outSuffix:String): Unit = {
- invokeBackward(dmlScript, outSuffix, List[String](dX, dgamma, dbeta), dout, out, ema_mean, ema_var, cache_mean, cache_var, cache_norm, X, gamma, beta, numChannels,
+ invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dgamma, dbeta), dout, out, ema_mean, ema_var, cache_mean, cache_var, cache_norm, X, gamma, beta, numChannels,
Hin, Win, "\"train\"", ema_mean, ema_var, ma_fraction, eps)
}
@@ -190,8 +334,9 @@ class Scale(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends
if(!param.getScaleParam.getBiasTerm) throw new LanguageException("Add \"scale_param { bias_term: true }\" to the layer " + param.getName)
override def sourceFileName = null
override def init(dmlScript: StringBuilder): Unit = {}
+ // TODO: Generalize this !!
def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = assign(dmlScript, out, X)
- override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = assign(dmlScript, dX + outSuffix, dout)
+ override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = assignDoutToDX(dmlScript, outSuffix)
}
// ------------------------------------------------------------------
@@ -200,10 +345,10 @@ class Elementwise(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ex
override def init(dmlScript: StringBuilder): Unit = {}
if(param.getEltwiseParam.hasOperation && param.getEltwiseParam.getOperation != EltwiseOp.SUM)
throw new LanguageException("Currently only elementwise sum operation supported")
- def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
+ override def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
addAndAssign(dmlScript, out, param.getBottomList.map(b => net.getCaffeLayer(b).out).toList)
}
- override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = assign(dmlScript, dX + outSuffix, dout)
+ override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = assignDoutToDX(dmlScript, outSuffix)
override def outputShape = {
if(_out == null) _out = net.getCaffeLayer(net.getBottomLayers(param.getName).take(1).toSeq.get(0)).outputShape
_out
@@ -212,6 +357,117 @@ class Elementwise(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ex
}
+class Concat(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+ override def sourceFileName = null
+ override def init(dmlScript: StringBuilder): Unit = {}
+ var _childLayers:List[CaffeLayer] = null
+
+ // Utility function to create string of format:
+ // fn(fn(fn(_childLayers(0).out, _childLayers(1).out), _childLayers(2).out), ...)
+ // This is useful because we do not support multi-input cbind and rbind in DML.
+ def _getMultiFn(fn:String):String = {
+ if(_childLayers == null) _childLayers = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+ var tmp = fn + "(" + _childLayers(0).out + ", " + _childLayers(1).out + ")"
+ for(i <- 2 until _childLayers.size) {
+ tmp = fn + "(" + tmp + ", " + _childLayers(i).out + ")"
+ }
+ tmp
+ }
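For example, with three hypothetical bottom layers whose output variables are out1, out2 and out3 (names invented), _getMultiFn("cbind") produces the nested two-argument calls that DML supports; a small sketch of the same nesting:

  object ConcatNestingSketch {
    def main(args: Array[String]): Unit = {
      // Bottom-layer output variable names are invented for illustration.
      val outs = List("out1", "out2", "out3")
      val nested = outs.drop(2).foldLeft("cbind(" + outs(0) + ", " + outs(1) + ")")(
        (acc, o) => "cbind(" + acc + ", " + o + ")")
      println(nested)  // cbind(cbind(out1, out2), out3)
    }
  }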
+
+ /*
+ * Computes the forward pass for a concatenation layer.
+ *
+ * Inputs:
+ * - n_i * c_i * h * w for each input blob i from 1 to K.
+ *
+ * Outputs:
+ * - out: Outputs, of shape
+ * - if axis = 0: (n_1 + n_2 + ... + n_K) * c_1 * h * w, and all input c_i should be the same.
+ * - if axis = 1: n_1 * (c_1 + c_2 + ... + c_K) * h * w, and all input n_i should be the same.
+ */
+ override def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
+ if(param.getConcatParam.getAxis == 0) {
+ // rbind the inputs
+ assign(dmlScript, out, _getMultiFn("rbind"))
+ }
+ else if(param.getConcatParam.getAxis == 1) {
+ // cbind the inputs
+ assign(dmlScript, out, _getMultiFn("cbind"))
+ }
+ else {
+ throw new DMLRuntimeException("Incorrect axis parameter for the layer " + param.getName)
+ }
+ }
+
+ def startIndex(outSuffix:String):String = "concat_start_index_" + outSuffix
+ def endIndex(outSuffix:String):String = "concat_end_index_" + outSuffix
+ def getConcatIndex(bottomLayerOut:String, outSuffix:String):String =
+ startIndex(outSuffix) + " = " + endIndex(outSuffix) + " + 1; " +
+ endIndex(outSuffix) + " = " + startIndex(outSuffix) + " + nrow(" + bottomLayerOut + "); "
+
+ /*
+ * Computes the backward pass for a concatenation layer.
+ *
+ * The top gradients are deconcatenated back to the inputs.
+ *
+ */
+ override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = {
+ val bottomLayers = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+ val dOutVar = "dOut" + id + outSuffix
+ // concat_end_index = 0
+ dmlScript.append(dOutVar + " = " + dout + "; concat_end_index" + outSuffix + " = 0; ")
+
+ val indexString = "concat_start_index" + outSuffix + " : concat_end_index" + outSuffix
+ val doutVarAssignment = if(param.getConcatParam.getAxis == 0) " = " + dOutVar + "[" + indexString + ", ]; "
+ else " = " + dOutVar + "[," + indexString + " ]; "
+
+ // concat_start_index = concat_end_index + 1
+ // concat_end_index = concat_start_index + $$ - 1
+ val initializeIndexString = "concat_start_index" + outSuffix + " = concat_end_index" + outSuffix + " + 1; concat_end_index" + outSuffix +
+ " = concat_start_index" + outSuffix + " + $$ - 1; "
+ if(param.getConcatParam.getAxis == 0) {
+ bottomLayers.map(l => {
+ dmlScript.append(initializeIndexString.replaceAll("$$", nrow(l.out)))
+ // X1 = Z[concat_start_index:concat_end_index,]
+ .append( dX(l.id) + outSuffix + doutVarAssignment)
+ })
+ }
+ else {
+ bottomLayers.map(l => {
+ dmlScript.append(initializeIndexString.replaceAll("$$", int_mult(l.outputShape._1, l.outputShape._2, l.outputShape._3) ))
+ // X1 = Z[concat_start_index:concat_end_index,]
+ .append( dX(l.id) + outSuffix + doutVarAssignment)
+ })
+ }
+ dmlScript.append("\n")
+ }
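Roughly, the backward pass walks the bottom layers in order and hands each one a contiguous slice of the upstream gradient. A standalone sketch of the same index bookkeeping (the widths 64 and 32 are invented and stand for each bottom layer's C*H*W in the axis=1 case):

  object ConcatBackwardSketch {
    // concat_start_index = concat_end_index + 1; concat_end_index = concat_start_index + width - 1
    def slices(widths: Seq[Int]): Seq[(Int, Int)] = {
      var end = 0
      widths.map { w =>
        val start = end + 1
        end = start + w - 1
        (start, end)
      }
    }

    def main(args: Array[String]): Unit =
      // For widths 64 and 32 this prints List((1,64), (65,96)), i.e. the first bottom
      // layer receives dOut[, 1:64] and the second receives dOut[, 65:96].
      println(slices(Seq(64, 32)))
  }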
+ def sumChannels():String = {
+ val channels = _childLayers.map(_.outputShape._1)
+ try {
+ channels.reduce((c1, c2) => (c1.toInt + c2.toInt).toString())
+ }
+ catch {
+ case _:Throwable => sum(new StringBuilder, channels).toString
+ }
+ }
+ override def outputShape = {
+ if(_out == null) {
+ if(_childLayers == null) _childLayers = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+ if(param.getConcatParam.getAxis == 0) {
+ _out = _childLayers(0).outputShape
+ }
+ else if(param.getConcatParam.getAxis == 1) {
+ _out = (sumChannels(), _childLayers(0).outputShape._2, _childLayers(0).outputShape._3)
+ }
+ else {
+ throw new DMLRuntimeException("Incorrect axis parameter for the layer " + param.getName)
+ }
+ }
+ _out
+ }
+ var _out:(String, String, String) = null
+}
+
class SoftmaxWithLoss(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with IsLossLayer {
// -------------------------------------------------
override def sourceFileName = "softmax"
@@ -219,8 +475,13 @@ class SoftmaxWithLoss(val param:LayerParameter, val id:Int, val net:CaffeNetwork
override def forward(dmlScript:StringBuilder, isPrediction:Boolean) =
invokeForward(dmlScript, List[String](out), scores)
override def backward(dmlScript:StringBuilder, outSuffix:String) = {
- invoke(dmlScript, "cross_entropy_loss::", List[String]("dProbs" + outSuffix), "backward", out, "yb")
- invoke(dmlScript.append("\t"), "softmax::", List[String](dX + outSuffix), "backward", "dProbs", scores)
+ invoke(dmlScript, "cross_entropy_loss::", List[String]("dProbs" + outSuffix), "backward", false, out, "yb")
+ dmlScript.append("; ")
+ invoke(dmlScript, "softmax::", List[String]("dOut" + id + outSuffix), "backward", false, "dProbs", scores)
+ val bottomLayerIDs = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l).id)
+ dmlScript.append("; ")
+ bottomLayerIDs.map(bottomLayerID => dmlScript.append( dX(bottomLayerID) + outSuffix + " = " + "dOut" + id + outSuffix + "; "))
+ dmlScript.append("\n")
}
override def computeLoss(dmlScript:StringBuilder, numTabs:Int) = {
val tabBuilder = new StringBuilder
@@ -249,11 +510,36 @@ class SoftmaxWithLoss(val param:LayerParameter, val id:Int, val net:CaffeNetwork
}
class ReLU(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+ // TODO: Leaky ReLU: negative_slope [default 0]: specifies whether to leak the negative part by multiplying it with the slope value rather than setting it to 0.
// -------------------------------------------------
override def sourceFileName = "relu"
override def init(dmlScript:StringBuilder) = { }
+ /*
+ * Computes the forward pass for a ReLU nonlinearity layer.
+ *
+ * Performs an element-wise evaluation of `f(input) = max(0, input)`.
+ *
+ * Inputs:
+ * - X: Inputs, of shape (any, any).
+ *
+ * Outputs:
+ * - out: Outputs, of same shape as `X`.
+ */
override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = invokeForward(dmlScript, List[String](out), X)
- override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix, List[String](dX), dout, X)
+ /*
+ * Computes the backward pass for a ReLU nonlinearity layer.
+ *
+ * Essentially performs a pass-through of the upstream gradient
+ * for cells > 0.
+ *
+ * Inputs:
+ * - dout: Gradient wrt `out` from upstream, of same shape as `X`.
+ * - X: Previous input data matrix, of shape (any, any).
+ *
+ * Outputs:
+ * - dX: Gradient wrt `X`, of same shape as `X`.
+ */
+ override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id), dout, X)
// -------------------------------------------------
}
@@ -261,29 +547,114 @@ class Dropout(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extend
// -------------------------------------------------
override def sourceFileName = "dropout"
override def init(dmlScript:StringBuilder) = { }
+ /*
+ * Computes the forward pass for an inverted dropout layer.
+ *
+ * Drops the inputs element-wise with a probability p, and divides
+ * by p to maintain the expected values of those inputs (which are
+ * the outputs of neurons) at test time.
+ *
+ * Inputs:
+ * - X: Inputs, of shape (any, any).
+ * - p: Probability of keeping a neuron output.
+ * - seed: [Optional: -1] Random number generator seed to allow for
+ * deterministic evaluation. Set to -1 for a random seed.
+ *
+ * Outputs:
+ * - out: Outputs, of same shape as `X`.
+ * - mask: Dropout mask used to compute the output.
+ */
override def forward(dmlScript:StringBuilder, isPrediction:Boolean) =
if(!isPrediction)
invokeForward(dmlScript, List[String](out, mask), X, p, seed)
else
assign(dmlScript, out, X) // Forward-pass not required to be performed during prediction for Dropout layer
- override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix, List[String](dX), dout, X, p, mask)
+ /*
+ * Computes the backward pass for an inverted dropout layer.
+ *
+ * Applies the mask to the upstream gradient, and divides by p to
+ * maintain the expected values at test time.
+ *
+ * Inputs:
+ * - dout: Gradient wrt `out`, of same shape as `X`.
+ * - X: Inputs, of shape (any, any).
+ * - p: Probability of keeping a neuron output.
+ * - mask: Dropout mask used to compute the output.
+ *
+ * Outputs:
+ * - dX: Gradient wrt `X`, of same shape as `X`.
+ */
+ override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix,
+ List[String]("dOut" + id), dout, X, p, mask)
// -------------------------------------------------
def mask = "mask" + id
- def p = param.getDropoutParam.getDropoutRatio.toString
+ // dropout ratio
+ def p = if(param.getDropoutParam.hasDropoutRatio()) param.getDropoutParam.getDropoutRatio.toString else "0.5"
def seed = "-1"
}
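The inverted-dropout behaviour documented above is easy to see on a plain vector; a minimal sketch (not from the commit, and independent of the DML implementation) where the mask already carries the 1/p scaling:

  import scala.util.Random

  object DropoutSketch {
    // p is the probability of keeping a unit; multiplying kept units by 1/p preserves
    // the expected activation, so no rescaling is needed at test time.
    def forward(x: Array[Double], p: Double, seed: Long): (Array[Double], Array[Double]) = {
      val rng = new Random(seed)
      val mask = x.map(_ => if (rng.nextDouble() < p) 1.0 / p else 0.0)
      (x.zip(mask).map { case (v, m) => v * m }, mask)
    }

    // Backward simply reuses the mask on the upstream gradient.
    def backward(dout: Array[Double], mask: Array[Double]): Array[Double] =
      dout.zip(mask).map { case (d, m) => d * m }

    def main(args: Array[String]): Unit = {
      val (out, mask) = forward(Array(1.0, 2.0, 3.0, 4.0), p = 0.5, seed = 42L)
      println(out.mkString(", "))
      println(backward(Array(0.1, 0.1, 0.1, 0.1), mask).mkString(", "))
    }
  }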
class InnerProduct(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
// -------------------------------------------------
+ // TODO: bias_filler [default type: 'constant' value: 0]; bias_term [default true]: specifies whether to learn and apply a set of additive biases to the filter outputs
override def sourceFileName = "affine"
+ /*
+ * Initialize the parameters of this layer.
+ *
+ * Note: This is just a convenience function, and parameters
+ * may be initialized manually if needed.
+ *
+ * We use the heuristic by He et al., which limits the magnification
+ * of inputs/gradients during forward/backward passes by scaling
+ * unit-Gaussian weights by a factor of sqrt(2/n), under the
+ * assumption of relu neurons.
+ * - http://arxiv.org/abs/1502.01852
+ *
+ * Inputs:
+ * - D: Dimensionality of the input features (number of features).
+ * - M: Number of neurons in this layer.
+ *
+ * Outputs:
+ * - W: Weights, of shape (D, M).
+ * - b: Biases, of shape (1, M).
+ */
override def init(dmlScript:StringBuilder) = invokeInit(dmlScript, List[String](weight, bias), numFeatures, numNeurons)
+ /*
+ * Computes the forward pass for an affine (fully-connected) layer
+ * with M neurons. The input data has N examples, each with D
+ * features.
+ *
+ * Inputs:
+ * - X: Inputs, of shape (N, D).
+ * - W: Weights, of shape (D, M).
+ * - b: Biases, of shape (1, M).
+ *
+ * Outputs:
+ * - out: Outputs, of shape (N, M).
+ */
override def forward(dmlScript:StringBuilder, isPrediction:Boolean) =
invokeForward(dmlScript, List[String](out), X, weight, bias)
+ /*
+ * Computes the backward pass for a fully-connected (affine) layer
+ * with M neurons.
+ *
+ * Inputs:
+ * - dout: Gradient wrt `out` from upstream, of shape (N, M).
+ * - X: Inputs, of shape (N, D).
+ * - W: Weights, of shape (D, M).
+ * - b: Biases, of shape (1, M).
+ *
+ * Outputs:
+ * - dX: Gradient wrt `X`, of shape (N, D).
+ * - dW: Gradient wrt `W`, of shape (D, M).
+ * - db: Gradient wrt `b`, of shape (1, M).
+ */
override def backward(dmlScript:StringBuilder, outSuffix:String) =
- invokeBackward(dmlScript, outSuffix, List[String](dX, dWeight, dBias), dout, X, weight, bias)
+ invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias), dout, X, weight, bias)
// -------------------------------------------------
+ // num_output (c_o): the number of filters
def numNeurons = param.getInnerProductParam.getNumOutput.toString
def numFeatures = int_mult(bottomLayerOutputShape._1, bottomLayerOutputShape._2, bottomLayerOutputShape._3)
+ // n * c_o * 1 * 1
override def outputShape = ( param.getInnerProductParam.getNumOutput.toString, "1", "1" )
}
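To make the affine docstrings above concrete, here is a tiny self-contained sketch of the forward computation out = X %*% W + b on invented 1x2 and 2x3 matrices (no SystemML involved):

  object AffineSketch {
    type Mat = Array[Array[Double]]

    // out = X %*% W + b, with the bias row broadcast over the N examples.
    def forward(x: Mat, w: Mat, b: Array[Double]): Mat =
      x.map(row => w.transpose.map(col => row.zip(col).map { case (a, c) => a * c }.sum)
                    .zip(b).map { case (v, bi) => v + bi })

    def main(args: Array[String]): Unit = {
      val x = Array(Array(1.0, 2.0))                              // N=1, D=2
      val w = Array(Array(1.0, 0.0, 1.0), Array(0.0, 1.0, 1.0))   // D=2, M=3
      val b = Array(0.5, 0.5, 0.5)
      println(forward(x, w, b).map(_.mkString(", ")).mkString("\n"))  // 1.5, 2.5, 3.5
    }
  }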
@@ -291,11 +662,65 @@ class MaxPooling(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ext
// -------------------------------------------------
override def sourceFileName = "max_pool2d_builtin"
override def init(dmlScript:StringBuilder) = {}
+ /*
+ * Computes the forward pass for a 2D spatial max pooling layer.
+ * The input data has N examples, each represented as a 3D volume
+ * unrolled into a single vector.
+ *
+ * This implementation uses a built-in operator for higher
+ * performance.
+ *
+ * Inputs:
+ * - X: Inputs, of shape (N, C*Hin*Win).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - Hf: Filter height.
+ * - Wf: Filter width.
+ * - strideh: Stride over height.
+ * - stridew: Stride over width.
+ * - padh: Padding for top and bottom sides.
+ * A typical value is 0.
+ * - padw: Padding for left and right sides.
+ * A typical value is 0.
+ *
+ * Outputs:
+ * - out: Outputs, of shape (N, C*Hout*Wout).
+ * - Hout: Output height.
+ * - Wout: Output width.
+ */
override def forward(dmlScript:StringBuilder, isPrediction:Boolean) =
invokeForward(dmlScript, List[String](out, "ignoreHout_"+id, "ignoreWout_"+id),
X, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+ /*
+ * Computes the backward pass for a 2D spatial max pooling layer.
+ * The input data has N examples, each represented as a 3D volume
+ * unrolled into a single vector.
+ *
+ * Inputs:
+ * - dout: Gradient wrt `out` from upstream, of
+ * shape (N, C*Hout*Wout).
+ * - Hout: Output height.
+ * - Wout: Output width.
+ * - X: Inputs, of shape (N, C*Hin*Win).
+ * - C: Number of input channels (dimensionality of input depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - Hf: Filter height.
+ * - Wf: Filter width.
+ * - strideh: Stride over height.
+ * - stridew: Stride over width.
+ * - padh: Padding for top and bottom sides.
+ * A typical value is 0.
+ * - padw: Padding for left and right sides.
+ * A typical value is 0.
+ *
+ * Outputs:
+ * - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
+ */
override def backward(dmlScript:StringBuilder, outSuffix:String) =
- invokeBackward(dmlScript, outSuffix, List[String](dX), dout, Hout, Wout, X, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+ invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id), dout, Hout, Wout, X, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+ // n * c * h_o * w_o, where h_o and w_o are computed in the same way as convolution.
override def outputShape = ( numChannels, Hout, Wout )
// -------------------------------------------------
def Hin = bottomLayerOutputShape._2
@@ -304,29 +729,129 @@ class MaxPooling(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ext
def Wout = ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._3, kernel_w, stride_w, pad_w)
def poolingParam = param.getPoolingParam
def numChannels = bottomLayerOutputShape._1
+ // kernel_size (or kernel_h and kernel_w): specifies height and width of each filter
def kernel_h = if(poolingParam.hasKernelH) poolingParam.getKernelH.toString
else poolingParam.getKernelSize.toString
def kernel_w = if(poolingParam.hasKernelW) poolingParam.getKernelW.toString
else poolingParam.getKernelSize.toString
+ // stride (or stride_h and stride_w) [default 1]: specifies the intervals at which to apply the filters to the input
def stride_h = if(poolingParam.hasStrideH) poolingParam.getStrideH.toString
- else poolingParam.getStride.toString
+ else if(poolingParam.hasStride) poolingParam.getStride.toString
+ else "1"
def stride_w = if(poolingParam.hasStrideW) poolingParam.getStrideW.toString
- else poolingParam.getStride.toString
+ else if(poolingParam.hasStride) poolingParam.getStride.toString
+ else "1"
+ // pad (or pad_h and pad_w) [default 0]: specifies the number of pixels to (implicitly) add to each side of the input
def pad_h = if(poolingParam.hasPadH) poolingParam.getPadH.toString
- else poolingParam.getPad.toString
+ else if(poolingParam.hasPad) poolingParam.getPad.toString
+ else "0"
def pad_w = if(poolingParam.hasPadW) poolingParam.getPadW.toString
- else poolingParam.getPad.toString
+ else if(poolingParam.hasPad) poolingParam.getPad.toString
+ else "0"
}
class Convolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
// -------------------------------------------------
override def sourceFileName = "conv2d_builtin";
+ /*
+ * Initialize the parameters of this layer.
+ *
+ * Note: This is just a convenience function, and parameters
+ * may be initialized manually if needed.
+ *
+ * We use the heuristic by He et al., which limits the magnification
+ * of inputs/gradients during forward/backward passes by scaling
+ * unit-Gaussian weights by a factor of sqrt(2/n), under the
+ * assumption of relu neurons.
+ * - http://arxiv.org/abs/1502.01852
+ *
+ * Inputs:
+ * - F: Number of filters.
+ * - C: Number of input channels (dimensionality of depth).
+ * - Hf: Filter height.
+ * - Wf: Filter width.
+ *
+ * Outputs:
+ * - W: Weights, of shape (F, C*Hf*Wf).
+ * - b: Biases, of shape (F, 1).
+ */
override def init(dmlScript:StringBuilder) = invokeInit(dmlScript, List[String](weight, bias), numKernels, numChannels, kernel_h, kernel_w)
+ /*
+ * Computes the forward pass for a 2D spatial convolutional layer with
+ * F filters. The input data has N examples, each represented as a 3D
+ * volume unrolled into a single vector.
+ *
+ * This implementation uses a built-in operator for higher
+ * performance.
+ *
+ * Inputs:
+ * - X: Inputs, of shape (N, C*Hin*Win).
+ * - W: Weights, of shape (F, C*Hf*Wf).
+ * - b: Biases, of shape (F, 1).
+ * - C: Number of input channels (dimensionality of depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - Hf: Filter height.
+ * - Wf: Filter width.
+ * - strideh: Stride over height.
+ * - stridew: Stride over width.
+ * - padh: Padding for top and bottom sides.
+ * For same output height as input, set `padh = (Hf - 1) / 2`,
+ * assuming `strideh = 1`.
+ * More generally, `padh = (Hin*(strideh-1) + Hf - strideh) / 2`
+ * preserves the spatial dimensions of the input.
+ * - padw: Padding for left and right sides.
+ * For same output width as input, set `padw = (Wf - 1) / 2`,
+ * assuming `stridew = 1`.
+ * More generally, `padw = (Win*(stridew-1) + Wf - stridew) / 2`
+ * preserves the spatial dimensions of the input.
+ *
+ * Outputs:
+ * - out: Outputs, of shape (N, F*Hout*Wout).
+ * - Hout: Output height.
+ * - Wout: Output width.
+ */
override def forward(dmlScript:StringBuilder, isPrediction:Boolean) =
invokeForward(dmlScript, List[String](out, "ignoreHout_"+id, "ignoreWout_"+id),
X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+ /*
+ * Computes the backward pass for a 2D spatial convolutional layer
+ * with F filters.
+ *
+ * Inputs:
+ * - dout: Gradient wrt `out` from upstream, of
+ * shape (N, F*Hout*Wout).
+ * - Hout: Output height.
+ * - Wout: Output width.
+ * - X: Inputs, of shape (N, C*Hin*Win).
+ * - W: Weights, of shape (F, C*Hf*Wf).
+ * - b: Biases, of shape (F, 1).
+ * - C: Number of input channels (dimensionality of depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - Hf: Filter height.
+ * - Wf: Filter width.
+ * - strideh: Stride over height.
+ * - stridew: Stride over width.
+ * - padh: Padding for top and bottom sides.
+ * For same output height as input, set `padh = (Hf - 1) / 2`,
+ * assuming `strideh = 1`.
+ * More generally, `padh = (Hin*(strideh-1) + Hf - strideh) / 2`
+ * preserves the spatial dimensions of the input.
+ * - padw: Padding for left and right sides.
+ * For same output width as input, set `padw = (Wf - 1) / 2`,
+ * assuming `stridew = 1`.
+ * More generally, `padw = (Win*(stridew-1) + Wf - stridew) / 2`
+ * preserves the spatial dimensions of the input.
+ *
+ * Outputs:
+ * - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
+ * - dW: Gradient wrt `W`, of shape (F, C*Hf*Wf).
+ * - db: Gradient wrt `b`, of shape (F, 1).
+ */
override def backward(dmlScript:StringBuilder, outSuffix:String) =
- invokeBackward(dmlScript, outSuffix, List[String](dX, dWeight, dBias), dout, Hout, Wout, X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+ invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias), dout, Hout, Wout, X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+ // n * c_o * h_o * w_o, where h_o = (h_i + 2 * pad_h - kernel_h) / stride_h + 1 and w_o likewise.
override def outputShape = ( numKernels, Hout, Wout )
// -------------------------------------------------
def numChannels = bottomLayerOutputShape._1
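As a worked instance of the output-shape comment above, h_o = (h_i + 2 * pad_h - kernel_h) / stride_h + 1 (the kernel/stride/pad values below are invented):

  object ConvOutDimSketch {
    // Standard convolution output size; integer division gives the floor, as expected.
    def outDim(in: Int, kernel: Int, stride: Int, pad: Int): Int =
      (in + 2 * pad - kernel) / stride + 1

    def main(args: Array[String]): Unit = {
      println(outDim(28, 3, 1, 1))  // 28: a 3x3 kernel with stride 1 and pad 1 preserves the size
      println(outDim(28, 3, 2, 1))  // 14
    }
  }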
@@ -334,24 +859,162 @@ class Convolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ex
def Win = bottomLayerOutputShape._3
def Hout = ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._2, kernel_h, stride_h, pad_h)
def Wout = ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._3, kernel_w, stride_w, pad_w)
+ // -------------------------------------------------
+ def convParam = param.getConvolutionParam
+ // num_output (c_o): the number of filters
+ def numKernels = convParam.getNumOutput.toString
+ // kernel_size (or kernel_h and kernel_w): specifies height and width of each filter
+ def kernel_h = if(convParam.hasKernelH) convParam.getKernelH.toString
+ else if(convParam.getKernelSizeCount > 0) convParam.getKernelSize(0).toString
+ else throw new LanguageException("Incorrect kernel parameters")
+ def kernel_w = if(convParam.hasKernelW) convParam.getKernelW.toString
+ else if(convParam.getKernelSizeCount > 0) convParam.getKernelSize(0).toString
+ else throw new LanguageException("Incorrect kernel parameters")
+ // stride (or stride_h and stride_w) [default 1]: specifies the intervals at which to apply the filters to the input
+ def stride_h = if(convParam.hasStrideH) convParam.getStrideH.toString
+ else if(convParam.getStrideCount > 0) convParam.getStride(0).toString
+ else "1"
+ def stride_w = if(convParam.hasStrideW) convParam.getStrideW.toString
+ else if(convParam.getStrideCount > 0) convParam.getStride(0).toString
+ else "1"
+ // pad (or pad_h and pad_w) [default 0]: specifies the number of pixels to (implicitly) add to each side of the input
+ def pad_h = if(convParam.hasPadH) convParam.getPadH.toString
+ else if(convParam.getPadCount > 0) convParam.getPad(0).toString
+ else "0"
+ def pad_w = if(convParam.hasPadW) convParam.getPadW.toString
+ else if(convParam.getPadCount > 0) convParam.getPad(0).toString
+ else "0"
+}
+
+class DeConvolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
+ override def sourceFileName: String = "conv2d_transpose"
+ /*
+ * Utility function to initialize the parameters of this layer.
+ *
+ * We use the heuristic by He et al., which limits the magnification
+ * of inputs/gradients during forward/backward passes by scaling
+ * unit-Gaussian weights by a factor of sqrt(2/n), under the
+ * assumption of relu neurons.
+ * - http://arxiv.org/abs/1502.01852
+ *
+ * Inputs:
+ * - F: Number of filters.
+ * - C: Number of input channels (dimensionality of depth).
+ * - Hf: Filter height.
+ * - Wf: Filter width.
+ *
+ * Outputs:
+ * - W: Weights, of shape (F, C*Hf*Wf).
+ * - b: Biases, of shape (F, 1).
+ */
+ override def init(dmlScript: StringBuilder): Unit =
+ invokeInit(dmlScript, List[String](weight, bias), numKernels, numChannels, kernel_h, kernel_w)
+
+ /*
+ * Computes the forward pass for a 2D spatial transpose convolutional
+ * layer with F filters. The input data has N examples, each
+ * represented as a 3D tensor flattened into a single vector.
+ *
+ * Inputs:
+ * - X: Inputs, of shape (N, C*Hin*Win).
+ * - W: Weights, of shape (F, C*Hf*Wf).
+ * - b: Biases, of shape (F, 1).
+ * - C: Number of input channels (dimensionality of depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - Hf: Filter height.
+ * - Wf: Filter width.
+ * - strideh: Stride over height.
+ * - stridew: Stride over width.
+ * - padh: Padding for top and bottom sides.
+ * - padw: Padding for left and right sides.
+ * - out_padh: extra padding for top side. This should
+ * lie in [0, strideh-1].
+ * - out_padw: extra padding for right side. This should
+ * lie in [0, stridew-1].
+ *
+ * Outputs:
+ * - out: Outputs, of shape (N, F*Hout*Wout).
+ * - Hout: Output height.
+ * - Wout: Output width.
+ */
+ override def forward(dmlScript: StringBuilder,isPrediction: Boolean): Unit =
+ invokeForward(dmlScript, List[String](out, "ignoreHout_"+id, "ignoreWout_"+id),
+ X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, "0", "0")
+
+ /*
+ * Computes the backward pass for a 2D spatial transpose
+ * convolutional layer with F filters.
+ *
+ * Inputs:
+ * - dout: Gradient wrt `out` from upstream, of
+ * shape (N, F*Hout*Wout).
+ * - Hout: Output height.
+ * - Wout: Output width.
+ * - X: Inputs, of shape (N, C*Hin*Win).
+ * - W: Weights, of shape (F, C*Hf*Wf).
+ * - b: Biases, of shape (F, 1).
+ * - C: Number of input channels (dimensionality of depth).
+ * - Hin: Input height.
+ * - Win: Input width.
+ * - Hf: Filter height.
+ * - Wf: Filter width.
+ * - strideh: Stride over height.
+ * - stridew: Stride over width.
+ * - padh: Padding for top and bottom sides.
+ * - padw: Padding for left and right sides.
+ *
+ * Outputs:
+ * - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
+ * - dW: Gradient wrt `W`, of shape (F, C*Hf*Wf).
+ * - db: Gradient wrt `b`, of shape (F, 1).
+ */
+ override def backward(dmlScript:StringBuilder, outSuffix:String) =
+ invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id, dWeight, dBias),
+ dout, Hout, Wout, X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+ // n * c_o * h_o * w_o, where h_o = stride_h * (h_i - 1) - 2 * pad_h + kernel_h and w_o likewise.
+ override def outputShape = ( numChannels, Hout, Wout )
+ // -------------------------------------------------
+ def numChannels = bottomLayerOutputShape._1
+ def Hin = bottomLayerOutputShape._2
+ def Win = bottomLayerOutputShape._3
+ // Hout = strideh * (Hin-1) - 2*padh + Hf + out_padh
+ def Hout:String = try {
+ (stride_h.toInt * (Hin.toInt-1) - 2*pad_h.toInt + kernel_h.toInt).toString()
+ }
+ catch {
+ case _:Throwable => stride_h + " * " + "(" + Hin + "-1) - 2*" + pad_h + " + " + kernel_h
+ }
+ // Wout = stridew * (Win-1) - 2*padw + Wf + out_padw
+ def Wout:String = try {
+ (stride_w.toInt * (Win.toInt-1) - 2*pad_w.toInt + kernel_w.toInt).toString()
+ }
+ catch {
+ case _:Throwable => stride_w + " * " + "(" + Win + "-1) - 2*" + pad_w + " + " + kernel_w
+ }
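For the transpose-convolution case, the Hout/Wout expressions above invert that relationship; a worked sketch (a hypothetical 2x upsampling layer, with out_pad fixed to 0 as in the forward call, which passes "0", "0"):

  object DeconvOutDimSketch {
    // Hout = strideh * (Hin - 1) - 2 * padh + Hf + out_padh
    def outDim(in: Int, kernel: Int, stride: Int, pad: Int, outPad: Int = 0): Int =
      stride * (in - 1) - 2 * pad + kernel + outPad

    def main(args: Array[String]): Unit =
      println(outDim(14, 2, 2, 0))  // 28: a 14x14 input with a 2x2 kernel and stride 2 doubles in size
  }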
+ // -------------------------------------------------
def convParam = param.getConvolutionParam
+ // num_output (c_o): the number of filters
def numKernels = convParam.getNumOutput.toString
+ // kernel_size (or kernel_h and kernel_w): specifies height and width of each filter
def kernel_h = if(convParam.hasKernelH) convParam.getKernelH.toString
else if(convParam.getKernelSizeCount > 0) convParam.getKernelSize(0).toString
else throw new LanguageException("Incorrect kernel parameters")
def kernel_w = if(convParam.hasKernelW) convParam.getKernelW.toString
else if(convParam.getKernelSizeCount > 0) convParam.getKernelSize(0).toString
else throw new LanguageException("Incorrect kernel parameters")
+ // stride (or stride_h and stride_w) [default 1]: specifies the intervals at which to apply the filters to the input
def stride_h = if(convParam.hasStrideH) convParam.getStrideH.toString
else if(convParam.getStrideCount > 0) convParam.getStride(0).toString
- else throw new LanguageException("Incorrect stride parameters:" + convParam.getStrideH + " " + convParam.getStrideList + " " + param.getName)
+ else "1"
def stride_w = if(convParam.hasStrideW) convParam.getStrideW.toString
else if(convParam.getStrideCount > 0) convParam.getStride(0).toString
- else throw new LanguageException("Incorrect stride parameters")
+ else "1"
+ // pad (or pad_h and pad_w) [default 0]: specifies the number of pixels to (implicitly) add to each side of the input
def pad_h = if(convParam.hasPadH) convParam.getPadH.toString
else if(convParam.getPadCount > 0) convParam.getPad(0).toString
- else throw new LanguageException("Incorrect pad parameters")
+ else "0"
def pad_w = if(convParam.hasPadW) convParam.getPadW.toString
else if(convParam.getPadCount > 0) convParam.getPad(0).toString
- else throw new LanguageException("Incorrect pad parameters")
+ else "0"
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ed8e3f4/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
index 73490b4..c106cb7 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
@@ -125,6 +125,18 @@ class CaffeNetwork(netFilePath:String, val currentPhase:Phase,
}
else l
})
+
+ // Condition 5: Deal with incorrect naming
+ // Example: layer { name: foo, bottom: arbitrary, top: bar } ... Rename the layer to bar
+ private def isIncorrectNamingLayer(l:LayerParameter): Boolean = l.getTopCount == 1 && !l.getTop(0).equalsIgnoreCase(l.getName)
+ _caffeLayerParams = _caffeLayerParams.map(l => {
+ if(isIncorrectNamingLayer(l)) {
+ val builder = l.toBuilder();
+ builder.setName(l.getTop(0))
+ builder.build()
+ }
+ else l
+ })
// --------------------------------------------------------------------------------
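In effect, a layer whose single top blob differs from its name is renamed to that top blob (the names foo/bar come from the comment above); a minimal sketch of the rule on a plain case class standing in for LayerParameter:

  object RenameSketch {
    // Simplified stand-in for LayerParameter with only the fields relevant to the rule.
    case class Layer(name: String, tops: List[String])

    def rename(l: Layer): Layer =
      if (l.tops.size == 1 && !l.tops.head.equalsIgnoreCase(l.name)) l.copy(name = l.tops.head)
      else l

    def main(args: Array[String]): Unit =
      println(rename(Layer("foo", List("bar"))))  // Layer(bar,List(bar))
  }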
@@ -174,6 +186,8 @@ class CaffeNetwork(netFilePath:String, val currentPhase:Phase,
case "batchnorm" => new BatchNorm(param, id, this)
case "scale" => new Scale(param, id, this)
case "eltwise" => new Elementwise(param, id, this)
+ case "concat" => new Concat(param, id, this)
+ case "deconvolution" => new DeConvolution(param, id, this)
case _ => throw new LanguageException("Layer of type " + param.getType + " is not supported")
}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ed8e3f4/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index ec4269a..668d996 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -74,6 +74,9 @@ trait BaseDMLGenerator {
dmlScript.append("\n")
}
def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:List[String]):Unit = {
+ invoke(dmlScript, namespace1, returnVariables, functionName, arguments, true)
+ }
+ def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:List[String], appendNewLine:Boolean):Unit = {
if(returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have at least one return value")
if(returnVariables.length > 1) dmlScript.append("[")
dmlScript.append(returnVariables(0))
@@ -96,10 +99,15 @@ trait BaseDMLGenerator {
}
}
}
- dmlScript.append(")\n")
+ dmlScript.append(")")
+ if(appendNewLine)
+ dmlScript.append("\n")
+ }
+ def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, appendNewLine:Boolean, arguments:String*):Unit = {
+ invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, appendNewLine)
}
def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:String*):Unit = {
- invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList)
+ invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, true)
}
def rightIndexing(dmlScript:StringBuilder, varName:String, rl:String, ru:String, cl:String, cu:String):StringBuilder = {
dmlScript.append(varName).append("[")
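The new appendNewLine flag lets callers chain several generated calls on one line, separated by "; ", as the SoftmaxWithLoss backward now does; a stripped-down sketch of that usage (the variable and function names below are invented, and this is not the generator's actual signature):

  object ChainedInvokeSketch {
    // Minimal stand-in for invoke(..., appendNewLine): emit one call, optionally without a newline.
    def invoke(sb: StringBuilder, ret: String, fn: String, args: Seq[String], appendNewLine: Boolean): Unit = {
      sb.append(ret).append(" = ").append(fn).append(args.mkString("(", ", ", ")"))
      if (appendNewLine) sb.append("\n")
    }

    def main(args: Array[String]): Unit = {
      val sb = new StringBuilder
      invoke(sb, "dProbs_1", "cross_entropy_loss::backward", Seq("out3", "yb"), appendNewLine = false)
      sb.append("; ")
      invoke(sb, "dOut3_1", "softmax::backward", Seq("dProbs", "scores3"), appendNewLine = false)
      sb.append("\n")
      print(sb)  // dProbs_1 = cross_entropy_loss::backward(out3, yb); dOut3_1 = softmax::backward(dProbs, scores3)
    }
  }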
@@ -244,7 +252,7 @@ trait VisualizeDMLGenerator extends TabbedDMLGenerator {
case "dweight" => l.dWeight
case "dbias" => l.dBias
case "output" => l.out
- case "doutput" => l.dX
+ // case "doutput" => l.dX
case _ => throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
}
}