You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tr...@apache.org on 2015/03/17 23:45:23 UTC

[4/8] flink git commit: [FLINK-1698] [ml] Adds polynomial base feature mapper and test cases

[FLINK-1698] [ml] Adds polynomial base feature mapper and test cases

[ml] Adds comments to PolynomialBase


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/effea93d
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/effea93d
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/effea93d

Branch: refs/heads/master
Commit: effea93d72710dc9fa8184abc2d97ee33794b84f
Parents: 8648543
Author: Till Rohrmann <tr...@apache.org>
Authored: Mon Mar 9 18:10:52 2015 +0100
Committer: Till Rohrmann <tr...@apache.org>
Committed: Tue Mar 17 23:28:34 2015 +0100

----------------------------------------------------------------------
 .../org/apache/flink/ml/common/FlinkTools.scala |  30 ++--
 .../apache/flink/ml/common/ParameterMap.scala   |   7 +-
 .../apache/flink/ml/common/Transformer.scala    |   2 +-
 .../flink/ml/feature/PolynomialBase.scala       | 148 +++++++++++++++++++
 .../org/apache/flink/ml/math/DenseMatrix.scala  |  10 ++
 .../org/apache/flink/ml/math/DenseVector.scala  |  11 ++
 .../org/apache/flink/ml/math/package.scala      |  17 ++-
 .../flink/ml/feature/PolynomialBaseSuite.scala  | 118 +++++++++++++++
 .../MultipleLinearRegressionSuite.scala         |  53 ++++++-
 .../flink/ml/regression/RegressionData.scala    |  61 +++++++-
 10 files changed, 426 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
