You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/05/21 20:26:23 UTC
incubator-systemml git commit: [MINOR] Added documentation for
Caffe2DML APIs.
Repository: incubator-systemml
Updated Branches:
refs/heads/master 259742c40 -> 700b08094
[MINOR] Added documentation for Caffe2DML APIs.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/700b0809
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/700b0809
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/700b0809
Branch: refs/heads/master
Commit: 700b080940bf68a71728be12ea4e24c2450e9d13
Parents: 259742c
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Sun May 21 13:24:15 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Sun May 21 13:25:29 2017 -0700
----------------------------------------------------------------------
.../org/apache/sysml/api/dl/Caffe2DML.scala | 318 +++++++++----------
.../org/apache/sysml/api/dl/CaffeSolver.scala | 96 +++++-
.../org/apache/sysml/api/dl/DMLGenerator.scala | 112 ++++++-
3 files changed, 338 insertions(+), 188 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/700b0809/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index f7f85c3..fe6b159 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -52,6 +52,43 @@ import org.apache.commons.logging.LogFactory
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer
+/***************************************************************************************
+DESIGN OF CAFFE2DML:
+
+1. Caffe2DML is designed to fit well into the mllearn framework. Hence, the key methods that needed to be implemented are:
+- `getTrainingScript` for the Estimator class.
+- `getPredictionScript` for the Model class.
+
+2. To simplify the DML generation in the getTrainingScript and getPredictionScript methods, we use the DMLGenerator interface.
+This interface generates DML strings for common operations such as control-flow blocks (if, for, while) and built-in functions (read, write).
+It also makes this class easier to read.
+
+3. Additionally, we created mapping classes for layers, solvers and learning rates that map the corresponding Caffe abstractions to the SystemML-NN library.
+This greatly simplifies adding new layers to Caffe2DML:
+trait CaffeLayer {
+ // Any layer that wants to reuse SystemML-NN has to override the following methods, which help generate the DML for the given layer:
+ def sourceFileName:String;
+ def init(dmlScript:StringBuilder):Unit;
+ def forward(dmlScript:StringBuilder, isPrediction:Boolean):Unit;
+ def backward(dmlScript:StringBuilder, outSuffix:String):Unit;
+ ...
+}
+trait CaffeSolver {
+ def sourceFileName:String;
+ def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit;
+ def init(dmlScript:StringBuilder, layer:CaffeLayer):Unit;
+}
+
+4. To simplify the traversal of the network, we created a Network interface:
+trait Network {
+ def getLayers(): List[String]
+ def getCaffeLayer(layerName:String):CaffeLayer
+ def getBottomLayers(layerName:String): Set[String]
+ def getTopLayers(layerName:String): Set[String]
+ def getLayerID(layerName:String): Int
+}
+***************************************************************************************/
+
object Caffe2DML {
val LOG = LogFactory.getLog(classOf[Caffe2DML].getName())
// ------------------------------------------------------------------------
@@ -62,7 +99,7 @@ object Caffe2DML {
val X = "X"; val y = "y"; val batchSize = "BATCH_SIZE"; val numImages = "num_images"; val numValidationImages = "num_validation"
val XVal = "X_val"; val yVal = "y_val"
- var USE_NESTEROV_UDF = {
+ val USE_NESTEROV_UDF = {
// Developer environment variable flag 'USE_NESTEROV_UDF' until codegen starts working.
// Then, we will remove this flag and also the class org.apache.sysml.udf.lib.SGDNesterovUpdate
val envFlagNesterovUDF = System.getenv("USE_NESTEROV_UDF")
@@ -118,6 +155,93 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
// Method called by Python mllearn to visualize variable of certain layer
def visualizeLayer(layerName:String, varType:String, aggFn:String): Unit = visualizeLayer(net, layerName, varType, aggFn)
+ // ================================================================================================
+ // The method below parses the provided network and solver files and generates the DML training script.
+ def getTrainingScript(isSingleNode:Boolean):(Script, String, String) = {
+ val startTrainingTime = System.nanoTime()
+
+ reset // Reset the state of DML generator for training script.
+
+ // Flags passed by user
+ val DEBUG_TRAINING = if(inputs.containsKey("$debug")) inputs.get("$debug").toLowerCase.toBoolean else false
+ assign(tabDMLScript, "debug", if(DEBUG_TRAINING) "TRUE" else "FALSE")
+
+ appendHeaders(net, solver, true) // Appends DML corresponding to source and externalFunction statements.
+ readInputData(net, true) // Read X_full and y_full
+ // Initialize the layers and solvers. Reads weights and bias if $weights is set.
+ initWeights(net, solver, inputs.containsKey("$weights"), layersToIgnore)
+
+ // Split into training and validation set
+ // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
+ val shouldValidate = solverParam.getTestInterval > 0 && solverParam.getTestIterCount > 0 && solverParam.getTestIter(0) > 0
+ trainTestSplit(if(shouldValidate) solverParam.getTestIter(0) else 0)
+
+ // Set iteration-related variables such as max_epochs, num_iters_per_epoch, lr, etc.
+ setIterationVariables
+ val lossLayers = getLossLayers(net)
+ // ----------------------------------------------------------------------------
+ // Main logic
+ forBlock("e", "1", "max_epochs") {
+ solverParam.getTrainAlgo.toLowerCase match {
+ case "minibatch" =>
+ forBlock("i", "1", "num_iters_per_epoch") {
+ getTrainingBatch(tabDMLScript)
+ tabDMLScript.append("iter = start_iter + i\n")
+ forward; backward; update
+ displayLoss(lossLayers(0), shouldValidate)
+ performSnapshot
+ }
+ case "batch" => {
+ tabDMLScript.append("iter = start_iter + i\n")
+ forward; backward; update
+ displayLoss(lossLayers(0), shouldValidate)
+ performSnapshot
+ }
+ case "allreduce" => {
+ forBlock("i", "1", "num_iters_per_epoch") {
+ getTrainingBatch(tabDMLScript)
+ assign(tabDMLScript, "X_group_batch", "Xb")
+ assign(tabDMLScript, "y_group_batch", "yb")
+ tabDMLScript.append("iter = start_iter + i\n")
+ initAggGradients
+ parForBlock("j", "1", "nrow(y_group_batch)") {
+ assign(tabDMLScript, "Xb", "X_group_batch[j,]")
+ assign(tabDMLScript, "yb", "y_group_batch[j,]")
+ forward; backward("_agg")
+ flattenAndStoreAggGradients_j
+ }
+ aggregateAggGradients
+ tabDMLScript.append("iter = start_iter + parallel_batches\n")
+ update
+ displayLoss(lossLayers(0), shouldValidate)
+ performSnapshot
+ }
+ }
+ case _ => throw new DMLRuntimeException("Unsupported train algo:" + solverParam.getTrainAlgo)
+ }
+ // After every epoch, update the learning rate
+ tabDMLScript.append("# Learning rate\n")
+ lrPolicy.updateLearningRate(tabDMLScript)
+ tabDMLScript.append("start_iter = start_iter + num_iters_per_epoch\n")
+ }
+ // ----------------------------------------------------------------------------
+
+ // TODO: verify whether printing the visualization counter is necessary
+ if(doVisualize) tabDMLScript.append("print(" + asDMLString("Visualization counter:") + " + viz_counter)")
+
+ val trainingScript = tabDMLScript.toString()
+ // Print script generation time and the DML script on stdout
+ System.out.println("Time taken to generate training script from Caffe proto: " + ((System.nanoTime() - startTrainingTime)*1e-9) + " seconds." )
+ if(DEBUG_TRAINING) Utils.prettyPrintDMLScript(trainingScript)
+
+ // Set input/output variables and execute the script
+ val script = dml(trainingScript).in(inputs)
+ net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.out(l.weight))
+ net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.out(l.bias))
+ (script, "X_full", "y_full")
+ }
+ // ================================================================================================
+
// -------------------------------------------------------------------------------------------
// Helper functions to generate DML
// Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
@@ -153,30 +277,6 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
}
}
- private def printClassificationReport():Unit = {
- ifBlock("debug"){
- assign(tabDMLScript, "num_rows_error_measures", min("10", ncol("yb")))
- assign(tabDMLScript, "error_measures", matrix("0", "num_rows_error_measures", "5"))
- forBlock("class_i", "1", "num_rows_error_measures") {
- assign(tabDMLScript, "tp", "sum( (true_yb == predicted_yb) * (true_yb == class_i) )")
- assign(tabDMLScript, "tp_plus_fp", "sum( (predicted_yb == class_i) )")
- assign(tabDMLScript, "tp_plus_fn", "sum( (true_yb == class_i) )")
- assign(tabDMLScript, "precision", "tp / tp_plus_fp")
- assign(tabDMLScript, "recall", "tp / tp_plus_fn")
- assign(tabDMLScript, "f1Score", "2*precision*recall / (precision+recall)")
- assign(tabDMLScript, "error_measures[class_i,1]", "class_i")
- assign(tabDMLScript, "error_measures[class_i,2]", "precision")
- assign(tabDMLScript, "error_measures[class_i,3]", "recall")
- assign(tabDMLScript, "error_measures[class_i,4]", "f1Score")
- assign(tabDMLScript, "error_measures[class_i,5]", "tp_plus_fn")
- }
- val dmlTab = "\\t"
- val header = "class " + dmlTab + "precision" + dmlTab + "recall " + dmlTab + "f1-score" + dmlTab + "num_true_labels\\n"
- val errorMeasures = "toString(error_measures, decimal=7, sep=" + asDMLString(dmlTab) + ")"
- tabDMLScript.append(print(dmlConcat(asDMLString(header), errorMeasures)))
- }
- }
-
// Append the DML to display training and validation loss
private def displayLoss(lossLayer:IsLossLayer, shouldValidate:Boolean):Unit = {
if(solverParam.getDisplay > 0) {
@@ -275,54 +375,9 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
matrix(colSums(l.dBias + "_agg"), nrow(l.bias), ncol(l.bias)))
})
}
- // -------------------------------------------------------------------------------------------
-
- private def multiply(v1:String, v2:String):String = v1 + "*" + v2
- private def colSums(m:String):String = "colSums(" + m + ")"
-
- // Script generator
- def getTrainingScript(isSingleNode:Boolean):(Script, String, String) = {
- val startTrainingTime = System.nanoTime()
- val DEBUG_TRAINING = if(inputs.containsKey("$debug")) inputs.get("$debug").toLowerCase.toBoolean else false
- reset()
-
- // Add source for layers as well as solver as well as visualization header
- source(net, solver, Array[String]("l2_reg"))
- appendVisualizationHeaders(dmlScript, numTabs)
-
- if(Caffe2DML.USE_NESTEROV_UDF) {
- tabDMLScript(dmlScript, numTabs).append("update_nesterov = externalFunction(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v, double lambda) return (matrix[double] X, matrix[double] v) implemented in (classname=\"org.apache.sysml.udf.lib.SGDNesterovUpdate\",exectype=\"mem\"); \n")
- }
-
- // Read and convert to one-hote encoding
- assign(tabDMLScript, "X_full", "read(\" \", format=\"csv\")")
- assign(tabDMLScript, "y_full", "read(\" \", format=\"csv\")")
- tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y_full)\n")
- tabDMLScript.append("weights = ifdef($weights, \" \")\n")
- tabDMLScript.append("debug = ifdef($debug, FALSE)\n")
- tabDMLScript.append("# Convert to one-hot encoding (Assumption: 1-based labels) \n")
- tabDMLScript.append("y_full = table(seq(1," + Caffe2DML.numImages + ",1), y_full, " + Caffe2DML.numImages + ", " + Utils.numClasses(net) + ")\n")
-
- // Initialize the layers and solvers
- tabDMLScript.append("# Initialize the layers and solvers\n")
- net.getLayers.map(layer => net.getCaffeLayer(layer).init(tabDMLScript))
- if(inputs.containsKey("$weights")) {
- // Loading existing weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias
- tabDMLScript.append("# Load the weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias\n")
- net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(read(l.weight, l.param.getName + "_weight.mtx")))
- net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(read(l.bias, l.param.getName + "_bias.mtx")))
- }
- net.getLayers.map(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
-
- // Split into training and validation set
- // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
- val shouldValidate = solverParam.getTestInterval > 0 && solverParam.getTestIterCount > 0 && solverParam.getTestIter(0) > 0
- trainTestSplit(if(shouldValidate) solverParam.getTestIter(0) else 0)
-
- // Set iteration-related variables such as max_epochs, num_iters_per_epoch, lr, etc.
- val lossLayers = net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]).map(layer => net.getCaffeLayer(layer).asInstanceOf[IsLossLayer])
- if(lossLayers.length != 1) throw new DMLRuntimeException("Expected exactly one loss layer")
- solverParam.getTrainAlgo.toLowerCase match {
+ // Set iteration-related variables such as max_epochs, num_iters_per_epoch, lr, etc.
+ def setIterationVariables():Unit = {
+ solverParam.getTrainAlgo.toLowerCase match {
case "batch" =>
assign(tabDMLScript, "max_epochs", solverParam.getMaxIter.toString)
case _ => {
@@ -332,68 +387,8 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
}
assign(tabDMLScript, "start_iter", "0")
assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
-
- // ----------------------------------------------------------------------------
- // Main logic
- forBlock("e", "1", "max_epochs") {
- solverParam.getTrainAlgo.toLowerCase match {
- case "minibatch" =>
- forBlock("i", "1", "num_iters_per_epoch") {
- getTrainingBatch(tabDMLScript)
- tabDMLScript.append("iter = start_iter + i\n")
- forward; backward; update
- displayLoss(lossLayers(0), shouldValidate)
- performSnapshot
- }
- case "batch" => {
- tabDMLScript.append("iter = start_iter + i\n")
- forward; backward; update
- displayLoss(lossLayers(0), shouldValidate)
- performSnapshot
- }
- case "allreduce" => {
- forBlock("i", "1", "num_iters_per_epoch") {
- getTrainingBatch(tabDMLScript)
- assign(tabDMLScript, "X_group_batch", "Xb")
- assign(tabDMLScript, "y_group_batch", "yb")
- tabDMLScript.append("iter = start_iter + i\n")
- initAggGradients
- parForBlock("j", "1", "nrow(y_group_batch)") {
- assign(tabDMLScript, "Xb", "X_group_batch[j,]")
- assign(tabDMLScript, "yb", "y_group_batch[j,]")
- forward; backward("_agg")
- flattenAndStoreAggGradients_j
- }
- aggregateAggGradients
- tabDMLScript.append("iter = start_iter + parallel_batches\n")
- update
- displayLoss(lossLayers(0), shouldValidate)
- performSnapshot
- }
- }
- case _ => throw new DMLRuntimeException("Unsupported train algo:" + solverParam.getTrainAlgo)
- }
- // After every epoch, update the learning rate
- tabDMLScript.append("# Learning rate\n")
- lrPolicy.updateLearningRate(tabDMLScript)
- tabDMLScript.append("start_iter = start_iter + num_iters_per_epoch\n")
- }
- // ----------------------------------------------------------------------------
-
- // Check if this is necessary
- if(doVisualize) tabDMLScript.append("print(" + asDMLString("Visualization counter:") + " + viz_counter)")
-
- val trainingScript = tabDMLScript.toString()
- // Print script generation time and the DML script on stdout
- System.out.println("Time taken to generate training script from Caffe proto: " + ((System.nanoTime() - startTrainingTime)*1e-9) + " seconds." )
- if(DEBUG_TRAINING) Utils.prettyPrintDMLScript(trainingScript)
-
- // Set input/output variables and execute the script
- val script = dml(trainingScript).in(inputs)
- net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.out(l.weight))
- net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.out(l.bias))
- (script, "X_full", "y_full")
- }
+ }
+ // -------------------------------------------------------------------------------------------
}
class Caffe2DMLModel(val mloutput: MLResults,
@@ -431,40 +426,33 @@ class Caffe2DMLModel(val mloutput: MLResults,
ml.execute(script)
}
+ // ================================================================================================
+ // The method below parses the provided network and solver files and generates the DML prediction script.
def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String) = {
- reset()
val startPredictionTime = System.nanoTime()
- val DEBUG_PREDICTION = if(estimator.inputs.containsKey("$debug")) estimator.inputs.get("$debug").toLowerCase.toBoolean else false
+
+ reset // Reset the state of DML generator for prediction script.
- // Append source statements for each layer
- source(net, solver, null)
- tabDMLScript.append("weights = ifdef($weights, \" \")\n")
- // Initialize the layers and solvers
- tabDMLScript.append("# Initialize the layers and solvers\n")
- net.getLayers.map(layer => net.getCaffeLayer(layer).init(tabDMLScript))
- if(mloutput == null && estimator.inputs.containsKey("$weights")) {
- // fit was not called
- net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(read(l.weight, l.param.getName + "_weight.mtx")))
- net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(read(l.bias, l.param.getName + "_bias.mtx")))
- }
- else if(mloutput == null) {
- throw new DMLRuntimeException("Cannot call predict/score without calling either fit or by providing weights")
+ val DEBUG_PREDICTION = if(estimator.inputs.containsKey("$debug")) estimator.inputs.get("$debug").toLowerCase.toBoolean else false
+ assign(tabDMLScript, "debug", if(DEBUG_PREDICTION) "TRUE" else "FALSE")
+
+ appendHeaders(net, solver, false) // Appends DML corresponding to source and externalFunction statements.
+ readInputData(net, false) // Read X_full (labels are not read during prediction)
+ assign(tabDMLScript, "X", "X_full")
+
+ // Initialize the layers and solvers. Reads weights and bias if readWeights is true.
+ val readWeights = {
+ if(mloutput == null && estimator.inputs.containsKey("$weights")) true
+ else if(mloutput == null) throw new DMLRuntimeException("Cannot call predict/score without calling either fit or by providing weights")
+ else false
}
- net.getLayers.map(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
-
-// if(estimator.inputs.containsKey("$debug") && estimator.inputs.get("$debug").equals("TRUE")) {
-// System.out.println("The output shape of layers:")
-// net.getLayers.map(layer => System.out.println(net.getCaffeLayer(layer).param.getName + " " + net.getCaffeLayer(layer).outputShape))
-// }
+ initWeights(net, solver, readWeights)
// Donot update mean and variance in batchnorm
- net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).map(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var = false)
- tabDMLScript.append("X_full = read(\" \", format=\"csv\")\n")
- assign(tabDMLScript, "X", "X_full")
- tabDMLScript.append(Caffe2DML.numImages + " = nrow(X_full)\n")
+ updateMeanVarianceForBatchNorm(net, false)
+
+ val lossLayers = getLossLayers(net)
- val lossLayers = net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]).map(layer => net.getCaffeLayer(layer).asInstanceOf[IsLossLayer])
- customAssert(lossLayers.length == 1, "Expected exactly one loss layer, but found " + lossLayers.length + ":" + net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]))
assign(tabDMLScript, "Prob", matrix("0", Caffe2DML.numImages, numClasses))
estimator.solverParam.getTestAlgo.toLowerCase match {
case "minibatch" => {
@@ -495,8 +483,8 @@ class Caffe2DMLModel(val mloutput: MLResults,
System.out.println("Time taken to generate prediction script from Caffe proto:" + ((System.nanoTime() - startPredictionTime)*1e-9) + "secs." )
if(DEBUG_PREDICTION) Utils.prettyPrintDMLScript(predictionScript)
- // Reset
- net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).map(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var = true)
+ // Reset state of BatchNorm layer
+ updateMeanVarianceForBatchNorm(net, true)
val script = dml(predictionScript).out("Prob").in(estimator.inputs)
if(mloutput != null) {
@@ -504,9 +492,9 @@ class Caffe2DMLModel(val mloutput: MLResults,
net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.in(l.weight, mloutput.getBinaryBlockMatrix(l.weight)))
net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.in(l.bias, mloutput.getBinaryBlockMatrix(l.bias)))
}
-
(script, "X_full")
}
+ // ================================================================================================
// Prediction
def transform(X: MatrixBlock): MatrixBlock = {
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/700b0809/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
index 0620e44..0e39192 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeSolver.scala
@@ -88,11 +88,30 @@ class LearningRatePolicy(lr_policy:String="exp", base_lr:Double=0.01) {
}
}
-/**
- * lambda: regularization parameter
- * momentum: Momentum value. Typical values are in the range of [0.5, 0.99], usually started at the lower end and annealed towards the higher end.
- */
class SGD(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
+ /*
+ * Performs an SGD update with momentum.
+ *
+ * In SGD with momentum, we assume that the parameters have a velocity
+ * that continues with some momentum, and that is influenced by the
+ * gradient.
+ *
+ * Inputs:
+ * - X: Parameters to update, of shape (any, any).
+ * - dX: Gradient wrt `X` of a loss function being optimized, of
+ * same shape as `X`.
+ * - lr: Learning rate.
+ * - mu: Momentum value.
+ * Typical values are in the range of [0.5, 0.99], usually
+ * started at the lower end and annealed towards the higher end.
+ * - v: State maintaining the velocity of the parameters `X`, of same
+ * shape as `X`.
+ *
+ * Outputs:
+ * - X: Updated parameters `X`, of same shape as input `X`.
+ * - v: Updated velocity of the parameters `X`, of same shape as
+ * input `X`.
+ */
def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
l2reg_update(lambda, dmlScript, layer)
if(momentum == 0) {
@@ -117,13 +136,34 @@ class SGD(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
def sourceFileName:String = if(momentum == 0) "sgd" else "sgd_momentum"
}
-/**
- * lambda: regularization parameter
- * epsilon: Smoothing term to avoid divide by zero errors. Typical values are in the range of [1e-8, 1e-4].
- *
- * See Adaptive Subgradient Methods for Online Learning and Stochastic Optimization, Duchi et al.
- */
class AdaGrad(lambda:Double=5e-04, epsilon:Double=1e-6) extends CaffeSolver {
+ /*
+ * Performs an Adagrad update.
+ *
+ * This is an adaptive learning rate optimizer that maintains the
+ * sum of squared gradients to automatically adjust the effective
+ * learning rate.
+ *
+ * Reference:
+ * - Adaptive Subgradient Methods for Online Learning and Stochastic
+ * Optimization, Duchi et al.
+ * - http://jmlr.org/papers/v12/duchi11a.html
+ *
+ * Inputs:
+ * - X: Parameters to update, of shape (any, any).
+ * - dX: Gradient wrt `X` of a loss function being optimized, of
+ * same shape as `X`.
+ * - lr: Learning rate.
+ * - epsilon: Smoothing term to avoid divide by zero errors.
+ * Typical values are in the range of [1e-8, 1e-4].
+ * - cache: State that maintains per-parameter sum of squared
+ * gradients, of same shape as `X`.
+ *
+ * Outputs:
+ * - X: Updated parameters `X`, of same shape as input `X`.
+ * - cache: State that maintains per-parameter sum of squared
+ * gradients, of same shape as `X`.
+ */
def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
l2reg_update(lambda, dmlScript, layer)
if(layer.shouldUpdateWeight) dmlScript.append("\t").append("["+ commaSep(layer.weight, layer.weight+"_cache") + "] " +
@@ -138,11 +178,39 @@ class AdaGrad(lambda:Double=5e-04, epsilon:Double=1e-6) extends CaffeSolver {
def sourceFileName:String = "adagrad"
}
-/**
- * lambda: regularization parameter
- * momentum: Momentum value. Typical values are in the range of [0.5, 0.99], usually started at the lower end and annealed towards the higher end.
- */
class Nesterov(lambda:Double=5e-04, momentum:Double=0.9) extends CaffeSolver {
+ /*
+ * Performs an SGD update with Nesterov momentum.
+ *
+ * As with regular SGD with momentum, in SGD with Nesterov momentum,
+ * we assume that the parameters have a velocity that continues
+ * with some momentum, and that is influenced by the gradient.
+ * In this view specifically, we perform the position update from the
+ * position that the momentum is about to carry the parameters to,
+ * rather than from the previous position. Additionally, we always
+ * store the parameters in their position after momentum.
+ *
+ * Reference:
+ * - Advances in optimizing Recurrent Networks, Bengio et al.,
+ * section 3.5.
+ * - http://arxiv.org/abs/1212.0901
+ *
+ * Inputs:
+ * - X: Parameters to update, of shape (any, any).
+ * - dX: Gradient wrt `X` of a loss function being optimized, of
+ * same shape as `X`.
+ * - lr: Learning rate.
+ * - mu: Momentum value.
+ * Typical values are in the range of [0.5, 0.99], usually
+ * started at the lower end and annealed towards the higher end.
+ * - v: State maintaining the velocity of the parameters `X`, of same
+ * shape as `X`.
+ *
+ * Outputs:
+ * - X: Updated parameters X, of same shape as input X.
+ * - v: Updated velocity of the parameters X, of same shape as
+ * input v.
+ */
def update(dmlScript:StringBuilder, layer:CaffeLayer):Unit = {
val fn = if(Caffe2DML.USE_NESTEROV_UDF) "update_nesterov" else "sgd_nesterov::update"
val lastParameter = if(Caffe2DML.USE_NESTEROV_UDF) (", " + lambda) else ""
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/700b0809/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
index 668d996..456b032 100644
--- a/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/DMLGenerator.scala
@@ -54,7 +54,7 @@ trait BaseDMLGenerator {
def isNumber(x: String):Boolean = x forall Character.isDigit
def transpose(x:String):String = "t(" + x + ")"
def write(varName:String, fileName:String, format:String):String = "write(" + varName + ", \"" + fileName + "\", format=\"" + format + "\")\n"
- def read(varName:String, fileName:String, sep:String="/"):String = varName + " = read(weights + \"" + sep + fileName + "\")\n"
+ def readWeight(varName:String, fileName:String, sep:String="/"):String = varName + " = read(weights + \"" + sep + fileName + "\")\n"
def asDMLString(str:String):String = "\"" + str + "\""
def assign(dmlScript:StringBuilder, lhsVar:String, rhsVar:String):Unit = {
dmlScript.append(lhsVar).append(" = ").append(rhsVar).append("\n")
@@ -132,6 +132,11 @@ trait BaseDMLGenerator {
def nrow(m:String):String = "nrow(" + m + ")"
def ncol(m:String):String = "ncol(" + m + ")"
def customAssert(cond:Boolean, msg:String) = if(!cond) throw new DMLRuntimeException(msg)
+ def multiply(v1:String, v2:String):String = v1 + "*" + v2
+ def colSums(m:String):String = "colSums(" + m + ")"
+ def ifdef(cmdLineVar:String, defaultVal:String):String = "ifdef(" + cmdLineVar + ", " + defaultVal + ")"
+ def ifdef(cmdLineVar:String):String = ifdef(cmdLineVar, "\" \"")
+ def read(filePathVar:String, format:String):String = "read(" + filePathVar + ", format=\""+ format + "\")"
}
trait TabbedDMLGenerator extends BaseDMLGenerator {
@@ -229,14 +234,6 @@ trait VisualizeDMLGenerator extends TabbedDMLGenerator {
+ ");\n")
dmlScript.append("viz_counter = viz_counter + viz_counter1\n")
}
- def appendVisualizationHeaders(dmlScript:StringBuilder, numTabs:Int): Unit = {
- if(doVisualize) {
- tabDMLScript(dmlScript, numTabs).append("visualize = externalFunction(String layerName, String varType, String aggFn, Double x, Double y, String logDir) return (Double B) " +
- "implemented in (classname=\"org.apache.sysml.udf.lib.Caffe2DMLVisualizeWrapper\",exectype=\"mem\"); \n")
- tabDMLScript(dmlScript, numTabs).append("viz_counter = 0\n")
- System.out.println("Please use the following command for visualizing: tensorboard --logdir=" + tensorboardLogDir)
- }
- }
def visualizeLayer(net:CaffeNetwork, layerName:String, varType:String, aggFn:String): Unit = {
// 'weight', 'bias', 'dweight', 'dbias', 'output' or 'doutput'
// 'sum', 'mean', 'var' or 'sd'
@@ -316,4 +313,101 @@ trait DMLGenerator extends SourceDMLGenerator with NextBatchGenerator with Visua
tabDMLScript.append("}\n")
}
+ def printClassificationReport():Unit = {
+ ifBlock("debug"){
+ assign(tabDMLScript, "num_rows_error_measures", min("10", ncol("yb")))
+ assign(tabDMLScript, "error_measures", matrix("0", "num_rows_error_measures", "5"))
+ forBlock("class_i", "1", "num_rows_error_measures") {
+ assign(tabDMLScript, "tp", "sum( (true_yb == predicted_yb) * (true_yb == class_i) )")
+ assign(tabDMLScript, "tp_plus_fp", "sum( (predicted_yb == class_i) )")
+ assign(tabDMLScript, "tp_plus_fn", "sum( (true_yb == class_i) )")
+ assign(tabDMLScript, "precision", "tp / tp_plus_fp")
+ assign(tabDMLScript, "recall", "tp / tp_plus_fn")
+ assign(tabDMLScript, "f1Score", "2*precision*recall / (precision+recall)")
+ assign(tabDMLScript, "error_measures[class_i,1]", "class_i")
+ assign(tabDMLScript, "error_measures[class_i,2]", "precision")
+ assign(tabDMLScript, "error_measures[class_i,3]", "recall")
+ assign(tabDMLScript, "error_measures[class_i,4]", "f1Score")
+ assign(tabDMLScript, "error_measures[class_i,5]", "tp_plus_fn")
+ }
+ val dmlTab = "\\t"
+ val header = "class " + dmlTab + "precision" + dmlTab + "recall " + dmlTab + "f1-score" + dmlTab + "num_true_labels\\n"
+ val errorMeasures = "toString(error_measures, decimal=7, sep=" + asDMLString(dmlTab) + ")"
+ tabDMLScript.append(print(dmlConcat(asDMLString(header), errorMeasures)))
+ }
+ }
+
+ // Appends DML corresponding to source and externalFunction statements.
+ def appendHeaders(net:CaffeNetwork, solver:CaffeSolver, isTraining:Boolean):Unit = {
+ // Append source statements for layers as well as solver
+ source(net, solver, if(isTraining) Array[String]("l2_reg") else null)
+
+ if(isTraining) {
+ // Append external built-in function headers:
+ // 1. visualize external built-in function header
+ if(doVisualize) {
+ tabDMLScript.append("visualize = externalFunction(String layerName, String varType, String aggFn, Double x, Double y, String logDir) return (Double B) " +
+ "implemented in (classname=\"org.apache.sysml.udf.lib.Caffe2DMLVisualizeWrapper\",exectype=\"mem\"); \n")
+ tabDMLScript.append("viz_counter = 0\n")
+ System.out.println("Please use the following command for visualizing: tensorboard --logdir=" + tensorboardLogDir)
+ }
+ // 2. update_nesterov external built-in function header
+ if(Caffe2DML.USE_NESTEROV_UDF) {
+ tabDMLScript.append("update_nesterov = externalFunction(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v, double lambda) return (matrix[double] X, matrix[double] v) implemented in (classname=\"org.apache.sysml.udf.lib.SGDNesterovUpdate\",exectype=\"mem\"); \n")
+ }
+ }
+ }
+
+ def readMatrix(varName:String, cmdLineVar:String):Unit = {
+ val pathVar = varName + "_path"
+ assign(tabDMLScript, pathVar, ifdef(cmdLineVar))
+ // Uncomment the following lines if we want the user to pass the format
+ // val formatVar = varName + "_fmt"
+ // assign(tabDMLScript, formatVar, ifdef(cmdLineVar + "_fmt", "\"csv\""))
+ // assign(tabDMLScript, varName, "read(" + pathVar + ", format=" + formatVar + ")")
+ assign(tabDMLScript, varName, "read(" + pathVar + ")")
+ }
+
+ def readInputData(net:CaffeNetwork, isTraining:Boolean):Unit = {
+ // Read X_full; during training, also read y_full and convert it to one-hot encoding
+ readMatrix("X_full", "$X")
+ if(isTraining) {
+ readMatrix("y_full", "$y")
+ tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y_full)\n")
+ tabDMLScript.append("# Convert to one-hot encoding (Assumption: 1-based labels) \n")
+ tabDMLScript.append("y_full = table(seq(1," + Caffe2DML.numImages + ",1), y_full, " + Caffe2DML.numImages + ", " + Utils.numClasses(net) + ")\n")
+ }
+ else {
+ tabDMLScript.append(Caffe2DML.numImages + " = nrow(X_full)\n")
+ }
+ }
+
+ def initWeights(net:CaffeNetwork, solver:CaffeSolver, readWeights:Boolean): Unit = {
+ initWeights(net, solver, readWeights, new HashSet[String]())
+ }
+
+ def initWeights(net:CaffeNetwork, solver:CaffeSolver, readWeights:Boolean, layersToIgnore:HashSet[String]): Unit = {
+ tabDMLScript.append("weights = ifdef($weights, \" \")\n")
+ // Initialize the layers and solvers
+ tabDMLScript.append("# Initialize the layers and solvers\n")
+ net.getLayers.map(layer => net.getCaffeLayer(layer).init(tabDMLScript))
+ if(readWeights) {
+ // Loading existing weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias
+ tabDMLScript.append("# Load the weights. Note: keeping the initialization code in case the layer wants to initialize non-weights and non-bias\n")
+ net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(readWeight(l.weight, l.param.getName + "_weight.mtx")))
+ net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(readWeight(l.bias, l.param.getName + "_bias.mtx")))
+ }
+ net.getLayers.map(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
+ }
+
+ def getLossLayers(net:CaffeNetwork):List[IsLossLayer] = {
+ val lossLayers = net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]).map(layer => net.getCaffeLayer(layer).asInstanceOf[IsLossLayer])
+ if(lossLayers.length != 1)
+ throw new DMLRuntimeException("Expected exactly one loss layer, but found " + lossLayers.length + ":" + net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]))
+ lossLayers
+ }
+
+ def updateMeanVarianceForBatchNorm(net:CaffeNetwork, value:Boolean):Unit = {
+ net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).map(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var = value)
+ }
}
\ No newline at end of file