Posted to commits@systemml.apache.org by ni...@apache.org on 2017/04/19 22:08:33 UTC

[2/3] incubator-systemml git commit: [SYSTEMML-692] Added initial version of DML generator for Caffe

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/src/main/proto/tensorflow/event.proto
----------------------------------------------------------------------
diff --git a/src/main/proto/tensorflow/event.proto b/src/main/proto/tensorflow/event.proto
new file mode 100644
index 0000000..06d1992
--- /dev/null
+++ b/src/main/proto/tensorflow/event.proto
@@ -0,0 +1,102 @@
+//-------------------------------------------------------------
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+//-------------------------------------------------------------
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "EventProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.util";
+
+import "summary.proto";
+
+// Protocol buffer representing an event that happened during
+// the execution of a Brain model.
+message Event {
+  // Timestamp of the event.
+  double wall_time = 1;
+
+  // Global step of the event.
+  int64 step = 2;
+
+  oneof what {
+    // An event file was started, with the specified version.
+    // This is used to identify the contents of the record IO files
+    // easily.  Current version is "brain.Event:2".  All versions
+    // start with "brain.Event:".
+    string file_version = 3;
+    // An encoded version of a GraphDef.
+    bytes graph_def = 4;
+    // A summary was generated.
+    Summary summary = 5;
+    // A log message output by the user. Not all messages are logged, only
+    // those generated via the Python tensorboard_logging module.
+    LogMessage log_message = 6;
+    // The state of the session which can be used for restarting after crashes.
+    SessionLog session_log = 7;
+    // The metadata returned by running a session.run() call.
+    TaggedRunMetadata tagged_run_metadata = 8;
+    // An encoded version of a MetaGraphDef.
+    bytes meta_graph_def = 9;
+  }
+}
+
+// Protocol buffer used for logging messages to the events file.
+message LogMessage {
+  enum Level {
+    UNKNOWN = 0;
+    // Note: The logging level 10 cannot be named DEBUG. Some software
+    // projects compile their C/C++ code with -DDEBUG in debug builds. So the
+    // C++ code generated from this file should not have an identifier named
+    // DEBUG.
+    DEBUGGING = 10;
+    INFO = 20;
+    WARN = 30;
+    ERROR = 40;
+    FATAL = 50;
+  }
+  Level level = 1;
+  string message = 2;
+}
+
+// Protocol buffer used for logging session state.
+message SessionLog {
+  enum SessionStatus {
+    STATUS_UNSPECIFIED = 0;
+    START = 1;
+    STOP = 2;
+    CHECKPOINT = 3;
+  }
+
+  SessionStatus status = 1;
+  // This checkpoint_path contains both the path and filename.
+  string checkpoint_path = 2;
+  string msg = 3;
+}
+
+// For logging the metadata output for a single session.run() call.
+message TaggedRunMetadata {
+  // Tag name associated with this metadata.
+  string tag = 1;
+  // Byte-encoded version of the `RunMetadata` proto in order to allow lazy
+  // deserialization.
+  bytes run_metadata = 2;
+}

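As a quick orientation, here is a minimal sketch of constructing an Event record from these definitions, assuming Python bindings generated by protoc (the module name event_pb2 and the timestamp are illustrative assumptions):

    # Bindings assumed generated via: protoc --python_out=. event.proto summary.proto
    from event_pb2 import Event

    # A file-version event marking the start of an event file; assigning
    # file_version selects that member of the "what" oneof.
    ev = Event(wall_time=1492639713.0, step=0, file_version="brain.Event:2")
    record = ev.SerializeToString()  # bytes to append to a record-IO event file
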
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/src/main/proto/tensorflow/summary.proto
----------------------------------------------------------------------
diff --git a/src/main/proto/tensorflow/summary.proto b/src/main/proto/tensorflow/summary.proto
new file mode 100644
index 0000000..fc8053c
--- /dev/null
+++ b/src/main/proto/tensorflow/summary.proto
@@ -0,0 +1,123 @@
+//-------------------------------------------------------------
+//
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+//
+//-------------------------------------------------------------
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "SummaryProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+// import "tensorflow/core/framework/tensor.proto";
+
+// Metadata associated with a series of Summary data
+message SummaryDescription {
+  // Hint on how plugins should process the data in this series.
+  // Supported values include "scalar", "histogram", "image", "audio"
+  string type_hint = 1;
+}
+
+// Serialization format for histogram module in
+// core/lib/histogram/histogram.h
+message HistogramProto {
+  double min = 1;
+  double max = 2;
+  double num = 3;
+  double sum = 4;
+  double sum_squares = 5;
+
+  // Parallel arrays encoding the bucket boundaries and the bucket values.
+  // bucket(i) is the count for bucket i.  The range for
+  // a bucket is:
+  //   i == 0:  -DBL_MAX .. bucket_limit(0)
+  //   i != 0:  bucket_limit(i-1) .. bucket_limit(i)
+  repeated double bucket_limit = 6 [packed = true];
+  repeated double bucket = 7 [packed = true];
+};
+
+// A Summary is a set of named values to be displayed by the
+// visualizer.
+//
+// Summaries are produced regularly during training, as controlled by
+// the "summary_interval_secs" attribute of the training operation.
+// Summaries are also produced at the end of an evaluation.
+message Summary {
+  message Image {
+    // Dimensions of the image.
+    int32 height = 1;
+    int32 width = 2;
+    // Valid colorspace values are
+    //   1 - grayscale
+    //   2 - grayscale + alpha
+    //   3 - RGB
+    //   4 - RGBA
+    //   5 - DIGITAL_YUV
+    //   6 - BGRA
+    int32 colorspace = 3;
+    // Image data in encoded format.  All image formats supported by
+    // image_codec::CoderUtil can be stored here.
+    bytes encoded_image_string = 4;
+  }
+
+  message Audio {
+    // Sample rate of the audio in Hz.
+    float sample_rate = 1;
+    // Number of channels of audio.
+    int64 num_channels = 2;
+    // Length of the audio in frames (samples per channel).
+    int64 length_frames = 3;
+    // Encoded audio data and its associated RFC 2045 content type (e.g.
+    // "audio/wav").
+    bytes encoded_audio_string = 4;
+    string content_type = 5;
+  }
+
+  message Value {
+    // Name of the node that output this summary; in general, the name of a
+    // TensorSummary node. If the node in question has multiple outputs, then
+    // a ":\d+" suffix will be appended, like "some_op:13".
+    // Might not be set for legacy summaries (i.e. those not using the tensor
+    // value field)
+    string node_name = 7;
+
+    // Tag name for the data.  Will only be used by legacy summaries
+    // (i.e. those not using the tensor value field).
+    // For legacy summaries, will be used as the title of the graph
+    // in the visualizer.
+    //
+    // Tag is usually "op_name:value_name", where "op_name" itself can have
+    // structure to indicate grouping.
+    string tag = 1;
+
+    // Value associated with the tag.
+    oneof value {
+      float simple_value = 2;
+      bytes obsolete_old_style_histogram = 3;
+      Image image = 4;
+      HistogramProto histo = 5;
+      Audio audio = 6;
+      // TensorProto tensor = 8;
+    }
+  }
+
+  // Set of values for the summary.
+  repeated Value value = 1;
+}

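Similarly, a scalar summary for the visualizer could be assembled as below (again assuming protoc-generated bindings; summary_pb2 is a hypothetical module name, and the tag follows the "op_name:value_name" convention described above):

    from summary_pb2 import Summary
    from event_pb2 import Event

    # One named scalar value; simple_value selects that member of the oneof.
    s = Summary(value=[Summary.Value(tag="train:loss", simple_value=0.25)])
    # Wrap the summary in an Event so it can be written to the event file.
    ev = Event(wall_time=1492639713.0, step=100, summary=s)
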
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/src/main/python/setup.py
----------------------------------------------------------------------
diff --git a/src/main/python/setup.py b/src/main/python/setup.py
index 635dad7..fcda255 100644
--- a/src/main/python/setup.py
+++ b/src/main/python/setup.py
@@ -38,10 +38,12 @@ ARTIFACT_VERSION_SHORT = ARTIFACT_VERSION.split("-")[0]
 
 numpy_version = '1.8.2'
 scipy_version = '0.15.1'
+pillow_version = '2.0.0'
 REQUIRED_PACKAGES = [
     'numpy >= %s' % numpy_version,
     'scipy >= %s' % scipy_version,
-    'pandas'
+    'pandas',
+    'Pillow >= %s' % pillow_version
 ]
 
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/src/main/python/systemml/converters.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/converters.py b/src/main/python/systemml/converters.py
index 9651f14..8bf05d7 100644
--- a/src/main/python/systemml/converters.py
+++ b/src/main/python/systemml/converters.py
@@ -19,7 +19,7 @@
 #
 #-------------------------------------------------------------
 
-__all__ = [ 'getNumCols', 'convertToMatrixBlock', 'convertToNumPyArr', 'convertToPandasDF', 'SUPPORTED_TYPES' , 'convertToLabeledDF']
+__all__ = [ 'getNumCols', 'convertToMatrixBlock', 'convertToNumPyArr', 'convertToPandasDF', 'SUPPORTED_TYPES' , 'convertToLabeledDF', 'convertImageToNumPyArr']
 
 import numpy as np
 import pandas as pd
@@ -118,6 +118,35 @@ def convertToNumPyArr(sc, mb):
     else:
         raise TypeError('sc needs to be of type SparkContext') # TODO: We can generalize this by creating py4j gateway ourselves
 
+# Example usage: convertImageToNumPyArr(im, img_shape=(3, 224, 224), add_rotated_images=True, add_mirrored_images=True)
+# The above call returns a numpy array of shape (6, 150528) in NCHW format
+def convertImageToNumPyArr(im, img_shape=None, add_rotated_images=False, add_mirrored_images=False):
+    from PIL import Image
+    if img_shape is not None:
+        num_channels = img_shape[0]
+        size = (img_shape[1], img_shape[2])
+    else:
+        num_channels = 1 if im.mode == 'L' else 3
+        size = None
+    if num_channels != 1 and num_channels != 3:
+        raise ValueError('Expected the number of channels to be either 1 or 3')
+    if size is not None:
+        im = im.resize(size, Image.LANCZOS)
+    expected_mode = 'L' if num_channels == 1 else 'RGB'
+    if expected_mode != im.mode:
+        im = im.convert(expected_mode)
+    def _im2NumPy(im):
+        if expected_mode == 'L':
+            return np.asarray(im.getdata()).reshape((1, -1))
+        else:
+            # (H,W,C) --> (C,H,W) --> (1, C*H*W)
+            return np.asarray(im).transpose(2, 0, 1).reshape((1, -1))
+    ret = _im2NumPy(im)
+    if add_rotated_images:
+        ret = np.vstack((ret, _im2NumPy(im.rotate(90)), _im2NumPy(im.rotate(180)), _im2NumPy(im.rotate(270)) ))
+    if add_mirrored_images:
+        ret = np.vstack((ret, _im2NumPy(im.transpose(Image.FLIP_LEFT_RIGHT)), _im2NumPy(im.transpose(Image.FLIP_TOP_BOTTOM))))
+    return ret
 
 def convertToPandasDF(X):
     if not isinstance(X, pd.DataFrame):

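A brief usage sketch for the new converter (the image file name is hypothetical):

    from PIL import Image
    from systemml.converters import convertImageToNumPyArr

    im = Image.open('cat.jpg')  # hypothetical input image
    # Resize to 3x224x224, then append the 90/180/270-degree rotations and the
    # left-right and top-bottom mirrors: 1 + 3 + 2 = 6 rows in NCHW order.
    X = convertImageToNumPyArr(im, img_shape=(3, 224, 224),
                               add_rotated_images=True, add_mirrored_images=True)
    print(X.shape)  # (6, 150528), i.e. 3*224*224 columns per image
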
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/src/main/python/systemml/mllearn/estimators.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py
index d6ad069..94aa1f2 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -19,7 +19,7 @@
 #
 #-------------------------------------------------------------
 
-__all__ = ['LinearRegression', 'LogisticRegression', 'SVM', 'NaiveBayes']
+__all__ = ['LinearRegression', 'LogisticRegression', 'SVM', 'NaiveBayes', 'Caffe2DML']
 
 import numpy as np
 from pyspark.ml import Estimator
@@ -45,6 +45,7 @@ def assemble(sparkSession, pdf, inputCols, outputCol):
 class BaseSystemMLEstimator(Estimator):
     features_col = 'features'
     label_col = 'label'
+    do_visualize = False
     
     def set_features_col(self, colName):
         """
@@ -66,6 +67,21 @@ class BaseSystemMLEstimator(Estimator):
         """
         self.label_col = colName
 
+    def setGPU(self, enableGPU):
+        self.estimator.setGPU(enableGPU)
+        return self
+    
+    def setExplain(self, explain):
+        self.estimator.setExplain(explain)
+        return self
+            
+    def setStatistics(self, stat):
+        self.estimator.setStatistics(stat)
+        return self
+    
+    def setConfigProperty(self, propertyName, propertyValue):
+        self.estimator.setConfigProperty(propertyName, propertyValue)
+        return self
     
     def _fit_df(self):
         try:
@@ -158,6 +174,11 @@ class BaseSystemMLEstimator(Estimator):
         ----------
         X: NumPy ndarray, Pandas DataFrame, scipy sparse matrix or PySpark DataFrame
         """
+        try:
+            if self.estimator is not None and self.model is not None:
+                self.estimator.copyProperties(self.model)
+        except AttributeError:
+            pass
         if isinstance(X, SUPPORTED_TYPES):
             if self.transferUsingDF:
                 pdfX = convertToPandasDF(X)
@@ -206,6 +227,13 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator):
         else:
             return [ self.labelMap[int(i)] for i in y ]
         
+    def predict(self, X):
+        predictions = np.asarray(super(BaseSystemMLClassifier, self).predict(X))
+        try:
+            return np.asarray(predictions, dtype='double')
+        except ValueError:
+            return np.asarray(predictions, dtype='str')
+            
     def score(self, X, y):
         """
         Scores the predicted value with ground truth 'y'
@@ -215,8 +243,11 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator):
         X: NumPy ndarray, Pandas DataFrame, scipy sparse matrix
         y: NumPy ndarray, Pandas DataFrame, scipy sparse matrix
         """
-        return accuracy_score(y, self.predict(X))
-
+        predictions = np.asarray(self.predict(X))
+        if np.issubdtype(predictions.dtype.type, np.number):
+            return accuracy_score(y, predictions)
+        else:
+            return accuracy_score(np.asarray(y, dtype='str'), np.asarray(predictions, dtype='str'))
 
 class BaseSystemMLRegressor(BaseSystemMLEstimator):
 
@@ -499,4 +530,133 @@ class NaiveBayes(BaseSystemMLClassifier):
         self.estimator = self.sc._jvm.org.apache.sysml.api.ml.NaiveBayes(self.uid, self.sc._jsc.sc())
         self.estimator.setLaplace(laplace)
         self.transferUsingDF = transferUsingDF
-        self.setOutputRawPredictionsToFalse = False
\ No newline at end of file
+        self.setOutputRawPredictionsToFalse = False
+
+class Caffe2DML(BaseSystemMLClassifier):
+    """
+    Performs training/prediction for a given caffe network.
+    
+    Examples
+    --------
+    
+    >>> from systemml.mllearn import Caffe2DML
+    >>> from pyspark.sql import SQLContext
+    >>> sqlCtx = SQLContext(sc)
+    >>> from mlxtend.data import mnist_data
+    >>> import numpy as np
+    >>> from sklearn.utils import shuffle
+    >>> X, y = mnist_data()
+    >>> X, y = shuffle(X, y)
+    >>> imgShape = (1, 28, 28)
+    >>> import urllib
+    >>> urllib.urlretrieve('https://raw.githubusercontent.com/niketanpansare/model_zoo/master/caffe/vision/lenet/mnist/lenet.proto', 'lenet.proto')
+    >>> urllib.urlretrieve('https://raw.githubusercontent.com/niketanpansare/model_zoo/master/caffe/vision/lenet/mnist/lenet_solver.proto', 'lenet_solver.proto')
+    >>> caffe2DML = Caffe2DML(sqlCtx, 'lenet_solver.proto', input_shape=imgShape)
+    >>> caffe2DML.fit(X, y)
+    """
+    def __init__(self, sqlCtx, solver, input_shape, weights=None, ignore_weights=None, transferUsingDF=False, tensorboard_log_dir=None):
+        """
+        Performs training/prediction for a given caffe network. 
+
+        Parameters
+        ----------
+        sqlCtx: PySpark SQLContext
+        solver: caffe solver file path
+        input_shape: 3-element list (number of channels, input height, input width)
+        weights: directory where learned weights are stored (default: None)
+        ignore_weights: names of layers whose weights should not be read from the weights directory (list of strings, default: None)
+        transferUsingDF: whether to pass the input dataset via PySpark DataFrame (default: False)
+        tensorboard_log_dir: directory to store the event logs (default: None, in which case a temporary directory is used)
+        """
+        self.sqlCtx = sqlCtx
+        self.sc = sqlCtx._sc
+        self.uid = "Caffe2DML"
+        self.model = None
+        if len(input_shape) != 3:
+            raise ValueError('Expected input_shape to be a list of 3 elements')
+        solver = self.sc._jvm.org.apache.sysml.api.dl.Utils.readCaffeSolver(solver)
+        self.estimator = self.sc._jvm.org.apache.sysml.api.dl.Caffe2DML(self.sc._jsc.sc(), solver, str(input_shape[0]), str(input_shape[1]), str(input_shape[2]))
+        self.weights = weights
+        if weights is not None:
+            self.estimator.setInput("$weights", str(weights))
+            self._loadLabelTxt()
+            if ignore_weights is not None:
+                self.estimator.setWeightsToIgnore(ignore_weights)
+        self.transferUsingDF = transferUsingDF
+        self.setOutputRawPredictionsToFalse = False
+        self.visualize_called = False
+        if tensorboard_log_dir is not None:
+            self.estimator.setTensorBoardLogDir(tensorboard_log_dir)
+    
+    def _loadLabelTxt(self, format="binary", sep="/"):
+        if(self.weights is not None):
+            self.model = self.sc._jvm.org.apache.sysml.api.dl.Caffe2DMLModel(self.estimator)
+            df = self.sqlCtx.read.csv(self.weights + sep + 'labels.txt', header=False).toPandas()
+            keys = np.asarray(df._c0, dtype='int')
+            values = np.asarray(df._c1, dtype='str')
+            self.labelMap = {}
+            self.le = None
+            for i in range(len(keys)):
+                self.labelMap[int(keys[i])] = values[i]
+            # self.encode(classes) # Giving incorrect results
+    
+    def set(self, num_classes=None, debug=None):
+        """
+        Set input to Caffe2DML
+        
+        Parameters
+        ----------
+        num_classes: number of output classes (default: None)
+        debug: whether to add debugging DML code, such as a classification report and printing of the DML script (default: False)
+        """
+        if debug is not None: self.estimator.setInput("$debug", str(debug).upper())
+        return self
+    
+    def visualize(self, layerName=None, varType='weight', aggFn='mean'):
+        """
+        Use this to visualize the training procedure (requires validation to be enabled, i.e. a non-zero test_iter).
+        When no arguments are provided, training and validation loss are visualized.
+        
+        Parameters
+        ----------
+        layerName: Name of the layer in the Caffe prototype
+        varType: should be either 'weight', 'bias', 'dweight', 'dbias', 'output' or 'doutput'
+        aggFn: should be either 'sum', 'mean', 'var' or 'sd'
+        """
+        valid_vis_var_types = ['weight', 'bias', 'dweight', 'dbias', 'output', 'doutput']
+        valid_vis_aggFn = [ 'sum', 'mean', 'var', 'sd' ]
+        if layerName is not None and varType is not None and aggFn is not None:
+            # Visualize other values
+            if varType not in valid_vis_var_types:
+                raise ValueError('The second argument should be either weight, bias, dweight, dbias, output or doutput')
+            if aggFn not in valid_vis_aggFn:
+                raise ValueError('The third argument should be either sum, mean, var or sd')
+            if self.visualize_called:
+                self.estimator.visualizeLoss()
+            self.estimator.visualizeLayer(layerName, varType, aggFn)
+        else:
+            self.estimator.visualizeLoss()
+        self.visualize_called = True
+        return self
+    
+    def save(self, outputDir, format='binary', sep='/'):
+        """
+        Save a trained model.
+        
+        Parameters
+        ----------
+        outputDir: Directory to save the model to
+        format: optional format (default: 'binary')
+        sep: separator to use (default: '/')
+        """
+        if self.model is not None:
+            self.model.save(outputDir, format, sep)
+            if self.le is not None:
+                labelMapping = dict(enumerate(list(self.le.classes_), 1))
+            else:
+                labelMapping = self.labelMap
+            lStr = [ [ int(k), str(labelMapping[k]) ] for k in labelMapping ]
+            df = self.sqlCtx.createDataFrame(lStr)
+            df.write.csv(outputDir + sep + 'labels.txt', mode='overwrite', header=False)
+        else:
+            raise Exception('Cannot save: the model must be trained first using fit')
+        return self

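Building on the docstring example above, a sketch of saving a trained Caffe2DML model and reloading it for prediction via the weights directory (the directory name is hypothetical):

    >>> caffe2DML.save('lenet_model')  # after fit: writes the weights plus labels.txt
    >>> lenet = Caffe2DML(sqlCtx, 'lenet_solver.proto', input_shape=(1, 28, 28),
    ...                   weights='lenet_model')  # reload without retraining
    >>> lenet.predict(X)
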
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
new file mode 100644
index 0000000..7ab9160
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -0,0 +1,510 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.dl
+
+import caffe.Caffe.LayerParameter;
+import caffe.Caffe.NetParameter;
+import caffe.Caffe.SolverParameter;
+
+import org.apache.sysml.parser.LanguageException;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.api.ml.ScriptsUtils
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+import scala.collection.JavaConversions._
+import java.util.ArrayList
+import caffe.Caffe.Phase
+import caffe.Caffe
+import java.util.HashSet
+import org.apache.sysml.api.DMLScript
+import java.io.File
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+import org.apache.sysml.runtime.DMLRuntimeException
+import org.apache.sysml.runtime.instructions.spark.utils.{ RDDConverterUtilsExt => RDDConverterUtils }
+import org.apache.sysml.api.mlcontext._
+import org.apache.sysml.api.mlcontext.ScriptFactory._
+import org.apache.sysml.api.ml._
+import java.util.Random
+import org.apache.commons.logging.Log
+import org.apache.commons.logging.LogFactory
+import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer
+
+
+object Caffe2DML  {
+  val LOG = LogFactory.getLog(classOf[Caffe2DML].getName())
+  def fileSep():String = { if(File.separator.equals("\\")) "\\\\" else File.separator }
+  def setNNLibraryPath(path:String):Unit = { prefix = path + fileSep + "nn"}  
+  // ------------------------------------------------------------------------
+  var prefix = Utils.getPrefix()
+  def layerDir = prefix + fileSep + "layers" + fileSep
+  def optimDir = prefix + fileSep + "optim" + fileSep
+  
+  // Naming conventions:
+  val X = "X"; val y = "y"; val batchSize = "BATCH_SIZE"; val numImages = "num_images"; val numValidationImages = "num_validation"
+  val XVal = "X_val"; val yVal = "y_val"
+}
+
+class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter, 
+    val solver:CaffeSolver, val net:CaffeNetwork, 
+    val lrPolicy:LearningRatePolicy, val numChannels:String, val height:String, val width:String) extends Estimator[Caffe2DMLModel] 
+  with BaseSystemMLClassifier with DMLGenerator {
+  // --------------------------------------------------------------
+  // Invoked by Python, MLPipeline
+  def this(sc: SparkContext, solver1:Caffe.SolverParameter, networkPath:String, numChannels:String, height:String, width:String) {
+    this(sc, solver1, Utils.parseSolver(solver1), 
+        new CaffeNetwork(networkPath, caffe.Caffe.Phase.TRAIN, numChannels, height, width),
+        new LearningRatePolicy(solver1), numChannels, height, width)
+  }
+  def this(sc: SparkContext, solver1:Caffe.SolverParameter, numChannels:String, height:String, width:String) {
+    this(sc, solver1, Utils.parseSolver(solver1), new CaffeNetwork(solver1.getNet, caffe.Caffe.Phase.TRAIN, numChannels, height, width), 
+        new LearningRatePolicy(solver1), numChannels, height, width)
+  } 
+  val uid:String = "caffe_classifier_" + (new Random).nextLong
+  override def copy(extra: org.apache.spark.ml.param.ParamMap): Estimator[Caffe2DMLModel] = {
+    val that = new Caffe2DML(sc, solverParam, solver, net, lrPolicy, numChannels, height, width)
+    copyValues(that, extra)
+  }
+  // Note: will update the y_mb as this will be called by Python mllearn
+  def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): Caffe2DMLModel = {
+    val ret = baseFit(X_mb, y_mb, sc)
+    new Caffe2DMLModel(ret, Utils.numClasses(net), sc, solver, net, lrPolicy, this)
+  }
+  def fit(df: ScriptsUtils.SparkDataType): Caffe2DMLModel = {
+    val ret = baseFit(df, sc)
+    new Caffe2DMLModel(ret, Utils.numClasses(net), sc, solver, net, lrPolicy, this)
+  }
+	// --------------------------------------------------------------
+  
+  // Used for simplifying transfer learning
+  private val layersToIgnore:HashSet[String] = new HashSet[String]() 
+  def setWeightsToIgnore(layerName:String):Unit = layersToIgnore.add(layerName)
+  def setWeightsToIgnore(layerNames:ArrayList[String]):Unit = layersToIgnore.addAll(layerNames)
+  	  
+  // Input parameters to prediction and scoring script
+  val inputs:java.util.HashMap[String, String] = new java.util.HashMap[String, String]()
+  def setInput(key: String, value:String):Unit = inputs.put(key, value)
+  customAssert(solverParam.getTestIterCount <= 1, "Multiple test_iter variables are not supported")
+  customAssert(solverParam.getMaxIter > 0, "Please set max_iter to a positive value")
+  customAssert(net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[IsLossLayer]).length == 1, "Expected exactly one loss layer")
+    
+  // TODO: throw error or warning if user tries to set solver_mode == GPU instead of using setGPU method
+  
+  // Method called by Python mllearn to visualize variable of certain layer
+  def visualizeLayer(layerName:String, varType:String, aggFn:String): Unit = visualizeLayer(net, layerName, varType, aggFn)
+  
+  // -------------------------------------------------------------------------------------------
+  // Helper functions to generate DML
+  // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
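+  // For example, test_iter=10 generates DML along these lines (illustrative; the
+  // exact text comes from the assign/rightIndexing helpers below):
+  //   num_validation = 10 * BATCH_SIZE
+  //   X = X_full[(num_validation+1):num_images,]; y = y_full[(num_validation+1):num_images,]
+  //   X_val = X_full[1:num_validation,]; y_val = y_full[1:num_validation,]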
+  private def trainTestSplit(numValidationBatches:Int):Unit = {
+    if(numValidationBatches > 0) {
+      if(solverParam.getDisplay <= 0) 
+        throw new DMLRuntimeException("Since test_iter and test_interval are greater than zero, display should also be set to a value greater than zero")
+      tabDMLScript.append(Caffe2DML.numValidationImages).append(" = " + numValidationBatches + " * " + Caffe2DML.batchSize + "\n")
+      tabDMLScript.append("# Sanity check to ensure that validation set is not too large\n")
+      val maxValidationSize = "ceil(0.3 * " + Caffe2DML.numImages + ")"
+      ifBlock(Caffe2DML.numValidationImages  + " > " + maxValidationSize) {
+        assign(tabDMLScript, "max_test_iter", "floor(" + maxValidationSize + " / " + Caffe2DML.batchSize + ")")
+        tabDMLScript.append("stop(" +
+            dmlConcat(asDMLString("Too large validation size. Please reduce test_iter to "), "max_test_iter") 
+            + ")\n")
+      }
+      val one = "1"
+      val rl = int_add(Caffe2DML.numValidationImages, one)
+      rightIndexing(tabDMLScript.append(Caffe2DML.X).append(" = "), "X_full", rl, Caffe2DML.numImages, null, null)
+      tabDMLScript.append("; ")
+      rightIndexing(tabDMLScript.append(Caffe2DML.y).append(" = "), "y_full", rl, Caffe2DML.numImages, null, null)
+      tabDMLScript.append("; ")
+      rightIndexing(tabDMLScript.append(Caffe2DML.XVal).append(" = "), "X_full", one, Caffe2DML.numValidationImages, null, null)
+      tabDMLScript.append("; ")
+      rightIndexing(tabDMLScript.append(Caffe2DML.yVal).append(" = "), "y_full", one, Caffe2DML.numValidationImages, null, null)
+      tabDMLScript.append("; ")
+      tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y)\n")
+    }
+    else {
+      assign(tabDMLScript, Caffe2DML.X, "X_full")
+	    assign(tabDMLScript, Caffe2DML.y, "y_full")
+	    tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(" + Caffe2DML.y + ")\n")
+    }
+  }
+  
+  private def printClassificationReport():Unit = {
+    ifBlock("debug"){
+      assign(tabDMLScript, "num_rows_error_measures", min("10", ncol("yb")))
+      assign(tabDMLScript, "error_measures", matrix("0", "num_rows_error_measures", "5"))
+      forBlock("class_i", "1", "num_rows_error_measures") {
+        assign(tabDMLScript, "tp", "sum( (true_yb == predicted_yb) * (true_yb == class_i) )")
+        assign(tabDMLScript, "tp_plus_fp", "sum( (predicted_yb == class_i) )")
+        assign(tabDMLScript, "tp_plus_fn", "sum( (true_yb == class_i) )")
+        assign(tabDMLScript, "precision", "tp / tp_plus_fp")
+        assign(tabDMLScript, "recall", "tp / tp_plus_fn")
+        assign(tabDMLScript, "f1Score", "2*precision*recall / (precision+recall)")
+        assign(tabDMLScript, "error_measures[class_i,1]", "class_i")
+        assign(tabDMLScript, "error_measures[class_i,2]", "precision")
+        assign(tabDMLScript, "error_measures[class_i,3]", "recall")
+        assign(tabDMLScript, "error_measures[class_i,4]", "f1Score")
+        assign(tabDMLScript, "error_measures[class_i,5]", "tp_plus_fn")
+      }
+      val dmlTab = "\\t"
+      val header = "class    " + dmlTab + "precision" + dmlTab + "recall  " + dmlTab + "f1-score" + dmlTab + "num_true_labels\\n"
+      val errorMeasures = "toString(error_measures, decimal=7, sep=" + asDMLString(dmlTab) + ")"
+      tabDMLScript.append(print(dmlConcat(asDMLString(header), errorMeasures)))
+    }
+  }
+  
+  // Append the DML to display training and validation loss
+  private def displayLoss(lossLayer:IsLossLayer, shouldValidate:Boolean):Unit = {
+    if(solverParam.getDisplay > 0) {
+      // Append the DML to compute training loss
+      tabDMLScript.append("# Compute training loss & accuracy\n")
+      ifBlock("iter  %% " + solverParam.getDisplay + " == 0") {
+        assign(tabDMLScript, "loss", "0"); assign(tabDMLScript, "accuracy", "0")
+        lossLayer.computeLoss(dmlScript, numTabs)
+        assign(tabDMLScript, "training_loss", "loss"); assign(tabDMLScript, "training_accuracy", "accuracy")
+        tabDMLScript.append(print( dmlConcat( asDMLString("Iter:"), "iter", 
+            asDMLString(", training loss:"), "training_loss", asDMLString(", training accuracy:"), "training_accuracy" )))
+        appendTrainingVisualizationBody(dmlScript, numTabs)
+        printClassificationReport
+      }
+      if(shouldValidate) {
+        // Append the DML to compute validation loss
+        val numValidationBatches = if(solverParam.getTestIterCount > 0) solverParam.getTestIter(0) else 0
+        tabDMLScript.append("# Compute validation loss & accuracy\n")
+        ifBlock("iter  %% " + solverParam.getTestInterval + " == 0") {
+          assign(tabDMLScript, "loss", "0"); assign(tabDMLScript, "accuracy", "0")
+          solverParam.getTestAlgo.toLowerCase match {
+            case "minibatch" => {
+              assign(tabDMLScript, "validation_loss", "0")
+              assign(tabDMLScript, "validation_accuracy", "0")
+              forBlock("iVal", "1", "num_iters_per_epoch") {
+    	          getValidationBatch(tabDMLScript)
+    	          tabDMLScript.append("iter = start_iter + i\n")
+    	          forward;  lossLayer.computeLoss(dmlScript, numTabs)
+                tabDMLScript.append("validation_loss = validation_loss + loss\n")
+                tabDMLScript.append("validation_accuracy = validation_accuracy + accuracy\n")
+    	        }
+              tabDMLScript.append("validation_accuracy = validation_accuracy / num_iters_per_epoch\n")
+            }
+            case "batch" => {
+              assign(tabDMLScript, "Xb", Caffe2DML.XVal); assign(tabDMLScript, "yb", Caffe2DML.yVal)
+              net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, false))
+              lossLayer.computeLoss(dmlScript, numTabs)
+              assign(tabDMLScript, "validation_loss", "loss"); assign(tabDMLScript, "validation_accuracy", "accuracy")
+              
+            }
+            case _ => throw new DMLRuntimeException("Unsupported test algo:" + solverParam.getTestAlgo)
+          }
+          tabDMLScript.append(print( dmlConcat( asDMLString("Iter:"), "iter", 
+              asDMLString(", validation loss:"), "validation_loss", asDMLString(", validation accuracy:"), "validation_accuracy" )))
+          appendValidationVisualizationBody(dmlScript, numTabs)
+        }
+      }
+    }
+  }
+  
+  private def performSnapshot():Unit = {
+    if(solverParam.getSnapshot > 0) {
+      ifBlock("iter %% snapshot == 0") {
+        tabDMLScript.append("snapshot_dir= \"" + solverParam.getSnapshotPrefix + "\" + \"/iter_\" + iter + \"/\"\n")
+        net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(write(l.weight, "snapshot_dir + \"" + l.param.getName + "_weight.mtx\"", "binary")))
+  		  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(write(l.bias, "snapshot_dir + \"" + l.param.getName + "_bias.mtx\"", "binary")))
+      }
+  	}
+  }
+  
+  private def forward():Unit = {
+    tabDMLScript.append("# Perform forward pass\n")
+	  net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, false))
+  }
+  private def backward():Unit = backward("")
+  private def backward(suffix:String):Unit = {
+    tabDMLScript.append("# Perform backward pass\n")
+    net.getLayers.reverse.map(layer => net.getCaffeLayer(layer).backward(tabDMLScript, suffix))
+  }
+  private def update():Unit = {
+    tabDMLScript.append("# Update the parameters\n")
+    net.getLayers.map(layer => solver.update(tabDMLScript, net.getCaffeLayer(layer)))
+  }
+  private def initAggGradients():Unit = {
+    tabDMLScript.append("# Data structure to store gradients computed in parallel\n")
+    net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
+      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg", matrix("0", "parallel_batches", multiply(nrow(l.weight), ncol(l.weight))))
+      if(l.shouldUpdateBias) assign(tabDMLScript, l.dBias + "_agg", matrix("0", "parallel_batches", multiply(nrow(l.bias), ncol(l.bias)))) 
+    })
+  }
+  private def flattenAndStoreAggGradients_j():Unit = {
+    tabDMLScript.append("# Flatten and store gradients for this parallel execution\n")
+    net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
+      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg[j,]", 
+          matrix(l.dWeight, "1", multiply(nrow(l.weight), ncol(l.weight)))) 
+      if(l.shouldUpdateBias) assign(tabDMLScript, l.dBias + "_agg[j,]", 
+          matrix(l.dBias, "1", multiply(nrow(l.bias), ncol(l.bias))))
+    })
+  }
+  private def aggregateAggGradients():Unit = {
+    tabDMLScript.append("# Aggregate the gradients\n")
+    net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
+      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight, 
+          matrix(colSums(l.dWeight + "_agg"), nrow(l.weight), ncol(l.weight))) 
+      if(l.shouldUpdateBias) assign(tabDMLScript, l.dBias, 
+          matrix(colSums(l.dBias + "_agg"), nrow(l.bias), ncol(l.bias)))
+    })
+  }
+  // -------------------------------------------------------------------------------------------
+  
+  private def multiply(v1:String, v2:String):String = v1 + "*" + v2
+  private def colSums(m:String):String = "colSums(" + m + ")"
+  
+	// Script generator
+	def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
+	  val startTrainingTime = System.nanoTime()
+	  val DEBUG_TRAINING = if(inputs.containsKey("$debug")) inputs.get("$debug").toLowerCase.toBoolean else false
+    reset()
+	  
+	  // Add source for layers as well as solver as well as visualization header
+	  source(net, solver, Array[String]("l2_reg"))
+	  appendVisualizationHeaders(dmlScript, numTabs)
+	  
+	  // Read and convert to one-hot encoding
+	  assign(tabDMLScript, "X_full", "read(\" \", format=\"csv\")")
+	  assign(tabDMLScript, "y_full", "read(\" \", format=\"csv\")")
+	  tabDMLScript.append(Caffe2DML.numImages).append(" = nrow(y_full)\n")
+	  tabDMLScript.append("weights = ifdef($weights, \" \")\n")
+	  tabDMLScript.append("debug = ifdef($debug, FALSE)\n")
+	  tabDMLScript.append("# Convert to one-hot encoding (Assumption: 1-based labels) \n")
+	  tabDMLScript.append("y_full = table(seq(1," + Caffe2DML.numImages + ",1), y_full, " + Caffe2DML.numImages + ", " + Utils.numClasses(net) + ")\n")
+	  
+	  // Initialize the layers and solvers
+	  tabDMLScript.append("# Initialize the layers and solvers\n")
+	  net.getLayers.map(layer => net.getCaffeLayer(layer).init(tabDMLScript))
+	  if(inputs.containsKey("$weights")) {
+		  // Loading existing weights. Note: the initialization code is kept in case a layer needs to initialize variables other than weights and biases
+		  tabDMLScript.append("# Load the weights. Note: the initialization code is kept in case a layer needs to initialize variables other than weights and biases\n")
+		  net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(read(l.weight, l.param.getName + "_weight.mtx")))
+		  net.getLayers.filter(l => !layersToIgnore.contains(l)).map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(read(l.bias, l.param.getName + "_bias.mtx")))
+	  }
+	  net.getLayers.map(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
+	  
+	  // Split into training and validation set
+	  // Initializes Caffe2DML.X, Caffe2DML.y, Caffe2DML.XVal, Caffe2DML.yVal and Caffe2DML.numImages
+	  val shouldValidate = solverParam.getTestInterval > 0 && solverParam.getTestIterCount > 0 && solverParam.getTestIter(0) > 0
+	  trainTestSplit(if(shouldValidate) solverParam.getTestIter(0) else 0)
+	  
+	  // Set iteration-related variables such as max_epochs, num_iters_per_epoch, lr, etc.
+	  val lossLayers = net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]).map(layer => net.getCaffeLayer(layer).asInstanceOf[IsLossLayer])
+	  if(lossLayers.length != 1) throw new DMLRuntimeException("Expected exactly one loss layer")
+	  solverParam.getTrainAlgo.toLowerCase match {
+	    case "batch" => 
+	      assign(tabDMLScript, "max_epochs", solverParam.getMaxIter.toString)
+	    case _ => {
+	      ceilDivide(tabDMLScript, "num_iters_per_epoch", Caffe2DML.numImages, Caffe2DML.batchSize)
+	      ceilDivide(tabDMLScript, "max_epochs", solverParam.getMaxIter.toString, "num_iters_per_epoch")
+	    }
+	  }
+	  assign(tabDMLScript, "start_iter", "0")
+	  assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
+	  
+	  // ----------------------------------------------------------------------------
+	  // Main logic
+	  forBlock("e", "1", "max_epochs") {
+	    solverParam.getTrainAlgo.toLowerCase match {
+	      case "minibatch" => 
+	        forBlock("i", "1", "num_iters_per_epoch") {
+	          getTrainingBatch(tabDMLScript)
+	          tabDMLScript.append("iter = start_iter + i\n")
+	          forward; backward; update
+	          displayLoss(lossLayers(0), shouldValidate)
+            performSnapshot
+	        }
+	      case "batch" => {
+          tabDMLScript.append("iter = start_iter + i\n")
+          forward; backward; update
+          displayLoss(lossLayers(0), shouldValidate)
+          performSnapshot
+	      }
+	      case "allreduce" => {
+	        forBlock("i", "1", "num_iters_per_epoch") {
+	          getTrainingBatch(tabDMLScript)
+	          assign(tabDMLScript, "X_group_batch", "Xb")
+	          assign(tabDMLScript, "y_group_batch", "yb")
+	          tabDMLScript.append("iter = start_iter + i\n")
+	          initAggGradients
+	          parForBlock("j", "1", "nrow(y_group_batch)") {
+	            assign(tabDMLScript, "Xb", "X_group_batch[j,]")
+	            assign(tabDMLScript, "yb", "y_group_batch[j,]")
+	            forward; backward("_agg")
+              flattenAndStoreAggGradients_j
+	          }
+	          aggregateAggGradients
+            tabDMLScript.append("iter = start_iter + parallel_batches\n")    
+	          update
+            displayLoss(lossLayers(0), shouldValidate)
+            performSnapshot
+	        }
+	      }
+	      case _ => throw new DMLRuntimeException("Unsupported train algo:" + solverParam.getTrainAlgo)
+	    }
+	    // After every epoch, update the learning rate
+	    tabDMLScript.append("# Learning rate\n")
+	    lrPolicy.updateLearningRate(tabDMLScript)
+	    tabDMLScript.append("start_iter = start_iter + num_iters_per_epoch\n")
+	  }
+	  // ----------------------------------------------------------------------------
+	  
+	  // Check if this is necessary
+	  if(doVisualize) tabDMLScript.append("print(" + asDMLString("Visualization counter:") + " + viz_counter)")
+	  
+	  val trainingScript = tabDMLScript.toString()
+	  // Print script generation time and the DML script on stdout
+	  System.out.println("Time taken to generate training script from Caffe proto: " + ((System.nanoTime() - startTrainingTime)*1e-9) + " seconds." )
+	  if(DEBUG_TRAINING) Utils.prettyPrintDMLScript(trainingScript)
+	  
+	  // Set input/output variables and execute the script
+	  val script = dml(trainingScript).in(inputs)
+	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.out(l.weight))
+	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.out(l.bias))
+	  (script, "X_full", "y_full")
+	}
+}
+
+class Caffe2DMLModel(val mloutput: MLResults,  
+    val numClasses:String, val sc: SparkContext, val solver:CaffeSolver,
+    val net:CaffeNetwork, val lrPolicy:LearningRatePolicy,
+    val estimator:Caffe2DML) 
+  extends Model[Caffe2DMLModel] with HasMaxOuterIter with BaseSystemMLClassifierModel with DMLGenerator {
+  // --------------------------------------------------------------
+  // Invoked by Python, MLPipeline
+  val uid:String = "caffe_model_" + (new Random).nextLong 
+  def this(estimator:Caffe2DML) =  {
+    this(null, Utils.numClasses(estimator.net), estimator.sc, estimator.solver,
+        estimator.net,
+        // new CaffeNetwork(estimator.solverParam.getNet, caffe.Caffe.Phase.TEST, estimator.numChannels, estimator.height, estimator.width), 
+        estimator.lrPolicy, estimator) 
+  }
+      
+  override def copy(extra: org.apache.spark.ml.param.ParamMap): Caffe2DMLModel = {
+    val that = new Caffe2DMLModel(mloutput, numClasses, sc, solver, net, lrPolicy, estimator)
+    copyValues(that, extra)
+  }
+  // --------------------------------------------------------------
+  
+  def save(outputDir:String, format:String="binary", sep:String="/"):Unit = {
+	  if(mloutput == null) throw new DMLRuntimeException("Cannot save: the model must be trained first using fit")
+	  val dmlScript = new StringBuilder
+	  dmlScript.append("print(\"Saving the model to " + outputDir + "...\")\n")
+	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => dmlScript.append(write(l.weight, outputDir + sep + l.param.getName + "_weight.mtx", format)))
+	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => dmlScript.append(write(l.bias, outputDir + sep + l.param.getName + "_bias.mtx", format)))
+	  
+	  val script = dml(dmlScript.toString)
+	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.in(l.weight, mloutput.getBinaryBlockMatrix(l.weight)))
+	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.in(l.bias, mloutput.getBinaryBlockMatrix(l.bias)))
+	  val ml = new MLContext(sc)
+	  ml.execute(script)
+	}
+    
+  def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String)  = {
+    reset()
+    val startPredictionTime = System.nanoTime()
+	  val DEBUG_PREDICTION = if(estimator.inputs.containsKey("$debug")) estimator.inputs.get("$debug").toLowerCase.toBoolean else false
+	  
+	  // Append source statements for each layer
+	  source(net, solver, null)
+    tabDMLScript.append("weights = ifdef($weights, \" \")\n")
+	  // Initialize the layers and solvers
+	  tabDMLScript.append("# Initialize the layers and solvers\n")
+	  net.getLayers.map(layer => net.getCaffeLayer(layer).init(tabDMLScript))
+	  if(mloutput == null && estimator.inputs.containsKey("$weights")) {
+		  // fit was not called
+		  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => tabDMLScript.append(read(l.weight, l.param.getName + "_weight.mtx")))
+		  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => tabDMLScript.append(read(l.bias, l.param.getName + "_bias.mtx")))
+	  }
+	  else if(mloutput == null) {
+		  throw new DMLRuntimeException("Cannot call predict/score without either calling fit or providing weights")
+	  }
+	  net.getLayers.map(layer => solver.init(tabDMLScript, net.getCaffeLayer(layer)))
+	  
+//	  if(estimator.inputs.containsKey("$debug") && estimator.inputs.get("$debug").equals("TRUE")) {
+//		  System.out.println("The output shape of layers:")
+//		  net.getLayers.map(layer =>  System.out.println(net.getCaffeLayer(layer).param.getName + " " + net.getCaffeLayer(layer).outputShape))
+//	  }
+	  
+	  // Do not update mean and variance in batchnorm
+	  net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).map(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var = false)
+	  tabDMLScript.append("X_full = read(\" \", format=\"csv\")\n")
+	  assign(tabDMLScript, "X", "X_full")
+	  tabDMLScript.append(Caffe2DML.numImages + " = nrow(X_full)\n")
+	  
+	  val lossLayers = net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]).map(layer => net.getCaffeLayer(layer).asInstanceOf[IsLossLayer])
+	  customAssert(lossLayers.length == 1, "Expected exactly one loss layer, but found " + lossLayers.length + ":" + net.getLayers.filter(layer => net.getCaffeLayer(layer).isInstanceOf[IsLossLayer]))
+	  assign(tabDMLScript, "Prob", matrix("0", Caffe2DML.numImages, numClasses))
+	  estimator.solverParam.getTestAlgo.toLowerCase match {
+      case "minibatch" => {
+        ceilDivide(tabDMLScript(), "num_iters", Caffe2DML.numImages, Caffe2DML.batchSize)
+        forBlock("i", "1", "num_iters") {
+          getTestBatch(tabDMLScript)
+          net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, true))
+          assign(tabDMLScript, "Prob[beg:end,]", lossLayers(0).out)
+        }
+      }
+      case "batch" => {
+        assign(tabDMLScript, "Xb", "X_full")
+        net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, true))
+        assign(tabDMLScript, "Prob", lossLayers(0).out)
+      }
+      case "allreduce" => {
+        ceilDivide(tabDMLScript(), "num_iters", Caffe2DML.numImages, Caffe2DML.batchSize)
+        parForBlock("i", "1", "num_iters") {
+          getTestBatch(tabDMLScript)
+          net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, true))
+          assign(tabDMLScript, "Prob[beg:end,]", lossLayers(0).out)
+        }
+      }
+      case _ => throw new DMLRuntimeException("Unsupported test algo:" + estimator.solverParam.getTestAlgo)
+    }
+		
+		val predictionScript = dmlScript.toString()
+		System.out.println("Time taken to generate prediction script from Caffe proto: " + ((System.nanoTime() - startPredictionTime)*1e-9) + " seconds.")
+		if(DEBUG_PREDICTION) Utils.prettyPrintDMLScript(predictionScript)
+		
+		// Reset
+		net.getLayers.filter(net.getCaffeLayer(_).isInstanceOf[BatchNorm]).map(net.getCaffeLayer(_).asInstanceOf[BatchNorm].update_mean_var = true)
+		
+	  val script = dml(predictionScript).out("Prob").in(estimator.inputs)
+	  if(mloutput != null) {
+	    // fit was called
+  	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.in(l.weight, mloutput.getBinaryBlockMatrix(l.weight)))
+  	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.in(l.bias, mloutput.getBinaryBlockMatrix(l.bias)))
+	  }
+	  
+	  (script, "X_full")
+  }
+  
+  // Prediction
+  def transform(X: MatrixBlock): MatrixBlock = {
+	  baseTransform(X, mloutput, sc, "Prob")
+  }
+  def transform(df: ScriptsUtils.SparkDataType): DataFrame = {
+	  baseTransform(df, mloutput, sc, "Prob")
+  }
+}
\ No newline at end of file

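The Scala estimator and model above are what the Python Caffe2DML wrapper drives; a short sketch of the visualization and TensorBoard hooks they expose (the log directory and layer name are hypothetical):

    lenet = Caffe2DML(sqlCtx, 'lenet_solver.proto', input_shape=(1, 28, 28),
                      tensorboard_log_dir='/tmp/tb_logs')  # events use event.proto above
    lenet.visualize()                                  # training/validation loss
    lenet.visualize('conv1', varType='dweight', aggFn='mean')  # per-layer statistic
    lenet.fit(X, y)
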
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
new file mode 100644
index 0000000..4faa203
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
@@ -0,0 +1,357 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.dl
+
+import caffe.Caffe.LayerParameter
+import scala.collection.JavaConversions._
+import org.apache.sysml.parser.LanguageException
+import java.util.HashSet
+import java.io.File
+import org.apache.sysml.api.DMLScript
+import org.apache.sysml.runtime.util.ConvolutionUtils
+import caffe.Caffe.EltwiseParameter.EltwiseOp
+import org.apache.sysml.runtime.DMLRuntimeException;
+import java.util.ArrayList
+
+trait CaffeLayer extends BaseDMLGenerator {
+  // -------------------------------------------------
+  // Any layer that wants to reuse SystemML-NN has to override the following methods, which help generate the DML for the given layer:
+  def sourceFileName:String;
+  def init(dmlScript:StringBuilder):Unit;
+  def forward(dmlScript:StringBuilder, isPrediction:Boolean):Unit;
+  def backward(dmlScript:StringBuilder, outSuffix:String):Unit;
+  var computedOutputShape:(String, String, String) = null
+  def outputShape:(String, String, String) = {
+    if(computedOutputShape == null) computedOutputShape = bottomLayerOutputShape
+    computedOutputShape
+  }
+  // -------------------------------------------------
+  var computedBottomLayerOutputShape:(String, String, String) = null
+  def bottomLayerOutputShape:(String, String, String) = {
+    if(computedBottomLayerOutputShape == null) {
+      val ret = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+      if(ret.size == 0) throw new LanguageException("Expected at least 1 bottom layer for " + param.getName)
+      computedBottomLayerOutputShape = ret(0).outputShape
+    }
+    computedBottomLayerOutputShape
+  }
+  def param:LayerParameter
+  def id:Int
+  def net:CaffeNetwork
+  // --------------------------------------------------------------------------------------
+  // No need to override these methods in subclasses
+  // Exception: Only Data layer overrides "out" method to use 'Xb' for consistency
+  // Naming of the below methods is consistent with the nn library:
+  // X (feature map from the previous layer) ----> Forward pass  ----> out (feature map to the next layer)
+  // dX (errors to the previous layer)       <---- Backward pass <---- dout (errors from the next layer)
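+  // For example, a layer with id=3 whose single bottom layer has id=2 and whose
+  // single top layer has id=4 reads X = out2 and writes out = out3 in the forward
+  // pass, and writes dX = dOut3 while reading dout = dOut4 in the backward pass
+  // (illustrative ids).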
+  def out = "out" + id  
+  var computedX:String = null
+  def X:String = {
+    if(computedX == null) {
+      val ret = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+      if(ret.size == 0) throw new LanguageException("Expected at least 1 bottom layer for " + param.getName)
+      else if(ret.size == 1)    computedX = ret(0).out
+      else                      computedX = sum(new StringBuilder, ret.map(_.out).toList).toString()
+    }
+    computedX
+  }
+  var computedDout:String = null
+  def dout: String = {
+    if(computedDout == null) {
+      val ret = net.getTopLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+      if(ret.size == 0) throw new LanguageException("Expected at least 1 top layer for " + param.getName)
+      else if(ret.size == 1)     computedDout = ret(0).dX
+      else                       computedDout = sum(new StringBuilder, ret.map(_.dX).toList).toString()
+    }
+    computedDout
+  }
+  val dX = "dOut" + id
+  // --------------------------------------------------------------------------------------
+  // No need to override these methods in subclasses, instead classes that have weights and biases 
+  // should implement HasWeight and HasBias traits.
+  def dWeight():String = throw new DMLRuntimeException("dWeight is not implemented in super class")
+  def dBias():String = throw new DMLRuntimeException("dBias is not implemented in super class")
+  def weight():String = null;
+  def bias():String = null;
+  def shouldUpdateWeight():Boolean = weight != null
+  def shouldUpdateBias():Boolean = bias != null
+  // --------------------------------------------------------------------------------------
+  // Helper methods to simplify the code of subclasses
+  def invokeInit(dmlScript:StringBuilder, returnVariables:List[String], arguments:String*):Unit = {
+    invoke(dmlScript, sourceFileName + "::", returnVariables, "init", arguments.toList)
+  }
+  def invokeForward(dmlScript:StringBuilder, returnVariables:List[String], arguments:String*):Unit = {
+    invoke(dmlScript, sourceFileName + "::", returnVariables, "forward", arguments.toList)
+  }
+  def invokeBackward(dmlScript:StringBuilder, outSuffix:String, resultVariables:List[String],  arguments:String*):Unit = {
+    invoke(dmlScript, sourceFileName + "::", resultVariables.map(_ + outSuffix), "backward", arguments.toList)
+  }
+  // --------------------------------------------------------------------------------------
+}
+
+
+trait IsLossLayer extends CaffeLayer {
+  def computeLoss(dmlScript:StringBuilder, numTabs:Int):Unit
+}
+
+trait HasWeight extends CaffeLayer {
+  override def weight = "W" + id
+  override def dWeight = "dW" + id
+}
+
+trait HasBias extends CaffeLayer {
+  override def bias = "b" + id
+  override def dBias = "db" + id
+}
+
+class Data(val param:LayerParameter, val id:Int, val net:CaffeNetwork, val numChannels:String, val height:String, val width:String) extends CaffeLayer {
+  // -------------------------------------------------
+  override def sourceFileName = null
+  override def init(dmlScript:StringBuilder) = {
+    if(param.hasTransformParam && param.getTransformParam.hasScale) {
+      dmlScript.append("X_full = X_full * " + param.getTransformParam.getScale + "\n")
+    }
+    dmlScript.append("BATCH_SIZE = " + param.getDataParam.getBatchSize + "\n")
+  }
+  var dataOutputShape = ("$num_channels", "$height", "$width")
+  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = { }
+  override def out = "Xb"
+  override def backward(dmlScript:StringBuilder, outSuffix:String) = { }
+  override def outputShape = (numChannels, height, width)
+  // -------------------------------------------------
+}
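+// A minimal (hypothetical) Caffe definition handled by the Data layer above:
+//   layer { name: "mnist" type: "Data" top: "data" top: "label"
+//           data_param { batch_size: 64 } transform_param { scale: 0.00390625 } }
+// For this definition, init() emits "X_full = X_full * 0.00390625" and "BATCH_SIZE = 64".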
+
+
+// ------------------------------------------------------------------
+// weight is ema_mean and bias is ema_var
+// BatchNorm is fused with the immediately following Scale layer: gamma/beta (and their
+// gradients) are owned by the Scale layer and fetched via checkNextLayer().
+class BatchNorm(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
+  override def sourceFileName = "batch_norm2d"
+  override def init(dmlScript:StringBuilder) = invokeInit(dmlScript, List[String](gamma, beta, ema_mean, ema_var), numChannels)
+  var update_mean_var = true
+  def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
+    val mode = if(isPrediction) "\"test\"" else "\"train\""
+    invokeForward(dmlScript, List[String](out, withSuffix(ema_mean), withSuffix(ema_var), withSuffix(cache_mean), withSuffix(cache_var), withSuffix(cache_norm)), 
+        X, gamma, beta, numChannels, Hin, Win, mode, ema_mean, ema_var,  ma_fraction, eps)  
+  }
+  
+  def backward(dmlScript: StringBuilder, outSuffix:String): Unit = {
+    invokeBackward(dmlScript, outSuffix, List[String](dX, dgamma, dbeta), dout, out, ema_mean, ema_var, cache_mean, cache_var, cache_norm, X, gamma, beta, numChannels, 
+          Hin, Win, "\"train\"", ema_mean, ema_var,  ma_fraction, eps)
+  }
+  
+  private def withSuffix(str:String):String = if(update_mean_var) str else str + "_ignore"
+  override def weight = "ema_mean" + id
+  override def bias = "ema_var" + id
+  def cache_mean(): String = "cache_mean" + id
+  def cache_var():String = "cache_var" + id
+  def cache_norm():String = "cache_norm" + id
+  var scaleLayer:Scale = null
+  def gamma():String = { checkNextLayer(); scaleLayer.weight }
+  def ma_fraction():String = if(param.getBatchNormParam.hasMovingAverageFraction()) param.getBatchNormParam.getMovingAverageFraction.toString else "0.999"
+  def eps():String = if(param.getBatchNormParam.hasEps()) param.getBatchNormParam.getEps.toString else "1e-5"
+  def beta():String = { checkNextLayer(); scaleLayer.bias }
+  def dgamma():String = { checkNextLayer();  scaleLayer.dWeight }
+  def dbeta():String = { checkNextLayer();  scaleLayer.dBias }
+  override def shouldUpdateWeight():Boolean = false
+  override def shouldUpdateBias():Boolean = false
+  def ema_mean(): String = weight
+  def ema_var(): String = bias
+  def checkNextLayer(): Unit = {
+    if(scaleLayer == null) {
+      val topLayers = net.getTopLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+      if(topLayers.length != 1 || !topLayers(0).isInstanceOf[Scale]) throw new LanguageException("Only one top layer of type Scale allowed for BatchNorm")
+      scaleLayer = topLayers(0).asInstanceOf[Scale]
+    }
+  }
+  def numChannels = bottomLayerOutputShape._1
+  def Hin = bottomLayerOutputShape._2
+  def Win = bottomLayerOutputShape._3
+}
+// weight is gamma and bias is beta
+class Scale(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
+  if(!param.getScaleParam.getBiasTerm) throw new LanguageException("Add \"scale_param { bias_term: true }\" to the layer " + param.getName)
+  override def sourceFileName = null
+  override def init(dmlScript: StringBuilder): Unit = {}
+  def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = assign(dmlScript, out, X)
+  override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = assign(dmlScript, dX + outSuffix, dout)
+}
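+// Illustrative (hypothetical) prototxt pair satisfying the fusion checks above:
+//   layer { name: "bn1"    type: "BatchNorm" bottom: "conv1" top: "bn1" }
+//   layer { name: "scale1" type: "Scale"     bottom: "bn1"   top: "scale1"
+//           scale_param { bias_term: true } }
+// BatchNorm borrows gamma/beta (the Scale layer's weight/bias) via checkNextLayer().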
+// ------------------------------------------------------------------
+
+class Elementwise(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+  override def sourceFileName = null
+  override def init(dmlScript: StringBuilder): Unit = {}
+  if(param.getEltwiseParam.hasOperation && param.getEltwiseParam.getOperation != EltwiseOp.SUM)
+    throw new LanguageException("Currently only elementwise sum operation supported")
+  def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = {
+    addAndAssign(dmlScript, out, param.getBottomList.map(b => net.getCaffeLayer(b).out).toList)
+  }
+  override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = assign(dmlScript, dX + outSuffix, dout)
+  override def outputShape = {
+    if(_out == null) _out = net.getCaffeLayer(net.getBottomLayers(param.getName).head).outputShape
+    _out
+  }
+  var _out:(String, String, String) = null
+  
+}
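+// Illustrative: for layer { name: "sum1" type: "Eltwise" bottom: "a" bottom: "b" } with id 5,
+// forward() appends a sum-and-assign of the bottoms' outputs, e.g. out5 = <a.out> + <b.out>
+// (assuming addAndAssign joins the given outputs with "+").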
+
+class SoftmaxWithLoss(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with IsLossLayer {
+  // -------------------------------------------------
+  override def sourceFileName = "softmax"
+  override def init(dmlScript:StringBuilder) = {}
+  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = 
+    invokeForward(dmlScript, List[String](out), scores)
+  override def backward(dmlScript:StringBuilder, outSuffix:String) =  {
+    invoke(dmlScript, "cross_entropy_loss::", List[String]("dProbs" + outSuffix), "backward", out, "yb")
+    invoke(dmlScript.append("\t"), "softmax::", List[String](dX + outSuffix), "backward", "dProbs" + outSuffix, scores)
+  }
+  override def computeLoss(dmlScript:StringBuilder, numTabs:Int) = {
+    val tabBuilder = new StringBuilder
+    for(i <- 0 until numTabs) tabBuilder.append("\t")
+    val tabs = tabBuilder.toString
+    dmlScript.append("tmp_loss = cross_entropy_loss::forward(" + commaSep(out, "yb") + ")\n")
+    dmlScript.append(tabs).append("loss = loss + tmp_loss\n")
+    dmlScript.append(tabs).append("true_yb = rowIndexMax(yb)\n")
+    dmlScript.append(tabs).append("predicted_yb = rowIndexMax(" + out + ")\n")
+    dmlScript.append(tabs).append("accuracy = mean(predicted_yb == true_yb)*100\n")
+  }
+  def scores():String = {
+    val ret = net.getBottomLayers(param.getName).map(l => net.getCaffeLayer(l)).toList
+    if(ret.size == 1) return ret(0).out
+    else if(ret.size == 2) {
+      val ret1 = if(!ret(0).out.equals("Xb")) ret(0).out else ""
+      val ret2 = if(!ret(1).out.equals("Xb")) ret(1).out else ""
+      if(!ret1.equals("") && !ret2.equals("")) throw new LanguageException("At least one of the outputs of the previous layers should be Xb")
+      else if(!ret1.equals("")) return ret1
+      else return ret2
+    }
+    else
+      throw new LanguageException("More than two bottom layers are not supported")
+  }
+  // -------------------------------------------------
+}
+
+class ReLU(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+  // -------------------------------------------------
+  override def sourceFileName = "relu"
+  override def init(dmlScript:StringBuilder) = { }
+  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = invokeForward(dmlScript, List[String](out), X)
+  override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix, List[String](dX), dout, X)
+  // -------------------------------------------------
+}
+
+class Dropout(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+  // -------------------------------------------------
+  override def sourceFileName = "dropout"
+  override def init(dmlScript:StringBuilder) = { }
+  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) =
+    if(!isPrediction)
+      invokeForward(dmlScript, List[String](out, mask), X, p, seed)
+    else
+      assign(dmlScript, out, X) // Dropout acts as the identity during prediction, so no forward call is needed
+  override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix, List[String](dX), dout, X, p, mask)
+  // -------------------------------------------------
+  def mask = "mask" + id
+  def p = param.getDropoutParam.getDropoutRatio.toString
+  def seed = "-1"
+}
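+// Illustrative DML emitted for a dropout layer with id 4, input out3, and dropout_ratio 0.5
+// (assuming invoke()/assign() emit DML of this shape):
+//   training:   [out4, mask4] = dropout::forward(out3, 0.5, -1)
+//   prediction: out4 = out3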
+
+class InnerProduct(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
+  // -------------------------------------------------
+  override def sourceFileName = "affine"
+  override def init(dmlScript:StringBuilder) = invokeInit(dmlScript, List[String](weight, bias), numFeatures, numNeurons)
+  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = 
+      invokeForward(dmlScript, List[String](out), X, weight, bias)
+  override def backward(dmlScript:StringBuilder, outSuffix:String) = 
+      invokeBackward(dmlScript, outSuffix, List[String](dX, dWeight, dBias), dout, X, weight, bias)
+  // -------------------------------------------------
+  def numNeurons = param.getInnerProductParam.getNumOutput.toString
+  def numFeatures = int_mult(bottomLayerOutputShape._1, bottomLayerOutputShape._2, bottomLayerOutputShape._3)
+  override def outputShape = ( param.getInnerProductParam.getNumOutput.toString, "1", "1" )
+}
+
+class MaxPooling(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+  // -------------------------------------------------
+  override def sourceFileName = "max_pool2d_builtin"
+  override def init(dmlScript:StringBuilder) = {}
+  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = 
+    invokeForward(dmlScript, List[String](out, "ignoreHout_"+id, "ignoreWout_"+id), 
+        X, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+  override def backward(dmlScript:StringBuilder, outSuffix:String) = 
+    invokeBackward(dmlScript, outSuffix, List[String](dX), dout, Hout, Wout, X, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+  override def outputShape = ( numChannels, Hout, Wout )
+  // -------------------------------------------------
+  def Hin = bottomLayerOutputShape._2
+  def Win = bottomLayerOutputShape._3
+  def Hout = ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._2, kernel_h, stride_h, pad_h)
+  def Wout =  ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._3, kernel_w, stride_w, pad_w)
+  def poolingParam = param.getPoolingParam
+  def numChannels = bottomLayerOutputShape._1
+  def kernel_h = if(poolingParam.hasKernelH) poolingParam.getKernelH.toString 
+                   else poolingParam.getKernelSize.toString 
+  def kernel_w = if(poolingParam.hasKernelW) poolingParam.getKernelW.toString 
+                   else poolingParam.getKernelSize.toString
+  def stride_h = if(poolingParam.hasStrideH) poolingParam.getStrideH.toString 
+                   else poolingParam.getStride.toString
+  def stride_w = if(poolingParam.hasStrideW) poolingParam.getStrideW.toString 
+                   else poolingParam.getStride.toString
+  def pad_h =   if(poolingParam.hasPadH) poolingParam.getPadH.toString 
+                   else poolingParam.getPad.toString
+  def pad_w =   if(poolingParam.hasPadW) poolingParam.getPadW.toString 
+                   else poolingParam.getPad.toString
+}
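+// Illustrative (hypothetical) prototxt covered by the accessors above:
+//   layer { name: "pool1" type: "Pooling" bottom: "conv1" top: "pool1"
+//           pooling_param { pool: MAX kernel_size: 2 stride: 2 } }
+// kernel_h/kernel_w fall back to kernel_size, stride_h/stride_w to stride, and pad_h/pad_w to
+// pad (which defaults to 0 in the Caffe proto).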
+
+class Convolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
+  // -------------------------------------------------
+  override def sourceFileName = "conv2d_builtin";
+  override def init(dmlScript:StringBuilder) = invokeInit(dmlScript, List[String](weight, bias), numKernels, numChannels, kernel_h, kernel_w)
+  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = 
+    invokeForward(dmlScript, List[String](out, "ignoreHout_"+id, "ignoreWout_"+id), 
+        X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+  override def backward(dmlScript:StringBuilder, outSuffix:String) = 
+    invokeBackward(dmlScript, outSuffix, List[String](dX, dWeight, dBias), dout, Hout, Wout, X, weight, bias, numChannels, Hin, Win, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w)
+  override def outputShape = ( numKernels, Hout, Wout )
+  // -------------------------------------------------
+  def numChannels = bottomLayerOutputShape._1
+  def Hin = bottomLayerOutputShape._2
+  def Win = bottomLayerOutputShape._3
+  def Hout = ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._2, kernel_h, stride_h, pad_h) 
+  def Wout =  ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._3, kernel_w, stride_w, pad_w)
+  def convParam = param.getConvolutionParam
+  def numKernels = convParam.getNumOutput.toString
+  def kernel_h = if(convParam.hasKernelH) convParam.getKernelH.toString 
+                   else if(convParam.getKernelSizeCount > 0)  convParam.getKernelSize(0).toString 
+                   else throw new LanguageException("Incorrect kernel parameters")
+  def kernel_w = if(convParam.hasKernelW) convParam.getKernelW.toString 
+                   else if(convParam.getKernelSizeCount > 0)  convParam.getKernelSize(0).toString 
+                   else throw new LanguageException("Incorrect kernel parameters")
+  def stride_h = if(convParam.hasStrideH) convParam.getStrideH.toString 
+                   else if(convParam.getStrideCount > 0)  convParam.getStride(0).toString 
+                   else throw new LanguageException("Incorrect stride parameters:" + convParam.getStrideH + " " + convParam.getStrideList + " " + param.getName)
+  def stride_w = if(convParam.hasStrideW) convParam.getStrideW.toString 
+                   else if(convParam.getStrideCount > 0)  convParam.getStride(0).toString 
+                   else throw new LanguageException("Incorrect stride parameters")
+  def pad_h =   if(convParam.hasPadH) convParam.getPadH.toString 
+                   else if(convParam.getPadCount > 0)  convParam.getPad(0).toString 
+                   else throw new LanguageException("Incorrect pad parameters")
+  def pad_w =   if(convParam.hasPadW) convParam.getPadW.toString 
+                   else if(convParam.getPadCount > 0)  convParam.getPad(0).toString 
+                   else throw new LanguageException("Incorrect pad parameters")
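+  // Illustrative (hypothetical) prototxt covered by the accessors above:
+  //   convolution_param { num_output: 32 kernel_size: 5 stride: 1 pad: 2 }
+  // Unlike PoolingParameter, kernel_size/stride/pad are repeated fields here, hence the
+  // getKernelSize(0)/getStride(0)/getPad(0) fallbacks.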
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/cc7993fc/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
new file mode 100644
index 0000000..e585e30
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.dl
+
+import org.apache.sysml.runtime.DMLRuntimeException
+import scala.collection.JavaConversions._
+import caffe.Caffe.NetParameter
+import caffe.Caffe.LayerParameter
+import caffe.Caffe.Phase
+import java.util.ArrayList
+import java.util.HashSet
+import scala.collection.mutable.Stack
+import org.apache.sysml.parser.LanguageException;
+import java.util.HashMap
+import caffe.Caffe.PoolingParameter
+import org.apache.commons.logging.LogFactory
+
+trait Network {
+  def getLayers(): List[String]
+  def getCaffeLayer(layerName:String):CaffeLayer
+  def getBottomLayers(layerName:String): Set[String]
+  def getTopLayers(layerName:String): Set[String]
+  def getLayerID(layerName:String): Int
+}
+
+object CaffeNetwork {
+  val LOG = LogFactory.getLog(classOf[CaffeNetwork].getName)
+}
+
+class CaffeNetwork(netFilePath:String, val currentPhase:Phase, 
+     val numChannels:String, val height:String, val width:String
+    ) extends Network {
+  private def isIncludedInCurrentPhase(l:LayerParameter): Boolean = {
+    if(l.getIncludeCount == 0) true else l.getIncludeList.filter(r => r.hasPhase() && r.getPhase != currentPhase).length == 0
+  }
+  private var id = 1
+  
+  // --------------------------------------------------------------------------------
+  private var _caffeLayerParams:List[LayerParameter] = Utils.readCaffeNet(netFilePath).getLayerList.filter(l => isIncludedInCurrentPhase(l)).toList
+  // --------------------------------------------------------------------------------
+  
+  private var _layerNames: List[String] = _caffeLayerParams.map(l => l.getName).toList
+  CaffeNetwork.LOG.debug("Layers in current phase:" + _layerNames)
+  
+  // Condition 1: assert that each name is unique
+  private val _duplicateLayerNames =_layerNames.diff(_layerNames.distinct)
+  if(_duplicateLayerNames.size != 0) throw new LanguageException("Duplicate layer names are not supported: " + _duplicateLayerNames)
+  
+  // Condition 2: only 1 top name, except Data layer
+  private val _condition2Exceptions = Set("data")
+  _caffeLayerParams.filter(l => !_condition2Exceptions.contains(l.getType.toLowerCase)).foreach(l => if(l.getTopCount != 1) throw new LanguageException("Multiple top layers are not supported for " + l.getName))
+
+  // Condition 3: Replace top layer names referring to a Data layer with its name
+  // Example: layer{ name: mnist, top: data, top: label, ... }
+  private val _topToNameMappingForDataLayer = new HashMap[String, String]()
+  private def containsOnly(list:java.util.List[String], v:String): Boolean = list.toSet.diff(Set(v)).isEmpty
+  private def isData(l:LayerParameter):Boolean = l.getType.equalsIgnoreCase("data")
+  private def replaceTopWithNameOfDataLayer(l:LayerParameter):LayerParameter =  {
+    if(containsOnly(l.getTopList,l.getName))
+      return l
+    else {
+      val builder = l.toBuilder(); 
+      for(i <- 0 until l.getTopCount) {
+        if(! l.getTop(i).equals(l.getName)) { _topToNameMappingForDataLayer.put(l.getTop(i), l.getName) }
+        builder.setTop(i, l.getName) 
+      }
+      return builder.build() 
+    }
+  }
+  // 3a: Replace the tops of a Data layer with its name
+  // Example: layer{ name: mnist, top: mnist, top: mnist, ... }
+  _caffeLayerParams = _caffeLayerParams.map(l => if(isData(l)) replaceTopWithNameOfDataLayer(l) else l)
+  private def replaceBottomOfNonDataLayers(l:LayerParameter):LayerParameter = {
+    val builder = l.toBuilder();
+    // Note: a top will never refer to a Data layer
+    for(i <- 0 until l.getBottomCount) {
+      if(_topToNameMappingForDataLayer.containsKey(l.getBottom(i))) 
+        builder.setBottom(i, _topToNameMappingForDataLayer.get(l.getBottom(i)))
+    }
+    return builder.build()
+  }
+  // 3b: If bottoms of other layers refer to a Data layer's top, replace them with the Data layer's name
+  // Example: layer { name: "conv1_1", type: "Convolution", bottom: "data", ... }
+  _caffeLayerParams = if(_topToNameMappingForDataLayer.size == 0) _caffeLayerParams else _caffeLayerParams.map(l => if(isData(l)) l else replaceBottomOfNonDataLayers(l))
+  
+  // Condition 4: Deal with fused layer
+  // Example: layer { name: conv1, top: conv1, ... } layer { name: foo, bottom: conv1, top: conv1 }
+  private def isFusedLayer(l:LayerParameter): Boolean = l.getTopCount == 1 && l.getBottomCount == 1 && l.getTop(0).equalsIgnoreCase(l.getBottom(0))
+  private def containsReferencesToFusedLayer(l:LayerParameter):Boolean = l.getBottomList.exists(b => _fusedTopLayer.containsKey(b))
+  private val _fusedTopLayer = new HashMap[String, String]()
+  _caffeLayerParams = _caffeLayerParams.map(l => {
+    if(isFusedLayer(l)) {
+      val builder = l.toBuilder();
+      if(_fusedTopLayer.containsKey(l.getBottom(0))) {
+        builder.setBottom(0, _fusedTopLayer.get(l.getBottom(0)))
+      }
+      builder.setTop(0, l.getName)
+      _fusedTopLayer.put(l.getBottom(0), l.getName)
+      builder.build()
+    }
+    else if(containsReferencesToFusedLayer(l)) {
+      val builder = l.toBuilder();
+      for(i <- 0 until l.getBottomCount) {
+        if(_fusedTopLayer.containsKey(l.getBottomList.get(i))) {
+          builder.setBottom(i, _fusedTopLayer.get(l.getBottomList.get(i)))
+        }
+      }
+      builder.build()
+    }
+    else l
+  })
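+  // Illustrative rewrite performed above: given
+  //   layer { name: conv1, top: conv1 }  layer { name: relu1, bottom: conv1, top: conv1 }
+  //   layer { name: pool1, bottom: conv1 }
+  // relu1 becomes { bottom: conv1, top: relu1 } and pool1 becomes { bottom: relu1 }.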
+
+  // --------------------------------------------------------------------------------
+  
+  // Helper functions to extract bottom and top layers
+  private def convertTupleListToMap(m:List[(String, String)]):Map[String, Set[String]] = m.groupBy(_._1).map(x => (x._1, x._2.map(y => y._2).toSet)).toMap
+  private def flipKeyValues(t:List[(String, Set[String])]): List[(String, String)] = t.flatMap(x => x._2.map(b => b -> x._1))
+  private def expandBottomList(layerName:String, bottomList:java.util.List[String]): List[(String, String)] = bottomList.filter(b => !b.equals(layerName)).map(b => layerName -> b).toList 
+  
+  // The bottom layers are the layers available in the getBottomList (from Caffe .proto files)
+  private val _bottomLayers:Map[String, Set[String]] = convertTupleListToMap(
+      _caffeLayerParams.flatMap(l => expandBottomList(l.getName, l.getBottomList)))
+  CaffeNetwork.LOG.info("Bottom layers:" + _bottomLayers)
+  
+  // Find the top layers by reversing the bottom list
+  private val _topLayers:Map[String, Set[String]] = convertTupleListToMap(flipKeyValues(_bottomLayers.toList))
+  CaffeNetwork.LOG.info("Top layers:" + _topLayers)
+  
+  private val _layers: Map[String, CaffeLayer] = _caffeLayerParams.map(l => l.getName -> convertLayerParameterToCaffeLayer(l)).toMap
+  CaffeNetwork.LOG.info("Layers:" + _layers)
+  private val _layerIDs: Map[String, Int] = _layers.map(kv => kv._1 -> kv._2.id)
+  
+  
+  private def throwException(layerName:String) = throw new LanguageException("Layer with name " + layerName + " not found")                              
+  def getLayers(): List[String] =  _layerNames
+  def getCaffeLayer(layerName:String):CaffeLayer = if(checkKey(_layers, layerName)) _layers.get(layerName).get else throwException(layerName)
+  def getBottomLayers(layerName:String): Set[String] =  if(checkKey(_bottomLayers, layerName)) _bottomLayers.get(layerName).get else throwException(layerName)
+  def getTopLayers(layerName:String): Set[String] = if(checkKey(_topLayers, layerName)) _topLayers.get(layerName).get else throwException(layerName)
+  def getLayerID(layerName:String): Int = if(checkKey(_layerIDs, layerName))  _layerIDs.get(layerName).get else throwException(layerName)
+  
+  // Helper functions
+  private def checkKey(m:Map[String, Any], key:String): Boolean = {
+    if(m == null) throw new LanguageException("Map is null")
+    else if(key == null) throw new LanguageException("key is null")
+    else m.contains(key)
+  }
+  private def convertLayerParameterToCaffeLayer(param:LayerParameter):CaffeLayer = {
+    id = id + 1
+    param.getType.toLowerCase() match {
+      case "convolution" => new Convolution(param, id, this)
+      case "pooling" => if(param.getPoolingParam.getPool == PoolingParameter.PoolMethod.MAX)  new MaxPooling(param, id, this)
+                        else throw new LanguageException("Only max pooling is supported: " + param.getPoolingParam.getPool.name)
+      case "innerproduct" => new InnerProduct(param, id, this)
+      case "relu" => new ReLU(param, id, this)
+      case "softmaxwithloss" => new SoftmaxWithLoss(param, id, this)
+      case "dropout" => new Dropout(param, id, this)
+      case "data" => new Data(param, id, this, numChannels, height, width)
+      case "batchnorm" => new BatchNorm(param, id, this)
+      case "scale" => new Scale(param, id, this)
+      case "eltwise" => new Elementwise(param, id, this)
+      case _ => throw new LanguageException("Layer of type " + param.getType + " is not supported")
+    }
+  }
+}
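+// Minimal usage sketch (hypothetical file path and input dimensions):
+//   val net = new CaffeNetwork("lenet_train.proto", Phase.TRAIN, "1", "28", "28")
+//   net.getLayers().foreach(l => println(l + " -> " + net.getCaffeLayer(l).outputShape))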
\ No newline at end of file