You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/09/15 18:03:09 UTC
[2/4] systemml git commit: [SYSTEMML-540] Support loading of batch
normalization weights in .caffemodel file using Caffe2DML
http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
index 67682d5..2b07788 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -33,205 +33,197 @@ import org.apache.commons.logging.LogFactory
trait Network {
def getLayers(): List[String]
- def getCaffeLayer(layerName:String):CaffeLayer
- def getBottomLayers(layerName:String): Set[String]
- def getTopLayers(layerName:String): Set[String]
- def getLayerID(layerName:String): Int
+ def getCaffeLayer(layerName: String): CaffeLayer
+ def getBottomLayers(layerName: String): Set[String]
+ def getTopLayers(layerName: String): Set[String]
+ def getLayerID(layerName: String): Int
}
object CaffeNetwork {
val LOG = LogFactory.getLog(classOf[CaffeNetwork].getName)
}
-class CaffeNetwork(netFilePath:String, val currentPhase:Phase,
- var numChannels:String, var height:String, var width:String
- ) extends Network {
- private def isIncludedInCurrentPhase(l:LayerParameter): Boolean = {
- if(currentPhase == null) return true // while deployment
- else if(l.getIncludeCount == 0) true
+class CaffeNetwork(netFilePath: String, val currentPhase: Phase, var numChannels: String, var height: String, var width: String) extends Network {
+ private def isIncludedInCurrentPhase(l: LayerParameter): Boolean =
+ if (currentPhase == null) return true // while deployment
+ else if (l.getIncludeCount == 0) true
else l.getIncludeList.filter(r => r.hasPhase() && r.getPhase != currentPhase).length == 0
- }
private var id = 1
- def this(deployFilePath:String) {
+ def this(deployFilePath: String) {
this(deployFilePath, null, null, null, null)
}
// --------------------------------------------------------------------------------
- private var _net:NetParameter = Utils.readCaffeNet(netFilePath)
- private var _caffeLayerParams:List[LayerParameter] = _net.getLayerList.filter(l => isIncludedInCurrentPhase(l)).toList
+ private var _net: NetParameter = Utils.readCaffeNet(netFilePath)
+ private var _caffeLayerParams: List[LayerParameter] = _net.getLayerList.filter(l => isIncludedInCurrentPhase(l)).toList
// This method is used if the user doesnot provide number of channels, height and width
- private def setCHW(inputShapes:java.util.List[caffe.Caffe.BlobShape]):Unit = {
- if(inputShapes.size != 1)
- throw new DMLRuntimeException("Expected only one input shape")
+ private def setCHW(inputShapes: java.util.List[caffe.Caffe.BlobShape]): Unit = {
+ if (inputShapes.size != 1)
+ throw new DMLRuntimeException("Expected only one input shape")
val inputShape = inputShapes.get(0)
- if(inputShape.getDimCount != 4)
+ if (inputShape.getDimCount != 4)
throw new DMLRuntimeException("Expected the input shape of dimension 4")
numChannels = inputShape.getDim(1).toString
height = inputShape.getDim(2).toString
width = inputShape.getDim(3).toString
}
- if(numChannels == null && height == null && width == null) {
- val inputLayer:List[LayerParameter] = _caffeLayerParams.filter(_.getType.toLowerCase.equals("input"))
- if(inputLayer.size == 1) {
+ if (numChannels == null && height == null && width == null) {
+ val inputLayer: List[LayerParameter] = _caffeLayerParams.filter(_.getType.toLowerCase.equals("input"))
+ if (inputLayer.size == 1) {
setCHW(inputLayer(0).getInputParam.getShapeList)
- }
- else if(inputLayer.size == 0) {
- throw new DMLRuntimeException("Input shape (number of channels, height, width) is unknown. Hint: If you are using deprecated input/input_shape API, we recommend you use Input layer.")
- }
- else {
+ } else if (inputLayer.size == 0) {
+ throw new DMLRuntimeException(
+ "Input shape (number of channels, height, width) is unknown. Hint: If you are using deprecated input/input_shape API, we recommend you use Input layer."
+ )
+ } else {
throw new DMLRuntimeException("Multiple Input layer is not supported")
}
}
// --------------------------------------------------------------------------------
-
+
private var _layerNames: List[String] = _caffeLayerParams.map(l => l.getName).toList
CaffeNetwork.LOG.debug("Layers in current phase:" + _layerNames)
-
+
// Condition 1: assert that each name is unique
private val _duplicateLayerNames = _layerNames.diff(_layerNames.distinct)
- if(_duplicateLayerNames.size != 0) throw new LanguageException("Duplicate layer names is not supported:" + _duplicateLayerNames)
-
+ if (_duplicateLayerNames.size != 0) throw new LanguageException("Duplicate layer names is not supported:" + _duplicateLayerNames)
+
// Condition 2: only 1 top name, except Data layer
private val _condition2Exceptions = Set("data")
- _caffeLayerParams.filter(l => !_condition2Exceptions.contains(l.getType.toLowerCase)).map(l => if(l.getTopCount != 1) throw new LanguageException("Multiple top layers is not supported for " + l.getName))
+ _caffeLayerParams
+ .filter(l => !_condition2Exceptions.contains(l.getType.toLowerCase))
+ .map(l => if (l.getTopCount != 1) throw new LanguageException("Multiple top layers is not supported for " + l.getName))
// Condition 3: Replace top layer names referring to a Data layer with its name
// Example: layer{ name: mnist, top: data, top: label, ... }
- private val _topToNameMappingForDataLayer = new HashMap[String, String]()
- private def containsOnly(list:java.util.List[String], v:String): Boolean = list.toSet.diff(Set(v)).size() == 0
- private def isData(l:LayerParameter):Boolean = l.getType.equalsIgnoreCase("data")
- private def replaceTopWithNameOfDataLayer(l:LayerParameter):LayerParameter = {
- if(containsOnly(l.getTopList,l.getName))
+ private val _topToNameMappingForDataLayer = new HashMap[String, String]()
+ private def containsOnly(list: java.util.List[String], v: String): Boolean = list.toSet.diff(Set(v)).size() == 0
+ private def isData(l: LayerParameter): Boolean = l.getType.equalsIgnoreCase("data")
+ private def replaceTopWithNameOfDataLayer(l: LayerParameter): LayerParameter =
+ if (containsOnly(l.getTopList, l.getName))
return l
else {
- val builder = l.toBuilder();
- for(i <- 0 until l.getTopCount) {
- if(! l.getTop(i).equals(l.getName)) { _topToNameMappingForDataLayer.put(l.getTop(i), l.getName) }
- builder.setTop(i, l.getName)
+ val builder = l.toBuilder();
+ for (i <- 0 until l.getTopCount) {
+ if (!l.getTop(i).equals(l.getName)) { _topToNameMappingForDataLayer.put(l.getTop(i), l.getName) }
+ builder.setTop(i, l.getName)
}
- return builder.build()
+ return builder.build()
}
- }
// 3a: Replace top of DataLayer with its names
// Example: layer{ name: mnist, top: mnist, top: mnist, ... }
- _caffeLayerParams = _caffeLayerParams.map(l => if(isData(l)) replaceTopWithNameOfDataLayer(l) else l)
- private def replaceBottomOfNonDataLayers(l:LayerParameter):LayerParameter = {
+ _caffeLayerParams = _caffeLayerParams.map(l => if (isData(l)) replaceTopWithNameOfDataLayer(l) else l)
+ private def replaceBottomOfNonDataLayers(l: LayerParameter): LayerParameter = {
val builder = l.toBuilder();
// Note: Top will never be Data layer
- for(i <- 0 until l.getBottomCount) {
- if(_topToNameMappingForDataLayer.containsKey(l.getBottom(i)))
+ for (i <- 0 until l.getBottomCount) {
+ if (_topToNameMappingForDataLayer.containsKey(l.getBottom(i)))
builder.setBottom(i, _topToNameMappingForDataLayer.get(l.getBottom(i)))
}
return builder.build()
}
// 3a: If top/bottom of other layers refer DataLayer, then replace them
// layer { name: "conv1_1", type: "Convolution", bottom: "data"
- _caffeLayerParams = if(_topToNameMappingForDataLayer.size == 0) _caffeLayerParams else _caffeLayerParams.map(l => if(isData(l)) l else replaceBottomOfNonDataLayers(l))
-
+ _caffeLayerParams = if (_topToNameMappingForDataLayer.size == 0) _caffeLayerParams else _caffeLayerParams.map(l => if (isData(l)) l else replaceBottomOfNonDataLayers(l))
+
// Condition 4: Deal with fused layer
// Example: layer { name: conv1, top: conv1, ... } layer { name: foo, bottom: conv1, top: conv1 }
- private def isFusedLayer(l:LayerParameter): Boolean = l.getTopCount == 1 && l.getBottomCount == 1 && l.getTop(0).equalsIgnoreCase(l.getBottom(0))
- private def containsReferencesToFusedLayer(l:LayerParameter):Boolean = l.getBottomList.foldLeft(false)((prev, bLayer) => prev || _fusedTopLayer.containsKey(bLayer))
- private val _fusedTopLayer = new HashMap[String, String]()
+ private def isFusedLayer(l: LayerParameter): Boolean = l.getTopCount == 1 && l.getBottomCount == 1 && l.getTop(0).equalsIgnoreCase(l.getBottom(0))
+ private def containsReferencesToFusedLayer(l: LayerParameter): Boolean = l.getBottomList.foldLeft(false)((prev, bLayer) => prev || _fusedTopLayer.containsKey(bLayer))
+ private val _fusedTopLayer = new HashMap[String, String]()
_caffeLayerParams = _caffeLayerParams.map(l => {
- if(isFusedLayer(l)) {
+ if (isFusedLayer(l)) {
val builder = l.toBuilder();
- if(_fusedTopLayer.containsKey(l.getBottom(0))) {
+ if (_fusedTopLayer.containsKey(l.getBottom(0))) {
builder.setBottom(0, _fusedTopLayer.get(l.getBottom(0)))
}
builder.setTop(0, l.getName)
_fusedTopLayer.put(l.getBottom(0), l.getName)
builder.build()
- }
- else if(containsReferencesToFusedLayer(l)) {
+ } else if (containsReferencesToFusedLayer(l)) {
val builder = l.toBuilder();
- for(i <- 0 until l.getBottomCount) {
- if(_fusedTopLayer.containsKey(l.getBottomList.get(i))) {
+ for (i <- 0 until l.getBottomCount) {
+ if (_fusedTopLayer.containsKey(l.getBottomList.get(i))) {
builder.setBottom(i, _fusedTopLayer.get(l.getBottomList.get(i)))
}
}
builder.build()
- }
- else l
+ } else l
})
-
+
// Used while reading caffemodel
val replacedLayerNames = new HashMap[String, String]();
-
+
// Condition 5: Deal with incorrect naming
// Example: layer { name: foo, bottom: arbitrary, top: bar } ... Rename the layer to bar
- private def isIncorrectNamingLayer(l:LayerParameter): Boolean = l.getTopCount == 1 && !l.getTop(0).equalsIgnoreCase(l.getName)
+ private def isIncorrectNamingLayer(l: LayerParameter): Boolean = l.getTopCount == 1 && !l.getTop(0).equalsIgnoreCase(l.getName)
_caffeLayerParams = _caffeLayerParams.map(l => {
- if(isIncorrectNamingLayer(l)) {
+ if (isIncorrectNamingLayer(l)) {
val builder = l.toBuilder();
replacedLayerNames.put(l.getName, l.getTop(0))
builder.setName(l.getTop(0))
builder.build()
- }
- else l
+ } else l
})
// --------------------------------------------------------------------------------
-
+
// Helper functions to extract bottom and top layers
- private def convertTupleListToMap(m:List[(String, String)]):Map[String, Set[String]] = m.groupBy(_._1).map(x => (x._1, x._2.map(y => y._2).toSet)).toMap
- private def flipKeyValues(t:List[(String, Set[String])]): List[(String, String)] = t.flatMap(x => x._2.map(b => b -> x._1))
- private def expandBottomList(layerName:String, bottomList:java.util.List[String]): List[(String, String)] = bottomList.filter(b => !b.equals(layerName)).map(b => layerName -> b).toList
-
+ private def convertTupleListToMap(m: List[(String, String)]): Map[String, Set[String]] = m.groupBy(_._1).map(x => (x._1, x._2.map(y => y._2).toSet)).toMap
+ private def flipKeyValues(t: List[(String, Set[String])]): List[(String, String)] = t.flatMap(x => x._2.map(b => b -> x._1))
+ private def expandBottomList(layerName: String, bottomList: java.util.List[String]): List[(String, String)] =
+ bottomList.filter(b => !b.equals(layerName)).map(b => layerName -> b).toList
+
// The bottom layers are the layers available in the getBottomList (from Caffe .proto files)
- private val _bottomLayers:Map[String, Set[String]] = convertTupleListToMap(
- _caffeLayerParams.flatMap(l => expandBottomList(l.getName, l.getBottomList)))
+ private val _bottomLayers: Map[String, Set[String]] = convertTupleListToMap(_caffeLayerParams.flatMap(l => expandBottomList(l.getName, l.getBottomList)))
CaffeNetwork.LOG.debug("Bottom layers:" + _bottomLayers)
-
+
// Find the top layers by reversing the bottom list
- private val _topLayers:Map[String, Set[String]] = convertTupleListToMap(flipKeyValues(_bottomLayers.toList))
+ private val _topLayers: Map[String, Set[String]] = convertTupleListToMap(flipKeyValues(_bottomLayers.toList))
CaffeNetwork.LOG.debug("Top layers:" + _topLayers)
-
+
private val _layers: Map[String, CaffeLayer] = _caffeLayerParams.map(l => l.getName -> convertLayerParameterToCaffeLayer(l)).toMap
CaffeNetwork.LOG.debug("Layers:" + _layers)
private val _layerIDs: Map[String, Int] = _layers.entrySet().map(x => x.getKey -> x.getValue.id).toMap
-
-
- private def throwException(layerName:String) = throw new LanguageException("Layer with name " + layerName + " not found")
- def getLayers(): List[String] = _layerNames
- def getCaffeLayer(layerName:String):CaffeLayer = {
- if(checkKey(_layers, layerName)) _layers.get(layerName).get
+
+ private def throwException(layerName: String) = throw new LanguageException("Layer with name " + layerName + " not found")
+ def getLayers(): List[String] = _layerNames
+ def getCaffeLayer(layerName: String): CaffeLayer =
+ if (checkKey(_layers, layerName)) _layers.get(layerName).get
else {
- if(replacedLayerNames.contains(layerName) && checkKey(_layers, replacedLayerNames.get(layerName))) {
+ if (replacedLayerNames.contains(layerName) && checkKey(_layers, replacedLayerNames.get(layerName))) {
_layers.get(replacedLayerNames.get(layerName)).get
- }
- else throwException(layerName)
+ } else throwException(layerName)
}
- }
- def getBottomLayers(layerName:String): Set[String] = if(checkKey(_bottomLayers, layerName)) _bottomLayers.get(layerName).get else throwException(layerName)
- def getTopLayers(layerName:String): Set[String] = if(checkKey(_topLayers, layerName)) _topLayers.get(layerName).get else throwException(layerName)
- def getLayerID(layerName:String): Int = if(checkKey(_layerIDs, layerName)) _layerIDs.get(layerName).get else throwException(layerName)
-
+ def getBottomLayers(layerName: String): Set[String] = if (checkKey(_bottomLayers, layerName)) _bottomLayers.get(layerName).get else throwException(layerName)
+ def getTopLayers(layerName: String): Set[String] = if (checkKey(_topLayers, layerName)) _topLayers.get(layerName).get else throwException(layerName)
+ def getLayerID(layerName: String): Int = if (checkKey(_layerIDs, layerName)) _layerIDs.get(layerName).get else throwException(layerName)
+
// Helper functions
- private def checkKey(m:Map[String, Any], key:String): Boolean = {
- if(m == null) throw new LanguageException("Map is null (key=" + key + ")")
- else if(key == null) throw new LanguageException("key is null (map=" + m + ")")
+ private def checkKey(m: Map[String, Any], key: String): Boolean =
+ if (m == null) throw new LanguageException("Map is null (key=" + key + ")")
+ else if (key == null) throw new LanguageException("key is null (map=" + m + ")")
else m.containsKey(key)
- }
- private def convertLayerParameterToCaffeLayer(param:LayerParameter):CaffeLayer = {
+ private def convertLayerParameterToCaffeLayer(param: LayerParameter): CaffeLayer = {
id = id + 1
param.getType.toLowerCase() match {
case "convolution" => new Convolution(param, id, this)
- case "pooling" => if(param.getPoolingParam.getPool == PoolingParameter.PoolMethod.MAX) new MaxPooling(param, id, this)
- else throw new LanguageException("Only maxpooling is supported:" + param.getPoolingParam.getPool.name)
- case "innerproduct" => new InnerProduct(param, id, this)
- case "relu" => new ReLU(param, id, this)
+ case "pooling" =>
+ if (param.getPoolingParam.getPool == PoolingParameter.PoolMethod.MAX) new MaxPooling(param, id, this)
+ else throw new LanguageException("Only maxpooling is supported:" + param.getPoolingParam.getPool.name)
+ case "innerproduct" => new InnerProduct(param, id, this)
+ case "relu" => new ReLU(param, id, this)
case "softmaxwithloss" => new SoftmaxWithLoss(param, id, this)
- case "dropout" => new Dropout(param, id, this)
- case "data" => new Data(param, id, this, numChannels, height, width)
- case "input" => new Data(param, id, this, numChannels, height, width)
- case "batchnorm" => new BatchNorm(param, id, this)
- case "scale" => new Scale(param, id, this)
- case "eltwise" => new Elementwise(param, id, this)
- case "concat" => new Concat(param, id, this)
- case "deconvolution" => new DeConvolution(param, id, this)
- case "threshold" => new Threshold(param, id, this)
- case "softmax" => new Softmax(param, id, this)
- case _ => throw new LanguageException("Layer of type " + param.getType + " is not supported")
+ case "dropout" => new Dropout(param, id, this)
+ case "data" => new Data(param, id, this, numChannels, height, width)
+ case "input" => new Data(param, id, this, numChannels, height, width)
+ case "batchnorm" => new BatchNorm(param, id, this)
+ case "scale" => new Scale(param, id, this)
+ case "eltwise" => new Elementwise(param, id, this)
+ case "concat" => new Concat(param, id, this)
+ case "deconvolution" => new DeConvolution(param, id, this)
+ case "threshold" => new Threshold(param, id, this)
+ case "softmax" => new Softmax(param, id, this)
+ case _ => throw new LanguageException("Layer of type " + param.getType + " is not supported")
}
}
-}
\ No newline at end of file
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
index 0e39192..a61ff10 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -23,72 +23,71 @@ import org.apache.sysml.runtime.DMLRuntimeException
import caffe.Caffe
trait CaffeSolver {
- def sourceFileName:String;
- def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit;
- def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit;
-
+ def sourceFileName: String;
+ def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit;
+ def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit;
+
// ----------------------------------------------------------------
// Used for Fine-tuning
- private def getLayerLr(layer:CaffeLayer, paramIndex:Int):String = {
+ private def getLayerLr(layer: CaffeLayer, paramIndex: Int): String = {
val param = layer.param.getParamList
- if(param == null || param.size() <= paramIndex)
+ if (param == null || param.size() <= paramIndex)
return "lr"
else
// TODO: Ignoring param.get(index).getDecayMult for now
return "(lr * " + param.get(paramIndex).getLrMult + ")"
}
// the first param { } is for the weights and the second is for the biases.
- def getWeightLr(layer:CaffeLayer):String = getLayerLr(layer, 0)
- def getBiasLr(layer:CaffeLayer):String = getLayerLr(layer, 1)
+ def getWeightLr(layer: CaffeLayer): String = getLayerLr(layer, 0)
+ def getBiasLr(layer: CaffeLayer): String = getLayerLr(layer, 1)
// ----------------------------------------------------------------
-
- def commaSep(arr:String*):String = {
- if(arr.length == 1) arr(0) else {
- var ret = arr(0)
- for(i <- 1 until arr.length) {
- ret = ret + "," + arr(i)
- }
- ret
- }
- }
-
- def l2reg_update(lambda:Double, dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
+
+ def commaSep(arr: String*): String =
+ if (arr.length == 1) arr(0)
+ else {
+ var ret = arr(0)
+ for (i <- 1 until arr.length) {
+ ret = ret + "," + arr(i)
+ }
+ ret
+ }
+
+ def l2reg_update(lambda: Double, dmlScript: StringBuilder, layer: CaffeLayer): Unit =
// val donotRegularizeLayers:Boolean = layer.isInstanceOf[BatchNorm] || layer.isInstanceOf[Scale];
- if(lambda != 0 && layer.shouldUpdateWeight) {
+ if (lambda != 0 && layer.shouldUpdateWeight) {
dmlScript.append("\t").append(layer.dWeight + "_reg = l2_reg::backward(" + layer.weight + ", " + lambda + ")\n")
dmlScript.append("\t").append(layer.dWeight + " = " + layer.dWeight + " + " + layer.dWeight + "_reg\n")
}
- }
}
-class LearningRatePolicy(lr_policy:String="exp", base_lr:Double=0.01) {
- def this(solver:Caffe.SolverParameter) {
+class LearningRatePolicy(lr_policy: String = "exp", base_lr: Double = 0.01) {
+ def this(solver: Caffe.SolverParameter) {
this(solver.getLrPolicy, solver.getBaseLr)
- if(solver.hasGamma) setGamma(solver.getGamma)
- if(solver.hasStepsize) setStepsize(solver.getStepsize)
- if(solver.hasPower()) setPower(solver.getPower)
+ if (solver.hasGamma) setGamma(solver.getGamma)
+ if (solver.hasStepsize) setStepsize(solver.getStepsize)
+ if (solver.hasPower()) setPower(solver.getPower)
}
- var gamma:Double = 0.95
- var step:Double = 100000
- var power:Double = 0.75
- def setGamma(gamma1:Double):Unit = { gamma = gamma1 }
- def setStepsize(step1:Double):Unit = { step = step1 }
- def setPower(power1:Double): Unit = { power = power1 }
- def updateLearningRate(dmlScript:StringBuilder):Unit = {
+ var gamma: Double = 0.95
+ var step: Double = 100000
+ var power: Double = 0.75
+ def setGamma(gamma1: Double): Unit = gamma = gamma1
+ def setStepsize(step1: Double): Unit = step = step1
+ def setPower(power1: Double): Unit = power = power1
+ def updateLearningRate(dmlScript: StringBuilder): Unit = {
val new_lr = lr_policy.toLowerCase match {
- case "fixed" => base_lr.toString
- case "step" => "(" + base_lr + " * " + gamma + " ^ " + " floor(e/" + step + "))"
- case "exp" => "(" + base_lr + " * " + gamma + "^e)"
- case "inv" => "(" + base_lr + "* (1 + " + gamma + " * e) ^ (-" + power + "))"
- case "poly" => "(" + base_lr + " * (1 - e/ max_epochs) ^ " + power + ")"
+ case "fixed" => base_lr.toString
+ case "step" => "(" + base_lr + " * " + gamma + " ^ " + " floor(e/" + step + "))"
+ case "exp" => "(" + base_lr + " * " + gamma + "^e)"
+ case "inv" => "(" + base_lr + "* (1 + " + gamma + " * e) ^ (-" + power + "))"
+ case "poly" => "(" + base_lr + " * (1 - e/ max_epochs) ^ " + power + ")"
case "sigmoid" => "(" + base_lr + "( 1/(1 + exp(-" + gamma + "* (e - " + step + "))))"
- case _ => throw new DMLRuntimeException("The lr policy is not supported:" + lr_policy)
+ case _ => throw new DMLRuntimeException("The lr policy is not supported:" + lr_policy)
}
dmlScript.append("lr = " + new_lr + "\n")
}
}
-class SGD(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
+class SGD(lambda: Double = 5e-04, momentum: Double = 0.9) extends CaffeSolver {
/*
* Performs an SGD update with momentum.
*
@@ -112,31 +111,39 @@ class SGD(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
* - v: Updated velocity of the parameters `X`, of same shape as
* input `X`.
*/
- def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
+ def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
l2reg_update(lambda, dmlScript, layer)
- if(momentum == 0) {
+ if (momentum == 0) {
// Use sgd
- if(layer.shouldUpdateWeight) dmlScript.append("\t").append(layer.weight + " = sgd::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer)) + ")\n")
- if(layer.shouldUpdateBias) dmlScript.append("\t").append(layer.bias + " = sgd::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer)) + ")\n")
- }
- else {
+ if (layer.shouldUpdateWeight) dmlScript.append("\t").append(layer.weight + " = sgd::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer)) + ")\n")
+ if (layer.shouldUpdateBias) dmlScript.append("\t").append(layer.bias + " = sgd::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer)) + ")\n")
+ } else {
// Use sgd_momentum
- if(layer.shouldUpdateWeight) dmlScript.append("\t").append("["+ commaSep(layer.weight, layer.weight+"_v") + "] " +
- "= sgd_momentum::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight+"_v") + ")\n")
- if(layer.shouldUpdateBias) dmlScript.append("\t").append("["+ commaSep(layer.bias, layer.bias+"_v") + "] " +
- "= sgd_momentum::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias+"_v") + ")\n")
+ if (layer.shouldUpdateWeight)
+ dmlScript
+ .append("\t")
+ .append(
+ "[" + commaSep(layer.weight, layer.weight + "_v") + "] " +
+ "= sgd_momentum::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight + "_v") + ")\n"
+ )
+ if (layer.shouldUpdateBias)
+ dmlScript
+ .append("\t")
+ .append(
+ "[" + commaSep(layer.bias, layer.bias + "_v") + "] " +
+ "= sgd_momentum::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias + "_v") + ")\n"
+ )
}
}
- def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
- if(momentum != 0) {
- if(layer.shouldUpdateWeight) dmlScript.append(layer.weight+"_v = sgd_momentum::init(" + layer.weight + ")\n")
- if(layer.shouldUpdateBias) dmlScript.append(layer.bias+"_v = sgd_momentum::init(" + layer.bias + ")\n")
+ def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit =
+ if (momentum != 0) {
+ if (layer.shouldUpdateWeight) dmlScript.append(layer.weight + "_v = sgd_momentum::init(" + layer.weight + ")\n")
+ if (layer.shouldUpdateBias) dmlScript.append(layer.bias + "_v = sgd_momentum::init(" + layer.bias + ")\n")
}
- }
- def sourceFileName:String = if(momentum == 0) "sgd" else "sgd_momentum"
+ def sourceFileName: String = if (momentum == 0) "sgd" else "sgd_momentum"
}
-class AdaGrad(lambda:Double=5e-04, epsilon:Double=1e-6) extends CaffeSolver {
+class AdaGrad(lambda: Double = 5e-04, epsilon: Double = 1e-6) extends CaffeSolver {
/*
* Performs an Adagrad update.
*
@@ -164,21 +171,31 @@ class AdaGrad(lambda:Double=5e-04, epsilon:Double=1e-6) extends CaffeSolver {
* - cache: State that maintains per-parameter sum of squared
* gradients, of same shape as `X`.
*/
- def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
+ def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
l2reg_update(lambda, dmlScript, layer)
- if(layer.shouldUpdateWeight) dmlScript.append("\t").append("["+ commaSep(layer.weight, layer.weight+"_cache") + "] " +
- "= adagrad::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), epsilon.toString, layer.weight+"_cache") + ")\n")
- if(layer.shouldUpdateBias) dmlScript.append("\t").append("["+ commaSep(layer.bias, layer.bias+"_cache") + "] " +
- "= adagrad::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), epsilon.toString, layer.bias+"_cache") + ")\n")
+ if (layer.shouldUpdateWeight)
+ dmlScript
+ .append("\t")
+ .append(
+ "[" + commaSep(layer.weight, layer.weight + "_cache") + "] " +
+ "= adagrad::update(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), epsilon.toString, layer.weight + "_cache") + ")\n"
+ )
+ if (layer.shouldUpdateBias)
+ dmlScript
+ .append("\t")
+ .append(
+ "[" + commaSep(layer.bias, layer.bias + "_cache") + "] " +
+ "= adagrad::update(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), epsilon.toString, layer.bias + "_cache") + ")\n"
+ )
}
- def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
- if(layer.shouldUpdateWeight) dmlScript.append(layer.weight+"_cache = adagrad::init(" + layer.weight + ")\n")
- if(layer.shouldUpdateBias) dmlScript.append(layer.bias+"_cache = adagrad::init(" + layer.bias + ")\n")
+ def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
+ if (layer.shouldUpdateWeight) dmlScript.append(layer.weight + "_cache = adagrad::init(" + layer.weight + ")\n")
+ if (layer.shouldUpdateBias) dmlScript.append(layer.bias + "_cache = adagrad::init(" + layer.bias + ")\n")
}
- def sourceFileName:String = "adagrad"
+ def sourceFileName: String = "adagrad"
}
-class Nesterov(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
+class Nesterov(lambda: Double = 5e-04, momentum: Double = 0.9) extends CaffeSolver {
/*
* Performs an SGD update with Nesterov momentum.
*
@@ -211,20 +228,30 @@ class Nesterov(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
* - v: Updated velocity of the parameters X, of same shape as
* input v.
*/
- def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
- val fn = if(Caffe2DML.USE_NESTEROV_UDF) "update_nesterov" else "sgd_nesterov::update"
- val lastParameter = if(Caffe2DML.USE_NESTEROV_UDF) (", " + lambda) else ""
- if(!Caffe2DML.USE_NESTEROV_UDF) {
+ def update(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
+ val fn = if (Caffe2DML.USE_NESTEROV_UDF) "update_nesterov" else "sgd_nesterov::update"
+ val lastParameter = if (Caffe2DML.USE_NESTEROV_UDF) (", " + lambda) else ""
+ if (!Caffe2DML.USE_NESTEROV_UDF) {
l2reg_update(lambda, dmlScript, layer)
}
- if(layer.shouldUpdateWeight) dmlScript.append("\t").append("["+ commaSep(layer.weight, layer.weight+"_v") + "] " +
- "= " + fn + "(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight+"_v") + lastParameter + ")\n")
- if(layer.shouldUpdateBias) dmlScript.append("\t").append("["+ commaSep(layer.bias, layer.bias+"_v") + "] " +
- "= " + fn + "(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias+"_v") + lastParameter + ")\n")
+ if (layer.shouldUpdateWeight)
+ dmlScript
+ .append("\t")
+ .append(
+ "[" + commaSep(layer.weight, layer.weight + "_v") + "] " +
+ "= " + fn + "(" + commaSep(layer.weight, layer.dWeight, getWeightLr(layer), momentum.toString, layer.weight + "_v") + lastParameter + ")\n"
+ )
+ if (layer.shouldUpdateBias)
+ dmlScript
+ .append("\t")
+ .append(
+ "[" + commaSep(layer.bias, layer.bias + "_v") + "] " +
+ "= " + fn + "(" + commaSep(layer.bias, layer.dBias, getBiasLr(layer), momentum.toString, layer.bias + "_v") + lastParameter + ")\n"
+ )
}
- def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
- if(layer.shouldUpdateWeight) dmlScript.append(layer.weight+"_v = sgd_nesterov::init(" + layer.weight + ")\n")
- if(layer.shouldUpdateBias) dmlScript.append(layer.bias+"_v = sgd_nesterov::init(" + layer.bias + ")\n")
+ def init(dmlScript: StringBuilder, layer: CaffeLayer): Unit = {
+ if (layer.shouldUpdateWeight) dmlScript.append(layer.weight + "_v = sgd_nesterov::init(" + layer.weight + ")\n")
+ if (layer.shouldUpdateBias) dmlScript.append(layer.bias + "_v = sgd_nesterov::init(" + layer.bias + ")\n")
}
- def sourceFileName:String = "sgd_nesterov"
-}
\ No newline at end of file
+ def sourceFileName: String = "sgd_nesterov"
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/f07b5a2d/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index 6b06c26..b68d493 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -27,301 +27,292 @@ import scala.collection.JavaConversions._
import caffe.Caffe
trait BaseDMLGenerator {
- def commaSep(arr:List[String]):String = {
- if(arr.length == 1) arr(0) else {
- var ret = arr(0)
- for(i <- 1 until arr.length) {
- ret = ret + "," + arr(i)
- }
- ret
- }
- }
- def commaSep(arr:String*):String = {
- if(arr.length == 1) arr(0) else {
- var ret = arr(0)
- for(i <- 1 until arr.length) {
- ret = ret + "," + arr(i)
- }
- ret
- }
- }
- def int_add(v1:String, v2:String):String = {
- try { (v1.toDouble + v2.toDouble).toInt.toString } catch { case _:Throwable => "("+v1+"+"+v2+")"}
- }
- def int_mult(v1:String, v2:String, v3:String):String = {
- try { (v1.toDouble * v2.toDouble * v3.toDouble).toInt.toString } catch { case _:Throwable => "("+v1+"*"+v2+"*"+v3+")"}
- }
- def isNumber(x: String):Boolean = x forall Character.isDigit
- def transpose(x:String):String = "t(" + x + ")"
- def write(varName:String, fileName:String, format:String):String = "write(" + varName + ", \"" + fileName + "\", format=\"" + format + "\")\n"
- def readWeight(varName:String, fileName:String, sep:String="/"):String = varName + " = read(weights + \"" + sep + fileName + "\")\n"
- def asDMLString(str:String):String = "\"" + str + "\""
- def assign(dmlScript:StringBuilder, lhsVar:String, rhsVar:String):Unit = {
+ def commaSep(arr: List[String]): String =
+ if (arr.length == 1) arr(0)
+ else {
+ var ret = arr(0)
+ for (i <- 1 until arr.length) {
+ ret = ret + "," + arr(i)
+ }
+ ret
+ }
+ def commaSep(arr: String*): String =
+ if (arr.length == 1) arr(0)
+ else {
+ var ret = arr(0)
+ for (i <- 1 until arr.length) {
+ ret = ret + "," + arr(i)
+ }
+ ret
+ }
+ def int_add(v1: String, v2: String): String =
+ try { (v1.toDouble + v2.toDouble).toInt.toString } catch { case _: Throwable => "(" + v1 + "+" + v2 + ")" }
+ def int_mult(v1: String, v2: String, v3: String): String =
+ try { (v1.toDouble * v2.toDouble * v3.toDouble).toInt.toString } catch { case _: Throwable => "(" + v1 + "*" + v2 + "*" + v3 + ")" }
+ def isNumber(x: String): Boolean = x forall Character.isDigit
+ def transpose(x: String): String = "t(" + x + ")"
+ def write(varName: String, fileName: String, format: String): String = "write(" + varName + ", \"" + fileName + "\", format=\"" + format + "\")\n"
+ def readWeight(varName: String, fileName: String, sep: String = "/"): String = varName + " = read(weights + \"" + sep + fileName + "\")\n"
+ def readScalarWeight(varName: String, fileName: String, sep: String = "/"): String = varName + " = as.scalar(read(weights + \"" + sep + fileName + "\"))\n"
+ def asDMLString(str: String): String = "\"" + str + "\""
+ def assign(dmlScript: StringBuilder, lhsVar: String, rhsVar: String): Unit =
dmlScript.append(lhsVar).append(" = ").append(rhsVar).append("\n")
- }
- def sum(dmlScript:StringBuilder, variables:List[String]):StringBuilder = {
- if(variables.length > 1) dmlScript.append("(")
+ def sum(dmlScript: StringBuilder, variables: List[String]): StringBuilder = {
+ if (variables.length > 1) dmlScript.append("(")
dmlScript.append(variables(0))
- for(i <- 1 until variables.length) {
+ for (i <- 1 until variables.length) {
dmlScript.append(" + ").append(variables(i))
}
- if(variables.length > 1) dmlScript.append(")")
+ if (variables.length > 1) dmlScript.append(")")
return dmlScript
}
- def addAndAssign(dmlScript:StringBuilder, lhsVar:String, rhsVars:List[String]):Unit = {
+ def addAndAssign(dmlScript: StringBuilder, lhsVar: String, rhsVars: List[String]): Unit = {
dmlScript.append(lhsVar).append(" = ")
sum(dmlScript, rhsVars)
dmlScript.append("\n")
}
- def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:List[String]):Unit = {
+ def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String]): Unit =
invoke(dmlScript, namespace1, returnVariables, functionName, arguments, true)
- }
- def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:List[String], appendNewLine:Boolean):Unit = {
- if(returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have atleast one return value")
- if(returnVariables.length > 1) dmlScript.append("[")
+ def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: List[String], appendNewLine: Boolean): Unit = {
+ if (returnVariables.length == 0) throw new DMLRuntimeException("User-defined functions should have atleast one return value")
+ if (returnVariables.length > 1) dmlScript.append("[")
dmlScript.append(returnVariables(0))
- if(returnVariables.length > 1) {
- for(i <- 1 until returnVariables.length) {
- dmlScript.append(",").append(returnVariables(i))
- }
+ if (returnVariables.length > 1) {
+ for (i <- 1 until returnVariables.length) {
+ dmlScript.append(",").append(returnVariables(i))
+ }
dmlScript.append("]")
}
dmlScript.append(" = ")
dmlScript.append(namespace1)
dmlScript.append(functionName)
dmlScript.append("(")
- if(arguments != null) {
- if(arguments.length != 0)
+ if (arguments != null) {
+ if (arguments.length != 0)
dmlScript.append(arguments(0))
- if(arguments.length > 1) {
- for(i <- 1 until arguments.length) {
- dmlScript.append(",").append(arguments(i))
- }
+ if (arguments.length > 1) {
+ for (i <- 1 until arguments.length) {
+ dmlScript.append(",").append(arguments(i))
+ }
}
}
dmlScript.append(")")
- if(appendNewLine)
+ if (appendNewLine)
dmlScript.append("\n")
}
- def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, appendNewLine:Boolean, arguments:String*):Unit = {
+ def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, appendNewLine: Boolean, arguments: String*): Unit =
invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, appendNewLine)
- }
- def invoke(dmlScript:StringBuilder, namespace1:String, returnVariables:List[String], functionName:String, arguments:String*):Unit = {
+ def invoke(dmlScript: StringBuilder, namespace1: String, returnVariables: List[String], functionName: String, arguments: String*): Unit =
invoke(dmlScript, namespace1, returnVariables, functionName, arguments.toList, true)
- }
- def rightIndexing(dmlScript:StringBuilder, varName:String, rl:String, ru:String, cl:String, cu:String):StringBuilder = {
+ def rightIndexing(dmlScript: StringBuilder, varName: String, rl: String, ru: String, cl: String, cu: String): StringBuilder = {
dmlScript.append(varName).append("[")
- if(rl != null && ru != null) dmlScript.append(rl).append(":").append(ru)
+ if (rl != null && ru != null) dmlScript.append(rl).append(":").append(ru)
dmlScript.append(",")
- if(cl != null && cu != null) dmlScript.append(cl).append(":").append(cu)
+ if (cl != null && cu != null) dmlScript.append(cl).append(":").append(cu)
dmlScript.append("]")
}
// Performs assignVar = ceil(lhsVar/rhsVar)
- def ceilDivide(dmlScript:StringBuilder, assignVar:String, lhsVar:String, rhsVar:String):Unit =
+ def ceilDivide(dmlScript: StringBuilder, assignVar: String, lhsVar: String, rhsVar: String): Unit =
dmlScript.append(assignVar).append(" = ").append("ceil(").append(lhsVar).append(" / ").append(rhsVar).append(")\n")
- def print(arg:String):String = "print(" + arg + ")\n"
- def dmlConcat(arg:String*):String = {
+ def print(arg: String): String = "print(" + arg + ")\n"
+ def dmlConcat(arg: String*): String = {
val ret = new StringBuilder
ret.append(arg(0))
- for(i <- 1 until arg.length) {
+ for (i <- 1 until arg.length) {
ret.append(" + ").append(arg(i))
}
ret.toString
}
- def matrix(init:String, rows:String, cols:String):String = "matrix(" + init + ", rows=" + rows + ", cols=" + cols + ")"
- def nrow(m:String):String = "nrow(" + m + ")"
- def ncol(m:String):String = "ncol(" + m + ")"
- def customAssert(cond:Boolean, msg:String) = if(!cond) throw new DMLRuntimeException(msg)
- def multiply(v1:String, v2:String):String = v1 + "*" + v2
- def colSums(m:String):String = "colSums(" + m + ")"
- def ifdef(cmdLineVar:String, defaultVal:String):String = "ifdef(" + cmdLineVar + ", " + defaultVal + ")"
- def ifdef(cmdLineVar:String):String = ifdef(cmdLineVar, "\" \"")
- def read(filePathVar:String, format:String):String = "read(" + filePathVar + ", format=\""+ format + "\")"
+ def matrix(init: String, rows: String, cols: String): String = "matrix(" + init + ", rows=" + rows + ", cols=" + cols + ")"
+ def nrow(m: String): String = "nrow(" + m + ")"
+ def ncol(m: String): String = "ncol(" + m + ")"
+ def customAssert(cond: Boolean, msg: String) = if (!cond) throw new DMLRuntimeException(msg)
+ def multiply(v1: String, v2: String): String = v1 + "*" + v2
+ def colSums(m: String): String = "colSums(" + m + ")"
+ def ifdef(cmdLineVar: String, defaultVal: String): String = "ifdef(" + cmdLineVar + ", " + defaultVal + ")"
+ def ifdef(cmdLineVar: String): String = ifdef(cmdLineVar, "\" \"")
+ def read(filePathVar: String, format: String): String = "read(" + filePathVar + ", format=\"" + format + "\")"
}
trait TabbedDMLGenerator extends BaseDMLGenerator {
- def tabDMLScript(dmlScript:StringBuilder, numTabs:Int):StringBuilder = tabDMLScript(dmlScript, numTabs, false)
- def tabDMLScript(dmlScript:StringBuilder, numTabs:Int, prependNewLine:Boolean):StringBuilder = {
- if(prependNewLine) dmlScript.append("\n")
- for(i <- 0 until numTabs) dmlScript.append("\t")
- dmlScript
+ def tabDMLScript(dmlScript: StringBuilder, numTabs: Int): StringBuilder = tabDMLScript(dmlScript, numTabs, false)
+ def tabDMLScript(dmlScript: StringBuilder, numTabs: Int, prependNewLine: Boolean): StringBuilder = {
+ if (prependNewLine) dmlScript.append("\n")
+ for (i <- 0 until numTabs) dmlScript.append("\t")
+ dmlScript
}
}
trait SourceDMLGenerator extends TabbedDMLGenerator {
- val alreadyImported:HashSet[String] = new HashSet[String]
- def source(dmlScript:StringBuilder, numTabs:Int, sourceFileName:String, dir:String):Unit = {
- if(sourceFileName != null && !alreadyImported.contains(sourceFileName)) {
- tabDMLScript(dmlScript, numTabs).append("source(\"" + dir + sourceFileName + ".dml\") as " + sourceFileName + "\n")
+ val alreadyImported: HashSet[String] = new HashSet[String]
+ def source(dmlScript: StringBuilder, numTabs: Int, sourceFileName: String, dir: String): Unit =
+ if (sourceFileName != null && !alreadyImported.contains(sourceFileName)) {
+ tabDMLScript(dmlScript, numTabs).append("source(\"" + dir + sourceFileName + ".dml\") as " + sourceFileName + "\n")
alreadyImported.add(sourceFileName)
- }
+ }
+ def source(dmlScript: StringBuilder, numTabs: Int, net: CaffeNetwork, solver: CaffeSolver, otherFiles: Array[String]): Unit = {
+ // Add layers with multiple source files
+ if (net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[SoftmaxWithLoss]).length > 0) {
+ source(dmlScript, numTabs, "softmax", Caffe2DML.layerDir)
+ source(dmlScript, numTabs, "cross_entropy_loss", Caffe2DML.layerDir)
+ }
+ net.getLayers.map(layer => source(dmlScript, numTabs, net.getCaffeLayer(layer).sourceFileName, Caffe2DML.layerDir))
+ if (solver != null)
+ source(dmlScript, numTabs, solver.sourceFileName, Caffe2DML.optimDir)
+ if (otherFiles != null)
+ otherFiles.map(sourceFileName => source(dmlScript, numTabs, sourceFileName, Caffe2DML.layerDir))
}
- def source(dmlScript:StringBuilder, numTabs:Int, net:CaffeNetwork, solver:CaffeSolver, otherFiles:Array[String]):Unit = {
- // Add layers with multiple source files
- if(net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[SoftmaxWithLoss]).length > 0) {
- source(dmlScript, numTabs, "softmax", Caffe2DML.layerDir)
- source(dmlScript, numTabs, "cross_entropy_loss", Caffe2DML.layerDir)
- }
- net.getLayers.map(layer => source(dmlScript, numTabs, net.getCaffeLayer(layer).sourceFileName, Caffe2DML.layerDir))
- if(solver != null)
- source(dmlScript, numTabs, solver.sourceFileName, Caffe2DML.optimDir)
- if(otherFiles != null)
- otherFiles.map(sourceFileName => source(dmlScript, numTabs, sourceFileName, Caffe2DML.layerDir))
- }
}
trait NextBatchGenerator extends TabbedDMLGenerator {
- def min(lhs:String, rhs:String): String = "min(" + lhs + ", " + rhs + ")"
-
- def assignBatch(dmlScript:StringBuilder, Xb:String, X:String, yb:String, y:String, indexPrefix:String, N:String, i:String):StringBuilder = {
- dmlScript.append(indexPrefix).append("beg = ((" + i + "-1) * " + Caffe2DML.batchSize + ") %% " + N + " + 1; ")
- dmlScript.append(indexPrefix).append("end = min(beg + " + Caffe2DML.batchSize + " - 1, " + N + "); ")
- dmlScript.append(Xb).append(" = ").append(X).append("[").append(indexPrefix).append("beg:").append(indexPrefix).append("end,]; ")
- if(yb != null && y != null)
- dmlScript.append(yb).append(" = ").append(y).append("[").append(indexPrefix).append("beg:").append(indexPrefix).append("end,]; ")
- dmlScript.append("\n")
- }
- def getTestBatch(tabDMLScript:StringBuilder):Unit = {
+ def min(lhs: String, rhs: String): String = "min(" + lhs + ", " + rhs + ")"
+
+ def assignBatch(dmlScript: StringBuilder, Xb: String, X: String, yb: String, y: String, indexPrefix: String, N: String, i: String): StringBuilder = {
+ dmlScript.append(indexPrefix).append("beg = ((" + i + "-1) * " + Caffe2DML.batchSize + ") %% " + N + " + 1; ")
+ dmlScript.append(indexPrefix).append("end = min(beg + " + Caffe2DML.batchSize + " - 1, " + N + "); ")
+ dmlScript.append(Xb).append(" = ").append(X).append("[").append(indexPrefix).append("beg:").append(indexPrefix).append("end,]; ")
+ if (yb != null && y != null)
+ dmlScript.append(yb).append(" = ").append(y).append("[").append(indexPrefix).append("beg:").append(indexPrefix).append("end,]; ")
+ dmlScript.append("\n")
+ }
+ def getTestBatch(tabDMLScript: StringBuilder): Unit =
assignBatch(tabDMLScript, "Xb", Caffe2DML.X, null, null, "", Caffe2DML.numImages, "iter")
- }
- def getTrainingBatch(tabDMLScript:StringBuilder):Unit = {
+
+ def getTrainingBatch(tabDMLScript: StringBuilder): Unit =
assignBatch(tabDMLScript, "Xb", Caffe2DML.X, "yb", Caffe2DML.y, "", Caffe2DML.numImages, "iter")
- }
- def getTrainingBatch(tabDMLScript:StringBuilder, X:String, y:String, numImages:String):Unit = {
- assignBatch(tabDMLScript, "Xb", X, "yb", y, "", numImages, "i")
- }
- def getTrainingMaxiBatch(tabDMLScript:StringBuilder):Unit = {
+ def getTrainingBatch(tabDMLScript: StringBuilder, X: String, y: String, numImages: String): Unit =
+ assignBatch(tabDMLScript, "Xb", X, "yb", y, "", numImages, "i")
+ def getTrainingMaxiBatch(tabDMLScript: StringBuilder): Unit =
assignBatch(tabDMLScript, "X_group_batch", Caffe2DML.X, "y_group_batch", Caffe2DML.y, "group_", Caffe2DML.numImages, "g")
- }
- def getValidationBatch(tabDMLScript:StringBuilder):Unit = {
+ def getValidationBatch(tabDMLScript: StringBuilder): Unit =
assignBatch(tabDMLScript, "Xb", Caffe2DML.XVal, "yb", Caffe2DML.yVal, "", Caffe2DML.numValidationImages, "iVal")
- }
}
trait VisualizeDMLGenerator extends TabbedDMLGenerator {
- var doVisualize = false
- var _tensorboardLogDir:String = null
- def setTensorBoardLogDir(log:String): Unit = { _tensorboardLogDir = log }
- def tensorboardLogDir:String = {
- if(_tensorboardLogDir == null) {
+ var doVisualize = false
+ var _tensorboardLogDir: String = null
+ def setTensorBoardLogDir(log: String): Unit = _tensorboardLogDir = log
+ def tensorboardLogDir: String = {
+ if (_tensorboardLogDir == null) {
_tensorboardLogDir = java.io.File.createTempFile("temp", System.nanoTime().toString()).getAbsolutePath
}
_tensorboardLogDir
}
def visualizeLoss(): Unit = {
- checkTensorBoardDependency()
- doVisualize = true
- // Visualize for both training and validation
- visualize(" ", " ", "training_loss", "iter", "training_loss", true)
- visualize(" ", " ", "training_accuracy", "iter", "training_accuracy", true)
- visualize(" ", " ", "validation_loss", "iter", "validation_loss", false)
- visualize(" ", " ", "validation_accuracy", "iter", "validation_accuracy", false)
- }
- val visTrainingDMLScript: StringBuilder = new StringBuilder
+ checkTensorBoardDependency()
+ doVisualize = true
+ // Visualize for both training and validation
+ visualize(" ", " ", "training_loss", "iter", "training_loss", true)
+ visualize(" ", " ", "training_accuracy", "iter", "training_accuracy", true)
+ visualize(" ", " ", "validation_loss", "iter", "validation_loss", false)
+ visualize(" ", " ", "validation_accuracy", "iter", "validation_accuracy", false)
+ }
+ val visTrainingDMLScript: StringBuilder = new StringBuilder
val visValidationDMLScript: StringBuilder = new StringBuilder
- def checkTensorBoardDependency():Unit = {
- try {
- if(!doVisualize)
- Class.forName( "com.google.protobuf.GeneratedMessageV3")
- } catch {
- case _:ClassNotFoundException => throw new DMLRuntimeException("To use visualize() feature, you will have to include protobuf-java-3.2.0.jar in your classpath. Hint: you can download the jar from http://central.maven.org/maven2/com/google/protobuf/protobuf-java/3.2.0/protobuf-java-3.2.0.jar")
- }
- }
- private def visualize(layerName:String, varType:String, aggFn:String, x:String, y:String, isTraining:Boolean) = {
- val dmlScript = if(isTraining) visTrainingDMLScript else visValidationDMLScript
- dmlScript.append("viz_counter1 = visualize(" +
- commaSep(asDMLString(layerName), asDMLString(varType), asDMLString(aggFn), x, y, asDMLString(tensorboardLogDir))
- + ");\n")
+ def checkTensorBoardDependency(): Unit =
+ try {
+ if (!doVisualize)
+ Class.forName("com.google.protobuf.GeneratedMessageV3")
+ } catch {
+ case _: ClassNotFoundException =>
+ throw new DMLRuntimeException(
+ "To use visualize() feature, you will have to include protobuf-java-3.2.0.jar in your classpath. Hint: you can download the jar from http://central.maven.org/maven2/com/google/protobuf/protobuf-java/3.2.0/protobuf-java-3.2.0.jar"
+ )
+ }
+ private def visualize(layerName: String, varType: String, aggFn: String, x: String, y: String, isTraining: Boolean) = {
+ val dmlScript = if (isTraining) visTrainingDMLScript else visValidationDMLScript
+ dmlScript.append(
+ "viz_counter1 = visualize(" +
+ commaSep(asDMLString(layerName), asDMLString(varType), asDMLString(aggFn), x, y, asDMLString(tensorboardLogDir))
+ + ");\n"
+ )
dmlScript.append("viz_counter = viz_counter + viz_counter1\n")
}
- def visualizeLayer(net:CaffeNetwork, layerName:String, varType:String, aggFn:String): Unit = {
- // 'weight', 'bias', 'dweight', 'dbias', 'output' or 'doutput'
- // 'sum', 'mean', 'var' or 'sd'
- checkTensorBoardDependency()
- doVisualize = true
- if(net.getLayers.filter(_.equals(layerName)).size == 0)
- throw new DMLRuntimeException("Cannot visualize the layer:" + layerName)
- val dmlVar = {
- val l = net.getCaffeLayer(layerName)
- varType match {
- case "weight" => l.weight
- case "bias" => l.bias
- case "dweight" => l.dWeight
- case "dbias" => l.dBias
- case "output" => l.out
- // case "doutput" => l.dX
- case _ => throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
- }
- }
- if(dmlVar == null)
- throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
- // Visualize for both training and validation
- visualize(layerName, varType, aggFn, "iter", aggFn + "(" + dmlVar + ")", true)
- visualize(layerName, varType, aggFn, "iter", aggFn + "(" + dmlVar + ")", false)
- }
-
- def appendTrainingVisualizationBody(dmlScript:StringBuilder, numTabs:Int): Unit = {
- if(doVisualize)
- tabDMLScript(dmlScript, numTabs).append(visTrainingDMLScript.toString)
- }
- def appendValidationVisualizationBody(dmlScript:StringBuilder, numTabs:Int): Unit = {
- if(doVisualize)
- tabDMLScript(dmlScript, numTabs).append(visValidationDMLScript.toString)
+ def visualizeLayer(net: CaffeNetwork, layerName: String, varType: String, aggFn: String): Unit = {
+ // 'weight', 'bias', 'dweight', 'dbias', 'output' or 'doutput'
+ // 'sum', 'mean', 'var' or 'sd'
+ checkTensorBoardDependency()
+ doVisualize = true
+ if (net.getLayers.filter(_.equals(layerName)).size == 0)
+ throw new DMLRuntimeException("Cannot visualize the layer:" + layerName)
+ val dmlVar = {
+ val l = net.getCaffeLayer(layerName)
+ varType match {
+ case "weight" => l.weight
+ case "bias" => l.bias
+ case "dweight" => l.dWeight
+ case "dbias" => l.dBias
+ case "output" => l.out
+ // case "doutput" => l.dX
+ case _ => throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
+ }
+ }
+ if (dmlVar == null)
+ throw new DMLRuntimeException("Cannot visualize the variable of type:" + varType)
+ // Visualize for both training and validation
+ visualize(layerName, varType, aggFn, "iter", aggFn + "(" + dmlVar + ")", true)
+ visualize(layerName, varType, aggFn, "iter", aggFn + "(" + dmlVar + ")", false)
}
+
+ def appendTrainingVisualizationBody(dmlScript: StringBuilder, numTabs: Int): Unit =
+ if (doVisualize)
+ tabDMLScript(dmlScript, numTabs).append(visTrainingDMLScript.toString)
+ def appendValidationVisualizationBody(dmlScript: StringBuilder, numTabs: Int): Unit =
+ if (doVisualize)
+ tabDMLScript(dmlScript, numTabs).append(visValidationDMLScript.toString)
}
trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator with VisualizeDMLGenerator {
// Also makes "code reading" possible for Caffe2DML :)
- var dmlScript = new StringBuilder
- var numTabs = 0
- def reset():Unit = {
- dmlScript.clear()
- alreadyImported.clear()
- numTabs = 0
- visTrainingDMLScript.clear()
- visValidationDMLScript.clear()
- doVisualize = false
- }
- // -------------------------------------------------------------------------------------------------
- // Helper functions that calls super class methods and simplifies the code of this trait
- def tabDMLScript():StringBuilder = tabDMLScript(dmlScript, numTabs, false)
- def tabDMLScript(prependNewLine:Boolean):StringBuilder = tabDMLScript(dmlScript, numTabs, prependNewLine)
- def source(net:CaffeNetwork, solver:CaffeSolver, otherFiles:Array[String]):Unit = {
- source(dmlScript, numTabs, net, solver, otherFiles)
- }
- // -------------------------------------------------------------------------------------------------
-
- def ifBlock(cond:String)(op: => Unit) {
- tabDMLScript.append("if(" + cond + ") {\n")
- numTabs += 1
- op
- numTabs -= 1
- tabDMLScript.append("}\n")
- }
- def whileBlock(cond:String)(op: => Unit) {
- tabDMLScript.append("while(" + cond + ") {\n")
- numTabs += 1
- op
- numTabs -= 1
- tabDMLScript.append("}\n")
- }
- def forBlock(iterVarName:String, startVal:String, endVal:String)(op: => Unit) {
- tabDMLScript.append("for(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
- numTabs += 1
- op
- numTabs -= 1
- tabDMLScript.append("}\n")
- }
- def parForBlock(iterVarName:String, startVal:String, endVal:String)(op: => Unit) {
- tabDMLScript.append("parfor(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
- numTabs += 1
- op
- numTabs -= 1
- tabDMLScript.append("}\n")
- }
-
- def printClassificationReport():Unit = {
- ifBlock("debug"){
+ var dmlScript = new StringBuilder
+ var numTabs = 0
+ def reset(): Unit = {
+ dmlScript.clear()
+ alreadyImported.clear()
+ numTabs = 0
+ visTrainingDMLScript.clear()
+ visValidationDMLScript.clear()
+ doVisualize = false
+ }
+ // -------------------------------------------------------------------------------------------------
+  // Helper functions that call super-class methods and simplify the code of this trait
+ def tabDMLScript(): StringBuilder = tabDMLScript(dmlScript, numTabs, false)
+ def tabDMLScript(prependNewLine: Boolean): StringBuilder = tabDMLScript(dmlScript, numTabs, prependNewLine)
+ def source(net: CaffeNetwork, solver: CaffeSolver, otherFiles: Array[String]): Unit =
+ source(dmlScript, numTabs, net, solver, otherFiles)
+ // -------------------------------------------------------------------------------------------------
+
+ def ifBlock(cond: String)(op: => Unit) {
+ tabDMLScript.append("if(" + cond + ") {\n")
+ numTabs += 1
+ op
+ numTabs -= 1
+ tabDMLScript.append("}\n")
+ }
+ def whileBlock(cond: String)(op: => Unit) {
+ tabDMLScript.append("while(" + cond + ") {\n")
+ numTabs += 1
+ op
+ numTabs -= 1
+ tabDMLScript.append("}\n")
+ }
+ def forBlock(iterVarName: String, startVal: String, endVal: String)(op: => Unit) {
+ tabDMLScript.append("for(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
+ numTabs += 1
+ op
+ numTabs -= 1
+ tabDMLScript.append("}\n")
+ }
+ def parForBlock(iterVarName: String, startVal: String, endVal: String)(op: => Unit) {
+ tabDMLScript.append("parfor(" + iterVarName + " in " + startVal + ":" + endVal + ") {\n")
+ numTabs += 1
+ op
+ numTabs -= 1
+ tabDMLScript.append("}\n")
+ }
+
+ def printClassificationReport(): Unit =
+ ifBlock("debug") {
assign(tabDMLScript, "num_rows_error_measures", min("10", ncol("yb")))
assign(tabDMLScript, "error_measures", matrix("0", "num_rows_error_measures", "5"))
forBlock("class_i", "1", "num_rows_error_measures") {
@@ -337,35 +328,38 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator with Visua
assign(tabDMLScript, "error_measures[class_i,4]", "f1Score")
assign(tabDMLScript, "error_measures[class_i,5]", "tp_plus_fn")
}
- val dmlTab = "\\t"
- val header = "class " + dmlTab + "precision" + dmlTab + "recall " + dmlTab + "f1-score" + dmlTab + "num_true_labels\\n"
+ val dmlTab = "\\t"
+ val header = "class " + dmlTab + "precision" + dmlTab + "recall " + dmlTab + "f1-score" + dmlTab + "num_true_labels\\n"
val errorMeasures = "toString(error_measures, decimal=7, sep=" + asDMLString(dmlTab) + ")"
tabDMLScript.append(print(dmlConcat(asDMLString(header), errorMeasures)))
}
- }
-
- // Appends DML corresponding to source and externalFunction statements.
- def appendHeaders(net:CaffeNetwork, solver:CaffeSolver, isTraining:Boolean):Unit = {
+
+ // Appends DML corresponding to source and externalFunction statements.
+ def appendHeaders(net: CaffeNetwork, solver: CaffeSolver, isTraining: Boolean): Unit = {
// Append source statements for layers as well as solver
- source(net, solver, if(isTraining) Array[String]("l2_reg") else null)
-
- if(isTraining) {
- // Append external built-in function headers:
- // 1. visualize external built-in function header
- if(doVisualize) {
- tabDMLScript.append("visualize = externalFunction(String layerName, String varType, String aggFn, Double x, Double y, String logDir) return (Double B) " +
- "implemented in (classname=\"org.apache.sysml.udf.lib.Caffe2DMLVisualizeWrapper\",exectype=\"mem\"); \n")
- tabDMLScript.append("viz_counter = 0\n")
- System.out.println("Please use the following command for visualizing: tensorboard --logdir=" + tensorboardLogDir)
- }
- // 2. update_nesterov external built-in function header
- if(Caffe2DML.USE_NESTEROV_UDF) {
- tabDMLScript.append("update_nesterov = externalFunction(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v, double lambda) return (matrix[double] X, matrix[double] v) implemented in (classname=\"org.apache.sysml.udf.lib.SGDNesterovUpdate\",exectype=\"mem\"); \n")
- }
- }
+ source(net, solver, if (isTraining) Array[String]("l2_reg") else null)
+
+ if (isTraining) {
+ // Append external built-in function headers:
+ // 1. visualize external built-in function header
+ if (doVisualize) {
+ tabDMLScript.append(
+ "visualize = externalFunction(String layerName, String varType, String aggFn, Double x, Double y, String logDir) return (Double B) " +
+ "implemented in (classname=\"org.apache.sysml.udf.lib.Caffe2DMLVisualizeWrapper\",exectype=\"mem\"); \n"
+ )
+ tabDMLScript.append("viz_counter = 0\n")
+ System.out.println("Please use the following command for visualizing: tensorboard --logdir=" + tensorboardLogDir)
+ }
+ // 2. update_nesterov external built-in function header
+ if (Caffe2DML.USE_NESTEROV_UDF) {
+ tabDMLScript.append(
+ "update_nesterov = externalFunction(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v, double lambda) return (matrix[double] X, matrix[double] v) implemented in (classname=\"org.apache.sysml.udf.lib.SGDNesterovUpdate\",exectype=\"mem\"); \n"
+ )
+ }
+ }
}
-
- def readMatrix(varName:String, cmdLineVar:String):Unit = {
+
+ def readMatrix(varName: String, cmdLineVar: String): Unit = {
val pathVar = varName + "_path"
assign(tabDMLScript, pathVar, ifdef(cmdLineVar))
// Uncomment the following lines if we want to the user to pass the format
@@ -374,47 +368,47 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator with Visua
// assign(tabDMLScript, varName, "read(" + pathVar + ", format=" + formatVar + ")")
assign(tabDMLScript, varName, "read(" + pathVar + ")")
}
-
- def readInputData(net:CaffeNetwork, isTraining:Boolean):Unit = {
+
+ def readInputData(net: CaffeNetwork, isTraining: Boolean): Unit = {
// Read and convert to one-hot encoding
readMatrix("X_full", "$X")
- if(isTraining) {
- readMatrix("y_full", "$y")
- tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y_full)\n")
- tabDMLScript.append("# Convert to one-hot encoding (Assumption: 1-based labels) \n")
- tabDMLScript.append("y_full = table(seq(1," + Caffe2DML.numImages + ",1), y_full, " + Caffe2DML.numImages + ", " + Utils.numClasses(net) + ")\n")
- }
- else {
- tabDMLScript.append(Caffe2DML.numImages + " = nrow(X_full)\n")
- }
+ if (isTraining) {
+ readMatrix("y_full", "$y")
+ tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y_full)\n")
+ tabDMLScript.append("# Convert to one-hot encoding (Assumption: 1-based labels) \n")
+ tabDMLScript.append("y_full = table(seq(1," + Caffe2DML.numImages + ",1), y_full, " + Caffe2DML.numImages + ", " + Utils.numClasses(net) + ")\n")
+ } else {
+ tabDMLScript.append(Caffe2DML.numImages + " = nrow(X_full)\n")
+ }
}
-
- def initWeights(net:CaffeNetwork, solver:CaffeSolver, readWeights:Boolean): Unit = {
+
+ def initWeights(net: CaffeNetwork, solver: CaffeSolver, readWeights: Boolean): Unit =
initWeights(net, solver, readWeights, new HashSet[String]())
- }
-
- def initWeights(net:CaffeNetwork, solver:CaffeSolver, readWeights:Boolean, layersToIgnore:HashSet[String]): Unit = {
+
+ def initWeights(net: CaffeNetwork, solver: CaffeSolver, readWeights: Boolean, layersToIgnore: HashSet[String]): Unit = {
tabDMLScript.append("weights = ifdef($weights, \" \")\n")
- // Initialize the layers and solvers
- tabDMLScript.append("# Initialize the layers and solvers\n")
- net.getLayers.map(layer => net.getCaffeLayer(layer).init(tabDMLScript))
- if(readWeights) {
- // Loading existing weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias
- tabDMLScript.append("# Load the weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias\n")
- net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(readWeight(l.weight, l.param.getName + "_weight.mtx")))
- net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(readWeight(l.bias, l.param.getName + "_bias.mtx")))
- }
- net.getLayers.map(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
+ // Initialize the layers and solvers
+ tabDMLScript.append("# Initialize the layers and solvers\n")
+ net.getLayers.map(layer => net.getCaffeLayer(layer).init(tabDMLScript))
+ if (readWeights) {
+ // Loading existing weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias
+ tabDMLScript.append("# Load the weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias\n")
+ val allLayers = net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_))
+ allLayers.filter(_.weight != null).map(l => tabDMLScript.append(readWeight(l.weight, l.param.getName + "_weight.mtx")))
+ allLayers.filter(_.bias != null).map(l => tabDMLScript.append(readWeight(l.bias, l.param.getName + "_bias.mtx")))
+ }
+ net.getLayers.map(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
}
-
- def getLossLayers(net:CaffeNetwork):List[IsLossLayer] = {
+
+ def getLossLayers(net: CaffeNetwork): List[IsLossLayer] = {
val lossLayers = net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]).map(layer => net.getCaffeLayer(layer).asInstanceOf[IsLossLayer])
- if(lossLayers.length != 1)
- throw new DMLRuntimeException("Expected exactly one loss layer, but found " + lossLayers.length + ":" + net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]))
- lossLayers
+ if (lossLayers.length != 1)
+ throw new DMLRuntimeException(
+ "Expected exactly one loss layer, but found " + lossLayers.length + ":" + net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer])
+ )
+ lossLayers
}
-
- def updateMeanVarianceForBatchNorm(net:CaffeNetwork, value:Boolean):Unit = {
+
+ def updateMeanVarianceForBatchNorm(net: CaffeNetwork, value: Boolean): Unit =
net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).map(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var = value)
- }
-}
\ No newline at end of file
+}