Posted to commits@systemml.apache.org by ni...@apache.org on 2016/08/09 20:30:44 UTC

[1/2] incubator-systemml git commit: [SYSTEMML-234] [SYSTEMML-208] Added mllearn library to support scikit-learn and MLPipeline

Repository: incubator-systemml
Updated Branches:
  refs/heads/master b62a67c0e -> f02f7c018


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala b/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala
new file mode 100644
index 0000000..fd05f27
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/NaiveBayes.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.ml
+
+import org.apache.spark.rdd.RDD
+import java.io.File
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+import org.apache.sysml.runtime.DMLRuntimeException
+import org.apache.sysml.runtime.instructions.spark.utils.{ RDDConverterUtilsExt => RDDConverterUtils }
+import org.apache.sysml.api.mlcontext._
+import org.apache.sysml.api.mlcontext.ScriptFactory._
+
+object NaiveBayes {
+  final val scriptPath = "scripts" + File.separator + "algorithms" + File.separator + "naive-bayes.dml"
+}
+
+class NaiveBayes(override val uid: String, val sc: SparkContext) extends Estimator[NaiveBayesModel] with HasLaplace with BaseSystemMLClassifier {
+  override def copy(extra: ParamMap): Estimator[NaiveBayesModel] = {
+    val that = new NaiveBayes(uid, sc)
+    copyValues(that, extra)
+  }
+  def setLaplace(value: Double) = set(laplace, value)
+  
+  // Note: updates y_mb in place, since this overload is called from the Python mllearn wrapper
+  def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): NaiveBayesModel = {
+    val ret = fit(X_mb, y_mb, sc)
+    new NaiveBayesModel("naive")(ret._1, ret._2, sc)
+  }
+  
+  def fit(df: ScriptsUtils.SparkDataType): NaiveBayesModel = {
+    val ret = fit(df, sc)
+    new NaiveBayesModel("naive")(ret._1, ret._2, sc)
+  }
+  
+  def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
+    val script = dml(ScriptsUtils.getDMLScript(NaiveBayes.scriptPath))
+      .in("$X", " ")
+      .in("$Y", " ")
+      .in("$prior", " ")
+      .in("$conditionals", " ")
+      .in("$accuracy", " ")
+      .in("$laplace", toDouble(getLaplace))
+      .out("classPrior", "classConditionals")
+    (script, "D", "C")
+  }
+}
+
+
+object NaiveBayesModel {
+  final val scriptPath = "scripts" + File.separator + "algorithms" + File.separator + "naive-bayes-predict.dml"
+}
+
+class NaiveBayesModel(override val uid: String)
+  (val mloutput: MLResults, val labelMapping: java.util.HashMap[Int, String], val sc: SparkContext) 
+  extends Model[NaiveBayesModel] with HasLaplace with BaseSystemMLClassifierModel {
+  
+  override def copy(extra: ParamMap): NaiveBayesModel = {
+    val that = new NaiveBayesModel(uid)(mloutput, labelMapping, sc)
+    copyValues(that, extra)
+  }
+  
+  def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String)  = {
+    val script = dml(ScriptsUtils.getDMLScript(NaiveBayesModel.scriptPath))
+      .in("$X", " ")
+      .in("$prior", " ")
+      .in("$conditionals", " ")
+      .in("$probabilities", " ")
+      .out("probs")
+    
+    val classPrior = mloutput.getBinaryBlockMatrix("classPrior")
+    val classConditionals = mloutput.getBinaryBlockMatrix("classConditionals")
+    val ret = if(isSingleNode) {
+      script.in("prior", classPrior.getMatrixBlock, classPrior.getMatrixMetadata)
+            .in("conditionals", classConditionals.getMatrixBlock, classConditionals.getMatrixMetadata)
+    }
+    else {
+      script.in("prior", classPrior.getBinaryBlocks, classPrior.getMatrixMetadata)
+            .in("conditionals", classConditionals.getBinaryBlocks, classConditionals.getMatrixMetadata)
+    }
+    (ret, "D")
+  }
+  
+  def transform(X: MatrixBlock): MatrixBlock = transform(X, mloutput, labelMapping, sc, "probs")
+  def transform(df: ScriptsUtils.SparkDataType): DataFrame = transform(df, mloutput, labelMapping, sc, "probs")
+  
+}
\ No newline at end of file
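
A minimal Scala usage sketch of the new NaiveBayes estimator (illustrative only, not part of this commit): it assumes an active SparkContext `sc`, the SystemML jar plus the scripts/ directory reachable at runtime (ScriptsUtils resolves naive-bayes.dml via SYSTEMML_HOME), and a training DataFrame with "features" (Vector) and "label" columns, the schema the mllearn estimators expect.

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.sql.SQLContext
    import org.apache.sysml.api.ml.NaiveBayes

    val sqlCtx = new SQLContext(sc)
    import sqlCtx.implicits._

    // Toy count features with one-based numeric labels
    val train = Seq(
      (Vectors.dense(1.0, 0.0, 2.0), 1.0),
      (Vectors.dense(0.0, 3.0, 1.0), 2.0),
      (Vectors.dense(2.0, 0.0, 1.0), 1.0)).toDF("features", "label")

    val nb = new NaiveBayes("nb", sc).setLaplace(1.0)
    val model = nb.fit(train)            // runs naive-bayes.dml through the new MLContext

    val test = Seq(Tuple1(Vectors.dense(1.0, 1.0, 1.0))).toDF("features")
    model.transform(test).show()         // prediction column appended by PredictionUtils.updateLabels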

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/scala/org/apache/sysml/api/ml/PredictionUtils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/PredictionUtils.scala b/src/main/scala/org/apache/sysml/api/ml/PredictionUtils.scala
new file mode 100644
index 0000000..8e3893d
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/PredictionUtils.scala
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.ml
+
+import org.apache.spark.sql.functions.udf
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.SparkContext
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+import org.apache.sysml.runtime.DMLRuntimeException
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics
+import org.apache.sysml.runtime.instructions.spark.utils.{ RDDConverterUtilsExt => RDDConverterUtils }
+import org.apache.sysml.api.mlcontext.MLResults
+import org.apache.sysml.api.mlcontext.ScriptFactory._
+import org.apache.sysml.api.mlcontext.Script
+import org.apache.sysml.api.mlcontext.BinaryBlockMatrix
+
+object PredictionUtils {
+  
+  def getGLMPredictionScript(B_full: BinaryBlockMatrix, isSingleNode:Boolean, dfam:java.lang.Integer=1): (Script, String)  = {
+    val script = dml(ScriptsUtils.getDMLScript(LogisticRegressionModel.scriptPath))
+      .in("$X", " ")
+      .in("$B", " ")
+      .in("$dfam", dfam)
+      .out("means")
+    val ret = if(isSingleNode) {
+      script.in("B_full", B_full.getMatrixBlock, B_full.getMatrixMetadata)
+    }
+    else {
+      script.in("B_full", B_full)
+    }
+    (ret, "X")
+  }
+  
+  def fillLabelMapping(df: ScriptsUtils.SparkDataType, revLabelMapping: java.util.HashMap[Int, String]): RDD[String]  = {
+    val temp = df.select("label").distinct.rdd.map(_.apply(0).toString).collect()
+    val labelMapping = new java.util.HashMap[String, Int]
+    for(i <- 0 until temp.length) {
+      labelMapping.put(temp(i), i+1)
+      revLabelMapping.put(i+1, temp(i))
+    }
+    df.select("label").rdd.map( x => labelMapping.get(x.apply(0).toString).toString )
+  }
+  
+  def fillLabelMapping(y_mb: MatrixBlock, revLabelMapping: java.util.HashMap[Int, String]): Unit = {
+    val labelMapping = new java.util.HashMap[String, Int]
+    if(y_mb.getNumColumns != 1) {
+      throw new RuntimeException("Expected a column vector for y")
+    }
+    if(y_mb.isInSparseFormat()) {
+      throw new DMLRuntimeException("Sparse block is not implemented for fit")
+    }
+    else {
+      val denseBlock = y_mb.getDenseBlock()
+      var id:Int = 1
+      for(i <- 0 until denseBlock.length) {
+        val v = denseBlock(i).toString()
+        if(!labelMapping.containsKey(v)) {
+          labelMapping.put(v, id)
+          revLabelMapping.put(id, v)
+          id += 1
+        }
+        denseBlock.update(i, labelMapping.get(v))
+      }  
+    }
+  }
+  
+  class LabelMappingData(val labelMapping: java.util.HashMap[Int, String]) extends Serializable {
+   def mapLabelStr(x:Double):String = {
+     if(labelMapping.containsKey(x.toInt))
+       labelMapping.get(x.toInt)
+     else
+       throw new RuntimeException("Incorrect label mapping")
+   }
+   def mapLabelDouble(x:Double):Double = {
+     if(labelMapping.containsKey(x.toInt))
+       labelMapping.get(x.toInt).toDouble
+     else
+       throw new RuntimeException("Incorrect label mapping")
+   }
+   val mapLabel_udf =  {
+        try {
+          val it = labelMapping.values().iterator()
+          while(it.hasNext()) {
+            it.next().toDouble
+          }
+          udf(mapLabelDouble _)
+        } catch {
+          case e: Exception => udf(mapLabelStr _)
+        }
+      }
+  }  
+  def updateLabels(isSingleNode:Boolean, df:DataFrame, X: MatrixBlock, labelColName:String, labelMapping: java.util.HashMap[Int, String]): DataFrame = {
+    if(isSingleNode) {
+      if(X.isInSparseFormat()) {
+        throw new RuntimeException("Since predicted label is a column vector, expected it to be in dense format")
+      }
+      for(i <- 0 until X.getNumRows) {
+        val v:Int = X.getValue(i, 0).toInt
+        if(labelMapping.containsKey(v)) {
+          X.setValue(i, 0, labelMapping.get(v).toDouble)
+        }
+        else {
+          throw new RuntimeException("No mapping found for " + v + " in " + labelMapping.toString())
+        }
+      }
+      return null
+    }
+    else {
+      val serObj = new LabelMappingData(labelMapping)
+      return df.withColumn(labelColName, serObj.mapLabel_udf(df(labelColName)))
+               .withColumnRenamed(labelColName, "prediction")
+    }
+  }
+  
+  def joinUsingID(df1:DataFrame, df2:DataFrame):DataFrame = {
+    val tempDF1 = df1.withColumnRenamed("ID", "ID1")
+    tempDF1.join(df2, tempDF1.col("ID1").equalTo(df2.col("ID"))).drop("ID1")
+  }
+  
+  def computePredictedClassLabelsFromProbability(mlscoreoutput:MLResults, isSingleNode:Boolean, sc:SparkContext, inProbVar:String): MLResults = {
+    val ml = new org.apache.sysml.api.mlcontext.MLContext(sc)
+    val script = dml(
+        """
+        Prob = read("temp1");
+        Prediction = rowIndexMax(Prob); # assuming one-based label mapping
+        write(Prediction, "tempOut", "csv");
+        """).out("Prediction")
+    val probVar = mlscoreoutput.getBinaryBlockMatrix(inProbVar)
+    if(isSingleNode) {
+      ml.execute(script.in("Prob", probVar.getMatrixBlock, probVar.getMatrixMetadata))
+    }
+    else {
+      ml.execute(script.in("Prob", probVar))
+    }
+  }
+}
\ No newline at end of file
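
A plain-Scala illustration (self-contained, no SystemML classes) of the recode/decode scheme that fillLabelMapping and updateLabels implement: arbitrary labels are recoded to the one-based class indices the DML scripts expect, and predicted indices are mapped back to the original labels afterwards.

    import scala.collection.mutable

    val labels = Seq("ham", "spam", "ham", "eggs")

    val toIndex = mutable.LinkedHashMap[String, Int]()    // label -> 1..k (labelMapping)
    val toLabel = mutable.HashMap[Int, String]()          // 1..k -> label (revLabelMapping)
    labels.foreach { l =>
      val id = toIndex.getOrElseUpdate(l, toIndex.size + 1)
      toLabel(id) = l
    }

    val recoded   = labels.map(toIndex)    // Seq(1, 2, 1, 3), what the training script sees
    val predicted = Seq(2, 1, 3)           // e.g. rowIndexMax over the probability matrix
    val decoded   = predicted.map(toLabel) // back to the original labels, as in updateLabels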

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/scala/org/apache/sysml/api/ml/SVM.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/SVM.scala b/src/main/scala/org/apache/sysml/api/ml/SVM.scala
new file mode 100644
index 0000000..07a7283
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/SVM.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.ml
+
+import org.apache.spark.rdd.RDD
+import java.io.File
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+import org.apache.sysml.runtime.DMLRuntimeException
+import org.apache.sysml.runtime.instructions.spark.utils.{ RDDConverterUtilsExt => RDDConverterUtils }
+import org.apache.sysml.api.mlcontext._
+import org.apache.sysml.api.mlcontext.ScriptFactory._
+
+object SVM {
+  final val scriptPathBinary = "scripts" + File.separator + "algorithms" + File.separator + "l2-svm.dml"
+  final val scriptPathMulticlass = "scripts" + File.separator + "algorithms" + File.separator + "m-svm.dml"
+}
+
+class SVM (override val uid: String, val sc: SparkContext, val isMultiClass:Boolean=false) extends Estimator[SVMModel] with HasIcpt
+    with HasRegParam with HasTol with HasMaxOuterIter with BaseSystemMLClassifier {
+
+  def setIcpt(value: Int) = set(icpt, value)
+  def setMaxIter(value: Int) = set(maxOuterIter, value)
+  def setRegParam(value: Double) = set(regParam, value)
+  def setTol(value: Double) = set(tol, value)
+  
+  override def copy(extra: ParamMap): Estimator[SVMModel] = {
+    val that = new SVM(uid, sc, isMultiClass)
+    copyValues(that, extra)
+  }
+  
+  def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
+    val script = dml(ScriptsUtils.getDMLScript(if(isMultiClass) SVM.scriptPathMulticlass else SVM.scriptPathBinary))
+      .in("$X", " ")
+      .in("$Y", " ")
+      .in("$model", " ")
+      .in("$Log", " ")
+      .in("$icpt", toDouble(getIcpt))
+      .in("$reg", toDouble(getRegParam))
+      .in("$tol", toDouble(getTol))
+      .in("$maxiter", toDouble(getMaxOuterIte))
+      .out("w")
+    (script, "X", "Y")
+  }
+  
+  // Note: updates y_mb in place, since this overload is called from the Python mllearn wrapper
+  def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): SVMModel = {
+    val ret = fit(X_mb, y_mb, sc)
+    new SVMModel("svm")(ret._1, sc, isMultiClass, ret._2)
+  }
+  
+  def fit(df: ScriptsUtils.SparkDataType): SVMModel = {
+    val ret = fit(df, sc)
+    new SVMModel("svm")(ret._1, sc, isMultiClass, ret._2)
+  }
+  
+}
+
+object SVMModel {
+  final val predictionScriptPathBinary = "scripts" + File.separator + "algorithms" + File.separator + "l2-svm-predict.dml"
+  final val predictionScriptPathMulticlass = "scripts" + File.separator + "algorithms" + File.separator + "m-svm-predict.dml"
+}
+
+class SVMModel (override val uid: String)(val mloutput: MLResults, val sc: SparkContext, val isMultiClass:Boolean, 
+    val labelMapping: java.util.HashMap[Int, String]) extends Model[SVMModel] with BaseSystemMLClassifierModel {
+  override def copy(extra: ParamMap): SVMModel = {
+    val that = new SVMModel(uid)(mloutput, sc, isMultiClass, labelMapping)
+    copyValues(that, extra)
+  }
+  
+  def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String)  = {
+    val script = dml(ScriptsUtils.getDMLScript(if(isMultiClass) SVMModel.predictionScriptPathMulticlass else SVMModel.predictionScriptPathBinary))
+      .in("$X", " ")
+      .in("$model", " ")
+      .out("scores")
+    
+    val w = mloutput.getBinaryBlockMatrix("w")
+    val wVar = if(isMultiClass) "W" else "w"
+      
+    val ret = if(isSingleNode) {
+      script.in(wVar, w.getMatrixBlock, w.getMatrixMetadata)
+    }
+    else {
+      script.in(wVar, w)
+    }
+    (ret, "X")
+  }
+  
+  def transform(X: MatrixBlock): MatrixBlock = transform(X, mloutput, labelMapping, sc, "scores")
+  def transform(df: ScriptsUtils.SparkDataType): DataFrame = transform(df, mloutput, labelMapping, sc, "scores")
+}
\ No newline at end of file
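
For the Scala side of the MLPipeline integration, a sketch mirroring the Python example added to docs/algorithms-classification.md below (an active SparkContext `sc` and the SystemML scripts directory are assumed):

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
    import org.apache.spark.sql.SQLContext
    import org.apache.sysml.api.ml.SVM

    val sqlCtx = new SQLContext(sc)
    val training = sqlCtx.createDataFrame(Seq(
      (0L, "a b c d e spark", 1.0),
      (1L, "b d", 2.0),
      (2L, "spark f g h", 1.0),
      (3L, "hadoop mapreduce", 2.0))).toDF("id", "text", "label")

    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features").setNumFeatures(20)
    val svm       = new SVM("svm", sc, isMultiClass = true).setRegParam(1.0).setMaxIter(100)
    val pipeline  = new Pipeline().setStages(Array(tokenizer, hashingTF, svm))

    val model = pipeline.fit(training)
    val test  = sqlCtx.createDataFrame(Seq((4L, "spark i j k"), (5L, "l m n"))).toDF("id", "text")
    model.transform(test).show()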

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/scala/org/apache/sysml/api/ml/ScriptsUtils.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/ScriptsUtils.scala b/src/main/scala/org/apache/sysml/api/ml/ScriptsUtils.scala
index fdf682d..10f9d33 100644
--- a/src/main/scala/org/apache/sysml/api/ml/ScriptsUtils.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/ScriptsUtils.scala
@@ -26,6 +26,8 @@ import org.apache.sysml.runtime.DMLRuntimeException
 
 object ScriptsUtils {
   var systemmlHome = System.getenv("SYSTEMML_HOME")
+		  
+  type SparkDataType = org.apache.spark.sql.DataFrame // org.apache.spark.sql.Dataset[_]
 
   /**
    * set SystemML home
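
A small sketch of the intent behind the new type alias (names here are illustrative): client code written against ScriptsUtils.SparkDataType instead of DataFrame only needs the alias retargeted (e.g. to Dataset[_], as hinted in the comment above) when moving to Spark 2.0.

    import org.apache.sysml.api.ml.ScriptsUtils

    // Helper typed against the alias; unchanged when the alias moves to Dataset[_]
    def rowCount(df: ScriptsUtils.SparkDataType): Long = df.count()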


[2/2] incubator-systemml git commit: [SYSTEMML-234] [SYSTEMML-208] Added mllearn library to support scikit-learn and MLPipeline

Posted by ni...@apache.org.
[SYSTEMML-234] [SYSTEMML-208] Added mllearn library to support scikit-learn and MLPipeline

- Added the following algorithms: LogisticRegression, LinearRegression (DS/CG), SVM (l2/msvm), and NaiveBayes.
- These algorithms use the new MLContext.
- Added utility functions to convert NumPy arrays, SciPy sparse matrices, and Pandas DataFrames to DataFrame as well as MatrixBlock. These functions can be reused once the Python MLContext is switched over to the new Scala MLContext.
- Added the ScriptsUtils.SparkDataType type alias to help with the migration to Spark 2.0.
- Updated the documentation to describe the new usage.
- Added test cases covering both the Python and Scala MLPipeline APIs.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/f02f7c01
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/f02f7c01
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/f02f7c01

Branch: refs/heads/master
Commit: f02f7c018b221384b5e02b13ac181122853f65ba
Parents: b62a67c
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Tue Aug 9 13:22:59 2016 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Tue Aug 9 13:22:59 2016 -0700

----------------------------------------------------------------------
 docs/algorithms-classification.md               | 192 +++++++++++++++
 docs/algorithms-regression.md                   |  70 ++++++
 scripts/algorithms/l2-svm.dml                   |   5 +-
 scripts/algorithms/m-svm.dml                    |   5 +-
 scripts/algorithms/naive-bayes-predict.dml      |  15 +-
 scripts/algorithms/naive-bayes.dml              |   5 +-
 .../java/org/apache/sysml/api/MLContext.java    |  20 +-
 .../java/org/apache/sysml/api/MLOutput.java     |  15 +-
 .../sysml/api/mlcontext/BinaryBlockMatrix.java  |   9 +
 .../api/mlcontext/MLContextConversionUtil.java  |  13 +-
 .../sysml/api/mlcontext/MLContextUtil.java      |  29 ++-
 .../apache/sysml/api/mlcontext/MLResults.java   |   9 +-
 .../org/apache/sysml/api/mlcontext/Matrix.java  |   2 +-
 .../org/apache/sysml/api/python/SystemML.py     | 235 ++++++++++++++++++-
 .../java/org/apache/sysml/api/python/test.py    | 178 ++++++++++++++
 .../spark/utils/RDDConverterUtilsExt.java       |  61 +++++
 .../sysml/api/ml/BaseSystemMLClassifier.scala   | 162 +++++++++++++
 .../sysml/api/ml/BaseSystemMLRegressor.scala    |  86 +++++++
 .../apache/sysml/api/ml/LinearRegression.scala  |  97 ++++++++
 .../sysml/api/ml/LogisticRegression.scala       | 158 ++++---------
 .../org/apache/sysml/api/ml/NaiveBayes.scala    | 109 +++++++++
 .../apache/sysml/api/ml/PredictionUtils.scala   | 154 ++++++++++++
 .../scala/org/apache/sysml/api/ml/SVM.scala     | 113 +++++++++
 .../org/apache/sysml/api/ml/ScriptsUtils.scala  |   2 +
 24 files changed, 1612 insertions(+), 132 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/docs/algorithms-classification.md
----------------------------------------------------------------------
diff --git a/docs/algorithms-classification.md b/docs/algorithms-classification.md
index 2488a8c..f25d78e 100644
--- a/docs/algorithms-classification.md
+++ b/docs/algorithms-classification.md
@@ -127,6 +127,17 @@ Eqs. (1) and (2).
 ### Usage
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+import SystemML as sml
+# C = 1/reg
+logistic = sml.mllearn.LogisticRegression(sqlCtx, fit_intercept=True, max_iter=100, max_inner_iter=0, tol=0.000001, C=1.0)
+# X_train, y_train and X_test can be NumPy matrices, Pandas DataFrames, or SciPy sparse matrices
+y_test = logistic.fit(X_train, y_train).predict(X_test)
+# df_train is a DataFrame that contains two columns: "features" (of type Vector) and "label". df_test is a DataFrame that contains the column "features"
+y_test = logistic.fit(df_train).transform(df_test)
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f MultiLogReg.dml
                             -nvargs X=<file>
@@ -214,6 +225,58 @@ SystemML Language Reference for details.
 ### Examples
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+# Scikit-learn way
+from sklearn import datasets, neighbors
+import SystemML as sml
+from pyspark.sql import SQLContext
+sqlCtx = SQLContext(sc)
+digits = datasets.load_digits()
+X_digits = digits.data
+y_digits = digits.target + 1
+n_samples = len(X_digits)
+X_train = X_digits[:int(.9 * n_samples)]
+y_train = y_digits[:int(.9 * n_samples)]
+X_test = X_digits[int(.9 * n_samples):]
+y_test = y_digits[int(.9 * n_samples):]
+logistic = sml.mllearn.LogisticRegression(sqlCtx)
+print('LogisticRegression score: %f' % logistic.fit(X_train, y_train).score(X_test, y_test))
+
+# MLPipeline way
+from pyspark.ml import Pipeline
+import SystemML as sml
+from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.sql import SQLContext
+sqlCtx = SQLContext(sc)
+training = sqlCtx.createDataFrame([
+    (0L, "a b c d e spark", 1.0),
+    (1L, "b d", 2.0),
+    (2L, "spark f g h", 1.0),
+    (3L, "hadoop mapreduce", 2.0),
+    (4L, "b spark who", 1.0),
+    (5L, "g d a y", 2.0),
+    (6L, "spark fly", 1.0),
+    (7L, "was mapreduce", 2.0),
+    (8L, "e spark program", 1.0),
+    (9L, "a e c l", 2.0),
+    (10L, "spark compile", 1.0),
+    (11L, "hadoop software", 2.0)
+], ["id", "text", "label"])
+tokenizer = Tokenizer(inputCol="text", outputCol="words")
+hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
+lr = sml.mllearn.LogisticRegression(sqlCtx)
+pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
+model = pipeline.fit(training)
+test = sqlCtx.createDataFrame([
+    (12L, "spark i j k"),
+    (13L, "l m n"),
+    (14L, "mapreduce spark"),
+    (15L, "apache hadoop")], ["id", "text"])
+prediction = model.transform(test)
+prediction.show()
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f MultiLogReg.dml
                             -nvargs X=/user/ml/X.mtx
@@ -393,6 +456,17 @@ support vector machine (`y` with domain size `2`).
 **Binary-Class Support Vector Machines**:
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+import SystemML as sml
+# C = 1/reg
+svm = sml.mllearn.SVM(sqlCtx, fit_intercept=True, max_iter=100, tol=0.000001, C=1.0, is_multi_class=False)
+# X_train, y_train and X_test can be NumPy matrices, Pandas DataFrames, or SciPy sparse matrices
+y_test = svm.fit(X_train, y_train).predict(X_test)
+# df_train is a DataFrame that contains two columns: "features" (of type Vector) and "label". df_test is a DataFrame that contains the column "features"
+y_test = svm.fit(df_train).transform(df_test)
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f l2-svm.dml
                             -nvargs X=<file>
@@ -428,6 +502,14 @@ support vector machine (`y` with domain size `2`).
 **Binary-Class Support Vector Machines Prediction**:
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+# X_test can be a NumPy matrix, a Pandas DataFrame, or a SciPy sparse matrix
+y_test = svm.predict(X_test)
+# df_test is a DataFrame that contains the column "features" of type Vector
+y_test = svm.transform(df_test)
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f l2-svm-predict.dml
                             -nvargs X=<file>
@@ -630,6 +712,17 @@ class labels.
 **Multi-Class Support Vector Machines**:
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+import SystemML as sml
+# C = 1/reg
+svm = sml.mllearn.SVM(sqlCtx, fit_intercept=True, max_iter=100, tol=0.000001, C=1.0, is_multi_class=True)
+# X_train, y_train and X_test can be NumPy matrices, Pandas DataFrames, or SciPy sparse matrices
+y_test = svm.fit(X_train, y_train).predict(X_test)
+# df_train is a DataFrame that contains two columns: "features" (of type Vector) and "label". df_test is a DataFrame that contains the column "features"
+y_test = svm.fit(df_train).transform(df_test)
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f m-svm.dml
                             -nvargs X=<file>
@@ -665,6 +758,14 @@ class labels.
 **Multi-Class Support Vector Machines Prediction**:
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+# X_test can be a NumPy matrix, a Pandas DataFrame, or a SciPy sparse matrix
+y_test = svm.predict(X_test)
+# df_test is a DataFrame that contains the column "features" of type Vector
+y_test = svm.transform(df_test)
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f m-svm-predict.dml
                             -nvargs X=<file>
@@ -747,6 +848,58 @@ SystemML Language Reference for details.
 **Multi-Class Support Vector Machines**:
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+# Scikit-learn way
+from sklearn import datasets, neighbors
+import SystemML as sml
+from pyspark.sql import SQLContext
+sqlCtx = SQLContext(sc)
+digits = datasets.load_digits()
+X_digits = digits.data
+y_digits = digits.target 
+n_samples = len(X_digits)
+X_train = X_digits[:int(.9 * n_samples)]
+y_train = y_digits[:int(.9 * n_samples)]
+X_test = X_digits[int(.9 * n_samples):]
+y_test = y_digits[int(.9 * n_samples):]
+svm = sml.mllearn.SVM(sqlCtx, is_multi_class=True)
+print('SVM score: %f' % svm.fit(X_train, y_train).score(X_test, y_test))
+
+# MLPipeline way
+from pyspark.ml import Pipeline
+import SystemML as sml
+from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.sql import SQLContext
+sqlCtx = SQLContext(sc)
+training = sqlCtx.createDataFrame([
+    (0L, "a b c d e spark", 1.0),
+    (1L, "b d", 2.0),
+    (2L, "spark f g h", 1.0),
+    (3L, "hadoop mapreduce", 2.0),
+    (4L, "b spark who", 1.0),
+    (5L, "g d a y", 2.0),
+    (6L, "spark fly", 1.0),
+    (7L, "was mapreduce", 2.0),
+    (8L, "e spark program", 1.0),
+    (9L, "a e c l", 2.0),
+    (10L, "spark compile", 1.0),
+    (11L, "hadoop software", 2.0)
+], ["id", "text", "label"])
+tokenizer = Tokenizer(inputCol="text", outputCol="words")
+hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
+svm = sml.mllearn.SVM(sqlCtx, is_multi_class=True)
+pipeline = Pipeline(stages=[tokenizer, hashingTF, svm])
+model = pipeline.fit(training)
+test = sqlCtx.createDataFrame([
+    (12L, "spark i j k"),
+    (13L, "l m n"),
+    (14L, "mapreduce spark"),
+    (15L, "apache hadoop")], ["id", "text"])
+prediction = model.transform(test)
+prediction.show()
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f m-svm.dml
                             -nvargs X=/user/ml/X.mtx
@@ -871,6 +1024,16 @@ applicable when all features are counts of categorical values.
 **Naive Bayes**:
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+import SystemML as sml
+nb = sml.mllearn.NaiveBayes(sqlCtx, laplace=1.0)
+# X_train, y_train and X_test can be NumPy matrices, Pandas DataFrames, or SciPy sparse matrices
+y_test = nb.fit(X_train, y_train).predict(X_test)
+# df_train is a DataFrame that contains two columns: "features" (of type Vector) and "label". df_test is a DataFrame that contains the column "features"
+y_test = nb.fit(df_train).transform(df_test)
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f naive-bayes.dml
                             -nvargs X=<file>
@@ -902,6 +1065,14 @@ applicable when all features are counts of categorical values.
 **Naive Bayes Prediction**:
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+# X_test can be a NumPy matrix, a Pandas DataFrame, or a SciPy sparse matrix
+y_test = nb.predict(X_test)
+# df_test is a DataFrame that contains the column "features" of type Vector
+y_test = nb.transform(df_test)
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f naive-bayes-predict.dml
                             -nvargs X=<file>
@@ -974,6 +1145,27 @@ SystemML Language Reference for details.
 **Naive Bayes**:
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+from sklearn.datasets import fetch_20newsgroups
+from sklearn.feature_extraction.text import TfidfVectorizer
+import SystemML as sml
+from sklearn import metrics
+from pyspark.sql import SQLContext
+sqlCtx = SQLContext(sc)
+categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']
+newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
+newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
+vectorizer = TfidfVectorizer()
+# Both vectors and vectors_test are SciPy CSR matrix
+vectors = vectorizer.fit_transform(newsgroups_train.data)
+vectors_test = vectorizer.transform(newsgroups_test.data)
+nb = sml.mllearn.NaiveBayes(sqlCtx)
+nb.fit(vectors, newsgroups_train.target)
+pred = nb.predict(vectors_test)
+metrics.f1_score(newsgroups_test.target, pred, average='weighted')
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f naive-bayes.dml
                             -nvargs X=/user/ml/X.mtx

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/docs/algorithms-regression.md
----------------------------------------------------------------------
diff --git a/docs/algorithms-regression.md b/docs/algorithms-regression.md
index 6472c17..5241f5f 100644
--- a/docs/algorithms-regression.md
+++ b/docs/algorithms-regression.md
@@ -80,6 +80,17 @@ efficient when the number of features $m$ is relatively small
 **Linear Regression - Direct Solve**:
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+import SystemML as sml
+# C = 1/reg
+lr = sml.mllearn.LinearRegression(sqlCtx, fit_intercept=True, C=1.0, solver='direct-solve')
+# X_train, y_train and X_test can be NumPy matrices, Pandas DataFrames, or SciPy sparse matrices
+y_test = lr.fit(X_train, y_train).predict(X_test)
+# df_train is a DataFrame that contains two columns: "features" (of type Vector) and "label". df_test is a DataFrame that contains the column "features"
+y_test = lr.fit(df_train).transform(df_test)
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f LinearRegDS.dml
                             -nvargs X=<file>
@@ -111,6 +122,17 @@ efficient when the number of features $m$ is relatively small
 **Linear Regression - Conjugate Gradient**:
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+import SystemML as sml
+# C = 1/reg
+lr = sml.mllearn.LinearRegression(sqlCtx, fit_intercept=True, max_iter=100, tol=0.000001, C=1.0, solver='newton-cg')
+# X_train, y_train and X_test can be NumPy matrices, Pandas DataFrames, or SciPy sparse matrices
+y_test = lr.fit(X_train, y_train).predict(X_test)
+# df_train is a DataFrame that contains two columns: "features" (of type Vector) and "label". df_test is a DataFrame that contains the column "features"
+y_test = lr.fit(df_train).transform(df_test)
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f LinearRegCG.dml
                             -nvargs X=<file>
@@ -196,6 +218,30 @@ SystemML Language Reference for details.
 **Linear Regression - Direct Solve**:
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+import numpy as np
+from sklearn import datasets
+import SystemML as sml
+from pyspark.sql import SQLContext
+# Load the diabetes dataset
+diabetes = datasets.load_diabetes()
+# Use only one feature
+diabetes_X = diabetes.data[:, np.newaxis, 2]
+# Split the data into training/testing sets
+diabetes_X_train = diabetes_X[:-20]
+diabetes_X_test = diabetes_X[-20:]
+# Split the targets into training/testing sets
+diabetes_y_train = diabetes.target[:-20]
+diabetes_y_test = diabetes.target[-20:]
+# Create linear regression object
+regr = sml.mllearn.LinearRegression(sqlCtx, solver='direct-solve')
+# Train the model using the training sets
+regr.fit(diabetes_X_train, diabetes_y_train)
+# The mean square error
+print("Residual sum of squares: %.2f" % np.mean((regr.predict(diabetes_X_test) - diabetes_y_test) ** 2))
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f LinearRegDS.dml
                             -nvargs X=/user/ml/X.mtx
@@ -227,6 +273,30 @@ SystemML Language Reference for details.
 **Linear Regression - Conjugate Gradient**:
 
 <div class="codetabs">
+<div data-lang="Python" markdown="1">
+{% highlight python %}
+import numpy as np
+from sklearn import datasets
+import SystemML as sml
+from pyspark.sql import SQLContext
+# Load the diabetes dataset
+diabetes = datasets.load_diabetes()
+# Use only one feature
+diabetes_X = diabetes.data[:, np.newaxis, 2]
+# Split the data into training/testing sets
+diabetes_X_train = diabetes_X[:-20]
+diabetes_X_test = diabetes_X[-20:]
+# Split the targets into training/testing sets
+diabetes_y_train = diabetes.target[:-20]
+diabetes_y_test = diabetes.target[-20:]
+# Create linear regression object
+regr = sml.mllearn.LinearRegression(sqlCtx, solver='newton-cg')
+# Train the model using the training sets
+regr.fit(diabetes_X_train, diabetes_y_train)
+# The mean square error
+print("Residual sum of squares: %.2f" % np.mean((regr.predict(diabetes_X_test) - diabetes_y_test) ** 2))
+{% endhighlight %}
+</div>
 <div data-lang="Hadoop" markdown="1">
     hadoop jar SystemML.jar -f LinearRegCG.dml
                             -nvargs X=/user/ml/X.mtx

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/scripts/algorithms/l2-svm.dml
----------------------------------------------------------------------
diff --git a/scripts/algorithms/l2-svm.dml b/scripts/algorithms/l2-svm.dml
index fa40418..1117c71 100644
--- a/scripts/algorithms/l2-svm.dml
+++ b/scripts/algorithms/l2-svm.dml
@@ -160,4 +160,7 @@ extra_model_params[4,1] = dimensions
 w = t(append(t(w), t(extra_model_params)))
 write(w, $model, format=cmdLine_fmt)
 
-write(debug_str, $Log)
+logFile = $Log
+if(logFile != " ") {
+	write(debug_str, logFile)
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/scripts/algorithms/m-svm.dml
----------------------------------------------------------------------
diff --git a/scripts/algorithms/m-svm.dml b/scripts/algorithms/m-svm.dml
index e4a7cad..04f8a76 100644
--- a/scripts/algorithms/m-svm.dml
+++ b/scripts/algorithms/m-svm.dml
@@ -175,4 +175,7 @@ for(iter_class in 1:ncol(debug_mat)){
 			debug_str = append(debug_str, iter_class + "," + iter + "," + obj)
 	}
 }
-write(debug_str, $Log)
+logFile = $Log
+if(logFile != " ") {
+	write(debug_str, logFile)
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/scripts/algorithms/naive-bayes-predict.dml
----------------------------------------------------------------------
diff --git a/scripts/algorithms/naive-bayes-predict.dml b/scripts/algorithms/naive-bayes-predict.dml
index e6f8fa4..b687bfa 100644
--- a/scripts/algorithms/naive-bayes-predict.dml
+++ b/scripts/algorithms/naive-bayes-predict.dml
@@ -28,7 +28,6 @@
 cmdLine_Y = ifdef($Y, " ")
 cmdLine_accuracy = ifdef($accuracy, " ")
 cmdLine_confusion = ifdef($confusion, " ")
-cmdLine_probabilities = ifdef($probabilities, " ")
 cmdLine_fmt = ifdef($fmt, "text")
 
 D = read($X)
@@ -51,13 +50,13 @@ model = append(conditionals, prior)
 
 log_probs = D_w_ones %*% t(log(model))
 
-if(cmdLine_probabilities != " "){
-	mx = rowMaxs(log_probs)
-	ones = matrix(1, rows=1, cols=nrow(prior))
-	probs = log_probs - mx %*% ones
-	probs = exp(probs)/(rowSums(exp(probs)) %*% ones)
-	write(probs, cmdLine_probabilities, format=cmdLine_fmt)
-}
+
+mx = rowMaxs(log_probs)
+ones = matrix(1, rows=1, cols=nrow(prior))
+probs = log_probs - mx %*% ones
+probs = exp(probs)/(rowSums(exp(probs)) %*% ones)
+write(probs, $probabilities, format=cmdLine_fmt)
+
 
 if(cmdLine_Y != " "){
 	C = read(cmdLine_Y)

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/scripts/algorithms/naive-bayes.dml
----------------------------------------------------------------------
diff --git a/scripts/algorithms/naive-bayes.dml b/scripts/algorithms/naive-bayes.dml
index a01a5fc..c1dc44c 100644
--- a/scripts/algorithms/naive-bayes.dml
+++ b/scripts/algorithms/naive-bayes.dml
@@ -74,7 +74,10 @@ acc = sum(rowIndexMax(logProbs) == C) / numRows * 100
 
 acc_str = "Training Accuracy (%): " + acc
 print(acc_str)
-write(acc, $accuracy)
+accuracyFile = $accuracy
+if(accuracyFile != " ") {
+	write(acc, accuracyFile)
+}
 
 extraModelParams = as.matrix(numFeatures)
 classPrior = rbind(classPrior, extraModelParams)

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/java/org/apache/sysml/api/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLContext.java b/src/main/java/org/apache/sysml/api/MLContext.java
index a03c8b7..d8a290d 100644
--- a/src/main/java/org/apache/sysml/api/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/MLContext.java
@@ -600,6 +600,24 @@ public class MLContext {
 		checkIfRegisteringInputAllowed();
 	}
 	
+	public void registerInput(String varName, MatrixBlock mb) throws DMLRuntimeException {
+		MatrixCharacteristics mc = new MatrixCharacteristics(mb.getNumRows(), mb.getNumColumns(), OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, mb.getNonZeros());
+		registerInput(varName, mb, mc);
+	}
+	
+	public void registerInput(String varName, MatrixBlock mb, MatrixCharacteristics mc) throws DMLRuntimeException {
+		if(_variables == null)
+			_variables = new LocalVariableMap();
+		if(_inVarnames == null)
+			_inVarnames = new ArrayList<String>();
+		MatrixObject mo = new MatrixObject(ValueType.DOUBLE, "temp", new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
+		mo.acquireModify(mb); 
+		mo.release();
+		_variables.put(varName, mo);
+		_inVarnames.add(varName);
+		checkIfRegisteringInputAllowed();
+	}
+	
 	// =============================================================================================
 	
 	/**
@@ -1457,4 +1475,4 @@ public class MLContext {
 //		return MLMatrix.createMLMatrix(this, sqlContext, blocks, mc);
 //	}
 	
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/java/org/apache/sysml/api/MLOutput.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLOutput.java b/src/main/java/org/apache/sysml/api/MLOutput.java
index a3e6019..55daf17 100644
--- a/src/main/java/org/apache/sysml/api/MLOutput.java
+++ b/src/main/java/org/apache/sysml/api/MLOutput.java
@@ -39,6 +39,7 @@ import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -54,11 +55,17 @@ import scala.Tuple2;
  */
 public class MLOutput {
 	
-	
-	
 	Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> _outputs;
 	private Map<String, MatrixCharacteristics> _outMetadata = null;
 	
+	public MatrixBlock getMatrixBlock(String varName) throws DMLRuntimeException {
+		MatrixCharacteristics mc = getMatrixCharacteristics(varName);
+		// The matrix block is always materialized as an RDD and then collected.
+		// We could avoid this later by returning the symbol table rather than "Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> _outputs"
+		MatrixBlock mb = SparkExecutionContext.toMatrixBlock(getBinaryBlockedRDD(varName), (int) mc.getRows(), (int) mc.getCols(), 
+				mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
+		return mb;
+	}
 	public MLOutput(Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> outputs, Map<String, MatrixCharacteristics> outMetadata) {
 		this._outputs = outputs;
 		this._outMetadata = outMetadata;
@@ -238,7 +245,7 @@ public class MLOutput {
     		int lclen = UtilFunctions.computeBlockSize(clen, blockColIndex, bclen);
     		// ------------------------------------------------------------------
 			
-			long startRowIndex = (kv._1.getRowIndex()-1) * bclen;
+			long startRowIndex = (kv._1.getRowIndex()-1) * bclen + 1;
 			MatrixBlock blk = kv._2;
 			ArrayList<Tuple2<Long, Tuple2<Long, Double[]>>> retVal = new ArrayList<Tuple2<Long,Tuple2<Long,Double[]>>>();
 			for(int i = 0; i < lrlen; i++) {
@@ -410,4 +417,4 @@ public class MLOutput {
 			return RowFactory.create(row);
 		}
 	}
-}
+}
\ No newline at end of file
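
A quick sketch of the new MLOutput.getMatrixBlock helper (the MLOutput value `out` and the output name "w" are placeholders for illustration): the binary-block RDD behind a registered output variable is collected into a single in-memory MatrixBlock, which can then be read locally.

    import org.apache.sysml.api.MLOutput
    import org.apache.sysml.runtime.matrix.data.MatrixBlock

    def firstCell(out: MLOutput): Double = {
      val w: MatrixBlock = out.getMatrixBlock("w")   // collects the blocked RDD for "w"
      w.getValue(0, 0)                               // read a single cell on the driver
    }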

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
index 8c9f923..ea6fcf0 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java
@@ -21,6 +21,8 @@ package org.apache.sysml.api.mlcontext;
 
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.sql.DataFrame;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
@@ -97,6 +99,13 @@ public class BinaryBlockMatrix {
 	public JavaPairRDD<MatrixIndexes, MatrixBlock> getBinaryBlocks() {
 		return binaryBlocks;
 	}
+	
+	public MatrixBlock getMatrixBlock() throws DMLRuntimeException {
+		MatrixCharacteristics mc = getMatrixCharacteristics();
+		MatrixBlock mb = SparkExecutionContext.toMatrixBlock(binaryBlocks, (int) mc.getRows(), (int) mc.getCols(), 
+				mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
+		return mb;
+	}
 
 	/**
 	 * Obtain the SystemML binary-block matrix characteristics

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
index 33226d2..161ad17 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -676,7 +676,7 @@ public class MLContextConversionUtil {
 	 * @return the {@code MatrixObject} converted to a {@code DataFrame}
 	 */
 	public static DataFrame matrixObjectToDataFrame(MatrixObject matrixObject,
-			SparkExecutionContext sparkExecutionContext) {
+			SparkExecutionContext sparkExecutionContext, boolean isVectorDF) {
 		try {
 			@SuppressWarnings("unchecked")
 			JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlockMatrix = (JavaPairRDD<MatrixIndexes, MatrixBlock>) sparkExecutionContext
@@ -686,8 +686,17 @@ public class MLContextConversionUtil {
 			MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext();
 			SparkContext sc = activeMLContext.getSparkContext();
 			SQLContext sqlContext = new SQLContext(sc);
-			DataFrame df = RDDConverterUtilsExt.binaryBlockToDataFrame(binaryBlockMatrix, matrixCharacteristics,
+			DataFrame df = null;
+			if(isVectorDF) {
+				df = RDDConverterUtilsExt.binaryBlockToVectorDataFrame(binaryBlockMatrix, matrixCharacteristics,
+						sqlContext);
+			}
+			else {
+				df = RDDConverterUtilsExt.binaryBlockToDataFrame(binaryBlockMatrix, matrixCharacteristics,
 					sqlContext);
+			}
+			
+			
 			return df;
 		} catch (DMLRuntimeException e) {
 			throw new MLContextException("DMLRuntimeException while converting matrix object to DataFrame", e);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
index feb616e..fc942e9 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
@@ -44,7 +44,9 @@ import org.apache.sysml.conf.CompilerConfig.ConfigType;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.conf.DMLConfig;
 import org.apache.sysml.parser.ParseException;
+import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysml.runtime.controlprogram.caching.CacheException;
 import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.instructions.cp.BooleanObject;
@@ -52,8 +54,12 @@ import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.DoubleObject;
 import org.apache.sysml.runtime.instructions.cp.IntObject;
 import org.apache.sysml.runtime.instructions.cp.StringObject;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
 
 /**
  * Utility class containing methods for working with the MLContext API.
@@ -72,7 +78,7 @@ public final class MLContextUtil {
 	 */
 	@SuppressWarnings("rawtypes")
 	public static final Class[] COMPLEX_DATA_TYPES = { JavaRDD.class, RDD.class, DataFrame.class,
-			BinaryBlockMatrix.class, Matrix.class, (new double[][] {}).getClass() };
+			BinaryBlockMatrix.class, Matrix.class, (new double[][] {}).getClass(), MatrixBlock.class };
 
 	/**
 	 * All data types supported by the MLContext API
@@ -391,6 +397,8 @@ public final class MLContextUtil {
 				convertedMap.put(key, Double.toString((Double) value));
 			} else if (value instanceof String) {
 				convertedMap.put(key, (String) value);
+			} else {
+				throw new MLContextException("Incorrect type for input parameters");
 			}
 		}
 		return convertedMap;
@@ -448,7 +456,24 @@ public final class MLContextUtil {
 			}
 
 			return matrixObject;
-		} else if (value instanceof DataFrame) {
+		} else if (value instanceof MatrixBlock) {
+			MatrixCharacteristics matrixCharacteristics;
+			if (matrixMetadata != null) {
+				matrixCharacteristics = matrixMetadata.asMatrixCharacteristics();
+			} else {
+				matrixCharacteristics = new MatrixCharacteristics();
+			}
+			MatrixFormatMetaData mtd = new MatrixFormatMetaData(matrixCharacteristics, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
+			MatrixObject matrixObject = new MatrixObject(ValueType.DOUBLE, MLContextUtil.scratchSpace() + "/" + name, mtd);
+			try {
+				matrixObject.acquireModify((MatrixBlock)value);
+				matrixObject.release();
+			} catch (CacheException e) {
+				throw new MLContextException(e);
+			}
+			return matrixObject;
+		}
+		else if (value instanceof DataFrame) {
 			DataFrame dataFrame = (DataFrame) value;
 			MatrixObject matrixObject = MLContextConversionUtil
 					.dataFrameToMatrixObject(name, dataFrame, matrixMetadata);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
index bd1b6bc..582a73e 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java
@@ -255,7 +255,13 @@ public class MLResults {
 	 */
 	public DataFrame getDataFrame(String outputName) {
 		MatrixObject mo = getMatrixObject(outputName);
-		DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext);
+		DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false);
+		return df;
+	}
+	
+	public DataFrame getDataFrame(String outputName, boolean isVectorDF) {
+		MatrixObject mo = getMatrixObject(outputName);
+		DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, isVectorDF);
 		return df;
 	}
 
@@ -271,6 +277,7 @@ public class MLResults {
 		Matrix matrix = new Matrix(mo, sparkExecutionContext);
 		return matrix;
 	}
+	
 
 	/**
 	 * Obtain an output as a {@code BinaryBlockMatrix}.

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java b/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
index 178a6e5..3ee41b7 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/Matrix.java
@@ -108,7 +108,7 @@ public class Matrix {
 	 * @return the matrix as a {@code DataFrame}
 	 */
 	public DataFrame asDataFrame() {
-		DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, sparkExecutionContext);
+		DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(matrixObject, sparkExecutionContext, false);
 		return df;
 	}
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/java/org/apache/sysml/api/python/SystemML.py
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/python/SystemML.py b/src/main/java/org/apache/sysml/api/python/SystemML.py
index 8ad3117..689403e 100644
--- a/src/main/java/org/apache/sysml/api/python/SystemML.py
+++ b/src/main/java/org/apache/sysml/api/python/SystemML.py
@@ -20,12 +20,25 @@
 #
 #-------------------------------------------------------------
 
+from __future__ import division
 from py4j.protocol import Py4JJavaError, Py4JError
 import traceback
 import os
+from pyspark.context import SparkContext 
 from pyspark.sql import DataFrame, SQLContext
 from pyspark.rdd import RDD
+import numpy as np
+import pandas as pd
+import sklearn as sk
+from sklearn import metrics
+from pyspark.ml.feature import VectorAssembler
+from pyspark.mllib.linalg import Vectors
+import sys
+from pyspark.ml import Estimator, Model
+from scipy.sparse import spmatrix
+from scipy.sparse import coo_matrix
 
+SUPPORTED_TYPES = (np.ndarray, pd.DataFrame, spmatrix)
 
 class MLContext(object):
 
@@ -57,6 +70,7 @@ class MLContext(object):
             setForcedSparkExecType = (args[1] if len(args) > 1 else False)
             self.sc = sc
             self.ml = sc._jvm.org.apache.sysml.api.MLContext(sc._jsc, monitorPerformance, setForcedSparkExecType)
+            self.sqlCtx = SQLContext(sc)
         except Py4JError:
             traceback.print_exc()
 
@@ -171,7 +185,6 @@ class MLContext(object):
             else:
                 raise TypeError('Arguments do not match MLContext-API')
         except Py4JJavaError:
-
             traceback.print_exc()
 
     def registerOutput(self, varName):
@@ -232,6 +245,10 @@ class MLOutput(object):
         except Py4JJavaError:
             traceback.print_exc()
 
+    def getPandasDF(self, sqlContext, varName):
+        df = self.toDF(sqlContext, varName).sort('ID').drop('ID')
+        return df.toPandas()
+        
     def getMLMatrix(self, sqlContext, varName):
         raise Exception('Not supported in Python MLContext')
         #try:
@@ -247,3 +264,219 @@ class MLOutput(object):
         #    return rdd
         #except Py4JJavaError:
         #    traceback.print_exc()
+
+def getNumCols(numPyArr):
+    if numPyArr.ndim == 1:
+        return 1
+    else:
+        return numPyArr.shape[1]
+       
+def convertToMatrixBlock(sc, src):
+    if isinstance(src, spmatrix):
+        src = coo_matrix(src,  dtype=np.float64)
+        numRows = src.shape[0]
+        numCols = src.shape[1]
+        data = src.data
+        row = src.row.astype(np.int32)
+        col = src.col.astype(np.int32)
+        nnz = len(src.col)
+        buf1 = bytearray(data.tostring())
+        buf2 = bytearray(row.tostring())
+        buf3 = bytearray(col.tostring())
+        return sc._jvm.org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.convertSciPyCOOToMB(buf1, buf2, buf3, numRows, numCols, nnz)
+    elif isinstance(sc, SparkContext):
+        src = np.asarray(src)
+        numCols = getNumCols(src)
+        numRows = src.shape[0]
+        arr = src.ravel().astype(np.float64)
+        buf = bytearray(arr.tostring())
+        return sc._jvm.org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.convertPy4JArrayToMB(buf, numRows, numCols)
+    else:
+        raise TypeError('sc needs to be of type SparkContext') # TODO: We can generalize this by creating py4j gateway ourselves
+    
+
+def convertToNumpyArr(sc, mb):
+    if isinstance(sc, SparkContext):
+        numRows = mb.getNumRows()
+        numCols = mb.getNumColumns()
+        buf = sc._jvm.org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.convertMBtoPy4JDenseArr(mb)
+        return np.frombuffer(buf, count=numRows*numCols, dtype=np.float64)
+    else:
+        raise TypeError('sc needs to be of type SparkContext') # TODO: We can generalize this by creating py4j gateway ourselves
+
+def convertToPandasDF(X):
+    if not isinstance(X, pd.DataFrame):
+        return pd.DataFrame(X, columns=['C' + str(i) for i in range(getNumCols(X))])
+    return X
+            
+def tolist(inputCols):
+    return list(inputCols)
+
+def assemble(sqlCtx, pdf, inputCols, outputCol):
+    tmpDF = sqlCtx.createDataFrame(pdf, tolist(pdf.columns))
+    assembler = VectorAssembler(inputCols=tolist(inputCols), outputCol=outputCol)
+    return assembler.transform(tmpDF)
+
+class mllearn:
+    class BaseSystemMLEstimator(Estimator):
+        # TODO: Allow users to set featuresCol (with default 'features') and labelCol (with default 'label')
+    
+        # Fits the JVM Estimator on a Spark DataFrame that already has 'features' and 'label' columns; returns self
+        def _fit(self, X):
+            if hasattr(X, '_jdf') and 'features' in X.columns and 'label' in X.columns:
+                self.model = self.estimator.fit(X._jdf)
+                return self
+            else:
+                raise Exception('Incorrect usage: Expected dataframe as input with features/label as columns')
+        
+        # Fits the JVM Estimator on (X, y) via fit(X:MatrixBlock, y:MatrixBlock) and returns self
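+        # Dispatch: a Spark DataFrame with 'features'/'label' columns goes through _fit();
+        # NumPy/pandas/SciPy input is either assembled into a DataFrame (transferUsingDF=True)
+        # or converted directly to MatrixBlocks before invoking the JVM estimator.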
+        def fit(self, X, y=None, params=None):
+            if y is None:
+                return self._fit(X)
+            elif y is not None and isinstance(X, SUPPORTED_TYPES):
+                if self.transferUsingDF:
+                    pdfX = convertToPandasDF(X)
+                    pdfY = convertToPandasDF(y)
+                    if getNumCols(pdfY) != 1:
+                        raise Exception('y should be a column vector')
+                    if pdfX.shape[0] != pdfY.shape[0]:
+                        raise Exception('Number of rows of X and y should match')
+                    colNames = pdfX.columns
+                    pdfX['label'] = pdfY[pdfY.columns[0]]
+                    df = assemble(self.sqlCtx, pdfX, colNames, 'features').select('features', 'label')
+                    self.model = self.estimator.fit(df._jdf)
+                else:
+                    numColsy = getNumCols(y)
+                    if numColsy != 1:
+                        raise Exception('Expected y to be a column vector')
+                    self.model = self.estimator.fit(convertToMatrixBlock(self.sc, X), convertToMatrixBlock(self.sc, y))
+                if self.setOutputRawPredictionsToFalse:
+                    self.model.setOutputRawPredictions(False)
+                return self
+            else:
+                raise Exception('Unsupported input type')
+        
+        def transform(self, X):
+            return self.predict(X)
+        
+        # Predicts labels by calling transform(X) on the Model object on the JVM
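+        # (NumPy/SciPy/pandas input -> NumPy array of predictions, or a pandas DataFrame for
+        #  pandas/SciPy input when transferUsingDF=True; Spark DataFrame input -> DataFrame
+        #  with a 'prediction' column, sorted by 'ID')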
+        def predict(self, X):
+            if isinstance(X, SUPPORTED_TYPES):
+                if self.transferUsingDF:
+                    pdfX = convertToPandasDF(X)
+                    df = assemble(self.sqlCtx, pdfX, pdfX.columns, 'features').select('features')
+                    retjDF = self.model.transform(df._jdf)
+                    retDF = DataFrame(retjDF, self.sqlCtx)
+                    retPDF = retDF.sort('ID').select('prediction').toPandas()
+                    if isinstance(X, np.ndarray):
+                        return retPDF.as_matrix().flatten()
+                    else:
+                        return retPDF
+                else:
+                    retNumPy = convertToNumpyArr(self.sc, self.model.transform(convertToMatrixBlock(self.sc, X)))
+                    if isinstance(X, np.ndarray):
+                        return retNumPy
+                    else:
+                        return retNumPy # TODO: Convert to Pandas
+            elif hasattr(X, '_jdf'):
+                if 'features' in X.columns:
+                    # No need to assemble as input DF is likely coming via MLPipeline
+                    df = X
+                else:
+                    assembler = VectorAssembler(inputCols=X.columns, outputCol='features')
+                    df = assembler.transform(X)
+                retjDF = self.model.transform(df._jdf)
+                retDF = DataFrame(retjDF, self.sqlCtx)
+                # Return DF
+                return retDF.sort('ID')
+            else:
+                raise Exception('Unsupported input type')
+                
+    class BaseSystemMLClassifier(BaseSystemMLEstimator):
+
+        # Scores the predicted value with ground truth 'y'
+        def score(self, X, y):
+            return metrics.accuracy_score(y, self.predict(X))    
+    
+    class BaseSystemMLRegressor(BaseSystemMLEstimator):
+
+        # Scores the predicted value with ground truth 'y'
+        def score(self, X, y):
+            return metrics.r2_score(y, self.predict(X), multioutput='variance_weighted')
+
+    
+    # Alternatively, these classes could be moved into a separate Python package with a proper module structure
+    class LogisticRegression(BaseSystemMLClassifier):
+
+        # See https://apache.github.io/incubator-systemml/algorithms-reference for usage
+        def __init__(self, sqlCtx, penalty='l2', fit_intercept=True, max_iter=100, max_inner_iter=0, tol=0.000001, C=1.0, solver='newton-cg', transferUsingDF=False):
+            self.sqlCtx = sqlCtx
+            self.sc = sqlCtx._sc
+            self.uid = "logReg"
+            self.estimator = self.sc._jvm.org.apache.sysml.api.ml.LogisticRegression(self.uid, self.sc._jsc.sc())
+            self.estimator.setMaxOuterIter(max_iter)
+            self.estimator.setMaxInnerIter(max_inner_iter)
+            if C <= 0:
+                raise Exception('C has to be positive')
+            reg = 1.0 / C
+            self.estimator.setRegParam(reg)
+            self.estimator.setTol(tol)
+            self.estimator.setIcpt(int(fit_intercept))
+            self.transferUsingDF = transferUsingDF
+            self.setOutputRawPredictionsToFalse = True
+            if penalty != 'l2':
+                raise Exception('Only l2 penalty is supported')
+            if solver != 'newton-cg':
+                raise Exception('Only newton-cg solver supported')
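+            # Usage sketch (scikit-learn style; assumes `import SystemML as sml` and a SQLContext `sqlCtx`;
+            # X_train/y_train may be NumPy arrays, pandas DataFrames or SciPy sparse matrices):
+            #   lr = sml.mllearn.LogisticRegression(sqlCtx, max_iter=100, C=1.0)
+            #   y_pred = lr.fit(X_train, y_train).predict(X_test)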
+
+    class LinearRegression(BaseSystemMLRegressor):
+
+        # See https://apache.github.io/incubator-systemml/algorithms-reference for usage
+        def __init__(self, sqlCtx, fit_intercept=True, max_iter=100, tol=0.000001, C=1.0, solver='newton-cg', transferUsingDF=False):
+            self.sqlCtx = sqlCtx
+            self.sc = sqlCtx._sc
+            self.uid = "lr"
+            if solver == 'newton-cg' or solver == 'direct-solve':
+                self.estimator = self.sc._jvm.org.apache.sysml.api.ml.LinearRegression(self.uid, self.sc._jsc.sc(), solver)
+            else:
+                raise Exception('Only newton-cg and direct-solve solvers are supported')
+            self.estimator.setMaxIter(max_iter)
+            if C <= 0:
+                raise Exception('C has to be positive')
+            reg = 1.0 / C
+            self.estimator.setRegParam(reg)
+            self.estimator.setTol(tol)
+            self.estimator.setIcpt(int(fit_intercept))
+            self.transferUsingDF = transferUsingDF
+            self.setOutputRawPredictionsToFalse = False
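+            # Usage sketch (assumes `import SystemML as sml` and a SQLContext `sqlCtx`):
+            #   regr = sml.mllearn.LinearRegression(sqlCtx, solver='direct-solve')
+            #   r2 = regr.fit(X_train, y_train).score(X_test, y_test)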
+
+
+    class SVM(BaseSystemMLClassifier):
+
+        # See https://apache.github.io/incubator-systemml/algorithms-reference for usage
+        def __init__(self, sqlCtx, fit_intercept=True, max_iter=100, tol=0.000001, C=1.0, is_multi_class=False, transferUsingDF=False):
+            self.sqlCtx = sqlCtx
+            self.sc = sqlCtx._sc
+            self.uid = "svm"
+            self.estimator = self.sc._jvm.org.apache.sysml.api.ml.SVM(self.uid, self.sc._jsc.sc(), is_multi_class)
+            self.estimator.setMaxIter(max_iter)
+            if C <= 0:
+                raise Exception('C has to be positive')
+            reg = 1.0 / C
+            self.estimator.setRegParam(reg)
+            self.estimator.setTol(tol)
+            self.estimator.setIcpt(int(fit_intercept))
+            self.transferUsingDF = transferUsingDF
+            self.setOutputRawPredictionsToFalse = False    
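+            # Usage sketch (assumes `import SystemML as sml` and a SQLContext `sqlCtx`):
+            #   svm = sml.mllearn.SVM(sqlCtx, is_multi_class=True)
+            #   acc = svm.fit(X_train, y_train).score(X_test, y_test)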
+
+    class NaiveBayes(BaseSystemMLClassifier):
+
+        # See https://apache.github.io/incubator-systemml/algorithms-reference for usage
+        def __init__(self, sqlCtx, laplace=1.0, transferUsingDF=False):
+            self.sqlCtx = sqlCtx
+            self.sc = sqlCtx._sc
+            self.uid = "nb"
+            self.estimator = self.sc._jvm.org.apache.sysml.api.ml.NaiveBayes(self.uid, self.sc._jsc.sc())
+            self.estimator.setLaplace(laplace)
+            self.transferUsingDF = transferUsingDF
+            self.setOutputRawPredictionsToFalse = False
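+            # Usage sketch (assumes `import SystemML as sml` and a SQLContext `sqlCtx`;
+            # SciPy CSR input, e.g. TF-IDF features, is supported):
+            #   nb = sml.mllearn.NaiveBayes(sqlCtx, laplace=1.0)
+            #   y_pred = nb.fit(X_train, y_train).predict(X_test)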
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/java/org/apache/sysml/api/python/test.py
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/python/test.py b/src/main/java/org/apache/sysml/api/python/test.py
new file mode 100644
index 0000000..21a1f79
--- /dev/null
+++ b/src/main/java/org/apache/sysml/api/python/test.py
@@ -0,0 +1,178 @@
+#!/usr/bin/python
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+from sklearn import datasets, neighbors
+import SystemML as sml
+from pyspark.sql import SQLContext
+from pyspark.context import SparkContext
+import unittest
+from pyspark.ml.evaluation import MulticlassClassificationEvaluator
+from pyspark.ml import Pipeline
+from pyspark.ml.feature import HashingTF, Tokenizer
+import numpy as np
+from sklearn.datasets import fetch_20newsgroups
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn import metrics
+
+sc = SparkContext()
+sqlCtx = SQLContext(sc)
+
+# Currently not integrated with the JUnit test suite
+# ~/spark-1.6.1-scala-2.11/bin/spark-submit --master local[*] --driver-class-path SystemML.jar test.py
+class TestMLLearn(unittest.TestCase):
+    def testLogisticSK1(self):
+        digits = datasets.load_digits()
+        X_digits = digits.data
+        y_digits = digits.target
+        n_samples = len(X_digits)
+        X_train = X_digits[:int(.9 * n_samples)]
+        y_train = y_digits[:int(.9 * n_samples)]
+        X_test = X_digits[int(.9 * n_samples):]
+        y_test = y_digits[int(.9 * n_samples):]
+        logistic = sml.mllearn.LogisticRegression(sqlCtx)
+        score = logistic.fit(X_train, y_train).score(X_test, y_test)
+        self.failUnless(score > 0.9)
+        
+    def testLogisticSK2(self):
+        digits = datasets.load_digits()
+        X_digits = digits.data
+        y_digits = digits.target
+        n_samples = len(X_digits)
+        X_train = X_digits[:int(.9 * n_samples)]
+        y_train = y_digits[:int(.9 * n_samples)]
+        X_test = X_digits[int(.9 * n_samples):]
+        y_test = y_digits[int(.9 * n_samples):]
+        # Exercise the DataFrame-based transfer path (transferUsingDF=True) instead of the default MatrixBlock transfer
+        logistic = sml.mllearn.LogisticRegression(sqlCtx, transferUsingDF=True)
+        score = logistic.fit(X_train, y_train).score(X_test, y_test)
+        self.failUnless(score > 0.9)
+
+    def testLogisticMLPipeline1(self):
+        training = sqlCtx.createDataFrame([
+            (0L, "a b c d e spark", 1.0),
+            (1L, "b d", 2.0),
+            (2L, "spark f g h", 1.0),
+            (3L, "hadoop mapreduce", 2.0),
+            (4L, "b spark who", 1.0),
+            (5L, "g d a y", 2.0),
+            (6L, "spark fly", 1.0),
+            (7L, "was mapreduce", 2.0),
+            (8L, "e spark program", 1.0),
+            (9L, "a e c l", 2.0),
+            (10L, "spark compile", 1.0),
+            (11L, "hadoop software", 2.0)
+            ], ["id", "text", "label"])
+        tokenizer = Tokenizer(inputCol="text", outputCol="words")
+        hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
+        lr = sml.mllearn.LogisticRegression(sqlCtx)
+        pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
+        model = pipeline.fit(training)
+        test = sqlCtx.createDataFrame([
+            (12L, "spark i j k", 1.0),
+            (13L, "l m n", 2.0),
+            (14L, "mapreduce spark", 1.0),
+            (15L, "apache hadoop", 2.0)], ["id", "text", "label"])
+        result = model.transform(test)
+        predictionAndLabels = result.select("prediction", "label")
+        evaluator = MulticlassClassificationEvaluator()
+        score = evaluator.evaluate(predictionAndLabels)
+        self.failUnless(score == 1.0)
+
+    def testLinearRegressionSK1(self):
+        diabetes = datasets.load_diabetes()
+        diabetes_X = diabetes.data[:, np.newaxis, 2]
+        diabetes_X_train = diabetes_X[:-20]
+        diabetes_X_test = diabetes_X[-20:]
+        diabetes_y_train = diabetes.target[:-20]
+        diabetes_y_test = diabetes.target[-20:]
+        regr = sml.mllearn.LinearRegression(sqlCtx)
+        regr.fit(diabetes_X_train, diabetes_y_train)
+        score = regr.score(diabetes_X_test, diabetes_y_test)
+        self.failUnless(score > 0.4) # TODO: Improve the r2-score (maybe I am using it incorrectly)
+
+    def testLinearRegressionSK2(self):
+        diabetes = datasets.load_diabetes()
+        diabetes_X = diabetes.data[:, np.newaxis, 2]
+        diabetes_X_train = diabetes_X[:-20]
+        diabetes_X_test = diabetes_X[-20:]
+        diabetes_y_train = diabetes.target[:-20]
+        diabetes_y_test = diabetes.target[-20:]
+        regr = sml.mllearn.LinearRegression(sqlCtx, transferUsingDF=True)
+        regr.fit(diabetes_X_train, diabetes_y_train)
+        score = regr.score(diabetes_X_test, diabetes_y_test)
+        self.failUnless(score > 0.4) # TODO: Improve the r2-score (maybe I am using it incorrectly)
+
+    def testSVMSK1(self):
+        digits = datasets.load_digits()
+        X_digits = digits.data
+        y_digits = digits.target
+        n_samples = len(X_digits)
+        X_train = X_digits[:int(.9 * n_samples)]
+        y_train = y_digits[:int(.9 * n_samples)]
+        X_test = X_digits[int(.9 * n_samples):]
+        y_test = y_digits[int(.9 * n_samples):]
+        svm = sml.mllearn.SVM(sqlCtx, is_multi_class=True)
+        score = svm.fit(X_train, y_train).score(X_test, y_test)
+        self.failUnless(score > 0.9)
+
+    def testSVMSK2(self):
+        digits = datasets.load_digits()
+        X_digits = digits.data
+        y_digits = digits.target
+        n_samples = len(X_digits)
+        X_train = X_digits[:int(.9 * n_samples)]
+        y_train = y_digits[:int(.9 * n_samples)]
+        X_test = X_digits[int(.9 * n_samples):]
+        y_test = y_digits[int(.9 * n_samples):]
+        svm = sml.mllearn.SVM(sqlCtx, is_multi_class=True, transferUsingDF=True)
+        score = svm.fit(X_train, y_train).score(X_test, y_test)
+        self.failUnless(score > 0.9)
+
+    def testNaiveBayesSK1(self):
+        digits = datasets.load_digits()
+        X_digits = digits.data
+        y_digits = digits.target
+        n_samples = len(X_digits)
+        X_train = X_digits[:int(.9 * n_samples)]
+        y_train = y_digits[:int(.9 * n_samples)]
+        X_test = X_digits[int(.9 * n_samples):]
+        y_test = y_digits[int(.9 * n_samples):]
+        nb = sml.mllearn.NaiveBayes(sqlCtx)
+        score = nb.fit(X_train, y_train).score(X_test, y_test)
+        self.failUnless(score > 0.85)
+
+    def testNaiveBayesSK2(self):
+        categories = ['alt.atheism', 'talk.religion.misc', 'comp.graphics', 'sci.space']
+        newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)
+        newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)
+        vectorizer = TfidfVectorizer()
+        # Both vectors and vectors_test are SciPy CSR matrices
+        vectors = vectorizer.fit_transform(newsgroups_train.data)
+        vectors_test = vectorizer.transform(newsgroups_test.data)
+        nb = sml.mllearn.NaiveBayes(sqlCtx)
+        nb.fit(vectors, newsgroups_train.target)
+        pred = nb.predict(vectors_test)
+        score = metrics.f1_score(newsgroups_test.target, pred, average='weighted')
+        self.failUnless(score > 0.8)
+        
+
+if __name__ == '__main__':
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
index f022e40..72ab230 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
@@ -46,6 +46,8 @@ import org.apache.spark.sql.SQLContext;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 
 import scala.Tuple2;
 
@@ -260,6 +262,65 @@ public class RDDConverterUtilsExt
 		return dataFrameToBinaryBlock(sc, df, mcOut, false, columns);
 	}
 	
+	public static MatrixBlock convertPy4JArrayToMB(byte [] data, int rlen, int clen) throws DMLRuntimeException {
+		return convertPy4JArrayToMB(data, rlen, clen, false);
+	}
+	
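+	/**
+	 * Builds a sparse MatrixBlock from SciPy COO components: 'data' holds nnz doubles,
+	 * 'row' and 'col' hold nnz int32 indices, all serialized in native byte order.
+	 */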
+	public static MatrixBlock convertSciPyCOOToMB(byte [] data, byte [] row, byte [] col, int rlen, int clen, int nnz) throws DMLRuntimeException {
+		MatrixBlock mb = new MatrixBlock(rlen, clen, true);
+		mb.allocateSparseRowsBlock(false);
+		ByteBuffer buf1 = ByteBuffer.wrap(data);
+		buf1.order(ByteOrder.nativeOrder());
+		ByteBuffer buf2 = ByteBuffer.wrap(row);
+		buf2.order(ByteOrder.nativeOrder());
+		ByteBuffer buf3 = ByteBuffer.wrap(col);
+		buf3.order(ByteOrder.nativeOrder());
+		for(int i = 0; i < nnz; i++) {
+			double val = buf1.getDouble();
+			int rowIndex = buf2.getInt();
+			int colIndex = buf3.getInt();
+			mb.setValue(rowIndex, colIndex, val); // TODO: Improve the performance
+		}
+		return mb;
+	}
+	
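+	/**
+	 * Converts a row-major double array, received from Python as raw bytes in
+	 * native byte order, into a MatrixBlock of shape rlen x clen (dense only).
+	 */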
+	public static MatrixBlock convertPy4JArrayToMB(byte [] data, int rlen, int clen, boolean isSparse) throws DMLRuntimeException {
+		MatrixBlock mb = new MatrixBlock(rlen, clen, isSparse, -1);
+		if(isSparse) {
+			throw new DMLRuntimeException("Convertion to sparse format not supported");
+		}
+		else {
+			double [] denseBlock = new double[rlen*clen];
+			ByteBuffer buf = ByteBuffer.wrap(data);
+			buf.order(ByteOrder.nativeOrder());
+			for(int i = 0; i < rlen*clen; i++) {
+				denseBlock[i] = buf.getDouble();
+			}
+			mb.init( denseBlock, rlen, clen );
+		}
+		mb.examSparsity();
+		return mb;
+	}
+	
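+	/**
+	 * Serializes a dense MatrixBlock into a byte array of row-major doubles in
+	 * native byte order for transfer back to Python; sparse input is rejected.
+	 */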
+	public static byte [] convertMBtoPy4JDenseArr(MatrixBlock mb) throws DMLRuntimeException {
+		byte [] ret = null;
+		if(mb.isInSparseFormat()) {
+			throw new DMLRuntimeException("Sparse to dense conversion is not yet implemented");
+		}
+		else {
+			double [] denseBlock = mb.getDenseBlock();
+			if(denseBlock == null) {
+				throw new DMLRuntimeException("Sparse to dense conversion is not yet implemented");
+			}
+			int times = Double.SIZE / Byte.SIZE;
+			ret = new byte[denseBlock.length * times];
+			for(int i=0;i < denseBlock.length;i++){
+		        ByteBuffer.wrap(ret, i*times, times).order(ByteOrder.nativeOrder()).putDouble(denseBlock[i]);
+			}
+		}
+		return ret;
+	}
+	
 	/**
 	 * Converts DataFrame into binary blocked RDD. 
 	 * Note: mcOut will be set if you don't know the dimensions.

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
new file mode 100644
index 0000000..98def7c
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.ml
+
+import org.apache.spark.rdd.RDD
+import java.io.File
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+import org.apache.sysml.runtime.DMLRuntimeException
+import org.apache.sysml.runtime.instructions.spark.utils.{ RDDConverterUtilsExt => RDDConverterUtils }
+import org.apache.sysml.api.mlcontext._
+import org.apache.sysml.api.mlcontext.ScriptFactory._
+import org.apache.spark.sql._
+
+trait HasLaplace extends Params {
+  final val laplace: Param[Double] = new Param[Double](this, "laplace", "Laplace smoothing specified by the user to avoid creation of 0 probabilities.")
+  setDefault(laplace, 1.0)
+  final def getLaplace: Double = $(laplace)
+}
+trait HasIcpt extends Params {
+  final val icpt: Param[Int] = new Param[Int](this, "icpt", "Intercept presence, shifting and rescaling X columns")
+  setDefault(icpt, 0)
+  final def getIcpt: Int = $(icpt)
+}
+trait HasMaxOuterIter extends Params {
+  final val maxOuterIter: Param[Int] = new Param[Int](this, "maxOuterIter", "max. number of outer (Newton) iterations")
+  setDefault(maxOuterIter, 100)
+  final def getMaxOuterIte: Int = $(maxOuterIter)
+}
+trait HasMaxInnerIter extends Params {
+  final val maxInnerIter: Param[Int] = new Param[Int](this, "maxInnerIter", "max. number of inner (conjugate gradient) iterations, 0 = no max")
+  setDefault(maxInnerIter, 0)
+  final def getMaxInnerIter: Int = $(maxInnerIter)
+}
+trait HasTol extends Params {
+  final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms")
+  setDefault(tol, 0.000001)
+  final def getTol: Double = $(tol)
+}
+trait HasRegParam extends Params {
+  final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter (lambda)")
+  setDefault(regParam, 0.000001)
+  final def getRegParam: Double = $(regParam)
+}
+
+trait BaseSystemMLEstimator {
+  
+  def transformSchema(schema: StructType): StructType = schema
+  
+  // Returns the script and variables for X and y
+  def getTrainingScript(isSingleNode:Boolean):(Script, String, String)
+  
+  def toDouble(i:Int): java.lang.Double = {
+    double2Double(i.toDouble)
+  }
+  
+  def toDouble(d:Double): java.lang.Double = {
+    double2Double(d)
+  }
+}
+
+trait BaseSystemMLEstimatorModel {
+  def toDouble(i:Int): java.lang.Double = {
+    double2Double(i.toDouble)
+  }
+  def toDouble(d:Double): java.lang.Double = {
+    double2Double(d)
+  }
+  
+  def transformSchema(schema: StructType): StructType = schema
+  
+  // Returns the script and variable for X
+  def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String)
+}
+
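+// Shared fit() logic for SystemML classifiers: builds the label mapping, binds X/y
+// (single-node MatrixBlocks, or a DataFrame converted to binary blocks) into the
+// algorithm's DML training script, and executes it via MLContext.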
+trait BaseSystemMLClassifier extends BaseSystemMLEstimator {
+  
+  def fit(X_mb: MatrixBlock, y_mb: MatrixBlock, sc: SparkContext): (MLResults, java.util.HashMap[Int, String]) = {
+    val isSingleNode = true
+    val ml = new MLContext(sc)
+    val revLabelMapping = new java.util.HashMap[Int, String]
+    PredictionUtils.fillLabelMapping(y_mb, revLabelMapping)
+    val ret = getTrainingScript(isSingleNode)
+    val script = ret._1.in(ret._2, X_mb).in(ret._3, y_mb)
+    (ml.execute(script), revLabelMapping)
+  }
+  
+  def fit(df: ScriptsUtils.SparkDataType, sc: SparkContext): (MLResults, java.util.HashMap[Int, String]) = {
+    val isSingleNode = false
+    val ml = new MLContext(df.rdd.sparkContext)
+    val mcXin = new MatrixCharacteristics()
+    val Xin = RDDConverterUtils.vectorDataFrameToBinaryBlock(sc, df.asInstanceOf[DataFrame], mcXin, false, "features")
+    val revLabelMapping = new java.util.HashMap[Int, String]
+    val yin = PredictionUtils.fillLabelMapping(df, revLabelMapping)
+    val ret = getTrainingScript(isSingleNode)
+    val Xbin = new BinaryBlockMatrix(Xin, mcXin)
+    val script = ret._1.in(ret._2, Xbin).in(ret._3, yin)
+    (ml.execute(script), revLabelMapping)
+  }
+}
+
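+// Shared scoring logic: runs the prediction script, derives class labels from the
+// probability output, and maps them back to the original label strings.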
+trait BaseSystemMLClassifierModel extends BaseSystemMLEstimatorModel {
+  
+  def transform(X: MatrixBlock, mloutput: MLResults, labelMapping: java.util.HashMap[Int, String], sc: SparkContext, probVar:String): MatrixBlock = {
+    val isSingleNode = true
+    val ml = new MLContext(sc)
+    val script = getPredictionScript(mloutput, isSingleNode)
+    val modelPredict = ml.execute(script._1.in(script._2, X))
+    val ret = PredictionUtils.computePredictedClassLabelsFromProbability(modelPredict, isSingleNode, sc, probVar)
+              .getBinaryBlockMatrix("Prediction").getMatrixBlock
+              
+    if(ret.getNumColumns != 1) {
+      throw new RuntimeException("Expected predicted label to be a column vector")
+    }
+    PredictionUtils.updateLabels(isSingleNode, null, ret, null, labelMapping)
+    return ret
+  }
+  
+  def transform(df: ScriptsUtils.SparkDataType, mloutput: MLResults, labelMapping: java.util.HashMap[Int, String], sc: SparkContext, 
+      probVar:String, outputProb:Boolean=true): DataFrame = {
+    val isSingleNode = false
+    val ml = new MLContext(sc)
+    val mcXin = new MatrixCharacteristics()
+    val Xin = RDDConverterUtils.vectorDataFrameToBinaryBlock(df.rdd.sparkContext, df.asInstanceOf[DataFrame], mcXin, false, "features")
+    val script = getPredictionScript(mloutput, isSingleNode)
+    val Xin_bin = new BinaryBlockMatrix(Xin, mcXin)
+    val modelPredict = ml.execute(script._1.in(script._2, Xin_bin))
+    val predLabelOut = PredictionUtils.computePredictedClassLabelsFromProbability(modelPredict, isSingleNode, sc, probVar)
+    val predictedDF = PredictionUtils.updateLabels(isSingleNode, predLabelOut.getDataFrame("Prediction"), null, "C1", labelMapping).select("ID", "prediction")
+    if(outputProb) {
+      val prob = modelPredict.getDataFrame(probVar, true).withColumnRenamed("C1", "probability").select("ID", "probability")
+      val dataset = RDDConverterUtils.addIDToDataFrame(df.asInstanceOf[DataFrame], df.sqlContext, "ID")
+      return PredictionUtils.joinUsingID(dataset, PredictionUtils.joinUsingID(prob, predictedDF))
+    }
+    else {
+      val dataset = RDDConverterUtils.addIDToDataFrame(df.asInstanceOf[DataFrame], df.sqlContext, "ID")
+      return PredictionUtils.joinUsingID(dataset, predictedDF)
+    }
+    
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala
new file mode 100644
index 0000000..5bcde30
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLRegressor.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.ml
+
+import org.apache.spark.rdd.RDD
+import java.io.File
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+import org.apache.sysml.runtime.DMLRuntimeException
+import org.apache.sysml.runtime.instructions.spark.utils.{ RDDConverterUtilsExt => RDDConverterUtils }
+import org.apache.sysml.api.mlcontext._
+import org.apache.sysml.api.mlcontext.ScriptFactory._
+
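+// Shared fit() logic for SystemML regressors: binds X/y into the DML training script
+// (MatrixBlocks in single-node mode, binary-block matrices for DataFrame input) and
+// executes it via MLContext.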
+trait BaseSystemMLRegressor extends BaseSystemMLEstimator {
+  
+  def fit(X_mb: MatrixBlock, y_mb: MatrixBlock, sc: SparkContext): MLResults = {
+    val isSingleNode = true
+    val ml = new MLContext(sc)
+    val ret = getTrainingScript(isSingleNode)
+    val script = ret._1.in(ret._2, X_mb).in(ret._3, y_mb)
+    ml.execute(script)
+  }
+  
+  def fit(df: ScriptsUtils.SparkDataType, sc: SparkContext): MLResults = {
+    val isSingleNode = false
+    val ml = new MLContext(df.rdd.sparkContext)
+    val mcXin = new MatrixCharacteristics()
+    val Xin = RDDConverterUtils.vectorDataFrameToBinaryBlock(sc, df.asInstanceOf[DataFrame], mcXin, false, "features")
+    val yin = df.select("label")
+    val ret = getTrainingScript(isSingleNode)
+    val Xbin = new BinaryBlockMatrix(Xin, mcXin)
+    val script = ret._1.in(ret._2, Xbin).in(ret._3, yin)
+    ml.execute(script)
+  }
+}
+
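+// Shared prediction logic: runs the prediction script and returns either a MatrixBlock
+// of predictions or the input DataFrame joined with a 'prediction' column by row ID.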
+trait BaseSystemMLRegressorModel extends BaseSystemMLEstimatorModel {
+  
+  def transform(X: MatrixBlock, mloutput: MLResults, sc: SparkContext, predictionVar:String): MatrixBlock = {
+    val isSingleNode = true
+    val ml = new MLContext(sc)
+    val script = getPredictionScript(mloutput, isSingleNode)
+    val modelPredict = ml.execute(script._1.in(script._2, X))
+    val ret = modelPredict.getBinaryBlockMatrix(predictionVar).getMatrixBlock
+              
+    if(ret.getNumColumns != 1) {
+      throw new RuntimeException("Expected prediction to be a column vector")
+    }
+    return ret
+  }
+  
+  def transform(df: ScriptsUtils.SparkDataType, mloutput: MLResults, sc: SparkContext, predictionVar:String): DataFrame = {
+    val isSingleNode = false
+    val ml = new MLContext(sc)
+    val mcXin = new MatrixCharacteristics()
+    val Xin = RDDConverterUtils.vectorDataFrameToBinaryBlock(df.rdd.sparkContext, df.asInstanceOf[DataFrame], mcXin, false, "features")
+    val script = getPredictionScript(mloutput, isSingleNode)
+    val Xin_bin = new BinaryBlockMatrix(Xin, mcXin)
+    val modelPredict = ml.execute(script._1.in(script._2, Xin_bin))
+    val predictedDF = modelPredict.getDataFrame(predictionVar).select("ID", "C1").withColumnRenamed("C1", "prediction")
+    val dataset = RDDConverterUtils.addIDToDataFrame(df.asInstanceOf[DataFrame], df.sqlContext, "ID")
+    return PredictionUtils.joinUsingID(dataset, predictedDF)
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala b/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala
new file mode 100644
index 0000000..cce646d
--- /dev/null
+++ b/src/main/scala/org/apache/sysml/api/ml/LinearRegression.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.api.ml
+
+import org.apache.spark.rdd.RDD
+import java.io.File
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{ Model, Estimator }
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+import org.apache.sysml.runtime.DMLRuntimeException
+import org.apache.sysml.runtime.instructions.spark.utils.{ RDDConverterUtilsExt => RDDConverterUtils }
+import org.apache.sysml.api.mlcontext._
+import org.apache.sysml.api.mlcontext.ScriptFactory._
+
+object LinearRegression {
+  final val scriptPathCG = "scripts" + File.separator + "algorithms" + File.separator + "LinearRegCG.dml"
+  final val scriptPathDS = "scripts" + File.separator + "algorithms" + File.separator + "LinearRegDS.dml"
+}
+
+// solver = "direct-solve" (default) or "newton-cg"
+class LinearRegression(override val uid: String, val sc: SparkContext, val solver:String="direct-solve") 
+  extends Estimator[LinearRegressionModel] with HasIcpt
+    with HasRegParam with HasTol with HasMaxOuterIter with BaseSystemMLRegressor {
+  
+  def setIcpt(value: Int) = set(icpt, value)
+  def setMaxIter(value: Int) = set(maxOuterIter, value)
+  def setRegParam(value: Double) = set(regParam, value)
+  def setTol(value: Double) = set(tol, value)
+  
+  override def copy(extra: ParamMap): Estimator[LinearRegressionModel] = {
+    val that = new LinearRegression(uid, sc, solver)
+    copyValues(that, extra)
+  }
+  
+          
+  def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
+    val script = dml(ScriptsUtils.getDMLScript(
+        if(solver.compareTo("direct-solve") == 0) LinearRegression.scriptPathDS 
+        else if(solver.compareTo("newton-cg") == 0) LinearRegression.scriptPathCG
+        else throw new DMLRuntimeException("The algorithm should be direct-solve or newton-cg")))
+      .in("$X", " ")
+      .in("$Y", " ")
+      .in("$B", " ")
+      .in("$Log", " ")
+      .in("$fmt", "binary")
+      .in("$icpt", toDouble(getIcpt))
+      .in("$reg", toDouble(getRegParam))
+      .in("$tol", toDouble(getTol))
+      .in("$maxi", toDouble(getMaxOuterIte))
+      .out("beta_out")
+    (script, "X", "y")
+  }
+  
+  def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): LinearRegressionModel = 
+    new LinearRegressionModel("lr")(fit(X_mb, y_mb, sc), sc)
+    
+  def fit(df: ScriptsUtils.SparkDataType): LinearRegressionModel = 
+    new LinearRegressionModel("lr")(fit(df, sc), sc)
+  
+}
+
+class LinearRegressionModel(override val uid: String)(val mloutput: MLResults, val sc: SparkContext) extends Model[LinearRegressionModel] with HasIcpt
+    with HasRegParam with HasTol with HasMaxOuterIter with BaseSystemMLRegressorModel {
+  override def copy(extra: ParamMap): LinearRegressionModel = {
+    val that = new LinearRegressionModel(uid)(mloutput, sc)
+    copyValues(that, extra)
+  }
+  
+  def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String) =
+    PredictionUtils.getGLMPredictionScript(mloutput.getBinaryBlockMatrix("beta_out"), isSingleNode)
+  
+  def transform(df: ScriptsUtils.SparkDataType): DataFrame = transform(df, mloutput, sc, "means")
+  
+  def transform(X: MatrixBlock): MatrixBlock =  transform(X, mloutput, sc, "means")
+  
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f02f7c01/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala b/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala
index 2fabde1..a9ca6ab 100644
--- a/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/LogisticRegression.scala
@@ -19,47 +19,20 @@
 
 package org.apache.sysml.api.ml
 
+import org.apache.spark.rdd.RDD
 import java.io.File
-import org.apache.sysml.api.{ MLContext, MLOutput }
-import org.apache.sysml.runtime.matrix.MatrixCharacteristics
-import org.apache.sysml.runtime.instructions.spark.utils.{ RDDConverterUtilsExt => RDDConverterUtils }
-import org.apache.spark.{ SparkContext }
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{ Model, Estimator }
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.ml.{ Model, Estimator }
-import org.apache.spark.ml.classification._
 import org.apache.spark.ml.param.{ Params, Param, ParamMap, DoubleParam }
-import org.apache.spark.ml.param.shared._
-import org.apache.spark.SparkConf
-import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.mllib.regression.LabeledPoint
-import scala.reflect.ClassTag
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics
+import org.apache.sysml.runtime.matrix.data.MatrixBlock
+import org.apache.sysml.runtime.DMLRuntimeException
+import org.apache.sysml.runtime.instructions.spark.utils.{ RDDConverterUtilsExt => RDDConverterUtils }
+import org.apache.sysml.api.mlcontext._
+import org.apache.sysml.api.mlcontext.ScriptFactory._
 
-trait HasIcpt extends Params {
-  final val icpt: Param[Int] = new Param[Int](this, "icpt", "Intercept presence, shifting and rescaling X columns")
-  setDefault(icpt, 0)
-  final def getIcpt: Int = $(icpt)
-}
-trait HasMaxOuterIter extends Params {
-  final val maxOuterIter: Param[Int] = new Param[Int](this, "maxOuterIter", "max. number of outer (Newton) iterations")
-  setDefault(maxOuterIter, 100)
-  final def getMaxOuterIte: Int = $(maxOuterIter)
-}
-trait HasMaxInnerIter extends Params {
-  final val maxInnerIter: Param[Int] = new Param[Int](this, "maxInnerIter", "max. number of inner (conjugate gradient) iterations, 0 = no max")
-  setDefault(maxInnerIter, 0)
-  final def getMaxInnerIter: Int = $(maxInnerIter)
-}
-trait HasTol extends Params {
-  final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms")
-  setDefault(tol, 0.000001)
-  final def getTol: Double = $(tol)
-}
-trait HasRegParam extends Params {
-  final val regParam: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms")
-  setDefault(regParam, 0.000001)
-  final def getRegParam: Double = $(regParam)
-}
 object LogisticRegression {
   final val scriptPath = "scripts" + File.separator + "algorithms" + File.separator + "MultiLogReg.dml"
 }
@@ -68,7 +41,7 @@ object LogisticRegression {
  * Logistic Regression Scala API
  */
 class LogisticRegression(override val uid: String, val sc: SparkContext) extends Estimator[LogisticRegressionModel] with HasIcpt
-    with HasRegParam with HasTol with HasMaxOuterIter with HasMaxInnerIter {
+    with HasRegParam with HasTol with HasMaxOuterIter with HasMaxInnerIter with BaseSystemMLClassifier {
 
   def setIcpt(value: Int) = set(icpt, value)
   def setMaxOuterIter(value: Int) = set(maxOuterIter, value)
@@ -80,31 +53,31 @@ class LogisticRegression(override val uid: String, val sc: SparkContext) extends
     val that = new LogisticRegression(uid, sc)
     copyValues(that, extra)
   }
-  override def transformSchema(schema: StructType): StructType = schema
-  override def fit(df: DataFrame): LogisticRegressionModel = {
-    val ml = new MLContext(df.rdd.sparkContext)
-    val mcXin = new MatrixCharacteristics()
-    val Xin = RDDConverterUtils.vectorDataFrameToBinaryBlock(sc, df, mcXin, false, "features")
-    val yin = df.select("label").rdd.map { _.apply(0).toString() }
-
-    val mloutput = {
-      val paramsMap: Map[String, String] = Map(
-        "icpt" -> this.getIcpt.toString(),
-        "reg" -> this.getRegParam.toString(),
-        "tol" -> this.getTol.toString,
-        "moi" -> this.getMaxOuterIte.toString,
-        "mii" -> this.getMaxInnerIter.toString,
-
-        "X" -> " ",
-        "Y" -> " ",
-        "B" -> " ")
-      ml.registerInput("X", Xin, mcXin);
-      ml.registerInput("Y_vec", yin, "csv");
-      ml.registerOutput("B_out");
-      ml.executeScript(ScriptsUtils.getDMLScript(LogisticRegression.scriptPath), paramsMap)
-      //ml.execute(ScriptsUtils.resolvePath(LogisticRegression.scriptPath), paramsMap)
-    }
-    new LogisticRegressionModel("logisticRegression")(mloutput)
+  
+  // Note: will update the y_mb as this will be called by Python mllearn
+  def fit(X_mb: MatrixBlock, y_mb: MatrixBlock): LogisticRegressionModel = {
+    val ret = fit(X_mb, y_mb, sc)
+    new LogisticRegressionModel("log")(ret._1, ret._2, sc)
+  }
+  
+  def fit(df: ScriptsUtils.SparkDataType): LogisticRegressionModel = {
+    val ret = fit(df, sc)
+    new LogisticRegressionModel("log")(ret._1, ret._2, sc)
+  }
+  
+  
+  def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
+    val script = dml(ScriptsUtils.getDMLScript(LogisticRegression.scriptPath))
+      .in("$X", " ")
+      .in("$Y", " ")
+      .in("$B", " ")
+      .in("$icpt", toDouble(getIcpt))
+      .in("$reg", toDouble(getRegParam))
+      .in("$tol", toDouble(getTol))
+      .in("$moi", toDouble(getMaxOuterIte))
+      .in("$mii", toDouble(getMaxInnerIter))
+      .out("B_out")
+    (script, "X", "Y_vec")
   }
 }
 object LogisticRegressionModel {
@@ -115,55 +88,22 @@ object LogisticRegressionModel {
  * Logistic Regression Scala API
  */
 
-class LogisticRegressionModel(
-  override val uid: String)(
-    val mloutput: MLOutput) extends Model[LogisticRegressionModel] with HasIcpt
-    with HasRegParam with HasTol with HasMaxOuterIter with HasMaxInnerIter {
+class LogisticRegressionModel(override val uid: String)(
+    val mloutput: MLResults, val labelMapping: java.util.HashMap[Int, String], val sc: SparkContext) 
+    extends Model[LogisticRegressionModel] with HasIcpt
+    with HasRegParam with HasTol with HasMaxOuterIter with HasMaxInnerIter with BaseSystemMLClassifierModel {
   override def copy(extra: ParamMap): LogisticRegressionModel = {
-    val that = new LogisticRegressionModel(uid)(mloutput)
+    val that = new LogisticRegressionModel(uid)(mloutput, labelMapping, sc)
     copyValues(that, extra)
   }
-  override def transformSchema(schema: StructType): StructType = schema
-  override def transform(df: DataFrame): DataFrame = {
-    val ml = new MLContext(df.rdd.sparkContext)
-
-    val mcXin = new MatrixCharacteristics()
-    val Xin = RDDConverterUtils.vectorDataFrameToBinaryBlock(df.rdd.sparkContext, df, mcXin, false, "features")
-
-    val mlscoreoutput = {
-      val paramsMap: Map[String, String] = Map(
-        "X" -> " ",
-        "B" -> " ")
-      ml.registerInput("X", Xin, mcXin);
-      ml.registerInput("B_full", mloutput.getBinaryBlockedRDD("B_out"), mloutput.getMatrixCharacteristics("B_out"));
-      ml.registerOutput("means");
-      ml.executeScript(ScriptsUtils.getDMLScript(LogisticRegressionModel.scriptPath), paramsMap)
-    }
-
-    val prob = mlscoreoutput.getDF(df.sqlContext, "means", true).withColumnRenamed("C1", "probability")
-
-    val mlNew = new MLContext(df.rdd.sparkContext)
-    mlNew.registerInput("X", Xin, mcXin);
-    mlNew.registerInput("B_full", mloutput.getBinaryBlockedRDD("B_out"), mloutput.getMatrixCharacteristics("B_out"));
-    mlNew.registerInput("Prob", mlscoreoutput.getBinaryBlockedRDD("means"), mlscoreoutput.getMatrixCharacteristics("means"));
-    mlNew.registerOutput("Prediction");
-    mlNew.registerOutput("rawPred");
-
-    val outNew = mlNew.executeScript("Prob = read(\"temp1\"); "
-      + "Prediction = rowIndexMax(Prob); "
-      + "write(Prediction, \"tempOut\", \"csv\")"
-      + "X = read(\"temp2\");"
-      + "B_full = read(\"temp3\");"
-      + "rawPred = 1 / (1 + exp(- X * t(B_full)) );" // Raw prediction logic: 
-      + "write(rawPred, \"tempOut1\", \"csv\")");
-
-    val pred = outNew.getDF(df.sqlContext, "Prediction").withColumnRenamed("C1", "prediction").withColumnRenamed("ID", "ID1")
-    val rawPred = outNew.getDF(df.sqlContext, "rawPred", true).withColumnRenamed("C1", "rawPrediction").withColumnRenamed("ID", "ID2")
-    var predictionsNProb = prob.join(pred, prob.col("ID").equalTo(pred.col("ID1"))).select("ID", "probability", "prediction")
-    predictionsNProb = predictionsNProb.join(rawPred, predictionsNProb.col("ID").equalTo(rawPred.col("ID2"))).select("ID", "probability", "prediction", "rawPrediction")
-    val dataset1 = RDDConverterUtils.addIDToDataFrame(df, df.sqlContext, "ID")
-    dataset1.join(predictionsNProb, dataset1.col("ID").equalTo(predictionsNProb.col("ID")))
-  }
+  var outputRawPredictions = true
+  def setOutputRawPredictions(outRawPred:Boolean): Unit = { outputRawPredictions = outRawPred }
+  
+  def getPredictionScript(mloutput: MLResults, isSingleNode:Boolean): (Script, String) =
+    PredictionUtils.getGLMPredictionScript(mloutput.getBinaryBlockMatrix("B_out"), isSingleNode, 3)
+   
+  def transform(X: MatrixBlock): MatrixBlock = transform(X, mloutput, labelMapping, sc, "means")
+  def transform(df: ScriptsUtils.SparkDataType): DataFrame = transform(df, mloutput, labelMapping, sc, "means")
 }
 
 /**
@@ -190,7 +130,7 @@ object LogisticRegressionExample {
       LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 2.3))))
     val lr = new LogisticRegression("log", sc)
     val lrmodel = lr.fit(training.toDF)
-    lrmodel.mloutput.getDF(sqlContext, "B_out").show()
+    // lrmodel.mloutput.getDF(sqlContext, "B_out").show()
 
     val testing = sc.parallelize(Seq(
       LabeledPoint(1.0, Vectors.dense(1.0, 0.0, 3.0)),