You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tr...@apache.org on 2015/03/17 23:45:23 UTC
[4/8] flink git commit: [FLINK-1698] [ml] Adds polynomial base
feature mapper and test cases
[FLINK-1698] [ml] Adds polynomial base feature mapper and test cases
[ml] Adds comments to PolynomialBase
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/effea93d
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/effea93d
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/effea93d
Branch: refs/heads/master
Commit: effea93d72710dc9fa8184abc2d97ee33794b84f
Parents: 8648543
Author: Till Rohrmann <tr...@apache.org>
Authored: Mon Mar 9 18:10:52 2015 +0100
Committer: Till Rohrmann <tr...@apache.org>
Committed: Tue Mar 17 23:28:34 2015 +0100
----------------------------------------------------------------------
.../org/apache/flink/ml/common/FlinkTools.scala | 30 ++--
.../apache/flink/ml/common/ParameterMap.scala | 7 +-
.../apache/flink/ml/common/Transformer.scala | 2 +-
.../flink/ml/feature/PolynomialBase.scala | 148 +++++++++++++++++++
.../org/apache/flink/ml/math/DenseMatrix.scala | 10 ++
.../org/apache/flink/ml/math/DenseVector.scala | 11 ++
.../org/apache/flink/ml/math/package.scala | 17 ++-
.../flink/ml/feature/PolynomialBaseSuite.scala | 118 +++++++++++++++
.../MultipleLinearRegressionSuite.scala | 53 ++++++-
.../flink/ml/regression/RegressionData.scala | 61 +++++++-
10 files changed, 426 insertions(+), 31 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
index d972960..2b12f30 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
@@ -51,7 +51,7 @@ object FlinkTools {
dataset.output(outputFormat)
env.execute("FlinkTools persist")
- val inputFormat = new TypeSerializerInputFormat[T](dataset.getType.createSerializer())
+ val inputFormat = new TypeSerializerInputFormat[T](dataset.getType)
inputFormat.setFilePath(filePath)
env.createInput(inputFormat)
@@ -79,10 +79,10 @@ object FlinkTools {
env.execute("FlinkTools persist")
- val if1 = new TypeSerializerInputFormat[A](ds1.getType.createSerializer())
+ val if1 = new TypeSerializerInputFormat[A](ds1.getType)
if1.setFilePath(f1)
- val if2 = new TypeSerializerInputFormat[B](ds2.getType.createSerializer())
+ val if2 = new TypeSerializerInputFormat[B](ds2.getType)
if2.setFilePath(f2)
(env.createInput(if1), env.createInput(if2))
@@ -119,13 +119,13 @@ object FlinkTools {
env.execute("FlinkTools persist")
- val if1 = new TypeSerializerInputFormat[A](ds1.getType.createSerializer())
+ val if1 = new TypeSerializerInputFormat[A](ds1.getType)
if1.setFilePath(f1)
- val if2 = new TypeSerializerInputFormat[B](ds2.getType.createSerializer())
+ val if2 = new TypeSerializerInputFormat[B](ds2.getType)
if2.setFilePath(f2)
- val if3 = new TypeSerializerInputFormat[C](ds3.getType.createSerializer())
+ val if3 = new TypeSerializerInputFormat[C](ds3.getType)
if3.setFilePath(f3)
(env.createInput(if1), env.createInput(if2), env.createInput(if3))
@@ -173,16 +173,16 @@ object FlinkTools {
env.execute("FlinkTools persist")
- val if1 = new TypeSerializerInputFormat[A](ds1.getType.createSerializer())
+ val if1 = new TypeSerializerInputFormat[A](ds1.getType)
if1.setFilePath(f1)
- val if2 = new TypeSerializerInputFormat[B](ds2.getType.createSerializer())
+ val if2 = new TypeSerializerInputFormat[B](ds2.getType)
if2.setFilePath(f2)
- val if3 = new TypeSerializerInputFormat[C](ds3.getType.createSerializer())
+ val if3 = new TypeSerializerInputFormat[C](ds3.getType)
if3.setFilePath(f3)
- val if4 = new TypeSerializerInputFormat[D](ds4.getType.createSerializer())
+ val if4 = new TypeSerializerInputFormat[D](ds4.getType)
if4.setFilePath(f4)
(env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4))
@@ -238,19 +238,19 @@ object FlinkTools {
env.execute("FlinkTools persist")
- val if1 = new TypeSerializerInputFormat[A](ds1.getType.createSerializer())
+ val if1 = new TypeSerializerInputFormat[A](ds1.getType)
if1.setFilePath(f1)
- val if2 = new TypeSerializerInputFormat[B](ds2.getType.createSerializer())
+ val if2 = new TypeSerializerInputFormat[B](ds2.getType)
if2.setFilePath(f2)
- val if3 = new TypeSerializerInputFormat[C](ds3.getType.createSerializer())
+ val if3 = new TypeSerializerInputFormat[C](ds3.getType)
if3.setFilePath(f3)
- val if4 = new TypeSerializerInputFormat[D](ds4.getType.createSerializer())
+ val if4 = new TypeSerializerInputFormat[D](ds4.getType)
if4.setFilePath(f4)
- val if5 = new TypeSerializerInputFormat[E](ds5.getType.createSerializer())
+ val if5 = new TypeSerializerInputFormat[E](ds5.getType)
if5.setFilePath(f5)
(env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4), env
http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
index 1d8d4ce..a5efe8a 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
@@ -41,8 +41,9 @@ class ParameterMap(val map: mutable.Map[Parameter[_], Any]) extends Serializable
* @param value Value associated with the given key
* @tparam T Type of value
*/
- def add[T](parameter: Parameter[T], value: T): Unit = {
+ def add[T](parameter: Parameter[T], value: T): ParameterMap = {
map += (parameter -> value)
+ this
}
/**
@@ -100,6 +101,10 @@ class ParameterMap(val map: mutable.Map[Parameter[_], Any]) extends Serializable
object ParameterMap {
val Empty = new ParameterMap
+
+ def apply(): ParameterMap = {
+ new ParameterMap
+ }
}
/**
http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala
index 5ba0ea4..76abc62 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala
@@ -26,7 +26,7 @@ import org.apache.flink.api.scala.DataSet
* @tparam IN Type of incoming elements
* @tparam OUT Type of outgoing elements
*/
-trait Transformer[IN, OUT] {
+trait Transformer[IN, OUT] extends WithParameters {
def chain[CHAINED](transformer: Transformer[OUT, CHAINED]): ChainedTransformer[IN, OUT, CHAINED] = {
new ChainedTransformer[IN, OUT, CHAINED](this, transformer)
}
http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/feature/PolynomialBase.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/feature/PolynomialBase.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/feature/PolynomialBase.scala
new file mode 100644
index 0000000..632ded6
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/feature/PolynomialBase.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature
+
+import org.apache.flink.api.scala.DataSet
+import org.apache.flink.ml.common.{Parameter, ParameterMap, Transformer, LabeledVector}
+import org.apache.flink.ml.feature.PolynomialBase.Degree
+import org.apache.flink.ml.math.{DenseVector, Vector}
+
+import org.apache.flink.api.scala._
+
+/** Maps a vector into the polynomial feature space.
+ *
+ * This transformer takes a vector of values `(x, y, z, ...)` and maps it into the
+ * polynomial feature space of degree `n`. That is to say, it calculates the following
+ * representation:
+ *
+ * `(x^n, x^(n-1)y, x^(n-1)z, ..., x^2, xy, xz, ..., x, y, z, ...)^T`
+ *
+ * This transformer can be prepended to all [[Transformer]] and
+ * [[org.apache.flink.ml.common.Learner]] implementations which expect an input of
+ * [[LabeledVector]].
+ *
+ * @example
+ * {{{
+ * val trainingDS: DataSet[LabeledVector] = ...
+ *
+ * val polyBase = PolynomialBase()
+ * .setDegree(3)
+ *
+ * val mlr = MultipleLinearRegression()
+ *
+ * val chained = polyBase.chain(mlr)
+ *
+ * val model = chained.fit(trainingDS)
+ * }}}
+ *
+ * =Parameters=
+ *
+ * - [[PolynomialBase.Degree]]: Maximum polynomial degree
+ */
+class PolynomialBase extends Transformer[LabeledVector, LabeledVector] with Serializable {
+
+ def setDegree(degree: Int): PolynomialBase = {
+ parameters.add(Degree, degree)
+ this
+ }
+
+ override def transform(input: DataSet[LabeledVector], parameters: ParameterMap):
+ DataSet[LabeledVector] = {
+ val resultingParameters = this.parameters ++ parameters
+
+ val degree = resultingParameters(Degree)
+
+ input.map {
+ labeledVector => {
+ val vector = labeledVector.vector
+ val label = labeledVector.label
+
+ val transformedVector = calculatePolynomial(degree, vector)
+
+ LabeledVector(transformedVector, label)
+ }
+ }
+ }
+
+ private def calculatePolynomial(degree: Int, vector: Vector): Vector = {
+ new DenseVector(calculateCombinedCombinations(degree, vector).toArray)
+ }
+
+ /** Calculates for a given vector its representation in the polynomial feature space.
+ *
+ * @param degree Maximum degree of polynomial
+ * @param vector Values of the polynomial variables
+ * @return List of polynomial values
+ */
+ private def calculateCombinedCombinations(degree: Int, vector: Vector): List[Double] = {
+ if(degree == 0) {
+ List()
+ } else {
+ val partialResult = calculateCombinedCombinations(degree - 1, vector)
+
+ val combinations = calculateCombinations(vector.size, degree)
+
+ val result = combinations map {
+ combination =>
+ combination.zipWithIndex.map{
+ case (exp, idx) => math.pow(vector(idx), exp)
+ }.fold(1.0)(_ * _)
+ }
+
+ result ::: partialResult
+ }
+
+ }
+
+ /** Calculates all possible combinations of a polynomial of degree `value`, where the
+ * polynomial can consist of up to `length` factors. The return value is the list of the
+ * exponents of the individual factors.
+ *
+ * @param length maximum number of factors
+ * @param value degree of polynomial
+ * @return List of lists which contain the exponents of the individual factors
+ */
+ private def calculateCombinations(length: Int, value: Int): List[List[Int]] = {
+ if(length == 0) {
+ List()
+ } else if (length == 1) {
+ List(List(value))
+ } else {
+ value to 0 by -1 flatMap {
+ v =>
+ calculateCombinations(length - 1, value - v) map {
+ v::_
+ }
+ } toList
+ }
+ }
+}
+
+object PolynomialBase{
+
+ case object Degree extends Parameter[Int] {
+ override val defaultValue: Option[Int] = Some(1)
+ }
+
+ // ========================= Factory methods ======================================
+
+ def apply(): PolynomialBase = {
+ new PolynomialBase()
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
index c950dc7..f3bd630 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
@@ -53,6 +53,16 @@ case class DenseMatrix(val numRows: Int,
s"DenseMatrix($numRows, $numCols, ${values.mkString(", ")})"
}
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case dense: DenseMatrix =>
+ numRows == dense.numRows && numCols == dense.numCols && values.zip(dense.values).forall {
+ case (a, b) => a == b
+ }
+ case _ => false
+ }
+ }
+
}
object DenseMatrix {
http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
index 81c1d8d..8e0eed0 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
@@ -49,6 +49,17 @@ case class DenseVector(val values: Array[Double]) extends Vector {
override def toString: String = {
s"DenseVector(${values.mkString(", ")})"
}
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case dense: DenseVector =>
+ values.length == dense.values.length && values.zip(dense.values).forall{
+ case (a,b) => a == b
+ }
+
+ case _ => false
+ }
+ }
}
object DenseVector {
http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
index 4914d24..fce008a 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
@@ -22,11 +22,20 @@ package org.apache.flink.ml
* Convenience to handle Flink's [[org.apache.flink.ml.math.Matrix]] and [[Vector]] abstraction.
*/
package object math {
- implicit class RichDenseMatrix(matrix: DenseMatrix) extends Iterable[Double] {
- override def iterator: Iterator[Double] = matrix.values.iterator
+ implicit class RichMatrix(matrix: Matrix) extends Iterable[Double] {
+
+ override def iterator: Iterator[Double] = {
+ matrix match {
+ case dense: DenseMatrix => dense.values.iterator
+ }
+ }
}
- implicit class RichDenseVector(vector: DenseVector) extends Iterable[Double] {
- override def iterator: Iterator[Double] = vector.values.iterator
+ implicit class RichVector(vector: Vector) extends Iterable[Double] {
+ override def iterator: Iterator[Double] = {
+ vector match {
+ case dense: DenseVector => dense.values.iterator
+ }
+ }
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialBaseSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialBaseSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialBaseSuite.scala
new file mode 100644
index 0000000..8da822f
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialBaseSuite.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature
+
+import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.flink.ml.common.LabeledVector
+import org.apache.flink.ml.math.DenseVector
+import org.scalatest.{ShouldMatchers, FlatSpec}
+
+import org.apache.flink.api.scala._
+
+class PolynomialBaseSuite extends FlatSpec with ShouldMatchers {
+ behavior of "A PolynomialBase"
+
+ it should "map an element into a polynomial vector space" in {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val input = Seq(
+ LabeledVector(DenseVector(1), 1.0),
+ LabeledVector(DenseVector(2), 2.0)
+ )
+
+ val inputDS = env.fromCollection(input)
+
+ val transformer = PolynomialBase()
+ .setDegree(3)
+
+ val transformedDS = transformer.transform(inputDS)
+
+ val expectedMap = List(
+ (1.0 -> DenseVector(1.0, 1.0, 1.0)),
+ (2.0 -> DenseVector(8.0, 4.0, 2.0))
+ ) toMap
+
+ val result = transformedDS.collect
+
+ for(entry <- result) {
+ expectedMap.contains(entry.label) should be(true)
+ entry.vector should equal(expectedMap(entry.label))
+ }
+
+ }
+
+ it should "map a vector into a polynomial vector space" in {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val input = Seq(
+ LabeledVector(DenseVector(2, 3), 1.0),
+ LabeledVector(DenseVector(2, 3, 4), 2.0)
+ )
+
+ val expectedMap = List(
+ (1.0 -> DenseVector(8.0, 12.0, 18.0, 27.0, 4.0, 6.0, 9.0, 2.0, 3.0)),
+ (2.0 -> DenseVector(8.0, 12.0, 16.0, 18.0, 24.0, 32.0, 27.0, 36.0, 48.0, 64.0, 4.0, 6.0, 8.0,
+ 9.0, 12.0, 16.0, 2.0, 3.0, 4.0))
+ ) toMap
+
+ val inputDS = env.fromCollection(input)
+
+ val transformer = PolynomialBase()
+ .setDegree(3)
+
+ val transformedDS = transformer.transform(inputDS)
+
+ val result = transformedDS.collect
+
+ for(entry <- result) {
+ expectedMap.contains(entry.label) should be(true)
+ entry.vector should equal(expectedMap(entry.label))
+ }
+
+ println(result)
+ }
+
+ it should "return an empty vector if the polynomial degree is set to 0" in {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val input = Seq(
+ LabeledVector(DenseVector(2, 3), 1.0),
+ LabeledVector(DenseVector(2, 3, 4), 2.0)
+ )
+
+ val inputDS = env.fromCollection(input)
+
+ val transformer = PolynomialBase()
+ .setDegree(0)
+
+ val transformedDS = transformer.transform(inputDS)
+
+ val result = transformedDS.collect
+
+ val expectedMap = List(
+ (1.0 -> DenseVector()),
+ (2.0 -> DenseVector())
+ ) toMap
+
+ for(entry <- result) {
+ expectedMap.contains(entry.label) should be(true)
+ entry.vector should equal(expectedMap(entry.label))
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala
index 8d59b49..006ff93 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala
@@ -20,6 +20,7 @@ package org.apache.flink.ml.regression
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.ml.common.ParameterMap
+import org.apache.flink.ml.feature.PolynomialBase
import org.scalatest.{ShouldMatchers, FlatSpec}
import org.apache.flink.api.scala._
@@ -43,21 +44,57 @@ class MultipleLinearRegressionSuite extends FlatSpec with ShouldMatchers {
val inputDS = env.fromCollection(data)
val model = learner.fit(inputDS, parameters)
- val betasList = model.weights.collect
+ val weightList = model.weights.collect
- betasList.size should equal(1)
+ weightList.size should equal(1)
- val (betas, beta0) = betasList(0)
+ val (weights, weight0) = weightList(0)
- expectedBetas.data zip betas.data foreach {
- case (expectedBeta, beta) => {
- beta should be (expectedBeta +- 1)
- }
+ expectedWeights.data zip weights.data foreach {
+ case (expectedWeight, weight) =>
+ weight should be (expectedWeight +- 1)
}
- beta0 should be (expectedBeta0 +- 0.4)
+ weight0 should be (expectedWeight0 +- 0.4)
val srs = model.squaredResidualSum(inputDS).collect(0)
srs should be (expectedSquaredResidualSum +- 2)
}
+
+ it should "calculate the correct polynomial function" in {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ val polynomialBase = PolynomialBase()
+ val learner = MultipleLinearRegression()
+
+ val pipeline = polynomialBase.chain(learner)
+
+ val inputDS = env.fromCollection(RegressionData.polynomialData)
+
+ val parameters = ParameterMap()
+ .add(PolynomialBase.Degree, 3)
+ .add(MultipleLinearRegression.Stepsize, 0.002)
+ .add(MultipleLinearRegression.Iterations, 100)
+
+ val model = pipeline.fit(inputDS, parameters)
+
+ val weightList = model.weights.collect
+
+ weightList.size should equal(1)
+
+ val (weights, weight0) = weightList(0)
+
+ RegressionData.expectedPolynomialWeights.zip(weights.data) foreach {
+ case (expectedWeight, weight) =>
+ weight should be(expectedWeight +- 0.1)
+ }
+
+ weight0 should be(RegressionData.expectedPolynomialWeight0 +- 0.1)
+
+ val transformedInput = polynomialBase.transform(inputDS, parameters)
+
+ val srs = model.squaredResidualSum(transformedInput).collect(0)
+
+ srs should be(RegressionData.expectedPolynomialSquaredResidualSum +- 5)
+ }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
index c4050f3..138b6ec 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
@@ -24,8 +24,8 @@ import org.jblas.DoubleMatrix
object RegressionData {
- val expectedBetas: DoubleMatrix = new DoubleMatrix(1, 1, 3.0094)
- val expectedBeta0: Double = 9.8158
+ val expectedWeights: DoubleMatrix = new DoubleMatrix(1, 1, 3.0094)
+ val expectedWeight0: Double = 9.8158
val expectedSquaredResidualSum: Double = 49.7596
val data: Seq[LabeledVector] = Seq(
@@ -70,4 +70,61 @@ object RegressionData {
LabeledVector(DenseVector(0.4249), 11.9999),
LabeledVector(DenseVector(0.1192), 12.0442)
)
+
+ val expectedPolynomialWeights = Seq(0.2375, -0.3493, -0.1674)
+ val expectedPolynomialWeight0 = 0.0233
+ val expectedPolynomialSquaredResidualSum = 1.5389e+03
+
+ val polynomialData: Seq[LabeledVector] = Seq(
+ LabeledVector(DenseVector(3.6663), 2.1415),
+ LabeledVector(DenseVector(4.0761), 10.9835),
+ LabeledVector(DenseVector(0.5714), 7.2507),
+ LabeledVector(DenseVector(4.1102), 11.9274),
+ LabeledVector(DenseVector(2.8456), -4.2798),
+ LabeledVector(DenseVector(0.4389), 7.1929),
+ LabeledVector(DenseVector(1.2532), 4.5097),
+ LabeledVector(DenseVector(2.4610), -3.6059),
+ LabeledVector(DenseVector(4.3088), 18.1132),
+ LabeledVector(DenseVector(4.3420), 19.2674),
+ LabeledVector(DenseVector(0.7093), 7.0664),
+ LabeledVector(DenseVector(4.3677), 20.1836),
+ LabeledVector(DenseVector(4.3073), 18.0609),
+ LabeledVector(DenseVector(2.1842), -2.2090),
+ LabeledVector(DenseVector(3.6013), 1.1306),
+ LabeledVector(DenseVector(0.6385), 7.1903),
+ LabeledVector(DenseVector(1.8979), -0.2668),
+ LabeledVector(DenseVector(4.1208), 12.2281),
+ LabeledVector(DenseVector(3.5649), 0.6086),
+ LabeledVector(DenseVector(4.3177), 18.4202),
+ LabeledVector(DenseVector(2.9508), -4.1284),
+ LabeledVector(DenseVector(0.1607), 6.1964),
+ LabeledVector(DenseVector(3.8211), 4.9638),
+ LabeledVector(DenseVector(4.2030), 14.6677),
+ LabeledVector(DenseVector(3.0543), -3.8132),
+ LabeledVector(DenseVector(3.4098), -1.2891),
+ LabeledVector(DenseVector(3.3441), -1.9390),
+ LabeledVector(DenseVector(1.7650), 0.7293),
+ LabeledVector(DenseVector(2.9497), -4.1310),
+ LabeledVector(DenseVector(0.7703), 6.9131),
+ LabeledVector(DenseVector(3.1772), -3.2060),
+ LabeledVector(DenseVector(0.1432), 6.0899),
+ LabeledVector(DenseVector(1.2462), 4.5567),
+ LabeledVector(DenseVector(0.2078), 6.4562),
+ LabeledVector(DenseVector(0.4371), 7.1903),
+ LabeledVector(DenseVector(3.7056), 2.8017),
+ LabeledVector(DenseVector(3.1267), -3.4873),
+ LabeledVector(DenseVector(1.4269), 3.2918),
+ LabeledVector(DenseVector(4.2760), 17.0085),
+ LabeledVector(DenseVector(0.1550), 6.1622),
+ LabeledVector(DenseVector(1.9743), -0.8192),
+ LabeledVector(DenseVector(1.7170), 1.0957),
+ LabeledVector(DenseVector(3.4448), -0.9065),
+ LabeledVector(DenseVector(3.5784), 0.7986),
+ LabeledVector(DenseVector(0.8409), 6.6861),
+ LabeledVector(DenseVector(2.2039), -2.3274),
+ LabeledVector(DenseVector(2.0051), -1.0359),
+ LabeledVector(DenseVector(2.9084), -4.2092),
+ LabeledVector(DenseVector(3.1921), -3.1140),
+ LabeledVector(DenseVector(3.3961), -1.4323)
+ )
}