You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2013/12/26 07:31:22 UTC
[05/28] git commit: Bindings for linear, Lasso, and ridge regression.
Bindings for linear, Lasso, and ridge regression.
Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/ded67ee9
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/ded67ee9
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/ded67ee9
Branch: refs/heads/master
Commit: ded67ee90c2c0b22d67e623156a3f6cce8573abd
Parents: 2a41c9a
Author: Tor Myklebust <tm...@gmail.com>
Authored: Thu Dec 19 22:42:12 2013 -0500
Committer: Tor Myklebust <tm...@gmail.com>
Committed: Thu Dec 19 22:42:12 2013 -0500
----------------------------------------------------------------------
.../apache/spark/mllib/api/PythonMLLibAPI.scala | 42 +++++++++++++++++---
1 file changed, 37 insertions(+), 5 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/ded67ee9/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
index 3daf5dc..c9bd7c6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala
@@ -1,5 +1,6 @@
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression._
+import org.apache.spark.rdd.RDD
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.DoubleBuffer
@@ -38,14 +39,45 @@ class PythonMLLibAPI extends Serializable {
return bytes
}
- def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]]):
- java.util.List[java.lang.Object] = {
- val data = dataBytesJRDD.rdd.map(x => deserializeDoubleVector(x))
- .map(v => LabeledPoint(v(0), v.slice(1, v.length)))
- val model = LinearRegressionWithSGD.train(data, 222)
+ def trainRegressionModel(trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, // shared driver: decode the inputs, run trainFunc, pack the resulting model into a Java list
+ dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): // one byte array per training row, plus the byte-encoded initial weights
+ java.util.LinkedList[java.lang.Object] = { // returns a two-element list: [serialized weights, intercept]
+ val data = dataBytesJRDD.rdd.map(xBytes => { // turn every encoded row into a LabeledPoint
+ val x = deserializeDoubleVector(xBytes) // helper defined outside this hunk; presumably decodes a packed Array[Double] -- TODO confirm
+ LabeledPoint(x(0), x.slice(1, x.length)) // element 0 is the label, the remaining elements are the features
+ })
+ val initialWeights = deserializeDoubleVector(initialWeightsBA) // initial weights go through the same decoder as the rows
+ val model = trainFunc(data, initialWeights) // the caller-supplied function decides which algorithm runs
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(serializeDoubleVector(model.weights)) // slot 0: weights, re-encoded with serializeDoubleVector
ret.add(model.intercept: java.lang.Double) // slot 1: intercept, ascribed to java.lang.Double so a boxed value is stored
return ret
}
+
+ def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]], // trains with LinearRegressionWithSGD on byte-encoded rows
+ numIterations: Int, stepSize: Double, miniBatchFraction: Double, // forwarded unchanged to LinearRegressionWithSGD.train
+ initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { // initial weights arrive byte-encoded; the result list is whatever trainRegressionModel builds
+ return trainRegressionModel((data, initialWeights) => // input decoding and result packing are delegated to trainRegressionModel
+ LinearRegressionWithSGD.train(data, numIterations, stepSize, // the lambda receives the already-decoded data and initial weights
+ miniBatchFraction, initialWeights),
+ dataBytesJRDD, initialWeightsBA);
+ }
+
+ def trainLassoModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, // trains with LassoWithSGD on byte-encoded rows
+ stepSize: Double, regParam: Double, miniBatchFraction: Double, // forwarded unchanged to LassoWithSGD.train (regParam included)
+ initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { // initial weights arrive byte-encoded; the result list is whatever trainRegressionModel builds
+ return trainRegressionModel((data, initialWeights) => // input decoding and result packing are delegated to trainRegressionModel
+ LassoWithSGD.train(data, numIterations, stepSize, regParam, // the lambda receives the already-decoded data and initial weights
+ miniBatchFraction, initialWeights),
+ dataBytesJRDD, initialWeightsBA);
+ }
+
+ def trainRidgeModel(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, // trains with RidgeRegressionWithSGD on byte-encoded rows
+ stepSize: Double, regParam: Double, miniBatchFraction: Double, // forwarded unchanged to RidgeRegressionWithSGD.train (regParam included)
+ initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { // initial weights arrive byte-encoded; the result list is whatever trainRegressionModel builds
+ return trainRegressionModel((data, initialWeights) => // input decoding and result packing are delegated to trainRegressionModel
+ RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam, // the lambda receives the already-decoded data and initial weights
+ miniBatchFraction, initialWeights),
+ dataBytesJRDD, initialWeightsBA);
+ }
}