You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2013/12/26 07:31:22 UTC

[05/28] git commit: Bindings for linear, Lasso, and ridge regression.

Bindings for linear, Lasso, and ridge regression.


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/ded67ee9
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/ded67ee9
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/ded67ee9

Branch: refs/heads/master
Commit: ded67ee90c2c0b22d67e623156a3f6cce8573abd
Parents: 2a41c9a
Author: Tor Myklebust <tm...@gmail.com>
Authored: Thu Dec 19 22:42:12 2013 -0500
Committer: Tor Myklebust <tm...@gmail.com>
Committed: Thu Dec 19 22:42:12 2013 -0500

----------------------------------------------------------------------
 .../apache/spark/mllib/api/PythonMLLibAPI.scala | 42 +++++++++++++++++---
 1 file changed, 37 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ded67ee9/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
index 3daf5dc..c9bd7c6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
@@ -1,5 +1,6 @@
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.regression._
+import org.apache.spark.rdd.RDD
 import java.nio.ByteBuffer
 import java.nio.ByteOrder
 import java.nio.DoubleBuffer
@@ -38,14 +39,45 @@ class PythonMLLibAPI extends Serializable {
     return bytes
   }
 
-  def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]]):
-      java.util.List[java.lang.Object] = {
-    val data = dataBytesJRDD.rdd.map(x => deserializeDoubleVector(x))
-        .map(v => LabeledPoint(v(0), v.slice(1, v.length)))
-    val model = LinearRegressionWithSGD.train(data, 222)
+  def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
+      dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]):
+      java.util.LinkedList[java.lang.Object] = {
+    val data = dataBytesJRDD.rdd.map(xBytes => {
+        val x = deserializeDoubleVector(xBytes)
+        LabeledPoint(x(0), x.slice(1, x.length))
+    })
+    val initialWeights = deserializeDoubleVector(initialWeightsBA)
+    val model = trainFunc(data, initialWeights)
     val ret = new java.util.LinkedList[java.lang.Object]()
     ret.add(serializeDoubleVector(model.weights))
     ret.add(model.intercept: java.lang.Double)
     return ret
   }
+
+  def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]],
+      numIterations: Int, stepSize: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        LinearRegressionWithSGD.train(data, numIterations, stepSize,
+                                      miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA);
+  }
+
+  def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+      stepSize: Double, regParam: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        LassoWithSGD.train(data, numIterations, stepSize, regParam,
+                           miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA);
+  }
+
+  def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int,
+      stepSize: Double, regParam: Double, miniBatchFraction: Double,
+      initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+    return trainRegressionModel((data, initialWeights) =>
+        RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam,
+                                     miniBatchFraction, initialWeights),
+        dataBytesJRDD, initialWeightsBA);
+  }
 }