index d972960..2b12f30 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
@@ -51,7 +51,7 @@ object FlinkTools {
     dataset.output(outputFormat)
     env.execute("FlinkTools persist")
 
-    val inputFormat = new TypeSerializerInputFormat[T](dataset.getType.createSerializer())
+    val inputFormat = new TypeSerializerInputFormat[T](dataset.getType)
     inputFormat.setFilePath(filePath)
 
     env.createInput(inputFormat)
@@ -79,10 +79,10 @@ object FlinkTools {
 
     env.execute("FlinkTools persist")
 
-    val if1 = new TypeSerializerInputFormat[A](ds1.getType.createSerializer())
+    val if1 = new TypeSerializerInputFormat[A](ds1.getType)
     if1.setFilePath(f1)
 
-    val if2 = new TypeSerializerInputFormat[B](ds2.getType.createSerializer())
+    val if2 = new TypeSerializerInputFormat[B](ds2.getType)
     if2.setFilePath(f2)
 
     (env.createInput(if1), env.createInput(if2))
@@ -119,13 +119,13 @@ object FlinkTools {
 
     env.execute("FlinkTools persist")
 
-    val if1 = new TypeSerializerInputFormat[A](ds1.getType.createSerializer())
+    val if1 = new TypeSerializerInputFormat[A](ds1.getType)
     if1.setFilePath(f1)
 
-    val if2 = new TypeSerializerInputFormat[B](ds2.getType.createSerializer())
+    val if2 = new TypeSerializerInputFormat[B](ds2.getType)
     if2.setFilePath(f2)
 
-    val if3 = new TypeSerializerInputFormat[C](ds3.getType.createSerializer())
+    val if3 = new TypeSerializerInputFormat[C](ds3.getType)
     if3.setFilePath(f3)
 
     (env.createInput(if1), env.createInput(if2), env.createInput(if3))
@@ -173,16 +173,16 @@ object FlinkTools {
 
     env.execute("FlinkTools persist")
 
-    val if1 = new TypeSerializerInputFormat[A](ds1.getType.createSerializer())
+    val if1 = new TypeSerializerInputFormat[A](ds1.getType)
     if1.setFilePath(f1)
 
-    val if2 = new TypeSerializerInputFormat[B](ds2.getType.createSerializer())
+    val if2 = new TypeSerializerInputFormat[B](ds2.getType)
     if2.setFilePath(f2)
 
-    val if3 = new TypeSerializerInputFormat[C](ds3.getType.createSerializer())
+    val if3 = new TypeSerializerInputFormat[C](ds3.getType)
     if3.setFilePath(f3)
 
-    val if4 = new TypeSerializerInputFormat[D](ds4.getType.createSerializer())
+    val if4 = new TypeSerializerInputFormat[D](ds4.getType)
     if4.setFilePath(f4)
 
     (env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4))
@@ -238,19 +238,19 @@ object FlinkTools {
 
     env.execute("FlinkTools persist")
 
-    val if1 = new TypeSerializerInputFormat[A](ds1.getType.createSerializer())
+    val if1 = new TypeSerializerInputFormat[A](ds1.getType)
     if1.setFilePath(f1)
 
-    val if2 = new TypeSerializerInputFormat[B](ds2.getType.createSerializer())
+    val if2 = new TypeSerializerInputFormat[B](ds2.getType)
     if2.setFilePath(f2)
 
-    val if3 = new TypeSerializerInputFormat[C](ds3.getType.createSerializer())
+    val if3 = new TypeSerializerInputFormat[C](ds3.getType)
     if3.setFilePath(f3)
 
-    val if4 = new TypeSerializerInputFormat[D](ds4.getType.createSerializer())
+    val if4 = new TypeSerializerInputFormat[D](ds4.getType)
     if4.setFilePath(f4)
 
-    val if5 = new TypeSerializerInputFormat[E](ds5.getType.createSerializer())
+    val if5 = new TypeSerializerInputFormat[E](ds5.getType)
     if5.setFilePath(f5)
 
     (env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4), env

http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
index 1d8d4ce..a5efe8a 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
@@ -41,8 +41,9 @@ class ParameterMap(val map: mutable.Map[Parameter[_], Any]) extends Serializable
    * @param value Value associated with the given key
    * @tparam T Type of value
    */
-  def add[T](parameter: Parameter[T], value: T): Unit = {
+  def add[T](parameter: Parameter[T], value: T): ParameterMap = {
     map += (parameter -> value)
+    this
   }
 
   /**
@@ -100,6 +101,10 @@ class ParameterMap(val map: mutable.Map[Parameter[_], Any]) extends Serializable
 
 object ParameterMap {
   val Empty = new ParameterMap
+
+  def apply(): ParameterMap = {
+    new ParameterMap
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala
index 5ba0ea4..76abc62 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Transformer.scala
@@ -26,7 +26,7 @@ import org.apache.flink.api.scala.DataSet
  * @tparam IN Type of incoming elements
  * @tparam OUT Type of outgoing elements
  */
-trait Transformer[IN, OUT] {
+trait Transformer[IN, OUT] extends WithParameters {
   def chain[CHAINED](transformer: Transformer[OUT, CHAINED]): ChainedTransformer[IN, OUT, CHAINED] = {
     new ChainedTransformer[IN, OUT, CHAINED](this, transformer)
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/feature/PolynomialBase.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/feature/PolynomialBase.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/feature/PolynomialBase.scala
new file mode 100644
index 0000000..632ded6
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/feature/PolynomialBase.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature
+
+import org.apache.flink.api.scala.DataSet
+import org.apache.flink.ml.common.{Parameter, ParameterMap, Transformer, LabeledVector}
+import org.apache.flink.ml.feature.PolynomialBase.Degree
+import org.apache.flink.ml.math.{DenseVector, Vector}
+
+import org.apache.flink.api.scala._
+
+/** Maps a vector into the polynomial feature space.
+  *
+  * This transformer takes a vector of values `(x, y, z, ...)` and maps it into the
+  * polynomial feature space of degree `n`. That is to say, it calculates the following
+  * representation:
+  *
+  * `(x, y, z, x^2, xy, xz, y^2, yz, z^2, x^3, x^2y, x^2z, xy^2, xyz, ...)^T`
+  *
+  * This transformer can be prepended to all [[Transformer]] and
+  * [[org.apache.flink.ml.common.Learner]] implementations which expect an input of
+  * [[LabeledVector]].
+  *
+  * @example
+  *          {{{
+  *             val trainingDS: DataSet[LabeledVector] = ...
+  *
+  *             val polyBase = PolynomialBase()
+  *               .setDegree(3)
+  *
+  *             val mlr = MultipleLinearRegression()
+  *
+  *             val chained = polyBase.chain(mlr)
+  *
+  *             val model = chained.fit(trainingDS)
+  *          }}}
+  *
+  * =Parameters=
+  *
+  *  - [[PolynomialBase.Degree]]: Maximum polynomial degree
+  */
+class PolynomialBase extends Transformer[LabeledVector, LabeledVector] with Serializable {
+
+  def setDegree(degree: Int): PolynomialBase = {
+    parameters.add(Degree, degree)
+    this
+  }
+
+  override def transform(input: DataSet[LabeledVector], parameters: ParameterMap):
+  DataSet[LabeledVector] = {
+    val resultingParameters = this.parameters ++ parameters
+
+    val degree = resultingParameters(Degree)
+
+    input.map {
+      labeledVector => {
+        val vector = labeledVector.vector
+        val label = labeledVector.label
+
+        val transformedVector = calculatePolynomial(degree, vector)
+
+        LabeledVector(transformedVector, label)
+      }
+    }
+  }
+
+  private def calculatePolynomial(degree: Int, vector: Vector): Vector = {
+    new DenseVector(calculateCombinedCombinations(degree, vector).toArray)
+  }
+
+  /** Calculates for a given vector its representation in the polynomial feature space.
+    *
+    * @param degree Maximum degree of polynomial
+    * @param vector Values of the polynomial variables
+    * @return List of polynomial values
+    */
+  private def calculateCombinedCombinations(degree: Int, vector: Vector): List[Double] = {
+    if(degree == 0) {
+      List()
+    } else {
+      val partialResult = calculateCombinedCombinations(degree - 1, vector)
+
+      val combinations = calculateCombinations(vector.size, degree)
+
+      val result = combinations map {
+        combination =>
+          combination.zipWithIndex.map{
+            case (exp, idx) => math.pow(vector(idx), exp)
+          }.fold(1.0)(_ * _)
+      }
+
+      result ::: partialResult
+    }
+
+  }
+
+  /** Calculates all possible combinations of a polynomial of degree `value`, where the
+    * polynomial can consist of up to `length` factors. The return value is the list of the
+    * exponents of the individual factors.
+    *
+    * @param length maximum number of factors
+    * @param value degree of polynomial
+    * @return List of lists which contain the exponents of the individual factors
+    */
+  private def calculateCombinations(length: Int, value: Int): List[List[Int]] = {
+    if(length == 0) {
+      List()
+    } else if (length == 1) {
+      List(List(value))
+    } else {
+      value to 0 by -1 flatMap {
+        v =>
+          calculateCombinations(length - 1, value - v) map {
+            v::_
+          }
+      } toList
+    }
+  }
+}
+
+object PolynomialBase{
+
+  case object Degree extends Parameter[Int] {
+    override val defaultValue: Option[Int] = Some(1)
+  }
+
+  // ========================= Factory methods ======================================
+
+  def apply(): PolynomialBase = {
+    new PolynomialBase()
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
index c950dc7..f3bd630 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
@@ -53,6 +53,16 @@ case class DenseMatrix(val numRows: Int,
     s"DenseMatrix($numRows, $numCols, ${values.mkString(", ")})"
   }
 
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case dense: DenseMatrix =>
+        numRows == dense.numRows && numCols == dense.numCols && values.zip(dense.values).forall {
+          case (a, b) => a == b
+        }
+      case _ => false
+    }
+  }
+
 }
 
 object DenseMatrix {

http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
index 81c1d8d..8e0eed0 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
@@ -49,6 +49,17 @@ case class DenseVector(val values: Array[Double]) extends Vector {
   override def toString: String = {
     s"DenseVector(${values.mkString(", ")})"
   }
+
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case dense: DenseVector =>
+        values.length == dense.values.length && values.zip(dense.values).forall{
+          case (a,b) => a == b
+        }
+
+      case _ => false
+    }
+  }
 }
 
 object DenseVector {

http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
index 4914d24..fce008a 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
@@ -22,11 +22,20 @@ package org.apache.flink.ml
  * Convenience to handle Flink's [[org.apache.flink.ml.math.Matrix]] and [[Vector]] abstraction.
  */
 package object math {
-  implicit class RichDenseMatrix(matrix: DenseMatrix) extends Iterable[Double] {
-    override def iterator: Iterator[Double] = matrix.values.iterator
+  implicit class RichMatrix(matrix: Matrix) extends Iterable[Double] {
+
+    override def iterator: Iterator[Double] = {
+      matrix match {
+        case dense: DenseMatrix => dense.values.iterator
+      }
+    }
   }
 
-  implicit class RichDenseVector(vector: DenseVector) extends Iterable[Double] {
-    override def iterator: Iterator[Double] = vector.values.iterator
+  implicit class RichVector(vector: Vector) extends Iterable[Double] {
+    override def iterator: Iterator[Double] = {
+      vector match {
+        case dense: DenseVector => dense.values.iterator
+      }
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialBaseSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialBaseSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialBaseSuite.scala
new file mode 100644
index 0000000..8da822f
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialBaseSuite.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature
+
+import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.flink.ml.common.LabeledVector
+import org.apache.flink.ml.math.DenseVector
+import org.scalatest.{ShouldMatchers, FlatSpec}
+
+import org.apache.flink.api.scala._
+
+class PolynomialBaseSuite extends FlatSpec with ShouldMatchers {
+  behavior of "A PolynomialBase"
+
+  it should "map an element into a polynomial vector space" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val input = Seq(
+      LabeledVector(DenseVector(1), 1.0),
+      LabeledVector(DenseVector(2), 2.0)
+    )
+
+    val inputDS = env.fromCollection(input)
+
+    val transformer = PolynomialBase()
+    .setDegree(3)
+
+    val transformedDS = transformer.transform(inputDS)
+
+    val expectedMap = List(
+      (1.0 -> DenseVector(1.0, 1.0, 1.0)),
+      (2.0 -> DenseVector(8.0, 4.0, 2.0))
+    ) toMap
+
+    val result = transformedDS.collect
+
+    for(entry <- result) {
+      expectedMap.contains(entry.label) should be(true)
+      entry.vector should equal(expectedMap(entry.label))
+    }
+
+  }
+
+  it should "map a vector into a polynomial vector space" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val input = Seq(
+      LabeledVector(DenseVector(2, 3), 1.0),
+      LabeledVector(DenseVector(2, 3, 4), 2.0)
+    )
+
+    val expectedMap = List(
+      (1.0 -> DenseVector(8.0, 12.0, 18.0, 27.0, 4.0, 6.0, 9.0, 2.0, 3.0)),
+      (2.0 -> DenseVector(8.0, 12.0, 16.0, 18.0, 24.0, 32.0, 27.0, 36.0, 48.0, 64.0, 4.0, 6.0, 8.0,
+      9.0, 12.0, 16.0, 2.0, 3.0, 4.0))
+    ) toMap
+
+    val inputDS = env.fromCollection(input)
+
+    val transformer = PolynomialBase()
+    .setDegree(3)
+
+    val transformedDS = transformer.transform(inputDS)
+
+    val result = transformedDS.collect
+
+    for(entry <- result) {
+      expectedMap.contains(entry.label) should be(true)
+      entry.vector should equal(expectedMap(entry.label))
+    }
+
+    println(result)
+  }
+
+  it should "return an empty vector if the polynomial degree is set to 0" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val input = Seq(
+      LabeledVector(DenseVector(2, 3), 1.0),
+      LabeledVector(DenseVector(2, 3, 4), 2.0)
+    )
+
+    val inputDS = env.fromCollection(input)
+
+    val transformer = PolynomialBase()
+    .setDegree(0)
+
+    val transformedDS = transformer.transform(inputDS)
+
+    val result = transformedDS.collect
+
+    val expectedMap = List(
+      (1.0 -> DenseVector()),
+      (2.0 -> DenseVector())
+    ) toMap
+
+    for(entry <- result) {
+      expectedMap.contains(entry.label) should be(true)
+      entry.vector should equal(expectedMap(entry.label))
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala
index 8d59b49..006ff93 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionSuite.scala
@@ -20,6 +20,7 @@ package org.apache.flink.ml.regression
 
 import org.apache.flink.api.scala.ExecutionEnvironment
 import org.apache.flink.ml.common.ParameterMap
+import org.apache.flink.ml.feature.PolynomialBase
 import org.scalatest.{ShouldMatchers, FlatSpec}
 
 import org.apache.flink.api.scala._
@@ -43,21 +44,57 @@ class MultipleLinearRegressionSuite extends FlatSpec with ShouldMatchers {
     val inputDS = env.fromCollection(data)
     val model = learner.fit(inputDS, parameters)
 
-    val betasList = model.weights.collect
+    val weightList = model.weights.collect
 
-    betasList.size should equal(1)
+    weightList.size should equal(1)
 
-    val (betas, beta0) = betasList(0)
+    val (weights, weight0) = weightList(0)
 
-    expectedBetas.data zip betas.data foreach {
-      case (expectedBeta, beta) => {
-        beta should be (expectedBeta +- 1)
-      }
+    expectedWeights.data zip weights.data foreach {
+      case (expectedWeight, weight) =>
+        weight should be (expectedWeight +- 1)
     }
-    beta0 should be (expectedBeta0 +- 0.4)
+    weight0 should be (expectedWeight0 +- 0.4)
 
     val srs = model.squaredResidualSum(inputDS).collect(0)
 
     srs should be (expectedSquaredResidualSum +- 2)
   }
+
+  it should "calculate the correct polynomial function" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val polynomialBase = PolynomialBase()
+    val learner = MultipleLinearRegression()
+
+    val pipeline = polynomialBase.chain(learner)
+
+    val inputDS = env.fromCollection(RegressionData.polynomialData)
+
+    val parameters = ParameterMap()
+    .add(PolynomialBase.Degree, 3)
+    .add(MultipleLinearRegression.Stepsize, 0.002)
+    .add(MultipleLinearRegression.Iterations, 100)
+
+    val model = pipeline.fit(inputDS, parameters)
+
+    val weightList = model.weights.collect
+
+    weightList.size should equal(1)
+
+    val (weights, weight0) = weightList(0)
+
+    RegressionData.expectedPolynomialWeights.zip(weights.data) foreach {
+      case (expectedWeight, weight) =>
+        weight should be(expectedWeight +- 0.1)
+    }
+
+    weight0 should be(RegressionData.expectedPolynomialWeight0 +- 0.1)
+
+    val transformedInput = polynomialBase.transform(inputDS, parameters)
+
+    val srs = model.squaredResidualSum(transformedInput).collect(0)
+
+    srs should be(RegressionData.expectedPolynomialSquaredResidualSum +- 5)
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/effea93d/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
index c4050f3..138b6ec 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/RegressionData.scala
@@ -24,8 +24,8 @@ import org.jblas.DoubleMatrix
 
 object RegressionData {
 
-  val expectedBetas: DoubleMatrix = new DoubleMatrix(1, 1, 3.0094)
-  val expectedBeta0: Double = 9.8158
+  val expectedWeights: DoubleMatrix = new DoubleMatrix(1, 1, 3.0094)
+  val expectedWeight0: Double = 9.8158
   val expectedSquaredResidualSum: Double = 49.7596
 
   val data: Seq[LabeledVector] = Seq(
@@ -70,4 +70,61 @@ object RegressionData {
     LabeledVector(DenseVector(0.4249), 11.9999),
     LabeledVector(DenseVector(0.1192), 12.0442)
   )
+
+  val expectedPolynomialWeights = Seq(0.2375, -0.3493, -0.1674)
+  val expectedPolynomialWeight0 = 0.0233
+  val expectedPolynomialSquaredResidualSum = 1.5389e+03
+
+  val polynomialData: Seq[LabeledVector] = Seq(
+    LabeledVector(DenseVector(3.6663), 2.1415),
+    LabeledVector(DenseVector(4.0761), 10.9835),
+    LabeledVector(DenseVector(0.5714), 7.2507),
+    LabeledVector(DenseVector(4.1102), 11.9274),
+    LabeledVector(DenseVector(2.8456), -4.2798),
+    LabeledVector(DenseVector(0.4389), 7.1929),
+    LabeledVector(DenseVector(1.2532), 4.5097),
+    LabeledVector(DenseVector(2.4610), -3.6059),
+    LabeledVector(DenseVector(4.3088), 18.1132),
+    LabeledVector(DenseVector(4.3420), 19.2674),
+    LabeledVector(DenseVector(0.7093), 7.0664),
+    LabeledVector(DenseVector(4.3677), 20.1836),
+    LabeledVector(DenseVector(4.3073), 18.0609),
+    LabeledVector(DenseVector(2.1842), -2.2090),
+    LabeledVector(DenseVector(3.6013), 1.1306),
+    LabeledVector(DenseVector(0.6385), 7.1903),
+    LabeledVector(DenseVector(1.8979), -0.2668),
+    LabeledVector(DenseVector(4.1208), 12.2281),
+    LabeledVector(DenseVector(3.5649), 0.6086),
+    LabeledVector(DenseVector(4.3177), 18.4202),
+    LabeledVector(DenseVector(2.9508), -4.1284),
+    LabeledVector(DenseVector(0.1607), 6.1964),
+    LabeledVector(DenseVector(3.8211), 4.9638),
+    LabeledVector(DenseVector(4.2030), 14.6677),
+    LabeledVector(DenseVector(3.0543), -3.8132),
+    LabeledVector(DenseVector(3.4098), -1.2891),
+    LabeledVector(DenseVector(3.3441), -1.9390),
+    LabeledVector(DenseVector(1.7650), 0.7293),
+    LabeledVector(DenseVector(2.9497), -4.1310),
+    LabeledVector(DenseVector(0.7703), 6.9131),
+    LabeledVector(DenseVector(3.1772), -3.2060),
+    LabeledVector(DenseVector(0.1432), 6.0899),
+    LabeledVector(DenseVector(1.2462), 4.5567),
+    LabeledVector(DenseVector(0.2078), 6.4562),
+    LabeledVector(DenseVector(0.4371), 7.1903),
+    LabeledVector(DenseVector(3.7056), 2.8017),
+    LabeledVector(DenseVector(3.1267), -3.4873),
+    LabeledVector(DenseVector(1.4269), 3.2918),
+    LabeledVector(DenseVector(4.2760), 17.0085),
+    LabeledVector(DenseVector(0.1550), 6.1622),
+    LabeledVector(DenseVector(1.9743), -0.8192),
+    LabeledVector(DenseVector(1.7170), 1.0957),
+    LabeledVector(DenseVector(3.4448), -0.9065),
+    LabeledVector(DenseVector(3.5784), 0.7986),
+    LabeledVector(DenseVector(0.8409), 6.6861),
+    LabeledVector(DenseVector(2.2039), -2.3274),
+    LabeledVector(DenseVector(2.0051), -1.0359),
+    LabeledVector(DenseVector(2.9084), -4.2092),
+    LabeledVector(DenseVector(3.1921), -3.1140),
+    LabeledVector(DenseVector(3.3961), -1.4323)
+  )
 }