Posted to commits@systemml.apache.org by ni...@apache.org on 2017/05/26 05:35:59 UTC

[1/2] incubator-systemml git commit: [SYSTEMML-1632] Support loading and saving models via mllearn

Repository: incubator-systemml
Updated Branches:
  refs/heads/master d36a0c1b0 -> d69f3441c
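
This patch adds save() and load() methods to the mllearn estimators so that trained models (together with their label mappings) can be written out and restored later. A minimal usage sketch, assuming a SparkSession named spark and NumPy arrays X and y; the model directory name is illustrative:

>>> from systemml.mllearn import LogisticRegression
>>> lr = LogisticRegression(spark)
>>> lr.fit(X, y)
>>> lr.save('lr_model')            # writes the model matrices plus labels.txt
>>> lr2 = LogisticRegression(spark)
>>> lr2.load('lr_model')
>>> lr2.predict(X)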


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/src/main/python/systemml/mllearn/estimators.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py
index deed4c2..ec225c4 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -276,15 +276,22 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator):
     def decode(self, y):
         if self.le is not None:
             return self.le.inverse_transform(np.asarray(y - 1, dtype=int))
-        else:
+        elif self.labelMap is not None:
             return [ self.labelMap[int(i)] for i in y ]
+        else:
+            return y
         
     def predict(self, X):
-        predictions = np.asarray(super(BaseSystemMLClassifier, self).predict(X))
-        try:
-            return np.asarray(predictions, dtype='double')
-        except ValueError:
-            return np.asarray(predictions, dtype='str')
+        predictions = super(BaseSystemMLClassifier, self).predict(X)
+        from pyspark.sql.dataframe import DataFrame as df
+        if type(predictions) == df:
+            return predictions
+        else:
+            try:
+                return np.asarray(predictions, dtype='double')
+            except ValueError:
+                print(type(predictions))
+                return np.asarray(predictions, dtype='str')
             
     def score(self, X, y):
         """
@@ -300,6 +307,55 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator):
             return accuracy_score(y, predictions)
         else:
             return accuracy_score(np.asarray(y, dtype='str'), np.asarray(predictions, dtype='str'))
+            
+    def loadLabels(self, file_path):
+        createJavaObject(self.sc, 'dummy')
+        utilObj = self.sc._jvm.org.apache.sysml.api.ml.Utils()
+        if utilObj.checkIfFileExists(file_path):
+            df = self.sparkSession.read.csv(file_path, header=False).toPandas()
+            keys = np.asarray(df._c0, dtype='int')
+            values = np.asarray(df._c1, dtype='str')
+            self.labelMap = {}
+            self.le = None
+            for i in range(len(keys)):
+                self.labelMap[int(keys[i])] = values[i]
+            # self.encode(classes) # Giving incorrect results
+        
+    def load(self, weights=None, sep='/'):
+        """
+        Load a pretrained model. 
+
+        Parameters
+        ----------
+        weights: directory where the learned weights are stored (default: None)
+        sep: separator to use (default: '/')
+        """
+        self.weights = weights
+        self.model.load(self.sc._jsc, weights, sep)
+        self.loadLabels(weights + '/labels.txt')
+        
+    def save(self, outputDir, format='binary', sep='/'):
+        """
+        Save a trained model.
+        
+        Parameters
+        ----------
+        outputDir: Directory to save the model to
+        format: optional format (default: 'binary')
+        sep: separator to use (default: '/')
+        """
+        if self.model != None:
+            self.model.save(self.sc._jsc, outputDir, format, sep)
+            if self.le is not None:
+                labelMapping = dict(enumerate(list(self.le.classes_), 1))
+            else:
+                labelMapping = self.labelMap
+            lStr = [ [ int(k), str(labelMapping[k]) ] for k in labelMapping ]
+            df = self.sparkSession.createDataFrame(lStr)
+            df.write.csv(outputDir + sep + 'labels.txt', mode='overwrite', header=False)
+        else:
+            raise Exception('Cannot save as you need to train the model first using fit')
+        return self
 
 class BaseSystemMLRegressor(BaseSystemMLEstimator):
 
@@ -319,6 +375,34 @@ class BaseSystemMLRegressor(BaseSystemMLEstimator):
         y: NumPy ndarray, Pandas DataFrame, scipy sparse matrix
         """
         return r2_score(y, self.predict(X), multioutput='variance_weighted')
+        
+    def load(self, weights=None, sep='/'):
+        """
+        Load a pretrained model. 
+
+        Parameters
+        ----------
+        weights: directory where the learned weights are stored (default: None)
+        sep: separator to use (default: '/')
+        """
+        self.weights = weights
+        self.model.load(self.sc._jsc, weights, sep)
+
+    def save(self, outputDir, format='binary', sep='/'):
+        """
+        Save a trained model.
+        
+        Parameters
+        ----------
+        outputDir: Directory to save the model to
+        format: optional format (default: 'binary')
+        sep: separator to use (default: '/')
+        """
+        if self.model != None:
+            self.model.save(outputDir, format, sep)
+        else:
+            raise Exception('Cannot save as you need to train the model first using fit')
+        return self
 
 
 class LogisticRegression(BaseSystemMLClassifier):
@@ -411,11 +495,12 @@ class LogisticRegression(BaseSystemMLClassifier):
         self.estimator.setIcpt(icpt)
         self.transferUsingDF = transferUsingDF
         self.setOutputRawPredictionsToFalse = True
+        self.model = self.sc._jvm.org.apache.sysml.api.ml.LogisticRegressionModel(self.estimator)
         if penalty != 'l2':
             raise Exception('Only l2 penalty is supported')
         if solver != 'newton-cg':
             raise Exception('Only newton-cg solver supported')
-
+        
 
 class LinearRegression(BaseSystemMLRegressor):
     """
@@ -481,6 +566,7 @@ class LinearRegression(BaseSystemMLRegressor):
         self.estimator.setIcpt(icpt)
         self.transferUsingDF = transferUsingDF
         self.setOutputRawPredictionsToFalse = False
+        self.model = self.sc._jvm.org.apache.sysml.api.ml.LinearRegressionModel(self.estimator)
 
 
 class SVM(BaseSystemMLClassifier):
@@ -526,6 +612,7 @@ class SVM(BaseSystemMLClassifier):
         self.sc = sparkSession._sc
         self.uid = "svm"
         createJavaObject(self.sc, 'dummy')
+        self.is_multi_class = is_multi_class
         self.estimator = self.sc._jvm.org.apache.sysml.api.ml.SVM(self.uid, self.sc._jsc.sc(), is_multi_class)
         self.estimator.setMaxIter(max_iter)
         if C <= 0:
@@ -537,7 +624,7 @@ class SVM(BaseSystemMLClassifier):
         self.estimator.setIcpt(icpt)
         self.transferUsingDF = transferUsingDF
         self.setOutputRawPredictionsToFalse = False
-
+        self.model = self.sc._jvm.org.apache.sysml.api.ml.SVMModel(self.estimator, self.is_multi_class)
 
 class NaiveBayes(BaseSystemMLClassifier):
     """
@@ -583,6 +670,7 @@ class NaiveBayes(BaseSystemMLClassifier):
         self.estimator.setLaplace(laplace)
         self.transferUsingDF = transferUsingDF
         self.setOutputRawPredictionsToFalse = False
+        self.model = self.sc._jvm.org.apache.sysml.api.ml.NaiveBayesModel(self.estimator)
 
 class Caffe2DML(BaseSystemMLClassifier):
     """
@@ -592,8 +680,6 @@ class Caffe2DML(BaseSystemMLClassifier):
     --------
     
     >>> from systemml.mllearn import Caffe2DML
-    >>> from pyspark.sql import SQLContext
-    >>> sqlCtx = SQLContext(sc)
     >>> from mlxtend.data import mnist_data
     >>> import numpy as np
     >>> from sklearn.utils import shuffle
@@ -603,25 +689,23 @@ class Caffe2DML(BaseSystemMLClassifier):
     >>> import urllib
     >>> urllib.urlretrieve('https://raw.githubusercontent.com/niketanpansare/model_zoo/master/caffe/vision/lenet/mnist/lenet.proto', 'lenet.proto')
     >>> urllib.urlretrieve('https://raw.githubusercontent.com/niketanpansare/model_zoo/master/caffe/vision/lenet/mnist/lenet_solver.proto', 'lenet_solver.proto')
-    >>> caffe2DML = Caffe2DML(sqlCtx, 'lenet_solver.proto').set(max_iter=500)
+    >>> caffe2DML = Caffe2DML(spark, 'lenet_solver.proto').set(max_iter=500)
     >>> caffe2DML.fit(X, y)
     """
-    def __init__(self, sqlCtx, solver, input_shape, weights=None, ignore_weights=None, transferUsingDF=False, tensorboard_log_dir=None):
+    def __init__(self, sparkSession, solver, input_shape, transferUsingDF=False, tensorboard_log_dir=None):
         """
         Performs training/prediction for a given caffe network. 
 
         Parameters
         ----------
-        sqlCtx: PySpark SQLContext
+        sparkSession: PySpark SparkSession
         solver: caffe solver file path
         input_shape: 3-element list (number of channels, input height, input width)
-        weights: directory whether learned weights are stored (default: None)
-        ignore_weights: names of layers to not read from the weights directory (list of string, default:None)
         transferUsingDF: whether to pass the input dataset via PySpark DataFrame (default: False)
         tensorboard_log_dir: directory to store the event logs (default: None, we use a temporary directory)
         """
-        self.sqlCtx = sqlCtx
-        self.sc = sqlCtx._sc
+        self.sparkSession = sparkSession
+        self.sc = sparkSession._sc
         createJavaObject(self.sc, 'dummy')
         self.uid = "Caffe2DML"
         self.model = None
@@ -629,30 +713,30 @@ class Caffe2DML(BaseSystemMLClassifier):
             raise ValueError('Expected input_shape as list of 3 element')
         solver = self.sc._jvm.org.apache.sysml.api.dl.Utils.readCaffeSolver(solver)
         self.estimator = self.sc._jvm.org.apache.sysml.api.dl.Caffe2DML(self.sc._jsc.sc(), solver, str(input_shape[0]), str(input_shape[1]), str(input_shape[2]))
-        self.weights = weights
-        if weights is not None:
-            self.estimator.setInput("$weights", str(weights))
-            self._loadLabelTxt()
-            if ignore_weights is not None:
-                self.estimator.setWeightsToIgnore(ignore_weights)
         self.transferUsingDF = transferUsingDF
         self.setOutputRawPredictionsToFalse = False
         self.visualize_called = False
         if tensorboard_log_dir is not None:
             self.estimator.setTensorBoardLogDir(tensorboard_log_dir)
-    
-    def _loadLabelTxt(self, format="binary", sep="/"):
-        if(self.weights is not None):
-            self.model = self.sc._jvm.org.apache.sysml.api.dl.Caffe2DMLModel(self.estimator)
-            df = self.sqlCtx.read.csv(self.weights + sep + 'labels.txt', header=False).toPandas()
-            keys = np.asarray(df._c0, dtype='int')
-            values = np.asarray(df._c1, dtype='str')
-            self.labelMap = {}
-            self.le = None
-            for i in range(len(keys)):
-                self.labelMap[int(keys[i])] = values[i]
-            # self.encode(classes) # Giving incorrect results
-    
+
+    def load(self, weights=None, sep='/', ignore_weights=None):
+        """
+        Load a pretrained model. 
+
+        Parameters
+        ----------
+        weights: directory where the learned weights are stored (default: None)
+        sep: separator to use (default: '/')
+        ignore_weights: names of layers not to read from the weights directory (list of strings, default: None)
+        """
+        self.weights = weights
+        self.estimator.setInput("$weights", str(weights))
+        self.model = self.sc._jvm.org.apache.sysml.api.dl.Caffe2DMLModel(self.estimator)
+        self.model.load(self.sc._jsc, weights, sep)
+        self.loadLabels(weights + '/labels.txt')
+        if ignore_weights is not None:
+            self.estimator.setWeightsToIgnore(ignore_weights)
+            
     def set(self, num_classes=None, debug=None):
         """
         Set input to Caffe2DML
@@ -691,25 +775,4 @@ class Caffe2DML(BaseSystemMLClassifier):
         self.visualize_called = True
         return self
     
-    def save(self, outputDir, format='binary', sep='/'):
-        """
-        Save a trained model.
-        
-        Parameters
-        ----------
-        outputDir: Directory to save the model to
-        format: optional format (default: 'binary')
-        sep: seperator to use (default: '/')
-        """
-        if self.model != None:
-            self.model.save(outputDir, format, sep)
-            if self.le is not None:
-                labelMapping = dict(enumerate(list(self.le.classes_), 1))
-            else:
-                labelMapping = self.labelMap
-            lStr = [ [ int(k), str(labelMapping[k]) ] for k in labelMapping ]
-            df = self.sqlCtx.createDataFrame(lStr)
-            df.write.csv(outputDir + sep + 'labels.txt', mode='overwrite', header=False)
-        else:
-            raise Exception('Cannot save as you need to train the model first using fit')
-        return self
+    
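
With this change, Caffe2DML no longer takes pretrained weights in its constructor; they are loaded explicitly via the new load() method, which also restores the label mapping from labels.txt. A hypothetical sketch (the weights directory and the ignored layer name are illustrative), continuing the lenet example from the docstring above:

>>> lenet = Caffe2DML(spark, 'lenet_solver.proto', input_shape=[1, 28, 28])
>>> lenet.load(weights='lenet_weights', ignore_weights=['conv1'])
>>> lenet.predict(X)

For reference, labels.txt as written by save() and read by loadLabels() is a headerless two-column CSV mapping the 1-based internal class index to the original label; a line such as 2,dog makes decode() map predictions of class 2 back to the label dog.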

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index fe6b159..7fb3e17 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -55,15 +55,35 @@ import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyze
 /***************************************************************************************
 DESIGN OF CAFFE2DML:
 
-1. Caffe2DML is designed to fit well into the mllearn framework. Hence, the key methods that needed to be implemented are:
+1. Caffe2DML is designed to fit well into the mllearn framework. Hence, the key methods that need to be implemented are:
 - `getTrainingScript` for the Estimator class. 
 - `getPredictionScript` for the Model class.
 
+These methods are the starting point for any developer who wants to understand the DML generated for training and prediction, respectively.
+
 2. To simplify the DML generation in the getTrainingScript and getPredictionScript methods, we use the DMLGenerator interface.
 This interface generates DML strings for common operations such as control structures (if, for, while) as well as built-in functions (read, write), etc.
 Also, this interface helps in "code reading" of this class :)
 
-3. Additionally, we created mapping classes for layer, solver and learning rate that maps the corresponding Caffe abstraction to the SystemML-NN library.
+3. Here is an analogy to help SystemML developers think about the various moving components of Caffe2DML:
+- Like Dml.g4 in the org.apache.sysml.parser.dml package, caffe.proto in the src/main/proto/caffe directory
+is used to generate classes to parse the input files.
+
+Dml.g4      ---> antlr  ---> DmlLexer.java, DmlListener.java, DmlParser.java
+caffe.proto ---> protoc ---> target/generated-sources/caffe/Caffe.java
+
+- Just like the classes generated by Dml.g4 are used to parse an input DML file,
+the target/generated-sources/caffe/Caffe.java class is used to parse the input caffe network/deploy prototxt and solver files.
+
+- You can think of a .caffemodel file as a DML file with matrix values encoded in it (please see the example below).
+So it is possible to read a .caffemodel file with the Caffe.java class. This is done in Utils.scala's readCaffeNet method.
+
+X = matrix("1.2 3.5 0.999 7.123", rows=2, cols=2)
+...
+
+- Just like we convert the AST generated by antlr into our DMLProgram representation, we convert
+caffe's abstractions into the mapping classes given below for layer, solver and learning rate.
+These mapping classes map the corresponding Caffe abstraction to the SystemML-NN library.
 This greatly simplifies adding new layers into Caffe2DML:
 trait CaffeLayer {
   // Any layer that wants to reuse SystemML-NN has to override following methods that help in generating the DML for the given layer:
@@ -87,6 +107,13 @@ trait Network {
   def getTopLayers(layerName:String): Set[String]
   def getLayerID(layerName:String): Int
 }
+
+5. One of the key design restrictions of Caffe2DML is that every layer is identified uniquely by its name.
+This restriction simplifies the code significantly.
+To guard against network files that violate this restriction, Caffe2DML performs rewrites in the CaffeNetwork class (search for conditions 1-5).
+
+6. Caffe2DML also expects the layers to be in sorted order.
+
 ***************************************************************************************/
 
 object Caffe2DML  {
@@ -129,12 +156,12 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
   }
   // Note: will update the y_mb as this will be called by Python mllearn
   def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): Caffe2DMLModel = {
-    val ret = baseFit(X_mb, y_mb, sc)
-    new Caffe2DMLModel(ret, Utils.numClasses(net), sc, solver, net, lrPolicy, this)
+    mloutput = baseFit(X_mb, y_mb, sc)
+    new Caffe2DMLModel(this)
   }
   def fit(df: ScriptsUtils.SparkDataType): Caffe2DMLModel = {
-    val ret = baseFit(df, sc)
-    new Caffe2DMLModel(ret, Utils.numClasses(net), sc, solver, net, lrPolicy, this)
+    mloutput = baseFit(df, sc)
+    new Caffe2DMLModel(this)
   }
 	// --------------------------------------------------------------
   
@@ -412,23 +439,14 @@ class Caffe2DMLModel(val mloutput: MLResults,
   }
   // --------------------------------------------------------------
   
-  def save(outputDir:String, format:String="binary", sep:String="/"):Unit = {
-	  if(mloutput == null) throw new DMLRuntimeException("Cannot save as you need to train the model first using fit")
-	  val dmlScript = new StringBuilder
-	  dmlScript.append("print(\"Saving the model to " + outputDir + "...\")\n")
-	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => dmlScript.append(write(l.weight, outputDir + sep + l.param.getName + "_weight.mtx", format)))
-	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => dmlScript.append(write(l.bias, outputDir + sep + l.param.getName + "_bias.mtx", format)))
-	  
-	  val script = dml(dmlScript.toString)
-	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.in(l.weight, mloutput.getBinaryBlockMatrix(l.weight)))
-	  net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.in(l.bias, mloutput.getBinaryBlockMatrix(l.bias)))
-	  val ml = new MLContext(sc)
-	  ml.execute(script)
-	}
+  def modelVariables():List[String] = {
+    net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(_.weight) ++
+    net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(_.bias)
+  }
     
   // ================================================================================================
   // The below method parses the provided network and solver file and generates DML script.
-  def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String)  = {
+  def getPredictionScript(isSingleNode:Boolean): (Script, String)  = {
     val startPredictionTime = System.nanoTime()
     
 	  reset                                  // Reset the state of DML generator for training script.
@@ -496,11 +514,13 @@ class Caffe2DMLModel(val mloutput: MLResults,
   }
   // ================================================================================================
   
+  def baseEstimator():BaseSystemMLEstimator = estimator
+  
   // Prediction
   def transform(X: MatrixBlock): MatrixBlock = {
-	  baseTransform(X, mloutput, sc, "Prob")
+	  baseTransform(X, sc, "Prob")
   }
   def transform(df: ScriptsUtils.SparkDataType): DataFrame = {
-	  baseTransform(df, mloutput, sc, "Prob")
+	  baseTransform(df, sc, "Prob")
   }
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
index 0d1740e..3fdbdb1 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeLayer.scala
@@ -88,7 +88,9 @@ trait CaffeLayer extends BaseDMLGenerator {
   def dWeight():String = throw new DMLRuntimeException("dWeight is not implemented in super class")
   def dBias():String = throw new DMLRuntimeException("dBias is not implemented in super class")
   def weight():String = null;
+  def weightShape():Array[Int];
   def bias():String = null;
+  def biasShape():Array[Int];
   def shouldUpdateWeight():Boolean = if(weight != null) true else false
   def shouldUpdateBias():Boolean = if(bias != null) true else false
   // --------------------------------------------------------------------------------------
@@ -136,13 +138,13 @@ trait IsLossLayer extends CaffeLayer {
 }
 
 trait HasWeight extends CaffeLayer {
-  override def weight = "W" + id
-  override def dWeight = "dW" + id
+  override def weight = param.getName + "_weight"
+  override def dWeight = param.getName + "_dWeight"
 }
 
 trait HasBias extends CaffeLayer {
-  override def bias = "b" + id
-  override def dBias = "db" + id
+  override def bias = param.getName + "_bias"
+  override def dBias = param.getName + "_dBias"
 }
 
 class Data(val param:LayerParameter, val id:Int, val net:CaffeNetwork, val numChannels:String, val height:String, val width:String) extends CaffeLayer {
@@ -152,13 +154,21 @@ class Data(val param:LayerParameter, val id:Int, val net:CaffeNetwork, val numCh
     if(param.hasTransformParam && param.getTransformParam.hasScale) {
       dmlScript.append("X_full = X_full * " + param.getTransformParam.getScale + "\n")
     }
-    dmlScript.append("BATCH_SIZE = " + param.getDataParam.getBatchSize + "\n")
+    if(param.hasDataParam && param.getDataParam.hasBatchSize) {
+      dmlScript.append("BATCH_SIZE = " + param.getDataParam.getBatchSize + "\n")
+    }
+    else {
+      Caffe2DML.LOG.debug("Using default batch size of 64 as batch size is not set with DataParam")
+      dmlScript.append("BATCH_SIZE = 64\n")
+    }
   }
   var dataOutputShape = ("$num_channels", "$height", "$width")
   override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = { }
   override def out = "Xb"
   override def backward(dmlScript:StringBuilder, outSuffix:String) = { }
   override def outputShape = (numChannels, height, width)
+  override def weightShape():Array[Int] = null
+  override def biasShape():Array[Int] = null
   // -------------------------------------------------
 }
 
@@ -303,6 +313,8 @@ class BatchNorm(val param:LayerParameter, val id:Int, val net:CaffeNetwork) exte
   
   private def withSuffix(str:String):String = if(update_mean_var) str else str + "_ignore"
   override def weight = "ema_mean" + id
+  override def weightShape():Array[Int] = Array(numChannels.toInt, 1)
+  override def biasShape():Array[Int] = Array(numChannels.toInt, 1)
   override def bias = "ema_var" + id
   def cache_mean(): String = "cache_mean" + id
   def cache_var():String = "cache_mean" + id
@@ -337,6 +349,8 @@ class Scale(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends
   // TODO: Generalize this !!
   def forward(dmlScript: StringBuilder, isPrediction: Boolean): Unit = assign(dmlScript, out, X)
   override def backward(dmlScript: StringBuilder, outSuffix:String): Unit = assignDoutToDX(dmlScript, outSuffix)
+  override def weightShape():Array[Int] = Array(bottomLayerOutputShape._1.toInt, 1)
+  override def biasShape():Array[Int] = Array(bottomLayerOutputShape._1.toInt, 1)
 }
 // ------------------------------------------------------------------
 
@@ -354,7 +368,8 @@ class Elementwise(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ex
     _out
   }
   var _out:(String, String, String) = null
-  
+  override def weightShape():Array[Int] = null
+  override def biasShape():Array[Int] = null
 }
 
 class Concat(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
@@ -466,6 +481,8 @@ class Concat(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends
     _out
   }
   var _out:(String, String, String) = null
+  override def weightShape():Array[Int] = null
+  override def biasShape():Array[Int] = null
 }
 
 class SoftmaxWithLoss(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with IsLossLayer {
@@ -506,6 +523,8 @@ class SoftmaxWithLoss(val param:LayerParameter, val id:Int, val net:CaffeNetwork
 	  else 
 		  throw new LanguageException("More than 2 bottom layers is not supported")
   }
+  override def weightShape():Array[Int] = null
+  override def biasShape():Array[Int] = null
   // -------------------------------------------------
 }
 
@@ -540,9 +559,72 @@ class ReLU(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends C
    *  - dX: Gradient wrt `X`, of same shape as `X`.
    */
   override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id), dout, X)
+  override def weightShape():Array[Int] = null
+  override def biasShape():Array[Int] = null
+  // -------------------------------------------------
+}
+
+class Softmax(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+  // -------------------------------------------------
+  override def sourceFileName = "softmax"
+  override def init(dmlScript:StringBuilder) = { }
+  /*
+   * Computes the forward pass for a softmax classifier.  The inputs
+   * are interpreted as unnormalized, log-probabilities for each of
+   * N examples, and the softmax function transforms them to normalized
+   * probabilities.
+   *
+   * This can be interpreted as a generalization of the sigmoid
+   * function to multiple classes.
+   *
+   *   `probs_ij = e^scores_ij / sum(e^scores_i)`
+   *
+   * Inputs:
+   *  - scores: Inputs, of shape (N, D).
+   *
+   * Outputs:
+   *  - probs: Outputs, of shape (N, D).
+   */
+  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = invokeForward(dmlScript, List[String](out), X)
+  /*
+   * Computes the backward pass for a softmax classifier.
+   *
+   * Note that dscores_ij has multiple source branches:
+   *
+   *   ```
+   *   dprobs_ij/dscores_ij = probs_ij * (1 - probs_ij)
+   *   dprobs_ik/dscores_ij = -probs_ik * probs_ij, for all k != j
+   *
+   *   dloss/dscores_ij =
+   *      (dloss/dprobs_ij * dprobs_ij/dscores_ij)
+   *      + sum_{k!=j}(dloss/dprobs_ik * dprobs_ik/dscores_ij)
+   *   ```
+   *
+   * Inputs:
+   *  - dprobs: Gradient wrt `probs` from upstream, of shape (N, D).
+   *  - scores: Inputs, of shape (N, D).
+   *
+   * Outputs:
+   *  - dscores: Gradient wrt `scores`, of shape (N, D).
+   */
+  override def backward(dmlScript:StringBuilder, outSuffix:String) = invokeBackward(dmlScript, outSuffix, List[String]("dOut" + id), dout, X)
+  override def weightShape():Array[Int] = null
+  override def biasShape():Array[Int] = null
   // -------------------------------------------------
 }
 
+
+class Threshold(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
+  override def sourceFileName = null
+  override def init(dmlScript:StringBuilder) = { }
+  val threshold = if(param.getThresholdParam.hasThreshold) param.getThresholdParam.getThreshold else 0
+  override def forward(dmlScript:StringBuilder, isPrediction:Boolean) = assign(dmlScript, out, X + " > " + threshold)
+  override def backward(dmlScript:StringBuilder, outSuffix:String) = throw new DMLRuntimeException("Backward operation for Threshold layer is not supported.")
+  override def weightShape():Array[Int] = null
+  override def biasShape():Array[Int] = null
+}
+
+
 class Dropout(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
   // -------------------------------------------------
   override def sourceFileName = "dropout"
@@ -591,6 +673,8 @@ class Dropout(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extend
   // dropout ratio
   def p = if(param.getDropoutParam.hasDropoutRatio()) param.getDropoutParam.getDropoutRatio.toString else "0.5"
   def seed = "-1"
+  override def weightShape():Array[Int] = null
+  override def biasShape():Array[Int] = null
 }
 
 class InnerProduct(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
@@ -656,8 +740,11 @@ class InnerProduct(val param:LayerParameter, val id:Int, val net:CaffeNetwork) e
   def numFeatures = int_mult(bottomLayerOutputShape._1, bottomLayerOutputShape._2, bottomLayerOutputShape._3)
   // n * c_o * 1 * 1
   override def outputShape = ( param.getInnerProductParam.getNumOutput.toString, "1", "1" )
+  override def weightShape():Array[Int] = Array(numFeatures.toInt, numNeurons.toInt)
+  override def biasShape():Array[Int] = Array(1, numNeurons.toInt)
 }
 
+
 class MaxPooling(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer {
   // -------------------------------------------------
   override def sourceFileName = "max_pool2d_builtin"
@@ -748,6 +835,8 @@ class MaxPooling(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ext
   def pad_w =   if(poolingParam.hasPadW) poolingParam.getPadW.toString 
                    else if(poolingParam.hasPad) poolingParam.getPad.toString
                    else "0"
+  override def weightShape():Array[Int] = null
+  override def biasShape():Array[Int] = null
 }
 
 class Convolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) extends CaffeLayer with HasWeight with HasBias {
@@ -861,6 +950,8 @@ class Convolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork) ex
   def Wout =  ConvolutionUtils.getConv2dOutputMap(bottomLayerOutputShape._3, kernel_w, stride_w, pad_w)
   // -------------------------------------------------
   def convParam = param.getConvolutionParam
+  override def weightShape():Array[Int] = Array(numKernels.toInt, int_mult(numChannels, kernel_h, kernel_w).toInt)
+  override def biasShape():Array[Int] = Array(numKernels.toInt, 1)
   // num_output (c_o): the number of filters
   def numKernels = convParam.getNumOutput.toString
   // kernel_size (or kernel_h and kernel_w): specifies height and width of each filter
@@ -910,6 +1001,9 @@ class DeConvolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork)
   override def init(dmlScript: StringBuilder): Unit = 
     invokeInit(dmlScript, List[String](weight, bias), numKernels, numChannels, kernel_h, kernel_w)
     
+  override def weightShape():Array[Int] = Array(numKernels.toInt, int_mult(numChannels, kernel_h, kernel_w).toInt)
+  override def biasShape():Array[Int] = Array(numKernels.toInt, 1)
+    
   /*
    * Computes the forward pass for a 2D spatial transpose convolutional
    * layer with F filters.  The input data has N examples, each
@@ -1017,4 +1111,4 @@ class DeConvolution(val param:LayerParameter, val id:Int, val net:CaffeNetwork)
   def pad_w =   if(convParam.hasPadW) convParam.getPadW.toString 
                    else if(convParam.getPadCount > 0)  convParam.getPad(0).toString 
                    else "0"
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
index c106cb7..5c2dc77 100644
--- a/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/CaffeNetwork.scala
@@ -44,22 +44,50 @@ object CaffeNetwork {
 }
 
 class CaffeNetwork(netFilePath:String, val currentPhase:Phase, 
-     val numChannels:String, val height:String, val width:String
+     var numChannels:String, var height:String, var width:String
     ) extends Network {
   private def isIncludedInCurrentPhase(l:LayerParameter): Boolean = {
-    if(l.getIncludeCount == 0) true else l.getIncludeList.filter(r => r.hasPhase() && r.getPhase != currentPhase).length == 0
+    if(currentPhase == null) return true // while deployment
+    else if(l.getIncludeCount == 0) true 
+    else l.getIncludeList.filter(r => r.hasPhase() && r.getPhase != currentPhase).length == 0
   }
   private var id = 1
-  
+  def this(deployFilePath:String) {
+    this(deployFilePath, null, null, null, null)
+  }
   // --------------------------------------------------------------------------------
-  private var _caffeLayerParams:List[LayerParameter] = Utils.readCaffeNet(netFilePath).getLayerList.filter(l => isIncludedInCurrentPhase(l)).toList
+  private var _net:NetParameter = Utils.readCaffeNet(netFilePath)
+  private var _caffeLayerParams:List[LayerParameter] = _net.getLayerList.filter(l => isIncludedInCurrentPhase(l)).toList
+  // This method is used if the user does not provide the number of channels, height and width
+  private def setCHW(inputShapes:java.util.List[caffe.Caffe.BlobShape]):Unit = {
+    if(inputShapes.size != 1)
+        throw new DMLRuntimeException("Expected only one input shape")
+    val inputShape = inputShapes.get(0)
+    if(inputShape.getDimCount != 4)
+      throw new DMLRuntimeException("Expected the input shape of dimension 4")
+    numChannels = inputShape.getDim(1).toString
+    height = inputShape.getDim(2).toString
+    width = inputShape.getDim(3).toString
+  }
+  if(numChannels == null && height == null && width == null) {
+    val inputLayer:List[LayerParameter] = _caffeLayerParams.filter(_.getType.toLowerCase.equals("input"))
+    if(inputLayer.size == 1) {
+      setCHW(inputLayer(0).getInputParam.getShapeList)
+    }
+    else if(inputLayer.size == 0) {
+      throw new DMLRuntimeException("Input shape (number of channels, height, width) is unknown. Hint: If you are using deprecated input/input_shape API, we recommend you use Input layer.")
+    }
+    else {
+      throw new DMLRuntimeException("Multiple Input layer is not supported")
+    }
+  }
   // --------------------------------------------------------------------------------
   
   private var _layerNames: List[String] = _caffeLayerParams.map(l => l.getName).toList
   CaffeNetwork.LOG.debug("Layers in current phase:" + _layerNames)
   
   // Condition 1: assert that each name is unique
-  private val _duplicateLayerNames =_layerNames.diff(_layerNames.distinct)
+  private val _duplicateLayerNames = _layerNames.diff(_layerNames.distinct)
   if(_duplicateLayerNames.size != 0) throw new LanguageException("Duplicate layer names is not supported:" + _duplicateLayerNames)
   
   // Condition 2: only 1 top name, except Data layer
@@ -126,12 +154,16 @@ class CaffeNetwork(netFilePath:String, val currentPhase:Phase,
     else l
   })
   
+  // Used while reading caffemodel
+  val replacedLayerNames = new HashMap[String, String]();
+  
   // Condition 5: Deal with incorrect naming
   // Example: layer { name: foo, bottom: arbitrary, top: bar } ... Rename the layer to bar
   private def isIncorrectNamingLayer(l:LayerParameter): Boolean = l.getTopCount == 1 && !l.getTop(0).equalsIgnoreCase(l.getName)
   _caffeLayerParams = _caffeLayerParams.map(l => {
     if(isIncorrectNamingLayer(l)) {
       val builder = l.toBuilder();
+      replacedLayerNames.put(l.getName, l.getTop(0))
       builder.setName(l.getTop(0))
       builder.build()
     }
@@ -161,7 +193,15 @@ class CaffeNetwork(netFilePath:String, val currentPhase:Phase,
   
   private def throwException(layerName:String) = throw new LanguageException("Layer with name " + layerName + " not found")                              
   def getLayers(): List[String] =  _layerNames
-  def getCaffeLayer(layerName:String):CaffeLayer = if(checkKey(_layers, layerName)) _layers.get(layerName).get else throwException(layerName)
+  def getCaffeLayer(layerName:String):CaffeLayer = {
+    if(checkKey(_layers, layerName)) _layers.get(layerName).get
+    else {
+      if(replacedLayerNames.contains(layerName) && checkKey(_layers, replacedLayerNames.get(layerName))) {
+        _layers.get(replacedLayerNames.get(layerName)).get
+      }
+      else throwException(layerName)
+    }
+  }
   def getBottomLayers(layerName:String): Set[String] =  if(checkKey(_bottomLayers, layerName)) _bottomLayers.get(layerName).get else throwException(layerName)
   def getTopLayers(layerName:String): Set[String] = if(checkKey(_topLayers, layerName)) _topLayers.get(layerName).get else throwException(layerName)
   def getLayerID(layerName:String): Int = if(checkKey(_layerIDs, layerName))  _layerIDs.get(layerName).get else throwException(layerName)
@@ -183,11 +223,14 @@ class CaffeNetwork(netFilePath:String, val currentPhase:Phase,
       case "softmaxwithloss" => new SoftmaxWithLoss(param, id, this)
       case "dropout" => new Dropout(param, id, this)
       case "data" => new Data(param, id, this, numChannels, height, width)
+      case "input" => new Data(param, id, this, numChannels, height, width)
       case "batchnorm" => new BatchNorm(param, id, this)
       case "scale" => new Scale(param, id, this)
       case "eltwise" => new Elementwise(param, id, this)
       case "concat" => new Concat(param, id, this)
       case "deconvolution" => new DeConvolution(param, id, this)
+      case "threshold" => new Threshold(param, id, this)
+      case "softmax" => new Softmax(param, id, this)
       case _ => throw new LanguageException("Layer of type " + param.getType + " is not supported")
     }
   }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/src/main/scala/org/apache/sysml/api/dl/Utils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Utils.scala b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
index 5181c9b..5c7222c 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Utils.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Utils.scala
@@ -34,6 +34,11 @@ import java.io.InputStreamReader;
 import org.apache.sysml.runtime.DMLRuntimeException
 import java.io.StringReader
 import java.io.BufferedReader
+import com.google.protobuf.CodedInputStream
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+import org.apache.sysml.api.mlcontext.MLContext
+import org.apache.spark.SparkContext
+import org.apache.spark.api.java.JavaSparkContext
 
 object Utils {
   // ---------------------------------------------------------------------------------------------
@@ -80,12 +85,144 @@ object Utils {
 	// --------------------------------------------------------------
 	// Caffe utility functions
 	def readCaffeNet(netFilePath:String):NetParameter = {
+	  // Load network
 		val reader:InputStreamReader = getInputStreamReader(netFilePath); 
   	val builder:NetParameter.Builder =  NetParameter.newBuilder();
   	TextFormat.merge(reader, builder);
   	return builder.build();
 	}
 	
+	class CopyFloatToDoubleArray(data:java.util.List[java.lang.Float], rows:Int, cols:Int, transpose:Boolean, arr:Array[Double]) extends Thread {
+	  override def run(): Unit = {
+	    if(transpose) {
+        var iter = 0
+        for(i <- 0 until cols) {
+          for(j <- 0 until rows) {
+            arr(j*cols + i) = data.get(iter).doubleValue()
+            iter += 1
+          }
+        }
+      }
+      else {
+        for(i <- 0 until data.size()) {
+          arr(i) = data.get(i).doubleValue()
+        }
+      }
+	  }
+	}
+	
+	def allocateMatrixBlock(data:java.util.List[java.lang.Float], rows:Int, cols:Int, transpose:Boolean):(MatrixBlock,CopyFloatToDoubleArray) = {
+	  val mb =  new MatrixBlock(rows, cols, false)
+    mb.allocateDenseBlock()
+    val arr = mb.getDenseBlock
+    val thread = new CopyFloatToDoubleArray(data, rows, cols, transpose, arr)
+	  thread.start
+	  return (mb, thread)
+	}
+	def validateShape(shape:Array[Int], data:java.util.List[java.lang.Float], layerName:String): Unit = {
+	  if(shape == null) 
+      throw new DMLRuntimeException("Unexpected weight for layer: " + layerName)
+    else if(shape.length != 2) 
+      throw new DMLRuntimeException("Expected shape to be of length 2:" + layerName)
+    else if(shape(0)*shape(1) != data.size())
+      throw new DMLRuntimeException("Incorrect size of blob from caffemodel for the layer " + layerName + ". Expected of size " + shape(0)*shape(1) + ", but found " + data.size())
+	}
+	
+	def saveCaffeModelFile(sc:JavaSparkContext, deployFilePath:String, 
+	    caffeModelFilePath:String, outputDirectory:String, format:String):Unit = {
+	  saveCaffeModelFile(sc.sc, deployFilePath, caffeModelFilePath, outputDirectory, format)
+	}
+	
+	def saveCaffeModelFile(sc:SparkContext, deployFilePath:String, caffeModelFilePath:String, outputDirectory:String, format:String):Unit = {
+	  val inputVariables = new java.util.HashMap[String, MatrixBlock]()
+	  readCaffeNet(new CaffeNetwork(deployFilePath), deployFilePath, caffeModelFilePath, inputVariables)
+	  val ml = new MLContext(sc)
+	  val dmlScript = new StringBuilder
+	  if(inputVariables.keys.size == 0)
+	    throw new DMLRuntimeException("No weights found in the file " + caffeModelFilePath)
+	  for(input <- inputVariables.keys) {
+	    dmlScript.append("write(" + input + ", \"" + input + ".mtx\", format=\"" + format + "\");\n")
+	  }
+	  if(Caffe2DML.LOG.isDebugEnabled())
+	    Caffe2DML.LOG.debug("Executing the script:" + dmlScript.toString)
+	  val script = org.apache.sysml.api.mlcontext.ScriptFactory.dml(dmlScript.toString()).in(inputVariables)
+	  ml.execute(script)
+	}
+	
+	def readCaffeNet(net:CaffeNetwork, netFilePath:String, weightsFilePath:String, inputVariables:java.util.HashMap[String, MatrixBlock]):NetParameter = {
+	  // Load network
+		val reader:InputStreamReader = getInputStreamReader(netFilePath); 
+  	val builder:NetParameter.Builder =  NetParameter.newBuilder();
+  	TextFormat.merge(reader, builder);
+  	// Load weights
+	  val inputStream = CodedInputStream.newInstance(new FileInputStream(weightsFilePath))
+	  inputStream.setSizeLimit(Integer.MAX_VALUE)
+	  builder.mergeFrom(inputStream)
+	  val net1 = builder.build();
+	  
+	  val asyncThreads = new java.util.ArrayList[CopyFloatToDoubleArray]()
+	  for(layer <- net1.getLayerList) {
+	    if(layer.getBlobsCount == 0) {
+	      // No weight or bias
+	      Caffe2DML.LOG.debug("The layer:" + layer.getName + " has no blobs")
+	    }
+	    else if(layer.getBlobsCount == 2) {
+	      // Both weight and bias
+	      val caffe2DMLLayer = net.getCaffeLayer(layer.getName)
+	      val transpose = caffe2DMLLayer.isInstanceOf[InnerProduct]
+	      
+	      // weight
+	      val data = layer.getBlobs(0).getDataList
+	      val shape = caffe2DMLLayer.weightShape()
+	      if(shape == null)
+	        throw new DMLRuntimeException("Did not expect weights for the layer " + layer.getName)
+	      validateShape(shape, data, layer.getName)
+	      val ret1 = allocateMatrixBlock(data, shape(0), shape(1), transpose)
+	      asyncThreads.add(ret1._2)
+	      inputVariables.put(caffe2DMLLayer.weight, ret1._1)
+	      
+	      // bias
+	      val biasData = layer.getBlobs(1).getDataList
+	      val biasShape = caffe2DMLLayer.biasShape()
+	      if(biasShape == null)
+	        throw new DMLRuntimeException("Did not expect bias for the layer " + layer.getName)
+	      validateShape(biasShape, biasData, layer.getName)
+	      val ret2 = allocateMatrixBlock(biasData, biasShape(0), biasShape(1), transpose)
+	      asyncThreads.add(ret2._2)
+	      inputVariables.put(caffe2DMLLayer.bias, ret2._1)
+	      Caffe2DML.LOG.debug("Read weights/bias for layer:" + layer.getName)
+	    }
+	    else if(layer.getBlobsCount == 1) {
+	      // Special case: convolution/deconvolution without bias
+	      // TODO: Extend nn layers to handle this situation + Generalize this to other layers, for example: InnerProduct
+	      val caffe2DMLLayer = net.getCaffeLayer(layer.getName)
+	      val convParam = if((caffe2DMLLayer.isInstanceOf[Convolution] || caffe2DMLLayer.isInstanceOf[DeConvolution]) && caffe2DMLLayer.param.hasConvolutionParam())  caffe2DMLLayer.param.getConvolutionParam else null  
+	      if(convParam == null)
+	        throw new DMLRuntimeException("Layer with blob count " + layer.getBlobsCount + " is not supported for the layer " + layer.getName)
+	     
+	      val data = layer.getBlobs(0).getDataList
+	      val shape = caffe2DMLLayer.weightShape()
+	      validateShape(shape, data, layer.getName)
+	      val ret1 = allocateMatrixBlock(data, shape(0), shape(1), false)
+	      asyncThreads.add(ret1._2)
+	      inputVariables.put(caffe2DMLLayer.weight, ret1._1)
+	      inputVariables.put(caffe2DMLLayer.bias, new MatrixBlock(convParam.getNumOutput, 1, false))
+	      Caffe2DML.LOG.debug("Read only weight for layer:" + layer.getName)
+	    }
+	    else {
+	      throw new DMLRuntimeException("Layer with blob count " + layer.getBlobsCount + " is not supported for the layer " + layer.getName)
+	    }
+	  }
+	  
+	  // Wait for the copy to be finished
+	  for(t <- asyncThreads) {
+	    t.join()
+	  }
+	  
+	  // Return the NetParameter without the weights (re-read from the network prototxt only)
+	  return readCaffeNet(netFilePath)
+	}
+	
 	def readCaffeSolver(solverFilePath:String):SolverParameter = {
 		val reader = getInputStreamReader(solverFilePath);
 		val builder =  SolverParameter.newBuilder();
@@ -112,4 +249,12 @@ object Utils {
 		}
 	}
 	// --------------------------------------------------------------
+}
+
+class Utils {
+  def saveCaffeModelFile(sc:JavaSparkContext, deployFilePath:String, 
+	    caffeModelFilePath:String, outputDirectory:String, format:String):Unit = {
+    Utils.saveCaffeModelFile(sc, deployFilePath, caffeModelFilePath, outputDirectory, format)
+  }
+  
 }
\ No newline at end of file
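
The new Utils.readCaffeNet overload and saveCaffeModelFile convert the blobs of a .caffemodel file into SystemML MatrixBlocks (sized via the per-layer weightShape/biasShape methods introduced above) and write them out in the requested format. The companion class Utils appears to exist so that the converter can be reached from Python; a hypothetical invocation from PySpark via the py4j gateway, following the same Utils() instantiation pattern used in loadLabels earlier in this patch (file paths are illustrative):

>>> utils = sc._jvm.org.apache.sysml.api.dl.Utils()
>>> utils.saveCaffeModelFile(sc._jsc, 'lenet_deploy.proto', 'lenet.caffemodel', 'lenet_weights', 'binary')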

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
index f0af799..e601a7d 100644
--- a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
@@ -19,6 +19,7 @@
 
 package org.apache.sysml.api.ml
 
+import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.rdd.RDD
 import java.io.File
 import org.apache.spark.SparkContext
@@ -95,7 +96,7 @@ trait BaseSystemMLEstimatorOrModel {
 
 trait BaseSystemMLEstimator extends BaseSystemMLEstimatorOrModel {
   def transformSchema(schema: StructType): StructType = schema
-  
+  var mloutput:MLResults = null
   // Returns the script and variables for X and y
   def getTrainingScript(isSingleNode:Boolean):(Script, String, String)
   
@@ -120,7 +121,37 @@ trait BaseSystemMLEstimatorModel extends BaseSystemMLEstimatorOrModel {
   def transformSchema(schema: StructType): StructType = schema
   
   // Returns the script and variable for X
-  def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String)
+  def getPredictionScript(isSingleNode:Boolean): (Script, String)
+  def baseEstimator():BaseSystemMLEstimator
+  def modelVariables():List[String]
+  // self.model.load(self.sc._jsc, weights, format, sep)
+  def load(sc:JavaSparkContext, outputDir:String, sep:String):Unit = {
+  	val dmlScript = new StringBuilder
+  	dmlScript.append("print(\"Loading the model from " + outputDir + "...\")\n")
+		for(varName <- modelVariables) {
+			dmlScript.append(varName + " = read(\"" + outputDir + sep + varName + ".mtx\")\n")
+		}
+  	val script = dml(dmlScript.toString)
+		for(varName <- modelVariables) {
+			script.out(varName)
+		}
+	  val ml = new MLContext(sc)
+	  baseEstimator.mloutput = ml.execute(script)
+  }
+  def save(sc:JavaSparkContext, outputDir:String, format:String="binary", sep:String="/"):Unit = {
+	  if(baseEstimator.mloutput == null) throw new DMLRuntimeException("Cannot save as you need to train the model first using fit")
+	  val dmlScript = new StringBuilder
+	  dmlScript.append("print(\"Saving the model to " + outputDir + "...\")\n")
+	  for(varName <- modelVariables) {
+	  	dmlScript.append("write(" + varName + ", \"" + outputDir + sep + varName + ".mtx\", format=\"" + format + "\")\n")
+	  }
+	  val script = dml(dmlScript.toString)
+		for(varName <- modelVariables) {
+			script.in(varName, baseEstimator.mloutput.getBinaryBlockMatrix(varName))
+		}
+	  val ml = new MLContext(sc)
+	  ml.execute(script)
+	}
 }
 
 trait BaseSystemMLClassifier extends BaseSystemMLEstimator {
@@ -150,11 +181,11 @@ trait BaseSystemMLClassifier extends BaseSystemMLEstimator {
 
 trait BaseSystemMLClassifierModel extends BaseSystemMLEstimatorModel {
 
-  def baseTransform(X: MatrixBlock, mloutput: MLResults, sc: SparkContext, probVar:String): MatrixBlock = {
+  def baseTransform(X: MatrixBlock, sc: SparkContext, probVar:String): MatrixBlock = {
     val isSingleNode = true
     val ml = new MLContext(sc)
     updateML(ml)
-    val script = getPredictionScript(mloutput, isSingleNode)
+    val script = getPredictionScript(isSingleNode)
     // Uncomment for debugging
     // ml.setExplainLevel(ExplainLevel.RECOMPILE_RUNTIME)
     val modelPredict = ml.execute(script._1.in(script._2, X, new MatrixMetadata(X.getNumRows, X.getNumColumns, X.getNonZeros)))
@@ -167,14 +198,14 @@ trait BaseSystemMLClassifierModel extends BaseSystemMLEstimatorModel {
     return ret
   }
 
-  def baseTransform(df: ScriptsUtils.SparkDataType, mloutput: MLResults, sc: SparkContext, 
+  def baseTransform(df: ScriptsUtils.SparkDataType, sc: SparkContext, 
       probVar:String, outputProb:Boolean=true): DataFrame = {
     val isSingleNode = false
     val ml = new MLContext(sc)
     updateML(ml)
     val mcXin = new MatrixCharacteristics()
     val Xin = RDDConverterUtils.dataFrameToBinaryBlock(df.rdd.sparkContext, df.asInstanceOf[DataFrame].select("features"), mcXin, false, true)
-    val script = getPredictionScript(mloutput, isSingleNode)
+    val script = getPredictionScript(isSingleNode)
     val Xin_bin = new BinaryBlockMatrix(Xin, mcXin)
     val modelPredict = ml.execute(script._1.in(script._2, Xin_bin))
     val predLabelOut = PredictionUtils.computePredictedClassLabelsFromProbability(modelPredict, isSingleNode, sc, probVar)

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala
index 5dd23e0..9e2a34a 100644
--- a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala
@@ -60,11 +60,11 @@ trait BaseSystemMLRegressor extends BaseSystemMLEstimator {
 
 trait BaseSystemMLRegressorModel extends BaseSystemMLEstimatorModel {
   
-  def baseTransform(X: MatrixBlock, mloutput: MLResults, sc: SparkContext, predictionVar:String): MatrixBlock = {
+  def baseTransform(X: MatrixBlock, sc: SparkContext, predictionVar:String): MatrixBlock = {
     val isSingleNode = true
     val ml = new MLContext(sc)
     updateML(ml)
-    val script = getPredictionScript(mloutput, isSingleNode)
+    val script = getPredictionScript(isSingleNode)
     val modelPredict = ml.execute(script._1.in(script._2, X))
     val ret = modelPredict.getBinaryBlockMatrix(predictionVar).getMatrixBlock
               
@@ -74,13 +74,13 @@ trait BaseSystemMLRegressorModel extends BaseSystemMLEstimatorModel {
     return ret
   }
   
-  def baseTransform(df: ScriptsUtils.SparkDataType, mloutput: MLResults, sc: SparkContext, predictionVar:String): DataFrame = {
+  def baseTransform(df: ScriptsUtils.SparkDataType, sc: SparkContext, predictionVar:String): DataFrame = {
     val isSingleNode = false
     val ml = new MLContext(sc)
     updateML(ml)
     val mcXin = new MatrixCharacteristics()
     val Xin = RDDConverterUtils.dataFrameToBinaryBlock(df.rdd.sparkContext, df.asInstanceOf[DataFrame], mcXin, false, true)
-    val script = getPredictionScript(mloutput, isSingleNode)
+    val script = getPredictionScript(isSingleNode)
     val Xin_bin = new BinaryBlockMatrix(Xin, mcXin)
     val modelPredict = ml.execute(script._1.in(script._2, Xin_bin))
     val predictedDF = modelPredict.getDataFrame(predictionVar).select(RDDConverterUtils.DF_ID_COLUMN, "C1").withColumnRenamed("C1", "prediction")

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala b/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala
index 76bc0a3..463d81a 100644
--- a/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala
@@ -48,6 +48,7 @@ class LinearRegression(override val uid: String, val sc: SparkContext, val solve
   def setRegParam(value: Double) = set(regParam, value)
   def setTol(value: Double) = set(tol, value)
   
+
   override def copy(extra: ParamMap): Estimator[LinearRegressionModel] = {
     val that = new LinearRegression(uid, sc, solver)
     copyValues(that, extra)
@@ -72,26 +73,38 @@ class LinearRegression(override val uid: String, val sc: SparkContext, val solve
     (script, "X", "y")
   }
   
-  def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): LinearRegressionModel = 
-    new LinearRegressionModel("lr")(baseFit(X_mb, y_mb, sc), sc)
+  def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): LinearRegressionModel =  {
+    mloutput = baseFit(X_mb, y_mb, sc)
+    new LinearRegressionModel(this)
+  }
     
-  def fit(df: ScriptsUtils.SparkDataType): LinearRegressionModel = 
-    new LinearRegressionModel("lr")(baseFit(df, sc), sc)
+  def fit(df: ScriptsUtils.SparkDataType): LinearRegressionModel = { 
+    mloutput = baseFit(df, sc)
+    new LinearRegressionModel(this)
+  }
   
 }
 
-class LinearRegressionModel(override val uid: String)(val mloutput: MLResults, val sc: SparkContext) extends Model[LinearRegressionModel] with HasIcpt
+class LinearRegressionModel(override val uid: String)(estimator:LinearRegression, val sc: SparkContext) extends Model[LinearRegressionModel] with HasIcpt
     with HasRegParam with HasTol with HasMaxOuterIter with BaseSystemMLRegressorModel {
   override def copy(extra: ParamMap): LinearRegressionModel = {
-    val that = new LinearRegressionModel(uid)(mloutput, sc)
+    val that = new LinearRegressionModel(uid)(estimator, sc)
     copyValues(that, extra)
   }
   
-  def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String) =
-    PredictionUtils.getGLMPredictionScript(mloutput.getBinaryBlockMatrix("beta_out"), isSingleNode)
+  def baseEstimator():BaseSystemMLEstimator = estimator
+  
+  def this(estimator:LinearRegression) =  {
+  	this("model")(estimator, estimator.sc)
+  }
+  
+  def getPredictionScript(isSingleNode:Boolean): (Script, String) =
+    PredictionUtils.getGLMPredictionScript(estimator.mloutput.getBinaryBlockMatrix("beta_out"), isSingleNode)
+  
+  def modelVariables():List[String] = List[String]("beta_out")
   
-  def transform(df: ScriptsUtils.SparkDataType): DataFrame = baseTransform(df, mloutput, sc, "means")
+  def transform(df: ScriptsUtils.SparkDataType): DataFrame = baseTransform(df, sc, "means")
   
-  def transform(X: MatrixBlock): MatrixBlock =  baseTransform(X, mloutput, sc, "means")
+  def transform(X: MatrixBlock): MatrixBlock =  baseTransform(X, sc, "means")
   
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala b/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala
index 9f3d844..f4b5afe 100644
--- a/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala
@@ -54,15 +54,16 @@ class LogisticRegression(override val uid: String, val sc: SparkContext) extends
     copyValues(that, extra)
   }
   
+
   // Note: will update the y_mb as this will be called by Python mllearn
   def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): LogisticRegressionModel = {
-    val ret = baseFit(X_mb, y_mb, sc)
-    new LogisticRegressionModel("log")(ret, sc)
+    mloutput = baseFit(X_mb, y_mb, sc)
+    new LogisticRegressionModel(this)
   }
   
   def fit(df: ScriptsUtils.SparkDataType): LogisticRegressionModel = {
-    val ret = baseFit(df, sc)
-    new LogisticRegressionModel("log")(ret, sc)
+    mloutput = baseFit(df, sc)
+    new LogisticRegressionModel(this)
   }
   
   
@@ -89,21 +90,26 @@ object LogisticRegressionModel {
  */
 
 class LogisticRegressionModel(override val uid: String)(
-    val mloutput: MLResults, val sc: SparkContext) 
+    estimator: LogisticRegression, val sc: SparkContext) 
     extends Model[LogisticRegressionModel] with HasIcpt
     with HasRegParam with HasTol with HasMaxOuterIter with HasMaxInnerIter with BaseSystemMLClassifierModel {
   override def copy(extra: ParamMap): LogisticRegressionModel = {
-    val that = new LogisticRegressionModel(uid)(mloutput, sc)
+    val that = new LogisticRegressionModel(uid)(estimator, sc)
     copyValues(that, extra)
   }
   var outputRawPredictions = true
   def setOutputRawPredictions(outRawPred:Boolean): Unit = { outputRawPredictions = outRawPred }
+  def this(estimator:LogisticRegression) =  {
+  	this("model")(estimator, estimator.sc)
+  }
+  def getPredictionScript(isSingleNode:Boolean): (Script, String) =
+    PredictionUtils.getGLMPredictionScript(estimator.mloutput.getBinaryBlockMatrix("B_out"), isSingleNode, 3)
+  
+  def baseEstimator():BaseSystemMLEstimator = estimator
+  def modelVariables():List[String] = List[String]("B_out")
   
-  def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String) =
-    PredictionUtils.getGLMPredictionScript(mloutput.getBinaryBlockMatrix("B_out"), isSingleNode, 3)
-   
-  def transform(X: MatrixBlock): MatrixBlock = baseTransform(X, mloutput, sc, "means")
-  def transform(df: ScriptsUtils.SparkDataType): DataFrame = baseTransform(df, mloutput, sc, "means")
+  def transform(X: MatrixBlock): MatrixBlock = baseTransform(X, sc, "means")
+  def transform(df: ScriptsUtils.SparkDataType): DataFrame = baseTransform(df, sc, "means")
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala b/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala
index 9161a8f..b2e967b 100644
--- a/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala
@@ -46,13 +46,13 @@ class NaiveBayes(override val uid: String, val sc: SparkContext) extends Estimat
   
   // Note: will update the y_mb as this will be called by Python mllearn
   def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): NaiveBayesModel = {
-    val ret = baseFit(X_mb, y_mb, sc)
-    new NaiveBayesModel("naive")(ret, sc)
+    mloutput = baseFit(X_mb, y_mb, sc)
+    new NaiveBayesModel(this)
   }
   
   def fit(df: ScriptsUtils.SparkDataType): NaiveBayesModel = {
-    val ret = baseFit(df, sc)
-    new NaiveBayesModel("naive")(ret, sc)
+    mloutput = baseFit(df, sc)
+    new NaiveBayesModel(this)
   }
   
   def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
@@ -74,15 +74,20 @@ object NaiveBayesModel {
 }
 
 class NaiveBayesModel(override val uid: String)
-  (val mloutput: MLResults, val sc: SparkContext) 
+  (estimator:NaiveBayes, val sc: SparkContext) 
   extends Model[NaiveBayesModel] with HasLaplace with BaseSystemMLClassifierModel {
   
+  def this(estimator:NaiveBayes) =  {
+    this("model")(estimator, estimator.sc)
+  }
+  
   override def copy(extra: ParamMap): NaiveBayesModel = {
-    val that = new NaiveBayesModel(uid)(mloutput, sc)
+    val that = new NaiveBayesModel(uid)(estimator, sc)
     copyValues(that, extra)
   }
   
-  def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String)  = {
+  def modelVariables():List[String] = List[String]("classPrior", "classConditionals")
+  def getPredictionScript(isSingleNode:Boolean): (Script, String)  = {
     val script = dml(ScriptsUtils.getDMLScript(NaiveBayesModel.scriptPath))
       .in("$X", " ")
       .in("$prior", " ")
@@ -90,8 +95,8 @@ class NaiveBayesModel(override val uid: String)
       .in("$probabilities", " ")
       .out("probs")
     
-    val classPrior = mloutput.getBinaryBlockMatrix("classPrior")
-    val classConditionals = mloutput.getBinaryBlockMatrix("classConditionals")
+    val classPrior = estimator.mloutput.getBinaryBlockMatrix("classPrior")
+    val classConditionals = estimator.mloutput.getBinaryBlockMatrix("classConditionals")
     val ret = if(isSingleNode) {
       script.in("prior", classPrior.getMatrixBlock, classPrior.getMatrixMetadata)
             .in("conditionals", classConditionals.getMatrixBlock, classConditionals.getMatrixMetadata)
@@ -103,7 +108,8 @@ class NaiveBayesModel(override val uid: String)
     (ret, "D")
   }
   
-  def transform(X: MatrixBlock): MatrixBlock = baseTransform(X, mloutput, sc, "probs")
-  def transform(df: ScriptsUtils.SparkDataType): DataFrame = baseTransform(df, mloutput, sc, "probs")
+  def baseEstimator():BaseSystemMLEstimator = estimator
+  def transform(X: MatrixBlock): MatrixBlock = baseTransform(X, sc, "probs")
+  def transform(df: ScriptsUtils.SparkDataType): DataFrame = baseTransform(df, sc, "probs")
   
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/src/main/scala/org/apache/sysml/api/ml/SVM.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/SVM.scala b/src/main/scala/org/apache/sysml/api/ml/SVM.scala
index db8ce3a..d706101 100644
--- a/src/main/scala/org/apache/sysml/api/ml/SVM.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/SVM.scala
@@ -67,13 +67,13 @@ class SVM (override val uid: String, val sc: SparkContext, val isMultiClass:Bool
   
   // Note: will update the y_mb as this will be called by Python mllearn
   def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): SVMModel = {
-    val ret = baseFit(X_mb, y_mb, sc)
-    new SVMModel("svm")(ret, sc, isMultiClass)
+    mloutput = baseFit(X_mb, y_mb, sc)
+    new SVMModel(this, isMultiClass)
   }
   
   def fit(df: ScriptsUtils.SparkDataType): SVMModel = {
-    val ret = baseFit(df, sc)
-    new SVMModel("svm")(ret, sc, isMultiClass)
+    mloutput = baseFit(df, sc)
+    new SVMModel(this, isMultiClass)
   }
   
 }
@@ -83,20 +83,27 @@ object SVMModel {
   final val predictionScriptPathMulticlass = "scripts" + File.separator + "algorithms" + File.separator + "m-svm-predict.dml"
 }
 
-class SVMModel (override val uid: String)(val mloutput: MLResults, val sc: SparkContext, val isMultiClass:Boolean) 
+class SVMModel (override val uid: String)(estimator:SVM, val sc: SparkContext, val isMultiClass:Boolean) 
   extends Model[SVMModel] with BaseSystemMLClassifierModel {
   override def copy(extra: ParamMap): SVMModel = {
-    val that = new SVMModel(uid)(mloutput, sc, isMultiClass)
+    val that = new SVMModel(uid)(estimator, sc, isMultiClass)
     copyValues(that, extra)
   }
   
-  def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String)  = {
+  def this(estimator:SVM, isMultiClass:Boolean) =  {
+  	this("model")(estimator, estimator.sc, isMultiClass)
+  }
+  
+  def baseEstimator():BaseSystemMLEstimator = estimator
+  def modelVariables():List[String] = List[String]("w")
+  
+  def getPredictionScript(isSingleNode:Boolean): (Script, String)  = {
     val script = dml(ScriptsUtils.getDMLScript(if(isMultiClass) SVMModel.predictionScriptPathMulticlass else SVMModel.predictionScriptPathBinary))
       .in("$X", " ")
       .in("$model", " ")
       .out("scores")
     
-    val w = mloutput.getBinaryBlockMatrix("w")
+    val w = estimator.mloutput.getBinaryBlockMatrix("w")
     val wVar = if(isMultiClass) "W" else "w"
       
     val ret = if(isSingleNode) {
@@ -108,6 +115,6 @@ class SVMModel (override val uid: String)(val mloutput: MLResults, val sc: Spark
     (ret, "X")
   }
   
-  def transform(X: MatrixBlock): MatrixBlock = baseTransform(X, mloutput, sc, "scores")
-  def transform(df: ScriptsUtils.SparkDataType): DataFrame = baseTransform(df, mloutput, sc, "scores")
+  def transform(X: MatrixBlock): MatrixBlock = baseTransform(X, sc, "scores")
+  def transform(df: ScriptsUtils.SparkDataType): DataFrame = baseTransform(df, sc, "scores")
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/src/main/scala/org/apache/sysml/api/ml/Utils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/Utils.scala b/src/main/scala/org/apache/sysml/api/ml/Utils.scala
new file mode 100644
index 0000000..da3edf5
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/Utils.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.api.ml
+
+class Utils {
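+  // Checks whether the given file path exists, delegating to MapReduceTool.existsFileOnHDFS.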
+  def checkIfFileExists(filePath:String):Boolean = {
+    return org.apache.sysml.runtime.util.MapReduceTool.existsFileOnHDFS(filePath)
+  }
+}
\ No newline at end of file


[2/2] incubator-systemml git commit: [SYSTEMML-1632] Support loading and saving models via mllearn

Posted by ni...@apache.org.
[SYSTEMML-1632] Support loading and saving models via mllearn

- Also, updated documentation and fixed bugs.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/d69f3441
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/d69f3441
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/d69f3441

Branch: refs/heads/master
Commit: d69f3441c8243ddd13dd3da6aab9c2d5701c6e50
Parents: d36a0c1
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Thu May 25 22:32:02 2017 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Thu May 25 22:32:02 2017 -0700

----------------------------------------------------------------------
 docs/algorithms-classification.md               |  52 +-
 docs/algorithms-regression.md                   |   8 +-
 docs/beginners-guide-caffe2dml.md               | 264 ++++--
 docs/beginners-guide-python.md                  |  33 +-
 docs/native-backend.md                          |   4 +-
 docs/python-reference.md                        | 929 +++++--------------
 pom.xml                                         |   3 +
 .../caffe2dml/models/mnist_lenet/lenet.proto    | 195 ++++
 .../models/mnist_lenet/lenet_solver.proto       |  19 +
 src/main/python/systemml/converters.py          |  96 +-
 src/main/python/systemml/mllearn/estimators.py  | 179 ++--
 .../org/apache/sysml/api/dl/Caffe2DML.scala     |  64 +-
 .../org/apache/sysml/api/dl/CaffeLayer.scala    | 108 ++-
 .../org/apache/sysml/api/dl/CaffeNetwork.scala  |  55 +-
 .../scala/org/apache/sysml/api/dl/Utils.scala   | 145 +++
 .../sysml/api/ml/BaseSystemMLClassifier.scala   |  43 +-
 .../sysml/api/ml/BaseSystemMLRegressor.scala    |   8 +-
 .../apache/sysml/api/ml/LinearRegression.scala  |  33 +-
 .../sysml/api/ml/LogisticRegression.scala       |  28 +-
 .../org/apache/sysml/api/ml/NaiveBayes.scala    |  28 +-
 .../scala/org/apache/sysml/api/ml/SVM.scala     |  27 +-
 .../scala/org/apache/sysml/api/ml/Utils.scala   |  25 +
 22 files changed, 1365 insertions(+), 981 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/docs/algorithms-classification.md
----------------------------------------------------------------------
diff --git a/docs/algorithms-classification.md b/docs/algorithms-classification.md
index ed56c34..04c5eb8 100644
--- a/docs/algorithms-classification.md
+++ b/docs/algorithms-classification.md
@@ -131,7 +131,7 @@ Eqs. (1) and (2).
 {% highlight python %}
 from systemml.mllearn import LogisticRegression
 # C = 1/reg
-logistic = LogisticRegression(sqlCtx, fit_intercept=True, max_iter=100, max_inner_iter=0, tol=0.000001, C=1.0)
+logistic = LogisticRegression(spark, fit_intercept=True, max_iter=100, max_inner_iter=0, tol=0.000001, C=1.0)
 # X_train, y_train and X_test can be NumPy matrices or Pandas DataFrame or SciPy Sparse Matrix
 y_test = logistic.fit(X_train, y_train).predict(X_test)
 # df_train is DataFrame that contains two columns: "features" (of type Vector) and "label". df_test is a DataFrame that contains the column "features"
@@ -229,6 +229,8 @@ if no maximum limit provided
 `mm`, or `csv`; see read/write functions in
 SystemML Language Reference for details.
 
+Please see [mllearn documentation](https://apache.github.io/incubator-systemml/python-reference#mllearn-api) for
+more details on the Python API. 
 
 ### Examples
 
@@ -255,9 +257,7 @@ print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_te
 from pyspark.ml import Pipeline
 from systemml.mllearn import LogisticRegression
 from pyspark.ml.feature import HashingTF, Tokenizer
-from pyspark.sql import SQLContext
-sqlCtx = SQLContext(sc)
-training = sqlCtx.createDataFrame([
+training = spark.createDataFrame([
     (0L, "a b c d e spark", 1.0),
     (1L, "b d", 2.0),
     (2L, "spark f g h", 1.0),
@@ -273,10 +273,10 @@ training = sqlCtx.createDataFrame([
 ], ["id", "text", "label"])
 tokenizer = Tokenizer(inputCol="text", outputCol="words")
 hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
-lr = LogisticRegression(sqlCtx)
+lr = LogisticRegression(spark)
 pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
 model = pipeline.fit(training)
-test = sqlCtx.createDataFrame([
+test = spark.createDataFrame([
     (12L, "spark i j k"),
     (13L, "l m n"),
     (14L, "mapreduce spark"),
@@ -290,7 +290,7 @@ prediction.show()
 import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
 import org.apache.sysml.api.ml.LogisticRegression
 import org.apache.spark.ml.Pipeline
-val training = sqlContext.createDataFrame(Seq(
+val training = spark.createDataFrame(Seq(
     ("a b c d e spark", 1.0),
     ("b d", 2.0),
     ("spark f g h", 1.0),
@@ -308,7 +308,7 @@ val hashingTF = new HashingTF().setNumFeatures(20).setInputCol(tokenizer.getOutp
 val lr = new LogisticRegression("logReg", sc)
 val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
 val model = pipeline.fit(training)
-val test = sqlContext.createDataFrame(Seq(
+val test = spark.createDataFrame(Seq(
     ("spark i j k", 1.0),
     ("l m n", 2.0),
     ("mapreduce spark", 1.0),
@@ -500,7 +500,7 @@ support vector machine (`y` with domain size `2`).
 {% highlight python %}
 from systemml.mllearn import SVM
 # C = 1/reg
-svm = SVM(sqlCtx, fit_intercept=True, max_iter=100, tol=0.000001, C=1.0, is_multi_class=False)
+svm = SVM(spark, fit_intercept=True, max_iter=100, tol=0.000001, C=1.0, is_multi_class=False)
 # X_train, y_train and X_test can be NumPy matrices or Pandas DataFrame or SciPy Sparse Matrix
 y_test = svm.fit(X_train, y_train)
 # df_train is DataFrame that contains two columns: "features" (of type Vector) and "label". df_test is a DataFrame that contains the column "features"
@@ -637,6 +637,8 @@ held-out test set. Note that this is an optional argument.
 **confusion**: Location (on HDFS) to store the confusion matrix computed
 using a held-out test set. Note that this is an optional argument.
 
+Please see [mllearn documentation](https://apache.github.io/incubator-systemml/python-reference#mllearn-api) for
+more details on the Python API. 
 
 #### Examples
 
@@ -768,7 +770,7 @@ class labels.
 {% highlight python %}
 from systemml.mllearn import SVM
 # C = 1/reg
-svm = SVM(sqlCtx, fit_intercept=True, max_iter=100, tol=0.000001, C=1.0, is_multi_class=True)
+svm = SVM(spark, fit_intercept=True, max_iter=100, tol=0.000001, C=1.0, is_multi_class=True)
 # X_train, y_train and X_test can be NumPy matrices or Pandas DataFrame or SciPy Sparse Matrix
 y_test = svm.fit(X_train, y_train)
 # df_train is DataFrame that contains two columns: "features" (of type Vector) and "label". df_test is a DataFrame that contains the column "features"
@@ -906,6 +908,8 @@ SystemML Language Reference for details.
 **confusion**: Location (on HDFS) to store the confusion matrix computed
     using a held-out test set. Note that this is an optional argument.
 
+Please see [mllearn documentation](https://apache.github.io/incubator-systemml/python-reference#mllearn-api) for
+more details on the Python API. 
 
 #### Examples
 
@@ -917,25 +921,21 @@ SystemML Language Reference for details.
 # Scikit-learn way
 from sklearn import datasets, neighbors
 from systemml.mllearn import SVM
-from pyspark.sql import SQLContext
-sqlCtx = SQLContext(sc)
 digits = datasets.load_digits()
 X_digits = digits.data
 y_digits = digits.target 
 n_samples = len(X_digits)
-X_train = X_digits[:.9 * n_samples]
-y_train = y_digits[:.9 * n_samples]
-X_test = X_digits[.9 * n_samples:]
-y_test = y_digits[.9 * n_samples:]
-svm = SVM(sqlCtx, is_multi_class=True)
+X_train = X_digits[:int(.9 * n_samples)]
+y_train = y_digits[:int(.9 * n_samples)]
+X_test = X_digits[int(.9 * n_samples):]
+y_test = y_digits[int(.9 * n_samples):]
+svm = SVM(spark, is_multi_class=True)
 print('LogisticRegression score: %f' % svm.fit(X_train, y_train).score(X_test, y_test))
 
 # MLPipeline way
 from pyspark.ml import Pipeline
 from systemml.mllearn import SVM
 from pyspark.ml.feature import HashingTF, Tokenizer
-from pyspark.sql import SQLContext
-sqlCtx = SQLContext(sc)
 training = sqlCtx.createDataFrame([
     (0L, "a b c d e spark", 1.0),
     (1L, "b d", 2.0),
@@ -952,7 +952,7 @@ training = sqlCtx.createDataFrame([
 ], ["id", "text", "label"])
 tokenizer = Tokenizer(inputCol="text", outputCol="words")
 hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
-svm = SVM(sqlCtx, is_multi_class=True)
+svm = SVM(spark, is_multi_class=True)
 pipeline = Pipeline(stages=[tokenizer, hashingTF, svm])
 model = pipeline.fit(training)
 test = sqlCtx.createDataFrame([
@@ -969,7 +969,7 @@ prediction.show()
 import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
 import org.apache.sysml.api.ml.SVM
 import org.apache.spark.ml.Pipeline
-val training = sqlContext.createDataFrame(Seq(
+val training = spark.createDataFrame(Seq(
     ("a b c d e spark", 1.0),
     ("b d", 2.0),
     ("spark f g h", 1.0),
@@ -987,7 +987,7 @@ val hashingTF = new HashingTF().setNumFeatures(20).setInputCol(tokenizer.getOutp
 val svm = new SVM("svm", sc, isMultiClass=true)
 val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, svm))
 val model = pipeline.fit(training)
-val test = sqlContext.createDataFrame(Seq(
+val test = spark.createDataFrame(Seq(
     ("spark i j k", 1.0),
     ("l m n", 2.0),
     ("mapreduce spark", 1.0),
@@ -1123,7 +1123,7 @@ applicable when all features are counts of categorical values.
 <div data-lang="Python" markdown="1">
 {% highlight python %}
 from systemml.mllearn import NaiveBayes
-nb = NaiveBayes(sqlCtx, laplace=1.0)
+nb = NaiveBayes(spark, laplace=1.0)
 # X_train, y_train and X_test can be NumPy matrices or Pandas DataFrame or SciPy Sparse Matrix
 y_test = nb.fit(X_train, y_train)
 # df_train is DataFrame that contains two columns: "features" (of type Vector) and "label". df_test is a DataFrame that contains the column "features"
@@ -1246,6 +1246,8 @@ SystemML Language Reference for details.
 **confusion**: Location (on HDFS) to store the confusion matrix computed
     using a held-out test set. Note that this is an optional argument.
 
+Please see [mllearn documentation](https://apache.github.io/incubator-systemml/python-reference#mllearn-api) for
+more details on the Python API. 
 
 ### Examples
 
@@ -1258,8 +1260,6 @@ from sklearn.datasets import fetch_20newsgroups
 from sklearn.feature_extraction.text import TfidfVectorizer
 from systemml.mllearn import NaiveBayes
 from sklearn import metrics
-from pyspark.sql import SQLContext
-sqlCtx = SQLContext(sc)
 categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']
 newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
 newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
@@ -1267,7 +1267,7 @@ vectorizer = TfidfVectorizer()
 # Both vectors and vectors_test are SciPy CSR matrix
 vectors = vectorizer.fit_transform(newsgroups_train.data)
 vectors_test = vectorizer.transform(newsgroups_test.data)
-nb = NaiveBayes(sqlCtx)
+nb = NaiveBayes(spark)
 nb.fit(vectors, newsgroups_train.target)
 pred = nb.predict(vectors_test)
 metrics.f1_score(newsgroups_test.target, pred, average='weighted')

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/docs/algorithms-regression.md
----------------------------------------------------------------------
diff --git a/docs/algorithms-regression.md b/docs/algorithms-regression.md
index 22c6959..13f6cff 100644
--- a/docs/algorithms-regression.md
+++ b/docs/algorithms-regression.md
@@ -212,6 +212,8 @@ gradient iterations, or `0` if no maximum limit provided
 `mm`, or `csv`; see read/write functions in
 SystemML Language Reference for details.
 
+Please see [mllearn documentation](https://apache.github.io/incubator-systemml/python-reference#mllearn-api) for
+more details on the Python API. 
 
 ### Examples
 
@@ -223,7 +225,6 @@ SystemML Language Reference for details.
 import numpy as np
 from sklearn import datasets
 from systemml.mllearn import LinearRegression
-from pyspark.sql import SQLContext
 # Load the diabetes dataset
 diabetes = datasets.load_diabetes()
 # Use only one feature
@@ -235,7 +236,7 @@ diabetes_X_test = diabetes_X[-20:]
 diabetes_y_train = diabetes.target[:-20]
 diabetes_y_test = diabetes.target[-20:]
 # Create linear regression object
-regr = LinearRegression(sqlCtx, solver='direct-solve')
+regr = LinearRegression(spark, solver='direct-solve')
 # Train the model using the training sets
 regr.fit(diabetes_X_train, diabetes_y_train)
 # The mean square error
@@ -278,7 +279,6 @@ print("Residual sum of squares: %.2f" % np.mean((regr.predict(diabetes_X_test) -
 import numpy as np
 from sklearn import datasets
 from systemml.mllearn import LinearRegression
-from pyspark.sql import SQLContext
 # Load the diabetes dataset
 diabetes = datasets.load_diabetes()
 # Use only one feature
@@ -290,7 +290,7 @@ diabetes_X_test = diabetes_X[-20:]
 diabetes_y_train = diabetes.target[:-20]
 diabetes_y_test = diabetes.target[-20:]
 # Create linear regression object
-regr = LinearRegression(sqlCtx, solver='newton-cg')
+regr = LinearRegression(spark, solver='newton-cg')
 # Train the model using the training sets
 regr.fit(diabetes_X_train, diabetes_y_train)
 # The mean square error

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/docs/beginners-guide-caffe2dml.md
----------------------------------------------------------------------
diff --git a/docs/beginners-guide-caffe2dml.md b/docs/beginners-guide-caffe2dml.md
index dea53fd..55eb154 100644
--- a/docs/beginners-guide-caffe2dml.md
+++ b/docs/beginners-guide-caffe2dml.md
@@ -29,20 +29,45 @@ limitations under the License.
 
 ## Introduction
 
-Caffe2DML is an experimental API that converts an Caffe specification to DML.
+Caffe2DML is an **experimental API** that converts a Caffe specification to DML. 
+It is designed to fit well into the mllearn framework and hence supports NumPy, Pandas as well as PySpark DataFrames.
 
-## Example: Train Lenet
+## Examples
 
-1. Install `mlextend` package to get MNIST data: `pip install mlxtend`.
-2. (Optional but recommended) Follow the steps mentioned in [the user guide]([the user guide of native backend](http://apache.github.io/incubator-systemml/native-backend)) and install Intel MKL.
-3. Install [SystemML](http://apache.github.io/incubator-systemml/beginners-guide-python#install-systemml).
-4. Invoke PySpark shell: `pyspark --conf spark.executorEnv.LD_LIBRARY_PATH=/path/to/blas-n-other-dependencies`.
+### Train Lenet on MNIST dataset
+
+#### MNIST dataset
+
+The MNIST dataset was constructed from two datasets of the US National Institute of Standards and Technology (NIST). The training set consists of handwritten digits from 250 different people, 50 percent high school students, and 50 percent employees from the Census Bureau. Note that the test set contains handwritten digits from different people following the same split.
+In the below example, we use the mlxtend package to load the MNIST dataset into Python NumPy arrays, but you are free to download it directly from http://yann.lecun.com/exdb/mnist/.
 
 ```bash
-# Download the MNIST dataset
+pip install mlxtend
+```
+
+#### Lenet network
+
+Lenet is a simple convolutional neural network, proposed by Yann LeCun in 1998. It has two convolution/pooling layers followed by fully connected layers. 
+Similar to Caffe, the network has been modified to add dropout. 
+For more details, please see http://yann.lecun.com/exdb/lenet/
+
+The [solver specification](https://raw.githubusercontent.com/apache/incubator-systemml/master/scripts/nn/examples/caffe2dml/models/mnist_lenet/lenet_solver.proto)
+tells Caffe2DML to use the following configuration when generating the training DML script (see the partial solver sketch after this list):  
+- `type: "SGD", momentum: 0.9`: Stochastic Gradient Descent with momentum optimizer with `momentum=0.9`.
+- `lr_policy: "exp", gamma: 0.95, base_lr: 0.01`: Use exponential decay learning rate policy (`base_lr * gamma ^ iter`).
+- `display: 100`: Display training loss after every 100 iterations.
+- `test_interval: 500`: Display validation loss after every 500 iterations.
+- `test_iter: 10`: Validation data size = 10 * BATCH_SIZE.
+ 
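+For reference, the corresponding lines of such a solver file look roughly as follows (a partial sketch containing only the settings listed above; see the linked lenet_solver.proto for the full file):
+
+```
+type: "SGD"
+momentum: 0.9
+lr_policy: "exp"
+gamma: 0.95
+base_lr: 0.01
+display: 100
+test_interval: 500
+test_iter: 10
+```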
+
+```python
 from mlxtend.data import mnist_data
 import numpy as np
 from sklearn.utils import shuffle
+import urllib
+from systemml.mllearn import Caffe2DML
+
+# Download the MNIST dataset
 X, y = mnist_data()
 X, y = shuffle(X, y)
 
@@ -54,105 +79,182 @@ X_test = X[int(.9 * n_samples):]
 y_test = y[int(.9 * n_samples):]
 
 # Download the Lenet network
-import urllib
-urllib.urlretrieve('https://raw.githubusercontent.com/niketanpansare/model_zoo/master/caffe/vision/lenet/mnist/lenet.proto', 'lenet.proto')
-urllib.urlretrieve('https://raw.githubusercontent.com/niketanpansare/model_zoo/master/caffe/vision/lenet/mnist/lenet_solver.proto', 'lenet_solver.proto')
+urllib.urlretrieve('https://raw.githubusercontent.com/apache/incubator-systemml/master/scripts/nn/examples/caffe2dml/models/mnist_lenet/lenet.proto', 'lenet.proto')
+urllib.urlretrieve('https://raw.githubusercontent.com/apache/incubator-systemml/master/scripts/nn/examples/caffe2dml/models/mnist_lenet/lenet_solver.proto', 'lenet_solver.proto')
 
 # Train Lenet On MNIST using scikit-learn like API
-from systemml.mllearn import Caffe2DML
-lenet = Caffe2DML(sqlCtx, solver='lenet_solver.proto', input_shape=(1, 28, 28)).set(debug=True).setStatistics(True)
+# The MNIST dataset contains 28 x 28 gray-scale images (number of channels = 1).
+lenet = Caffe2DML(sqlCtx, solver='lenet_solver.proto', input_shape=(1, 28, 28))
+
+# debug=True prints the generated DML script along with the classification report. Please do not set this flag in production.
+lenet.set(debug=True)
+
+# If you want to see the statistics as well as the plan
+lenet.setStatistics(True).setExplain(True)
+
+# If you want to force GPU execution, please make sure the required dependencies are available.  
+# lenet.setGPU(True).setForceGPU(True)
+
+# (Optional but recommended) Enable native BLAS. For more detail see http://apache.github.io/incubator-systemml/native-backend
+lenet.setConfigProperty("native.blas", "auto")
+
+# In case you want to enable experimental features such as codegen
+# lenet.setConfigProperty("codegen.enabled", "true").setConfigProperty("codegen.plancache", "true")
+
+# Since Caffe2DML is an mllearn API, it supports scikit-learn-like methods for training.
 lenet.fit(X_train, y_train)
-y_predicted = lenet.predict(X_test)
+lenet.predict(X_test)
 ```
 
 ## Frequently asked questions
 
-- How to set batch size ?
+#### How can I speed up the training with Caffe2DML ?
+
+- Enable native BLAS to improve the performance of CP convolution and matrix multiplication operators.
+If you are using OpenBLAS, please ensure that it was built with the `USE_OPENMP` flag turned on.
+For more detail see http://apache.github.io/incubator-systemml/native-backend
+
+```python
+caffe2dmlObject.setConfigProperty("native.blas", "auto")
+```
+
+- Turn on the experimental codegen feature. This should help reduce unnecessary allocation cost after every binary operation.
+
+```python
+caffe2dmlObject.setConfigProperty("codegen.enabled", "true").setConfigProperty("codegen.plancache", "true")
+```
+
+- Tune the [Garbage Collector](http://spark.apache.org/docs/latest/tuning.html#garbage-collection-tuning). 
+
+- Enable GPU support (described below).
+
+#### How to enable GPU support in Caffe2DML ?
+
+To be consistent with other mllearn algorithms, we recommend that you use the following methods instead of setting 
+the `solver_mode` in the solver file.
+
+```python
+# The below method tells SystemML optimizer to use a GPU-enabled instruction if the operands fit in the GPU memory 
+caffe2dmlObject.setGPU(True)
+# The below method tells SystemML optimizer to always use a GPU-enabled instruction irrespective of the memory requirement
+caffe2dmlObject.setForceGPU(True)
+```
+
+#### What is lr_policy in the solver specification ?
+
+The parameter `lr_policy` specifies the learning rate decay policy. Caffe2DML supports the following policies (a small Python sketch of these formulas follows the list):
+- `fixed`: always return `base_lr`.
+- `step`: return `base_lr * gamma ^ (floor(iter / step))`
+- `exp`: return `base_lr * gamma ^ iter`
+- `inv`: return `base_lr * (1 + gamma * iter) ^ (- power)`
+- `poly`: the effective learning rate follows a polynomial decay, reaching zero at `max_iter`; return `base_lr * (1 - iter/max_iter) ^ power`
+- `sigmoid`: the effective learning rate follows a sigmoid decay; return `base_lr * (1 / (1 + exp(-gamma * (iter - stepsize))))`
+      
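+For illustration, these policies can be written down as a small plain-Python helper (a minimal sketch of the formulas above, not part of the Caffe2DML API; the parameter names mirror the solver fields, with `stepsize` used for both the `step` and `sigmoid` policies):
+
+```python
+import math
+
+def learning_rate(policy, iteration, base_lr, gamma=0.95, power=1.0, stepsize=100, max_iter=2000):
+    if policy == 'fixed':
+        return base_lr
+    elif policy == 'step':
+        return base_lr * gamma ** math.floor(iteration / stepsize)
+    elif policy == 'exp':
+        return base_lr * gamma ** iteration
+    elif policy == 'inv':
+        return base_lr * (1 + gamma * iteration) ** (-power)
+    elif policy == 'poly':
+        return base_lr * (1 - iteration / float(max_iter)) ** power
+    elif policy == 'sigmoid':
+        return base_lr / (1 + math.exp(-gamma * (iteration - stepsize)))
+    else:
+        raise ValueError('Unsupported lr_policy: ' + policy)
+
+# With the lenet solver above (exp policy, base_lr=0.01, gamma=0.95),
+# the learning rate decays by 5 percent after every iteration.
+```
+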
+#### How to set batch size ?
 
 Batch size is set in `data_param` of the Data layer:
 
-	layer {
-	  name: "mnist"
-	  type: "Data"
-	  top: "data"
-	  top: "label"
-	  data_param {
-	    source: "mnist_train"
-	    batch_size: 64
-	    backend: LMDB
-	  }
-	}
+```
+layer {
+  name: "mnist"
+  type: "Data"
+  top: "data"
+  top: "label"
+  data_param {
+    source: "mnist_train"
+    batch_size: 64
+    backend: LMDB
+  }
+}
+```
 	
-- How to set maximum number of iterations for training ?
+#### How to set maximum number of iterations for training ?
 
-Caffe allows you to set the maximum number of iterations in solver specification
+The maximum number of iterations can be set in the solver specification:
 
-	# The maximum number of iterations
-	max_iter: 2000
-	
-- How to set the size of the validation dataset ?
+```bash
+# The maximum number of iterations
+max_iter: 2000
+```
+
+#### How to set the size of the validation dataset ?
 
 The size of the validation dataset is determined by the parameters `test_iter` and the batch size. For example: If the batch size is 64 and 
 `test_iter` is 10, then the validation size is 640. This setting generates following DML code internally:
 
-	num_images = nrow(y_full)
-	BATCH_SIZE = 64
-	num_validation = 10 * BATCH_SIZE
-	X = X_full[(num_validation+1):num_images,]; y = y_full[(num_validation+1):num_images,]
-	X_val = X_full[1:num_validation,]; y_val = y_full[1:num_validation,]
-	num_images = nrow(y) 
+```python
+num_images = nrow(y_full)
+BATCH_SIZE = 64
+num_validation = 10 * BATCH_SIZE
+X = X_full[(num_validation+1):num_images,]; y = y_full[(num_validation+1):num_images,]
+X_val = X_full[1:num_validation,]; y_val = y_full[1:num_validation,]
+num_images = nrow(y)
+``` 
 
-- How to monitor loss via command-line ?
+#### How to monitor loss via command-line ?
 
 To monitor loss, please set following parameters in the solver specification
 
-	# Display training loss and accuracy every 100 iterations
-	display: 100
-	# Carry out validation every 500 training iterations and display validation loss and accuracy.
-	test_iter: 10
-	test_interval: 500
-	
- - How to pass a single jpeg image to Caffe2DML for prediction ?
+```
+# Display training loss and accuracy every 100 iterations
+display: 100
+# Carry out validation every 500 training iterations and display validation loss and accuracy.
+test_iter: 10
+test_interval: 500
+```
+
+#### How to pass a single jpeg image to Caffe2DML for prediction ?
+
+To convert a jpeg image into a NumPy matrix, you can use the [pillow package](https://pillow.readthedocs.io/) and 
+SystemML's `convertImageToNumPyArr` utility function. The below pyspark code demonstrates the usage:
  
-	from PIL import Image
-	import systemml as sml
-	from systemml.mllearn import Caffe2DML
-	img_shape = (3, 224, 224)
-	input_image = sml.convertImageToNumPyArr(Image.open(img_file_path), img_shape=img_shape)
-	resnet = Caffe2DML(sqlCtx, solver='ResNet_50_solver.proto', weights='ResNet_50_pretrained_weights', input_shape=img_shape)
-	resnet.predict(input_image)
+```python
+from PIL import Image
+import systemml as sml
+from systemml.mllearn import Caffe2DML
+img_shape = (3, 224, 224)
+input_image = sml.convertImageToNumPyArr(Image.open(img_file_path), img_shape=img_shape)
+resnet = Caffe2DML(sqlCtx, solver='ResNet_50_solver.proto', weights='ResNet_50_pretrained_weights', input_shape=img_shape)
+resnet.predict(input_image)
+```
 
-- How to prepare a directory of jpeg images for training with Caffe2DML ?
+#### How to prepare a directory of jpeg images for training with Caffe2DML ?
 
-The below example assumes that the input dataset has 2 labels `cat` and `dogs` and the filename has these labels as prefix.
+The below pyspark code assumes that the input dataset has 2 labels `cat` and `dog` and that the filenames have these labels as prefix.
 We iterate through the directory and convert each jpeg image into pyspark.ml.linalg.Vector using pyspark.
 These vectors are stored as DataFrame and randomized using Spark SQL's `orderBy(rand())` function.
 The DataFrame is then saved in parquet format to reduce the cost of preprocessing for repeated training.
 
-	from systemml.mllearn import Caffe2DML
-	from pyspark.sql import SQLContext
-	import numpy as np
-	import urllib, os, scipy.ndimage
-	from pyspark.ml.linalg import Vectors
-	from pyspark import StorageLevel
-	import systemml as sml
-	from pyspark.sql.functions import rand 
-	# ImageNet specific parameters
-	img_shape = (3, 224, 224)
-	train_dir = '/home/biuser/dogs_vs_cats/train'
-	def getLabelFeatures(filename):
-		from PIL import Image
-		vec = Vectors.dense(sml.convertImageToNumPyArr(Image.open(os.path.join(train_dir, filename)), img_shape=img_shape)[0,:])
-		if filename.lower().startswith('cat'):
-			return (1, vec)
-		elif filename.lower().startswith('dog'):
-			return (2, vec)
-		else:
-			raise ValueError('Expected the filename to start with either cat or dog')
-	
-	list_jpeg_files = os.listdir(train_dir)
-	# 10 files per partition
-	train_df = sc.parallelize(list_jpeg_files, int(len(list_jpeg_files)/10)).map(lambda filename : getLabelFeatures(filename)).toDF(['label', 'features']).orderBy(rand())
-	# Optional: but helps seperates conversion-related from training
-	# Alternatively, this dataframe can be passed directly to `caffe2dml_model.fit(train_df)`
-	train_df.write.parquet('kaggle-cats-dogs.parquet')
\ No newline at end of file
+```python
+from systemml.mllearn import Caffe2DML
+from pyspark.sql import SQLContext
+import numpy as np
+import urllib, os, scipy.ndimage
+from pyspark.ml.linalg import Vectors
+from pyspark import StorageLevel
+import systemml as sml
+from pyspark.sql.functions import rand 
+# ImageNet specific parameters
+img_shape = (3, 224, 224)
+train_dir = '/home/biuser/dogs_vs_cats/train'
+def getLabelFeatures(filename):
+	from PIL import Image
+	vec = Vectors.dense(sml.convertImageToNumPyArr(Image.open(os.path.join(train_dir, filename)), img_shape=img_shape)[0,:])
+	if filename.lower().startswith('cat'):
+		return (1, vec)
+	elif filename.lower().startswith('dog'):
+		return (2, vec)
+	else:
+		raise ValueError('Expected the filename to start with either cat or dog')
+list_jpeg_files = os.listdir(train_dir)
+# 10 files per partition
+train_df = sc.parallelize(list_jpeg_files, int(len(list_jpeg_files)/10)).map(lambda filename : getLabelFeatures(filename)).toDF(['label', 'features']).orderBy(rand())
+# Optional: but helps seperates conversion-related from training
+# Alternatively, this dataframe can be passed directly to `caffe2dml_model.fit(train_df)`
+train_df.write.parquet('kaggle-cats-dogs.parquet')
+```
+
+#### Can I use Caffe2DML via Scala ?
+
+Though we recommend using Caffe2DML via its Python interface, it is possible to use it by creating an object of the class
+`org.apache.sysml.api.dl.Caffe2DML`. It is important to note that Caffe2DML's Scala API is packaged in `systemml-*-extra.jar`.

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/docs/beginners-guide-python.md
----------------------------------------------------------------------
diff --git a/docs/beginners-guide-python.md b/docs/beginners-guide-python.md
index 9beba19..b75e73c 100644
--- a/docs/beginners-guide-python.md
+++ b/docs/beginners-guide-python.md
@@ -204,7 +204,7 @@ will use `mllearn` API described in the next section.
 
 ## Invoke SystemML's algorithms
 
-SystemML also exposes a subpackage `mllearn`. This subpackage allows Python users to invoke SystemML algorithms
+SystemML also exposes a subpackage [mllearn](https://apache.github.io/incubator-systemml/python-reference#mllearn-api). This subpackage allows Python users to invoke SystemML algorithms
 using Scikit-learn or MLPipeline API.  
 
 ### Scikit-learn interface
@@ -216,7 +216,6 @@ algorithm.
 import numpy as np
 from sklearn import datasets
 from systemml.mllearn import LinearRegression
-from pyspark.sql import SQLContext
 # Load the diabetes dataset
 diabetes = datasets.load_diabetes()
 # Use only one feature
@@ -228,7 +227,7 @@ X_test = diabetes_X[-20:]
 y_train = diabetes.target[:-20]
 y_test = diabetes.target[-20:]
 # Create linear regression object
-regr = LinearRegression(sqlCtx, fit_intercept=True, C=float("inf"), solver='direct-solve')
+regr = LinearRegression(spark, fit_intercept=True, C=float("inf"), solver='direct-solve')
 # Train the model using the training sets
 regr.fit(X_train, y_train)
 y_predicted = regr.predict(X_test)
@@ -248,24 +247,34 @@ algorithm on digits datasets.
 
 ```python
 # Scikit-learn way
-from sklearn import datasets
+from sklearn import datasets, neighbors
 from systemml.mllearn import LogisticRegression
 digits = datasets.load_digits()
 X_digits = digits.data
-y_digits = digits.target 
+y_digits = digits.target
 n_samples = len(X_digits)
 X_train = X_digits[:int(.9 * n_samples)]
 y_train = y_digits[:int(.9 * n_samples)]
 X_test = X_digits[int(.9 * n_samples):]
 y_test = y_digits[int(.9 * n_samples):]
-logistic = LogisticRegression(sqlCtx)
+logistic = LogisticRegression(spark)
 print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test))
 ```
 
 Output:
 
 ```bash
-LogisticRegression score: 0.922222
+LogisticRegression score: 0.927778
+```
+
+You can also save the trained model and load it later for prediction:
+
+```python
+# Assuming logistic.fit(X_train, y_train) is already invoked
+logistic.save('logistic_model')
+new_logistic = LogisticRegression(spark)
+new_logistic.load('logistic_model')
+print('LogisticRegression score: %f' % new_logistic.score(X_test, y_test))
 ```
 
 ### Passing PySpark DataFrame
@@ -275,7 +284,6 @@ To train the above algorithm on larger dataset, we can load the dataset into Dat
 ```python
 from sklearn import datasets
 from systemml.mllearn import LogisticRegression
-from pyspark.sql import SQLContext
 import pandas as pd
 from sklearn.metrics import accuracy_score
 import systemml as sml
@@ -285,8 +293,8 @@ y_digits = digits.target
 n_samples = len(X_digits)
 # Split the data into training/testing sets and convert to PySpark DataFrame
 df_train = sml.convertToLabeledDF(sqlCtx, X_digits[:int(.9 * n_samples)], y_digits[:int(.9 * n_samples)])
-X_test = sqlCtx.createDataFrame(pd.DataFrame(X_digits[int(.9 * n_samples):]))
-logistic = LogisticRegression(sqlCtx)
+X_test = spark.createDataFrame(pd.DataFrame(X_digits[int(.9 * n_samples):]))
+logistic = LogisticRegression(spark)
 logistic.fit(df_train)
 y_predicted = logistic.predict(X_test)
 y_predicted = y_predicted.select('prediction').toPandas().as_matrix().flatten()
@@ -310,8 +318,7 @@ large data pipelines.
 from pyspark.ml import Pipeline
 from systemml.mllearn import LogisticRegression
 from pyspark.ml.feature import HashingTF, Tokenizer
-from pyspark.sql import SQLContext
-training = sqlCtx.createDataFrame([
+training = spark.createDataFrame([
     (0, "a b c d e spark", 1.0),
     (1, "b d", 2.0),
     (2, "spark f g h", 1.0),
@@ -330,7 +337,7 @@ hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
 lr = LogisticRegression(sqlCtx)
 pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
 model = pipeline.fit(training)
-test = sqlCtx.createDataFrame([
+test = spark.createDataFrame([
     (12, "spark i j k"),
     (13, "l m n"),
     (14, "mapreduce spark"),

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/docs/native-backend.md
----------------------------------------------------------------------
diff --git a/docs/native-backend.md b/docs/native-backend.md
index 6207932..d6a6228 100644
--- a/docs/native-backend.md
+++ b/docs/native-backend.md
@@ -50,11 +50,11 @@ The current version of SystemML only supports BLAS on **Linux** machines.
 
 ## Step 1: Install BLAS
 
-### Option 1: Install Intel MKL (recommended)
+### Option 1: Install Intel MKL
 
 Download and install the [community version of Intel MKL](https://software.intel.com/sites/campaigns/nest/).
 Intel requires you to first register your email address and then sends the download link to your email address
-with license key.
+with license key. Since we use MKL DNN primitives, we depend on Intel MKL version 2017 or higher.
 
 * Linux users will have to extract the downloaded `.tgz` file, execute `install.sh` and follow the guided setup.
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/docs/python-reference.md
----------------------------------------------------------------------
diff --git a/docs/python-reference.md b/docs/python-reference.md
index 8d38598..0d90ec3 100644
--- a/docs/python-reference.md
+++ b/docs/python-reference.md
@@ -43,8 +43,36 @@ and its algorithms without the need to know DML or PyDML. We explain these APIs
 ## matrix API
 
 The matrix class allows users to perform linear algebra operations in SystemML using a NumPy-like interface.
-This class supports several arithmetic operators (such as +, -, *, /, ^, etc) and also supports most of NumPy's universal functions (i.e. ufuncs).
+This class supports several arithmetic operators (such as +, -, *, /, ^, etc).
 
+The matrix class is a Python wrapper that implements basic matrix
+operators, matrix functions as well as converters to common Python
+types (for example: NumPy arrays, PySpark DataFrame and Pandas
+DataFrame).
+
+The operators supported are:
+
+1.  Arithmetic operators: +, -, *, /, //, %, \** as well as dot
+    (i.e. matrix multiplication)
+2.  Indexing in the matrix
+3.  Relational/Boolean operators: \<, \<=, \>, \>=, ==, !=, &, \|
+
+In addition, the following functions are supported for matrix:
+
+1.  transpose
+2.  Aggregation functions: sum, mean, var, sd, max, min, argmin,
+    argmax, cumsum
+3.  Global statistical built-In functions: exp, log, abs, sqrt,
+    round, floor, ceil, sin, cos, tan, asin, acos, atan, sign, solve
+
+For all the above functions, we always return a two-dimensional matrix, even for aggregation functions along an axis. 
+For example: Assuming m1 is a matrix of (3, n), NumPy returns a 1d vector of dimension (3,) for the operation m1.sum(axis=1)
+whereas SystemML returns a 2d matrix of dimension (3, 1), as the short sketch below illustrates.
+
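+A minimal sketch of this difference, assuming an active SparkContext `sc` (the shapes are the point here, not the values):
+
+```python
+import numpy as np
+import systemml as sml
+sml.setSparkContext(sc)
+
+m_np = np.ones((3, 4))
+print(m_np.sum(axis=1).shape)               # (3,)   -> NumPy returns a 1d vector
+
+m_sml = sml.matrix(m_np)
+print(m_sml.sum(axis=1).toNumPy().shape)    # (3, 1) -> SystemML returns a 2d matrix
+```
+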
+Note: an evaluated matrix contains a data field computed by the eval
+method as a DataFrame or NumPy array.
+
+It is important to note that the matrix class also supports most of NumPy's universal functions (i.e. ufuncs).
 The current version of NumPy explicitly disables overriding ufunc, but this should be enabled in next release. 
 Until then to test above code, please use:
 
@@ -123,435 +151,62 @@ array([[-60.],
 ```
 
 
-### Reference Documentation:
-
- *class*`systemml.defmatrix.matrix`(*data*, *op=None*)
-:   Bases: `object`
-
-    matrix class is a python wrapper that implements basic matrix
-    operators, matrix functions as well as converters to common Python
-    types (for example: Numpy arrays, PySpark DataFrame and Pandas
-    DataFrame).
-
-    The operators supported are:
-
-    1.  Arithmetic operators: +, -, *, /, //, %, \** as well as dot
-        (i.e. matrix multiplication)
-    2.  Indexing in the matrix
-    3.  Relational/Boolean operators: \<, \<=, \>, \>=, ==, !=, &, \|
-
-    In addition, following functions are supported for matrix:
-
-    1.  transpose
-    2.  Aggregation functions: sum, mean, var, sd, max, min, argmin,
-        argmax, cumsum
-    3.  Global statistical built-In functions: exp, log, abs, sqrt,
-        round, floor, ceil, sin, cos, tan, asin, acos, atan, sign, solve
-
-    For all the above functions, we always return a two dimensional matrix, especially for aggregation functions with axis. 
-    For example: Assuming m1 is a matrix of (3, n), NumPy returns a 1d vector of dimension (3,) for operation m1.sum(axis=1)
-    whereas SystemML returns a 2d matrix of dimension (3, 1).
-    
-    Note: an evaluated matrix contains a data field computed by eval
-    method as DataFrame or NumPy array.
-
-        >>> import SystemML as sml
-        >>> import numpy as np
-        >>> sml.setSparkContext(sc)
-
-    Welcome to Apache SystemML!
-
-        >>> m1 = sml.matrix(np.ones((3,3)) + 2)
-        >>> m2 = sml.matrix(np.ones((3,3)) + 3)
-        >>> m2 = m1 * (m2 + m1)
-        >>> m4 = 1.0 - m2
-        >>> m4
-        # This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.
-        mVar1 = load(" ", format="csv")
-        mVar2 = load(" ", format="csv")
-        mVar3 = mVar2 + mVar1
-        mVar4 = mVar1 * mVar3
-        mVar5 = 1.0 - mVar4
-        save(mVar5, " ")
-        >>> m2.eval()
-        >>> m2
-        # This matrix (mVar4) is backed by NumPy array. To fetch the NumPy array, invoke toNumPy() method.
-        >>> m4
-        # This matrix (mVar5) is backed by below given PyDML script (which is not yet evaluated). To fetch the data of this matrix, invoke toNumPy() or toDF() or toPandas() methods.
-        mVar4 = load(" ", format="csv")
-        mVar5 = 1.0 - mVar4
-        save(mVar5, " ")
-        >>> m4.sum(axis=1).toNumPy()
-        array([[-60.],
-               [-60.],
-               [-60.]])
-
-    Design Decisions:
-
-    1.  Until eval() method is invoked, we create an AST (not exposed to
-        the user) that consist of unevaluated operations and data
-        required by those operations. As an anology, a spark user can
-        treat eval() method similar to calling RDD.persist() followed by
-        RDD.count().
-    2.  The AST consist of two kinds of nodes: either of type matrix or
-        of type DMLOp. Both these classes expose \_visit method, that
-        helps in traversing the AST in DFS manner.
-    3.  A matrix object can either be evaluated or not. If evaluated,
-        the attribute 'data' is set to one of the supported types (for
-        example: NumPy array or DataFrame). In this case, the attribute
-        'op' is set to None. If not evaluated, the attribute 'op' which
-        refers to one of the intermediate node of AST and if of type
-        DMLOp. In this case, the attribute 'data' is set to None.
-
-    5.  DMLOp has an attribute 'inputs' which contains list of matrix
-        objects or DMLOp.
-
-    6.  To simplify the traversal, every matrix object is considered
-        immutable and an matrix operations creates a new matrix object.
-        As an example: m1 = sml.matrix(np.ones((3,3))) creates a matrix
-        object backed by 'data=(np.ones((3,3))'. m1 = m1 \* 2 will
-        create a new matrix object which is now backed by 'op=DMLOp( ...
-        )' whose input is earlier created matrix object.
-
-    7.  Left indexing (implemented in \_\_setitem\_\_ method) is a
-        special case, where Python expects the existing object to be
-        mutated. To ensure the above property, we make deep copy of
-        existing object and point any references to the left-indexed
-        matrix to the newly created object. Then the left-indexed matrix
-        is set to be backed by DMLOp consisting of following pydml:
-        left-indexed-matrix = new-deep-copied-matrix
-        left-indexed-matrix[index] = value
-
-    8.  Please use m.print\_ast() and/or type m for debugging. Here is a
-        sample session:
-
-            >>> npm = np.ones((3,3))
-            >>> m1 = sml.matrix(npm + 3)
-            >>> m2 = sml.matrix(npm + 5)
-            >>> m3 = m1 + m2
-            >>> m3
-            mVar2 = load(" ", format="csv")
-            mVar1 = load(" ", format="csv")
-            mVar3 = mVar1 + mVar2
-            save(mVar3, " ")
-            >>> m3.print_ast()
-            - [mVar3] (op).
-              - [mVar1] (data).
-              - [mVar2] (data).    
-
- `abs`()
-:   
-
- `acos`()
-:   
-
- `arccos`()
-:   
-
- `arcsin`()
-:   
-
- `arctan`()
-:   
-
- `argmax`(*axis=None*)
-:   Returns the indices of the maximum values along an axis.
-
-    axis : int, optional (only axis=1, i.e. rowIndexMax is supported
-    in this version)
-
- `argmin`(*axis=None*)
-:   Returns the indices of the minimum values along an axis.
-
-    axis : int, optional (only axis=1, i.e. rowIndexMax is supported
-    in this version)
-
- `asfptype`()
-:   
-
- `asin`()
-:   
-
- `astype`(*t*)
-:   
-
- `atan`()
-:   
-
- `ceil`()
-:   
-
- `cos`()
-:   
-
- `cumsum`(*axis=None*)
-:   Returns the indices of the maximum values along an axis.
-
-    axis : int, optional (only axis=0, i.e. cumsum along the rows is
-    supported in this version)
-
- `deg2rad`()
-:   Convert angles from degrees to radians.
-
- `dot`(*other*)[](#systemml.defmatrix.matrix.dot "Permalink to this definition")
-:   Numpy way of performing matrix multiplication
-
- `eval`(*outputDF=False*)[](#systemml.defmatrix.matrix.eval "Permalink to this definition")
-:   This is a convenience function that calls the global eval method
-
- `exp`()[](#systemml.defmatrix.matrix.exp "Permalink to this definition")
-:   
-
- `exp2`()[](#systemml.defmatrix.matrix.exp2 "Permalink to this definition")
-:   
-
- `expm1`()[](#systemml.defmatrix.matrix.expm1 "Permalink to this definition")
-:   
-
- `floor`()[](#systemml.defmatrix.matrix.floor "Permalink to this definition")
-:   
-
- `get_shape`()[](#systemml.defmatrix.matrix.get_shape "Permalink to this definition")
-:   
-
- `ldexp`(*other*)[](#systemml.defmatrix.matrix.ldexp "Permalink to this definition")
-:   
-
- `log`(*y=None*)[](#systemml.defmatrix.matrix.log "Permalink to this definition")
-:   
-
- `log10`()[](#systemml.defmatrix.matrix.log10 "Permalink to this definition")
-:   
-
- `log1p`()[](#systemml.defmatrix.matrix.log1p "Permalink to this definition")
-:   
-
- `log2`()[](#systemml.defmatrix.matrix.log2 "Permalink to this definition")
-:   
-
- `logaddexp`(*other*)[](#systemml.defmatrix.matrix.logaddexp "Permalink to this definition")
-:   
-
- `logaddexp2`(*other*)[](#systemml.defmatrix.matrix.logaddexp2 "Permalink to this definition")
-:   
-
- `logical_not`()[](#systemml.defmatrix.matrix.logical_not "Permalink to this definition")
-:   
-
- `max`(*other=None*, *axis=None*)[](#systemml.defmatrix.matrix.max "Permalink to this definition")
-:   Compute the maximum value along the specified axis
-
-    other: matrix or numpy array (& other supported types) or scalar
-    axis : int, optional
-
- `mean`(*axis=None*)[](#systemml.defmatrix.matrix.mean "Permalink to this definition")
-:   Compute the arithmetic mean along the specified axis
-
-    axis : int, optional
-
- `min`(*other=None*, *axis=None*)[](#systemml.defmatrix.matrix.min "Permalink to this definition")
-:   Compute the minimum value along the specified axis
+### Design Decisions:
 
-    other: matrix or numpy array (& other supported types) or scalar
-    axis : int, optional
+1.  Until the eval() method is invoked, we create an AST (not exposed to
+the user) that consists of unevaluated operations and the data
+required by those operations. As an analogy, a Spark user can
+treat the eval() method as similar to calling RDD.persist() followed by
+RDD.count().
 
- `mod`(*other*)[](#systemml.defmatrix.matrix.mod "Permalink to this definition")
-:   
+2.  The AST consists of two kinds of nodes: either of type matrix or
+of type DMLOp. Both these classes expose a \_visit method that
+helps in traversing the AST in a DFS manner.
 
- `ndim`*= 2*[](#systemml.defmatrix.matrix.ndim "Permalink to this definition")
-:   
+3.  A matrix object can either be evaluated or not. If evaluated,
+the attribute 'data' is set to one of the supported types (for
+example: NumPy array or DataFrame) and the attribute
+'op' is set to None. If not evaluated, the attribute 'op' refers
+to one of the intermediate nodes of the AST and is of type
+DMLOp; in this case, the attribute 'data' is set to None.
 
- `negative`()[](#systemml.defmatrix.matrix.negative "Permalink to this definition")
-:   
+5.  DMLOp has an attribute 'inputs' which contains a list of matrix
+objects or DMLOps.
 
- `ones_like`()[](#systemml.defmatrix.matrix.ones_like "Permalink to this definition")
-:   
+6.  To simplify the traversal, every matrix object is considered
+immutable and a matrix operation creates a new matrix object.
+As an example: m1 = sml.matrix(np.ones((3,3))) creates a matrix
+object backed by 'data=np.ones((3,3))'. m1 = m1 \* 2 will
+create a new matrix object which is now backed by 'op=DMLOp( ...)' 
+whose input is the earlier created matrix object.
 
- `print_ast`()[](#systemml.defmatrix.matrix.print_ast "Permalink to this definition")
-:   Please use m.print\_ast() and/or type m for debugging. Here is a
-    sample session:
-
-        >>> npm = np.ones((3,3))
-        >>> m1 = sml.matrix(npm + 3)
-        >>> m2 = sml.matrix(npm + 5)
-        >>> m3 = m1 + m2
-        >>> m3
-        mVar2 = load(" ", format="csv")
-        mVar1 = load(" ", format="csv")
-        mVar3 = mVar1 + mVar2
-        save(mVar3, " ")
-        >>> m3.print_ast()
-        - [mVar3] (op).
-          - [mVar1] (data).
-          - [mVar2] (data).
-
- `rad2deg`()[](#systemml.defmatrix.matrix.rad2deg "Permalink to this definition")
-:   Convert angles from radians to degrees.
-
- `reciprocal`()[](#systemml.defmatrix.matrix.reciprocal "Permalink to this definition")
-:   
-
- `remainder`(*other*)[](#systemml.defmatrix.matrix.remainder "Permalink to this definition")
-:   
-
- `round`()[](#systemml.defmatrix.matrix.round "Permalink to this definition")
-:   
-
- `script`*= None*[](#systemml.defmatrix.matrix.script "Permalink to this definition")
-:   
-
- `sd`(*axis=None*)[](#systemml.defmatrix.matrix.sd "Permalink to this definition")
-:   Compute the standard deviation along the specified axis
-
-    axis : int, optional
-
- `set_shape`(*shape*)[](#systemml.defmatrix.matrix.set_shape "Permalink to this definition")
-:   
-
- `shape`[](#systemml.defmatrix.matrix.shape "Permalink to this definition")
-:   
-
- `sign`()[](#systemml.defmatrix.matrix.sign "Permalink to this definition")
-:   
-
- `sin`()[](#systemml.defmatrix.matrix.sin "Permalink to this definition")
-:   
-
- `sqrt`()[](#systemml.defmatrix.matrix.sqrt "Permalink to this definition")
-:   
-
- `square`()[](#systemml.defmatrix.matrix.square "Permalink to this definition")
-:   
-
- `sum`(*axis=None*)[](#systemml.defmatrix.matrix.sum "Permalink to this definition")
-:   Compute the sum along the specified axis. 
-
-    axis : int, optional
-
- `systemmlVarID`*= 0*[](#systemml.defmatrix.matrix.systemmlVarID "Permalink to this definition")
-:   
-
- `tan`()[](#systemml.defmatrix.matrix.tan "Permalink to this definition")
-:   
-
- `toDF`()[](#systemml.defmatrix.matrix.toDF "Permalink to this definition")
-:   This is a convenience function that calls the global eval method
-    and then converts the matrix object into DataFrame.
-
- `toNumPy`()[](#systemml.defmatrix.matrix.toNumPy "Permalink to this definition")
-:   This is a convenience function that calls the global eval method
-    and then converts the matrix object into NumPy array.
-
- `toPandas`()[](#systemml.defmatrix.matrix.toPandas "Permalink to this definition")
-:   This is a convenience function that calls the global eval method
-    and then converts the matrix object into Pandas DataFrame.
-
- `trace`()[](#systemml.defmatrix.matrix.trace "Permalink to this definition")
-:   Return the sum of the cells of the main diagonal square matrix
-
- `transpose`()[](#systemml.defmatrix.matrix.transpose "Permalink to this definition")
-:   Transposes the matrix.
-
- `var`(*axis=None*)[](#systemml.defmatrix.matrix.var "Permalink to this definition")
-:   Compute the variance along the specified axis
-
-    axis : int, optional
-
- `zeros_like`()[](#systemml.defmatrix.matrix.zeros_like "Permalink to this definition")
-:   
-
- `systemml.defmatrix.eval`(*outputs*, *outputDF=False*, *execute=True*)[](#systemml.defmatrix.eval "Permalink to this definition")
-:   Executes the unevaluated DML script and computes the matrices
-    specified by outputs.
-
-    outputs: list of matrices or a matrix object outputDF: back the data
-    of matrix as PySpark DataFrame
-
- `systemml.defmatrix.solve`(*A*, *b*)[](#systemml.defmatrix.solve "Permalink to this definition")
-:   Computes the least squares solution for system of linear equations A
-    %\*% x = b
-
-        >>> import numpy as np
-        >>> from sklearn import datasets
-        >>> import SystemML as sml
-        >>> from pyspark.sql import SQLContext
-        >>> diabetes = datasets.load_diabetes()
-        >>> diabetes_X = diabetes.data[:, np.newaxis, 2]
-        >>> X_train = diabetes_X[:-20]
-        >>> X_test = diabetes_X[-20:]
-        >>> y_train = diabetes.target[:-20]
-        >>> y_test = diabetes.target[-20:]
-        >>> sml.setSparkContext(sc)
-        >>> X = sml.matrix(X_train)
-        >>> y = sml.matrix(y_train)
-        >>> A = X.transpose().dot(X)
-        >>> b = X.transpose().dot(y)
-        >>> beta = sml.solve(A, b).toNumPy()
-        >>> y_predicted = X_test.dot(beta)
-        >>> print('Residual sum of squares: %.2f' % np.mean((y_predicted - y_test) ** 2))
-        Residual sum of squares: 25282.12
-
- `systemml.defmatrix.set_lazy`(*isLazy*)[](#systemml.defmatrix.set_max_depth "Permalink to this definition")
-:   This method allows users to set whether the matrix operations should be executed in lazy manner.
-
-    isLazy: True if matrix operations should be evaluated in lazy manner.
-
- `systemml.defmatrix.debug_array_conversion`(*throwError*)[](#systemml.defmatrix.debug_array_conversion "Permalink to this definition")
-:   
-
- `systemml.random.sampling.normal`(*loc=0.0*, *scale=1.0*, *size=(1*, *1)*, *sparsity=1.0*)(#systemml.random.sampling.normal "Permalink to this definition")
-:   Draw random samples from a normal (Gaussian) distribution.
-
-    loc: Mean ('centre') of the distribution. scale: Standard deviation
-    (spread or 'width') of the distribution. size: Output shape (only
-    tuple of length 2, i.e. (m, n), supported). sparsity: Sparsity
-    (between 0.0 and 1.0).
-
-        >>> import systemml as sml
-        >>> import numpy as np
-        >>> sml.setSparkContext(sc)
-        >>> from systemml import random
-        >>> m1 = sml.random.normal(loc=3, scale=2, size=(3,3))
-        >>> m1.toNumPy()
-        array([[ 3.48857226,  6.17261819,  2.51167259],
-               [ 3.60506708, -1.90266305,  3.97601633],
-               [ 3.62245706,  5.9430881 ,  2.53070413]])
-
- `systemml.random.sampling.uniform`(*low=0.0*, *high=1.0*, *size=(1*, *1)*, *sparsity=1.0*)(#systemml.random.sampling.uniform "Permalink to this definition")
-:   Draw samples from a uniform distribution.
-
-    low: Lower boundary of the output interval. high: Upper boundary of
-    the output interval. size: Output shape (only tuple of length 2,
-    i.e. (m, n), supported). sparsity: Sparsity (between 0.0 and 1.0).
-
-        >>> import systemml as sml
-        >>> import numpy as np
-        >>> sml.setSparkContext(sc)
-        >>> from systemml import random
-        >>> m1 = sml.random.uniform(size=(3,3))
-        >>> m1.toNumPy()
-        array([[ 0.54511396,  0.11937437,  0.72975775],
-               [ 0.14135946,  0.01944448,  0.52544478],
-               [ 0.67582422,  0.87068849,  0.02766852]])
-
- `systemml.random.sampling.poisson`(*lam=1.0*, *size=(1*, *1)*, *sparsity=1.0*)(#systemml.random.sampling.poisson "Permalink to this definition")
-:   Draw samples from a Poisson distribution.
-
-    lam: Expectation of interval, should be \> 0. size: Output shape
-    (only tuple of length 2, i.e. (m, n), supported). sparsity: Sparsity
-    (between 0.0 and 1.0).
-
-        >>> import systemml as sml
-        >>> import numpy as np
-        >>> sml.setSparkContext(sc)
-        >>> from systemml import random
-        >>> m1 = sml.random.poisson(lam=1, size=(3,3))
-        >>> m1.toNumPy()
-        array([[ 1.,  0.,  2.],
-               [ 1.,  0.,  0.],
-               [ 0.,  0.,  0.]])
+6.  Left indexing (implemented in the \_\_setitem\_\_ method) is a
+special case, where Python expects the existing object to be
+mutated. To ensure this property, we make a deep copy of the
+existing object and point any references to the left-indexed
+matrix to the newly created object. The left-indexed matrix
+is then set to be backed by a DMLOp consisting of the following PyDML:
+left-indexed-matrix = new-deep-copied-matrix
+left-indexed-matrix[index] = value
 
+7.  Please use m.print\_ast() and/or type m for debugging. Here is a
+sample session:
 
+```python
+>>> npm = np.ones((3,3))
+>>> m1 = sml.matrix(npm + 3)
+>>> m2 = sml.matrix(npm + 5)
+>>> m3 = m1 + m2
+>>> m3
+mVar2 = load(" ", format="csv")
+mVar1 = load(" ", format="csv")
+mVar3 = mVar1 + mVar2
+save(mVar3, " ")
+>>> m3.print_ast()
+- [mVar3] (op).
+  - [mVar1] (data).
+  - [mVar2] (data).    
+```
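+
+The following is a minimal sketch of the lazy evaluation described in
+points 1 and 5 above (it assumes an active SparkContext `sc` and that
+`sml.setSparkContext(sc)` has already been called):
+
+```python
+>>> import numpy as np
+>>> import systemml as sml
+>>> m1 = sml.matrix(np.ones((3,3)))   # evaluated: backed by 'data'
+>>> m2 = m1 * 2 + 1                   # unevaluated: backed by 'op', no DML executed yet
+>>> m2.toNumPy()                      # triggers eval(): the PyDML script is generated and run
+array([[ 3.,  3.,  3.],
+       [ 3.,  3.,  3.],
+       [ 3.,  3.,  3.]])
+```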
 
 ## MLContext API
 
@@ -584,323 +239,179 @@ script = sml.dml(scriptPath).input(X=X_df, Y_vec=y_df).output("B_out")
 beta = ml.execute(script).get('B_out').toNumPy()
 ```
 
-### Reference documentation
-
- *class*`systemml.mlcontext.MLResults`(*results*, *sc*)[](#systemml.mlcontext.MLResults "Permalink to this definition")
-:   Bases: `object`{.xref .py .py-class .docutils .literal}
-
-    Wrapper around a Java ML Results object.
-
-    results: JavaObject
-    :   A Java MLResults object as returned by calling ml.execute().
-    sc: SparkContext
-    :   SparkContext
-
-     `get`(*\*outputs*)[](#systemml.mlcontext.MLResults.get "Permalink to this definition")
-    :   outputs: string, list of strings
-        :   Output variables as defined inside the DML script.
-
- *class*`systemml.mlcontext.MLContext`(*sc*)[](#systemml.mlcontext.MLContext "Permalink to this definition")
-:   Bases: `object`{.xref .py .py-class .docutils .literal}
-
-    Wrapper around the new SystemML MLContext.
-
-    sc: SparkContext
-    :   SparkContext
-
- `execute`(*script*)[](#systemml.mlcontext.MLContext.execute "Permalink to this definition")
-:   Execute a DML / PyDML script.
-
-    script: Script instance
-    :   Script instance defined with the appropriate input and
-        output variables.
-
-    ml\_results: MLResults
-    :   MLResults instance.
-
- `setExplain`(*explain*)[](#systemml.mlcontext.MLContext.setExplain "Permalink to this definition")
-:   Explanation about the program. Mainly intended for developers.
-
-    explain: boolean
-
- `setExplainLevel`(*explainLevel*)[](#systemml.mlcontext.MLContext.setExplainLevel "Permalink to this definition")
-:   Set explain level.
-
-    explainLevel: string
-    :   Can be one of 'hops', 'runtime', 'recompile\_hops',
-        'recompile\_runtime' or in the above in upper case.
-
- `setStatistics`(*statistics*)[](#systemml.mlcontext.MLContext.setStatistics "Permalink to this definition")
-:   Whether or not to output statistics (such as execution time,
-    elapsed time) about script executions.
-
-    statistics: boolean
-
- `setStatisticsMaxHeavyHitters`(*maxHeavyHitters*)[](#systemml.mlcontext.MLContext.setStatisticsMaxHeavyHitters "Permalink to this definition")
-:   The maximum number of heavy hitters that are printed as part of
-    the statistics.
-
-    maxHeavyHitters: int
-
- *class*`systemml.mlcontext.Script`(*scriptString*, *scriptType='dml'*)[](#systemml.mlcontext.Script "Permalink to this definition")
-:   Bases: `object`{.xref .py .py-class .docutils .literal}
-
-    Instance of a DML/PyDML Script.
-
-    scriptString: string
-    :   Can be either a file path to a DML script or a DML script
-        itself.
-    scriptType: string
-    :   Script language, either 'dml' for DML (R-like) or 'pydml' for
-        PyDML (Python-like).
-
- `input`(*\*args*, *\*\*kwargs*)[](#systemml.mlcontext.Script.input "Permalink to this definition")
-:   args: name, value tuple
-    :   where name is a string, and currently supported value
-        formats are double, string, dataframe, rdd, and list of such
-        object.
-    kwargs: dict of name, value pairs
-    :   To know what formats are supported for name and value, look
-        above.
-
- `output`(*\*names*)[](#systemml.mlcontext.Script.output "Permalink to this definition")
-:   names: string, list of strings
-    :   Output variables as defined inside the DML script.
-
- `systemml.mlcontext.dml`(*scriptString*)[](#systemml.mlcontext.dml "Permalink to this definition")
-:   Create a dml script object based on a string.
-
-    scriptString: string
-    :   Can be a path to a dml script or a dml script itself.
-
-    script: Script instance
-    :   Instance of a script object.
-
- `systemml.mlcontext.pydml`(*scriptString*)[](#systemml.mlcontext.pydml "Permalink to this definition")
-:   Create a pydml script object based on a string.
-
-    scriptString: string
-    :   Can be a path to a pydml script or a pydml script itself.
-
-    script: Script instance
-    :   Instance of a script object.
-
- `systemml.mlcontext.getNumCols`(*numPyArr*)[](#systemml.mlcontext.getNumCols "Permalink to this definition")
-:   
-
- `systemml.mlcontext.convertToMatrixBlock`(*sc*, *src*)[](#systemml.mlcontext.convertToMatrixBlock "Permalink to this definition")
-:   
-
- `systemml.mlcontext.convertToNumPyArr`(*sc*, *mb*)[](#systemml.mlcontext.convertToNumPyArr "Permalink to this definition")
-:   
-
- `systemml.mlcontext.convertToPandasDF`(*X*)[](#systemml.mlcontext.convertToPandasDF "Permalink to this definition")
-:   
-
- `systemml.mlcontext.convertToLabeledDF`(*sqlCtx*, *X*, *y=None*)[](#systemml.mlcontext.convertToLabeledDF "Permalink to this definition")
-:   
-
 
 ## mllearn API
 
-### Usage
+The below code shows how to use the mllearn API for training:
+
+<div class="codetabs">
+<div data-lang="sklearn way" markdown="1">
+{% highlight python %}
+# Input: Two Python objects (X_train, y_train) of type numpy, pandas or scipy.
+model.fit(X_train, y_train)
+{% endhighlight %}
+</div>
+<div data-lang="mllib way" markdown="1">
+{% highlight python %}
+# Input: One LabeledPoint DataFrame with at least two columns: features (of type Vector) and labels.
+model.fit(X_df)
+{% endhighlight %}
+</div>
+</div>
+
+The below code shows how to use the mllearn API for prediction:
+
+<div class="codetabs">
+<div data-lang="sklearn way" markdown="1">
+{% highlight python %}
+# Input: One Python object (X_test) of type numpy, pandas or scipy.
+model.predict(X_test)
+# OR model.score(X_test, y_test)
+{% endhighlight %}
+</div>
+<div data-lang="mllib way" markdown="1">
+{% highlight python %}
+# Input: One LabeledPoint DataFrame (df_test) with at least one column: features (of type Vector).
+model.transform(df_test)
+{% endhighlight %}
+</div>
+</div>
+
+
+The table below describes the parameters available for the mllearn algorithms:
+
+| Parameters | Description of the Parameters | LogisticRegression | LinearRegression | SVM | NaiveBayes |
+|----------------|-----------------------------------------------------------------------------------------------|-----------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----|------------|
+| sparkSession | PySpark SparkSession | X | X | X | X |
+| penalty | Used to specify the norm used in the penalization (default: 'l2') | only 'l2' supported | - | - | - |
+| fit_intercept | Specifies whether to add intercept or not (default: True) | X | X | X | - |
+| normalize | This parameter is ignored when fit_intercept is set to False. (default: False) | X | X | X | - |
+| max_iter | Maximum number of iterations (default: 100) | X | X | X | - |
+| max_inner_iter | Maximum number of inner iterations, or 0 if no maximum limit provided (default: 0) | X | - | - | - |
+| tol | Tolerance used in the convergence criterion (default: 0.000001) | X | X | X | - |
+| C | 1/regularization parameter (default: 1.0). To disable regularization, please use float("inf") | X | X | X | - |
+| solver | Algorithm to use in the optimization problem. | Only 'newton-cg' solver supported | Supports either 'newton-cg' or 'direct-solve' (default: 'newton-cg'). Depending on the size and the sparsity of the feature matrix, one or the other solver may be more efficient. 'direct-solve' solver is more efficient when the number of features is relatively small (m < 1000) and input matrix X is either tall or fairly dense; otherwise 'newton-cg' solver is more efficient. | - | - |
+| is_multi_class | Specifies whether to use binary-class or multi-class classifier (default: False) | - | - | X | - |
+| laplace | Laplace smoothing specified by the user to avoid creation of 0 probabilities (default: 1.0) | - | - | - | X |
+
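+For instance, a minimal sketch of constructing an estimator with a few of
+these parameters (the values below are purely illustrative and assume an
+active SparkSession named `spark`):
+
+```python
+from systemml.mllearn import LogisticRegression
+# illustrative parameter values; see the table above for defaults
+logistic = LogisticRegression(spark, fit_intercept=True, max_iter=100, tol=0.000001, C=1.0)
+```
+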
+In the below example, we invoke SystemML's [Logistic Regression](https://apache.github.io/incubator-systemml/algorithms-classification.html#multinomial-logistic-regression)
+algorithm on the digits dataset.
 
 ```python
 # Scikit-learn way
 from sklearn import datasets, neighbors
 from systemml.mllearn import LogisticRegression
-from pyspark.sql import SQLContext
-sqlCtx = SQLContext(sc)
 digits = datasets.load_digits()
 X_digits = digits.data
-y_digits = digits.target 
+y_digits = digits.target
 n_samples = len(X_digits)
-X_train = X_digits[:.9 * n_samples]
-y_train = y_digits[:.9 * n_samples]
-X_test = X_digits[.9 * n_samples:]
-y_test = y_digits[.9 * n_samples:]
-logistic = LogisticRegression(sqlCtx)
+X_train = X_digits[:int(.9 * n_samples)]
+y_train = y_digits[:int(.9 * n_samples)]
+X_test = X_digits[int(.9 * n_samples):]
+y_test = y_digits[int(.9 * n_samples):]
+logistic = LogisticRegression(spark)
 print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test))
 ```
 
 Output:
 
 ```bash
-LogisticRegression score: 0.922222
+LogisticRegression score: 0.927778
 ```
 
-### Reference documentation
-
- *class*`systemml.mllearn.estimators.LinearRegression`(*sqlCtx*, *fit\_intercept=True*, *normalize=False*, *max\_iter=100*, *tol=1e-06*, *C=float("inf")*, *solver='newton-cg'*, *transferUsingDF=False*)(#systemml.mllearn.estimators.LinearRegression "Permalink to this definition")
-:   Bases: `systemml.mllearn.estimators.BaseSystemMLRegressor`{.xref .py
-    .py-class .docutils .literal}
-
-    Performs linear regression to model the relationship between one
-    numerical response variable and one or more explanatory (feature)
-    variables.
-
-        >>> import numpy as np
-        >>> from sklearn import datasets
-        >>> from systemml.mllearn import LinearRegression
-        >>> from pyspark.sql import SQLContext
-        >>> # Load the diabetes dataset
-        >>> diabetes = datasets.load_diabetes()
-        >>> # Use only one feature
-        >>> diabetes_X = diabetes.data[:, np.newaxis, 2]
-        >>> # Split the data into training/testing sets
-        >>> diabetes_X_train = diabetes_X[:-20]
-        >>> diabetes_X_test = diabetes_X[-20:]
-        >>> # Split the targets into training/testing sets
-        >>> diabetes_y_train = diabetes.target[:-20]
-        >>> diabetes_y_test = diabetes.target[-20:]
-        >>> # Create linear regression object
-        >>> regr = LinearRegression(sqlCtx, solver='newton-cg')
-        >>> # Train the model using the training sets
-        >>> regr.fit(diabetes_X_train, diabetes_y_train)
-        >>> # The mean square error
-        >>> print("Residual sum of squares: %.2f" % np.mean((regr.predict(diabetes_X_test) - diabetes_y_test) ** 2))
-
- *class*`systemml.mllearn.estimators.LogisticRegression`(*sqlCtx*, *penalty='l2'*, *fit\_intercept=True*, *normalize=False*,  *max\_iter=100*, *max\_inner\_iter=0*, *tol=1e-06*, *C=1.0*, *solver='newton-cg'*, *transferUsingDF=False*)(#systemml.mllearn.estimators.LogisticRegression "Permalink to this definition")
-:   Bases: `systemml.mllearn.estimators.BaseSystemMLClassifier`{.xref
-    .py .py-class .docutils .literal}
-
-    Performs both binomial and multinomial logistic regression.
-
-    Scikit-learn way
-
-        >>> from sklearn import datasets, neighbors
-        >>> from systemml.mllearn import LogisticRegression
-        >>> from pyspark.sql import SQLContext
-        >>> sqlCtx = SQLContext(sc)
-        >>> digits = datasets.load_digits()
-        >>> X_digits = digits.data
-        >>> y_digits = digits.target + 1
-        >>> n_samples = len(X_digits)
-        >>> X_train = X_digits[:.9 * n_samples]
-        >>> y_train = y_digits[:.9 * n_samples]
-        >>> X_test = X_digits[.9 * n_samples:]
-        >>> y_test = y_digits[.9 * n_samples:]
-        >>> logistic = LogisticRegression(sqlCtx)
-        >>> print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test))
-
-    MLPipeline way
-
-        >>> from pyspark.ml import Pipeline
-        >>> from systemml.mllearn import LogisticRegression
-        >>> from pyspark.ml.feature import HashingTF, Tokenizer
-        >>> from pyspark.sql import SQLContext
-        >>> sqlCtx = SQLContext(sc)
-        >>> training = sqlCtx.createDataFrame([
-        >>>     (0L, "a b c d e spark", 1.0),
-        >>>     (1L, "b d", 2.0),
-        >>>     (2L, "spark f g h", 1.0),
-        >>>     (3L, "hadoop mapreduce", 2.0),
-        >>>     (4L, "b spark who", 1.0),
-        >>>     (5L, "g d a y", 2.0),
-        >>>     (6L, "spark fly", 1.0),
-        >>>     (7L, "was mapreduce", 2.0),
-        >>>     (8L, "e spark program", 1.0),
-        >>>     (9L, "a e c l", 2.0),
-        >>>     (10L, "spark compile", 1.0),
-        >>>     (11L, "hadoop software", 2.0)
-        >>> ], ["id", "text", "label"])
-        >>> tokenizer = Tokenizer(inputCol="text", outputCol="words")
-        >>> hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
-        >>> lr = LogisticRegression(sqlCtx)
-        >>> pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
-        >>> model = pipeline.fit(training)
-        >>> test = sqlCtx.createDataFrame([
-        >>>     (12L, "spark i j k"),
-        >>>     (13L, "l m n"),
-        >>>     (14L, "mapreduce spark"),
-        >>>     (15L, "apache hadoop")], ["id", "text"])
-        >>> prediction = model.transform(test)
-        >>> prediction.show()
-
- *class*`systemml.mllearn.estimators.SVM`(*sqlCtx*, *fit\_intercept=True*, *normalize=False*, *max\_iter=100*, *tol=1e-06*, *C=1.0*, *is\_multi\_class=False*, *transferUsingDF=False*)(#systemml.mllearn.estimators.SVM "Permalink to this definition")
-:   Bases: `systemml.mllearn.estimators.BaseSystemMLClassifier`{.xref
-    .py .py-class .docutils .literal}
-
-    Performs both binary-class and multiclass SVM (Support Vector
-    Machines).
-
-        >>> from sklearn import datasets, neighbors
-        >>> from systemml.mllearn import SVM
-        >>> from pyspark.sql import SQLContext
-        >>> sqlCtx = SQLContext(sc)
-        >>> digits = datasets.load_digits()
-        >>> X_digits = digits.data
-        >>> y_digits = digits.target 
-        >>> n_samples = len(X_digits)
-        >>> X_train = X_digits[:.9 * n_samples]
-        >>> y_train = y_digits[:.9 * n_samples]
-        >>> X_test = X_digits[.9 * n_samples:]
-        >>> y_test = y_digits[.9 * n_samples:]
-        >>> svm = SVM(sqlCtx, is_multi_class=True)
-        >>> print('LogisticRegression score: %f' % svm.fit(X_train, y_train).score(X_test, y_test))
-
- *class*`systemml.mllearn.estimators.NaiveBayes`(*sqlCtx*, *laplace=1.0*, *transferUsingDF=False*)(#systemml.mllearn.estimators.NaiveBayes "Permalink to this definition")
-:   Bases: `systemml.mllearn.estimators.BaseSystemMLClassifier`{.xref
-    .py .py-class .docutils .literal}
-
-    Performs Naive Bayes.
-
-        >>> from sklearn.datasets import fetch_20newsgroups
-        >>> from sklearn.feature_extraction.text import TfidfVectorizer
-        >>> from systemml.mllearn import NaiveBayes
-        >>> from sklearn import metrics
-        >>> from pyspark.sql import SQLContext
-        >>> sqlCtx = SQLContext(sc)
-        >>> categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']
-        >>> newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
-        >>> newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
-        >>> vectorizer = TfidfVectorizer()
-        >>> # Both vectors and vectors_test are SciPy CSR matrix
-        >>> vectors = vectorizer.fit_transform(newsgroups_train.data)
-        >>> vectors_test = vectorizer.transform(newsgroups_test.data)
-        >>> nb = NaiveBayes(sqlCtx)
-        >>> nb.fit(vectors, newsgroups_train.target)
-        >>> pred = nb.predict(vectors_test)
-        >>> metrics.f1_score(newsgroups_test.target, pred, average='weighted')
-
-
-## Utility classes (used internally)
-
-### systemml.classloader 
-
- `systemml.classloader.createJavaObject`(*sc*, *obj\_type*)[](#systemml.classloader.createJavaObject "Permalink to this definition")
-:   Performs appropriate check if SystemML.jar is available and returns
-    the handle to MLContext object on JVM
+You can also save the trained model and load it later for prediction:
 
-    sc: SparkContext
-    :   SparkContext
+```python
+# Assuming logistic.fit(X_train, y_train) is already invoked
+logistic.save('logistic_model')
+new_logistic = LogisticRegression(spark)
+new_logistic.load('logistic_model')
+print('LogisticRegression score: %f' % new_logistic.score(X_test, y_test))
+```
 
-    obj\_type: Type of object to create ('mlcontext' or 'dummy')
+#### Passing PySpark DataFrame
 
-### systemml.converters
+To train the above algorithm on a larger dataset, we can load the dataset into a DataFrame and pass it to the `fit` method:
 
- `systemml.converters.getNumCols`(*numPyArr*)[](#systemml.converters.getNumCols "Permalink to this definition")
-:   
+```python
+from sklearn import datasets
+from systemml.mllearn import LogisticRegression
+import pandas as pd
+from sklearn.metrics import accuracy_score
+import systemml as sml
+digits = datasets.load_digits()
+X_digits = digits.data
+y_digits = digits.target
+n_samples = len(X_digits)
+# Split the data into training/testing sets and convert to PySpark DataFrame
+df_train = sml.convertToLabeledDF(spark, X_digits[:int(.9 * n_samples)], y_digits[:int(.9 * n_samples)])
+X_test = spark.createDataFrame(pd.DataFrame(X_digits[int(.9 * n_samples):]))
+logistic = LogisticRegression(spark)
+logistic.fit(df_train)
+y_predicted = logistic.predict(X_test)
+y_predicted = y_predicted.select('prediction').toPandas().as_matrix().flatten()
+y_test = y_digits[int(.9 * n_samples):]
+print('LogisticRegression score: %f' % accuracy_score(y_test, y_predicted))
+```
 
- `systemml.converters.convertToMatrixBlock`(*sc*, *src*)[](#systemml.converters.convertToMatrixBlock "Permalink to this definition")
-:   
+Output:
 
- `systemml.converters.convertToNumPyArr`(*sc*, *mb*)[](#systemml.converters.convertToNumPyArr "Permalink to this definition")
-:   
+```bash
+LogisticRegression score: 0.922222
+```
 
- `systemml.converters.convertToPandasDF`(*X*)[](#systemml.converters.convertToPandasDF "Permalink to this definition")
-:   
+#### MLPipeline interface
 
- `systemml.converters.convertToLabeledDF`(*sqlCtx*, *X*, *y=None*)[](#systemml.converters.convertToLabeledDF "Permalink to this definition")
-:  
+In the below example, we demonstrate how the same `LogisticRegression` class can allow SystemML to fit seamlessly into 
+large data pipelines.
 
-### Other classes from systemml.defmatrix
+```python
+# MLPipeline way
+from pyspark.ml import Pipeline
+from systemml.mllearn import LogisticRegression
+from pyspark.ml.feature import HashingTF, Tokenizer
+training = spark.createDataFrame([
+    (0, "a b c d e spark", 1.0),
+    (1, "b d", 2.0),
+    (2, "spark f g h", 1.0),
+    (3, "hadoop mapreduce", 2.0),
+    (4, "b spark who", 1.0),
+    (5, "g d a y", 2.0),
+    (6, "spark fly", 1.0),
+    (7, "was mapreduce", 2.0),
+    (8, "e spark program", 1.0),
+    (9, "a e c l", 2.0),
+    (10, "spark compile", 1.0),
+    (11, "hadoop software", 2.0)
+], ["id", "text", "label"])
+tokenizer = Tokenizer(inputCol="text", outputCol="words")
+hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
+lr = LogisticRegression(spark)
+pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
+model = pipeline.fit(training)
+test = spark.createDataFrame([
+    (12, "spark i j k"),
+    (13, "l m n"),
+    (14, "mapreduce spark"),
+    (15, "apache hadoop")], ["id", "text"])
+prediction = model.transform(test)
+prediction.show()
+```
+
+Output:
+
+```bash
++-------+---+---------------+------------------+--------------------+--------------------+----------+
+|__INDEX| id|           text|             words|            features|         probability|prediction|
++-------+---+---------------+------------------+--------------------+--------------------+----------+
+|    1.0| 12|    spark i j k|  [spark, i, j, k]|(20,[5,6,7],[2.0,...|[0.99999999999975...|       1.0|
+|    2.0| 13|          l m n|         [l, m, n]|(20,[8,9,10],[1.0...|[1.37552128844736...|       2.0|
+|    3.0| 14|mapreduce spark|[mapreduce, spark]|(20,[5,10],[1.0,1...|[0.99860290938153...|       1.0|
+|    4.0| 15|  apache hadoop|  [apache, hadoop]|(20,[9,14],[1.0,1...|[5.41688748236143...|       2.0|
++-------+---+---------------+------------------+--------------------+--------------------+----------+
+```
 
- *class*`systemml.defmatrix.DMLOp`(*inputs*, *dml=None*)[](#systemml.defmatrix.DMLOp "Permalink to this definition")
-:   Bases: `object`{.xref .py .py-class .docutils .literal}
 
-    Represents an intermediate node of Abstract syntax tree created to
-    generate the PyDML script
 
 
 ## Troubleshooting Python APIs

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index dd2558a..208dac5 100644
--- a/pom.xml
+++ b/pom.xml
@@ -101,6 +101,8 @@
 				<exclude>staging/**/*</exclude>
 				<exclude>staging</exclude>
 				<exclude>nn/test/compare_backends/*</exclude>
+				<exclude>nn/test/compare_backends/*</exclude>
+				<exclude>nn/examples/caffe2dml/**/*</exclude>
 				<!-- <exclude>*.sh</exclude> --> <!-- applies to sparkDML.sh -->
 			</excludes>
 			<targetPath>scripts</targetPath>
@@ -874,6 +876,7 @@
 								<exclude>src/main/proto/caffe/caffe.proto</exclude>
 								<exclude>src/main/proto/tensorflow/event.proto</exclude>
 								<exclude>src/main/proto/tensorflow/summary.proto</exclude>
+								<exclude>scripts/nn/examples/caffe2dml/models/**/*</exclude>
 								<!-- Test Validation files -->
 								<exclude>src/test/scripts/functions/jmlc/**/*.impute</exclude>
 								<exclude>src/test/scripts/functions/jmlc/**/*.map</exclude>

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/scripts/nn/examples/caffe2dml/models/mnist_lenet/lenet.proto
----------------------------------------------------------------------
diff --git a/scripts/nn/examples/caffe2dml/models/mnist_lenet/lenet.proto b/scripts/nn/examples/caffe2dml/models/mnist_lenet/lenet.proto
new file mode 100644
index 0000000..756734a
--- /dev/null
+++ b/scripts/nn/examples/caffe2dml/models/mnist_lenet/lenet.proto
@@ -0,0 +1,195 @@
+name: "LeNet"
+layer {
+  name: "mnist"
+  type: "Data"
+  top: "data"
+  top: "label"
+  include {
+    phase: TRAIN
+  }
+  transform_param {
+    scale: 0.00390625
+  }
+  data_param {
+    source: "mnist_train"
+    batch_size: 64
+    backend: LMDB
+  }
+}
+layer {
+  name: "mnist"
+  type: "Data"
+  top: "data"
+  top: "label"
+  include {
+    phase: TEST
+  }
+  transform_param {
+    scale: 0.00390625
+  }
+  data_param {
+    source: "mnist_test"
+    batch_size: 100
+    backend: LMDB
+  }
+}
+layer {
+  name: "conv1"
+  type: "Convolution"
+  bottom: "mnist"
+  top: "conv1"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  convolution_param {
+    num_output: 32
+    kernel_size: 5
+    stride: 1
+	pad: 2
+    weight_filler {
+      type: "msra"
+    }
+    bias_filler {
+      type: "constant"
+	  value: 0.1
+    }
+  }
+}
+layer {
+  name: "relu1"
+  type: "ReLU"
+  bottom: "conv1"
+  top: "relu1"
+}
+layer {
+  name: "pool1"
+  type: "Pooling"
+  bottom: "relu1"
+  top: "pool1"
+  pooling_param {
+    pool: MAX
+    kernel_size: 2
+    stride: 2
+  }
+}
+layer {
+  name: "conv2"
+  type: "Convolution"
+  bottom: "pool1"
+  top: "conv2"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  convolution_param {
+    num_output: 64
+    kernel_size: 5
+    stride: 1
+	pad: 2
+    weight_filler {
+      type: "msra"
+    }
+    bias_filler {
+      type: "constant"
+	  value: 0.1
+    }
+  }
+}
+layer {
+  name: "relu2"
+  type: "ReLU"
+  bottom: "conv2"
+  top: "relu2"
+}
+layer {
+  name: "pool2"
+  type: "Pooling"
+  bottom: "relu2"
+  top: "pool2"
+  pooling_param {
+    pool: MAX
+    kernel_size: 2
+    stride: 2
+  }
+}
+layer {
+  name: "ip1"
+  type: "InnerProduct"
+  bottom: "pool2"
+  top: "ip1"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  inner_product_param {
+    num_output: 512
+    weight_filler {
+      type: "msra"
+    }
+    bias_filler {
+      type: "constant"
+	  value: 0
+    }
+  }
+}
+layer {
+  name: "relu3"
+  type: "ReLU"
+  bottom: "ip1"
+  top: "relu3"
+}
+layer {
+  name: "drop1"
+  type: "Dropout"
+  bottom: "relu3"
+  top: "drop1"
+  dropout_param {
+    dropout_ratio: 0.5
+  }
+}
+layer {
+  name: "ip2"
+  type: "InnerProduct"
+  bottom: "drop1"
+  top: "ip2"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  inner_product_param {
+    num_output: 10
+    weight_filler {
+      type: "msra"
+    }
+    bias_filler {
+      type: "constant"
+	  value: 0
+    }
+  }
+}
+layer {
+  name: "accuracy"
+  type: "Accuracy"
+  bottom: "ip2"
+  bottom: "label"
+  top: "accuracy"
+  include {
+    phase: TEST
+  }
+}
+layer {
+  name: "loss"
+  type: "SoftmaxWithLoss"
+  bottom: "ip2"
+  bottom: "label"
+  top: "loss"
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/scripts/nn/examples/caffe2dml/models/mnist_lenet/lenet_solver.proto
----------------------------------------------------------------------
diff --git a/scripts/nn/examples/caffe2dml/models/mnist_lenet/lenet_solver.proto b/scripts/nn/examples/caffe2dml/models/mnist_lenet/lenet_solver.proto
new file mode 100644
index 0000000..3b943be
--- /dev/null
+++ b/scripts/nn/examples/caffe2dml/models/mnist_lenet/lenet_solver.proto
@@ -0,0 +1,19 @@
+# The train/test net protocol buffer definition
+net: "lenet.proto"
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.01
+momentum: 0.9
+weight_decay: 5e-4
+# The learning rate policy
+lr_policy: "exp"
+gamma: 0.95
+# Display every 100 iterations
+display: 100
+# solver mode: CPU or GPU
+solver_mode: CPU
+type: "SGD"
+# The maximum number of iterations
+max_iter: 2000
+# Carry out testing every 500 training iterations.
+test_iter: 10
+test_interval: 500
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d69f3441/src/main/python/systemml/converters.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/converters.py b/src/main/python/systemml/converters.py
index 8bf05d7..87a9a45 100644
--- a/src/main/python/systemml/converters.py
+++ b/src/main/python/systemml/converters.py
@@ -19,10 +19,11 @@
 #
 #-------------------------------------------------------------
 
-__all__ = [ 'getNumCols', 'convertToMatrixBlock', 'convertToNumPyArr', 'convertToPandasDF', 'SUPPORTED_TYPES' , 'convertToLabeledDF', 'convertImageToNumPyArr']
+__all__ = [ 'getNumCols', 'convertToMatrixBlock', 'convert_caffemodel', 'convert_lmdb_to_jpeg', 'convertToNumPyArr', 'convertToPandasDF', 'SUPPORTED_TYPES' , 'convertToLabeledDF', 'convertImageToNumPyArr']
 
 import numpy as np
 import pandas as pd
+import os
 import math
 from pyspark.context import SparkContext
 from scipy.sparse import coo_matrix, spmatrix, csr_matrix
@@ -36,6 +37,99 @@ def getNumCols(numPyArr):
     else:
         return numPyArr.shape[1]
 
+def get_pretty_str(key, value):
+    return '\t"' + key + '": ' + str(value) + ',\n'
+        
+def save_tensor_csv(tensor, file_path, shouldTranspose):
+    w = tensor.reshape(tensor.shape[0], -1)  # flatten the tensor to a 2-D matrix before saving as CSV
+    if shouldTranspose:
+        w = w.T
+    np.savetxt(file_path, w, delimiter=',')
+    with open(file_path + '.mtd', 'w') as file:
+        file.write('{\n\t"data_type": "matrix",\n\t"value_type": "double",\n')
+        file.write(get_pretty_str('rows', w.shape[0]))
+        file.write(get_pretty_str('cols', w.shape[1]))
+        file.write(get_pretty_str('nnz', np.count_nonzero(w)))
+        file.write('\t"format": "csv",\n\t"description": {\n\t\t"author": "SystemML"\n\t}\n}\n')
+    
+def convert_caffemodel(sc, deploy_file, caffemodel_file, output_dir, format="binary", is_caffe_installed=False):
+    """
+    Saves the weights and bias in the caffemodel file to output_dir in the specified format.
+    This method does not require caffe to be installed.
+    
+    Parameters
+    ----------
+    sc: SparkContext
+        SparkContext
+    
+    deploy_file: string
+        Path to the input network file
+        
+    caffemodel_file: string
+        Path to the input caffemodel file
+    
+    output_dir: string
+        Path to the output directory
+    
+    format: string
+        Format of the weights and bias (can be binary, csv or text)
+    
+    is_caffe_installed: bool
+        True if caffe is installed
+    """
+    if is_caffe_installed:
+        if format != 'csv':
+            raise ValueError('The format ' + str(format) + ' is not supported when caffe is installed. Hint: Please specify format=csv')
+        import caffe
+        net = caffe.Net(deploy_file, caffemodel_file, caffe.TEST)
+        for layerName in net.params.keys():
+            num_parameters = len(net.params[layerName])
+            if num_parameters == 0:
+                continue
+            elif num_parameters == 2:
+                # Weights and Biases
+                layerType = net.layers[list(net._layer_names).index(layerName)].type
+                shouldTranspose = True if layerType == 'InnerProduct' else False
+                save_tensor_csv(net.params[layerName][0].data, os.path.join(output_dir, layerName + '_weight.mtx'), shouldTranspose)
+                save_tensor_csv(net.params[layerName][1].data, os.path.join(output_dir, layerName + '_bias.mtx'), shouldTranspose)
+            elif num_parameters == 1:
+                # Only Weight
+                layerType = net.layers[list(net._layer_names).index(layerName)].type
+                shouldTranspose = True if layerType == 'InnerProduct' else False
+                save_tensor_csv(net.params[layerName][0].data, os.path.join(output_dir, layerName + '_weight.mtx'), shouldTranspose)
+            else:
+                raise ValueError('Unsupported number of parameters:' + str(num_parameters))
+    else:
+        createJavaObject(sc, 'dummy')
+        utilObj = sc._jvm.org.apache.sysml.api.dl.Utils()
+        utilObj.saveCaffeModelFile(sc._jsc, deploy_file, caffemodel_file, output_dir, format)
+
+    
+def convert_lmdb_to_jpeg(lmdb_img_file, output_dir):
+    """
+    Saves the images in the lmdb file as jpeg in the output_dir. This method requires caffe to be installed along with the lmdb and cv2 packages.
+    To install the cv2 package, do `pip install opencv-python`.
+    
+    Parameters
+    ----------
+    lmdb_img_file: string
+        Path to the input lmdb file
+    
+    output_dir: string
+        Output directory for images (local filesystem)
+    """
+    import lmdb, caffe, cv2
+    lmdb_cursor = lmdb.open(lmdb_img_file, readonly=True).begin().cursor()
+    datum = caffe.proto.caffe_pb2.Datum()
+    i = 1
+    for _, value in lmdb_cursor:
+        datum.ParseFromString(value)
+        data = caffe.io.datum_to_array(datum)
+        output_file_path = os.path.join(output_dir, 'file_' + str(i) + '.jpg')
+        image = np.transpose(data, (1,2,0)) # CxHxW to HxWxC in cv2
+        cv2.imwrite(output_file_path, image)
+        i = i + 1
+
 
 def convertToLabeledDF(sparkSession, X, y=None):
     from pyspark.ml.feature import VectorAssembler