You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tr...@apache.org on 2015/06/02 13:25:03 UTC
flink git commit: [FLINK-2102] [ml] Add predict function for labeled
data for SVM and MLR.
Repository: flink
Updated Branches:
refs/heads/master 7571959a1 -> d163a817f
[FLINK-2102] [ml] Add predict function for labeled data for SVM and MLR.
These functions return for each example in the input DataSet[LabeledVector] a pair (truth, prediction)
Added documentation for new predict functions
This closes #744.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/d163a817
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/d163a817
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/d163a817
Branch: refs/heads/master
Commit: d163a817fa2e330e86384d0bbcd104f051a6fb48
Parents: 7571959
Author: Theodore Vasiloudis <tv...@sics.se>
Authored: Thu May 28 18:51:17 2015 +0200
Committer: Till Rohrmann <tr...@apache.org>
Committed: Tue Jun 2 13:24:05 2015 +0200
----------------------------------------------------------------------
docs/libs/ml/multiple_linear_regression.md | 8 +++
docs/libs/ml/svm.md | 8 +++
.../apache/flink/ml/classification/SVM.scala | 53 +++++++++++++++++-
.../regression/MultipleLinearRegression.scala | 58 +++++++++++++++++++-
.../flink/ml/classification/SVMITSuite.scala | 31 +++++++++++
.../MultipleLinearRegressionITSuite.scala | 24 ++++++++
6 files changed, 178 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/docs/libs/ml/multiple_linear_regression.md
----------------------------------------------------------------------
diff --git a/docs/libs/ml/multiple_linear_regression.md b/docs/libs/ml/multiple_linear_regression.md
index d9bc951..aaf1fbf 100644
--- a/docs/libs/ml/multiple_linear_regression.md
+++ b/docs/libs/ml/multiple_linear_regression.md
@@ -77,6 +77,14 @@ MultipleLinearRegression predicts for all subtypes of `Vector` the corresponding
* `predict[T <: Vector]: DataSet[T] => DataSet[LabeledVector]`
+If `predict` is called with a `DataSet[LabeledVector]`, it predicts the regression value
+for each example and returns a `DataSet[(Double, Double)]`. In each tuple, the first element
+is the true value, as provided by the input `DataSet[LabeledVector]`, and the second element
+is the predicted value. You can then use these `(truth, prediction)` tuples to evaluate
+the algorithm's performance.
+
+* `predict: DataSet[LabeledVector] => DataSet[(Double, Double)]`
+
## Parameters
The multiple linear regression implementation can be controlled by the following parameters:
http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/docs/libs/ml/svm.md
----------------------------------------------------------------------
diff --git a/docs/libs/ml/svm.md b/docs/libs/ml/svm.md
index a9c94ec..e649949 100644
--- a/docs/libs/ml/svm.md
+++ b/docs/libs/ml/svm.md
@@ -74,6 +74,14 @@ SVM predicts for all subtypes of `Vector` the corresponding class label:
* `predict[T <: Vector]: DataSet[T] => DataSet[LabeledVector]`
+If `predict` is called with a `DataSet[LabeledVector]`, it computes a prediction for each
+example and returns a `DataSet[(Double, Double)]`. In each tuple, the first element is the true
+label, as provided by the input `DataSet[LabeledVector]`, and the second element is the raw
+decision value, which can be thresholded at 0 to obtain a class label. You can then use these
+`(truth, prediction)` tuples to evaluate the algorithm's performance.
+
+* `predict: DataSet[LabeledVector] => DataSet[(Double, Double)]`
+
## Parameters
The SVM implementation can be controlled by the following parameters:
http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
index a186c5d..95f2b23 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
@@ -33,7 +33,7 @@ import org.apache.flink.ml.math.Breeze._
import breeze.linalg.{Vector => BreezeVector, DenseVector => BreezeDenseVector}
-/** Implements a soft-maring SVM using the communication-efficient distributed dual coordinate
+/** Implements a soft-margin SVM using the communication-efficient distributed dual coordinate
* ascent algorithm (CoCoA) with hinge-loss function.
*
* The algorithm solves the following minimization problem:
@@ -276,6 +276,57 @@ object SVM{
}
}
+  /** [[org.apache.flink.ml.pipeline.PredictOperation]] for [[LabeledVector]] inputs. The result
+    * is a `(Double, Double)` tuple per example, corresponding to (truth, prediction). Calling it
+    * on an instance without weights throws a RuntimeException.
+    *
+    * @return A DataSet[(Double, Double)] where each tuple is a (truth, prediction) pair.
+    */
+  implicit def predictLabeledValues = {
+    new PredictOperation[SVM, LabeledVector, (Double, Double)]{
+      override def predict(
+          instance: SVM,
+          predictParameters: ParameterMap,
+          input: DataSet[LabeledVector])
+        : DataSet[(Double, Double)] = {
+        // the weights are broadcast to the mappers; a missing weight vector is reported as an error
+        instance.weightsOption match {
+          case Some(weights) => {
+            input.map(new LabeledPredictionMapper).withBroadcastSet(weights, WEIGHT_VECTOR)
+          }
+          case None => {
+            // the trailing space keeps "fit" and "before" from running together in the message
+            throw new RuntimeException("The SVM model has not been trained. Call first fit " +
+              "before calling the predict operation.")
+          }
+        }
+      }
+    }
+  }
+
+  /** Rich map function pairing the label of every [[LabeledVector]] with the dot product of the
+    * broadcast weight vector and the example's features. The weights are fetched in `open`.
+    */
+  class LabeledPredictionMapper extends RichMapFunction[LabeledVector, (Double, Double)] {
+
+    var weights: BreezeDenseVector[Double] = _
+
+    @throws(classOf[Exception])
+    override def open(configuration: Configuration): Unit = {
+      // the weight vector is the first element of the broadcast set
+      val broadcastWeights =
+        getRuntimeContext.getBroadcastVariable[BreezeDenseVector[Double]](WEIGHT_VECTOR)
+      weights = broadcastWeights.get(0)
+    }
+
+    override def map(labeledVector: LabeledVector): (Double, Double) = {
+      // raw decision value: weights dotted with the features, no thresholding applied here
+      val features = labeledVector.vector.asBreeze
+      val rawPrediction = weights dot features
+      (labeledVector.label, rawPrediction)
+    }
+  }
+
+
/** [[FitOperation]] which trains a SVM with soft-margin based on the given training data set.
*
*/
http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
index 64b24dc..32746a1 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
@@ -21,11 +21,9 @@ package org.apache.flink.ml.regression
import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.scala.DataSet
import org.apache.flink.configuration.Configuration
-import org.apache.flink.ml.math.Vector
+import org.apache.flink.ml.math.{DenseVector, BLAS, Vector, vector2Array}
import org.apache.flink.ml.common._
-import org.apache.flink.ml.math.vector2Array
-
import org.apache.flink.api.scala._
import com.github.fommil.netlib.BLAS.{ getInstance => blas }
@@ -348,6 +346,60 @@ object MultipleLinearRegression {
LabeledVector(prediction, value)
}
}
+
+  /** [[PredictOperation]] evaluating the learned linear model on labeled examples.
+    *
+    * @return A DataSet[(Double, Double)] where each tuple is a (truth, prediction) pair.
+    */
+  implicit def predictLabeledVectors = {
+    new PredictOperation[MultipleLinearRegression, LabeledVector, (Double, Double)] {
+      override def predict(
+          instance: MultipleLinearRegression,
+          predictParameters: ParameterMap,
+          input: DataSet[LabeledVector])
+        : DataSet[(Double, Double)] = {
+        instance.weightsOption match {
+          // without a weight vector there is nothing to evaluate
+          case None =>
+            throw new RuntimeException("The MultipleLinearRegression has not been fitted to the " +
+              "data. This is necessary to learn the weight vector of the linear function.")
+
+          // otherwise broadcast the weights to the mappers producing the pairs
+          case Some(weights) =>
+            val labeledPrediction = new LinearRegressionLabeledPrediction
+            input.map(labeledPrediction).withBroadcastSet(weights, WEIGHTVECTOR_BROADCAST)
+        }
+      }
+    }
+  }
+
+  /** Rich map function computing a (truth, prediction) pair for every labeled example. The
+    * prediction is the dot product of the broadcast weight vector with the example's features
+    * plus the broadcast offset `weight0`.
+    */
+  private class LinearRegressionLabeledPrediction
+    extends RichMapFunction[LabeledVector, (Double, Double)] {
+    // Both fields are populated in open().
+    private var weightVector: DenseVector = _
+    private var weight0: Double = 0
+
+    @throws(classOf[Exception])
+    override def open(configuration: Configuration): Unit = {
+      // the first element of the broadcast set is the (weights, weight0) tuple
+      val weightsPair = getRuntimeContext
+        .getBroadcastVariable[(Array[Double], Double)](WEIGHTVECTOR_BROADCAST)
+        .get(0)
+
+      // wrap the weights once here instead of building a DenseVector for every record in map()
+      weightVector = DenseVector(weightsPair._1)
+      weight0 = weightsPair._2
+    }
+
+    override def map(labeledVector: LabeledVector): (Double, Double) = {
+      val prediction = BLAS.dot(weightVector, labeledVector.vector) + weight0
+      (labeledVector.label, prediction)
+    }
+  }
}
//--------------------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
index 55ef056..25c2afb 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
@@ -49,4 +49,35 @@ class SVMITSuite extends FlatSpec with Matchers with FlinkTestBase {
weight should be(expectedWeight +- 0.1)
}
}
+
+  it should "make (mostly) correct predictions" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    // one block per parallel task and a fixed seed
+    val svm = SVM()
+      .setBlocks(env.getParallelism)
+      .setIterations(100)
+      .setLocalIterations(100)
+      .setRegularization(0.002)
+      .setStepsize(0.1)
+      .setSeed(0)
+
+    val trainingDS = env.fromCollection(Classification.trainingData)
+    svm.fit(trainingDS)
+
+    // raw predictions above the cut-off count as class 1.0, everything else as -1.0
+    val cutOff = 0.0
+    val rawPairs = svm.predict(trainingDS)
+    val predictionPairs = rawPairs.map {
+      pair =>
+        val label = if (pair._2 > cutOff) 1.0 else -1.0
+        (pair._1, label)
+    }
+
+    val absoluteErrorSum = predictionPairs.collect()
+      .map { case (truth, prediction) => Math.abs(truth - prediction) }
+      .sum
+
+    absoluteErrorSum should be < 15.0
+  }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala
index 8be239a..30338e5 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala
@@ -106,4 +106,28 @@ class MultipleLinearRegressionITSuite
srs should be(RegressionData.expectedPolynomialSquaredResidualSum +- 5)
}
+
+  it should "make (mostly) correct predictions" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    import RegressionData._
+    val inputDS = env.fromCollection(data)
+
+    // training configuration handed to fit()
+    val fitParameters = ParameterMap()
+    fitParameters.add(MultipleLinearRegression.Stepsize, 1.0)
+    fitParameters.add(MultipleLinearRegression.Iterations, 10)
+    fitParameters.add(MultipleLinearRegression.ConvergenceThreshold, 0.001)
+
+    val mlr = MultipleLinearRegression()
+    mlr.fit(inputDS, fitParameters)
+
+    // sum of |truth - prediction| over the same data the model was fitted on
+    val absoluteErrorSum = mlr.predict(inputDS)
+      .collect()
+      .map { case (truth, prediction) => Math.abs(truth - prediction) }
+      .sum
+
+    absoluteErrorSum should be < 50.0
+  }
}