You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tr...@apache.org on 2015/04/01 11:21:16 UTC
[1/3] flink git commit: [ml] Adds convenience functions for Breeze
matrix/vector conversion
Repository: flink
Updated Branches:
refs/heads/master c63580244 -> d2e2d79fc
[ml] Adds convenience functions for Breeze matrix/vector conversion
[ml] Adds breeze to flink-dist LICENSE file
[ml] Optimizes sanity checks in vector/matrix accessors
[ml] Fixes scala check style error with missing whitespaces before and after +
[ml] Fixes DenseMatrixTest
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/5ddb2dd9
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/5ddb2dd9
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/5ddb2dd9
Branch: refs/heads/master
Commit: 5ddb2dd9634ab0908c99a08a1d0e10e761444120
Parents: 9219af7
Author: Till Rohrmann <tr...@apache.org>
Authored: Thu Mar 26 17:44:17 2015 +0100
Committer: Till Rohrmann <tr...@apache.org>
Committed: Wed Apr 1 10:56:47 2015 +0200
----------------------------------------------------------------------
flink-dist/src/main/flink-bin/LICENSE | 1 +
.../scala/org/apache/flink/ml/math/Breeze.scala | 92 ++++++++++++++++++++
.../org/apache/flink/ml/math/DenseMatrix.scala | 4 +-
.../org/apache/flink/ml/math/DenseVector.scala | 6 +-
.../org/apache/flink/ml/math/SparseMatrix.scala | 10 +--
.../org/apache/flink/ml/math/SparseVector.scala | 6 +-
.../org/apache/flink/ml/math/package.scala | 82 ++++++++++++++---
.../regression/MultipleLinearRegression.scala | 12 +--
.../apache/flink/ml/math/BreezeMathTest.scala | 69 +++++++++++++++
.../apache/flink/ml/math/DenseVectorTest.scala | 2 +-
.../apache/flink/ml/math/SparseMatrixTest.scala | 42 +++++++--
.../apache/flink/ml/math/SparseVectorTest.scala | 18 +++-
12 files changed, 308 insertions(+), 36 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-dist/src/main/flink-bin/LICENSE
----------------------------------------------------------------------
diff --git a/flink-dist/src/main/flink-bin/LICENSE b/flink-dist/src/main/flink-bin/LICENSE
index d0b7fb4..8c733e4 100644
--- a/flink-dist/src/main/flink-bin/LICENSE
+++ b/flink-dist/src/main/flink-bin/LICENSE
@@ -250,6 +250,7 @@ under the Apache License (v 2.0):
- Twitter Hosebird Client (hbc) (https://github.com/twitter/hbc)
- Jettison (http://jettison.codehaus.org)
- Akka (http://akka.io)
+ - Breeze (https://github.com/scalanlp/breeze)
-----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala
new file mode 100644
index 0000000..dffb984
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import breeze.linalg.{ Matrix => BreezeMatrix, DenseMatrix => BreezeDenseMatrix,
+CSCMatrix => BreezeCSCMatrix, DenseVector => BreezeDenseVector, SparseVector => BreezeSparseVector,
+Vector => BreezeVector}
+
+/** Implicit converters which wrap Flink matrices/vectors into Breeze matrices/vectors
+ * (`asBreeze`) and unwrap Breeze results back into Flink types (`fromBreeze`).
+ *
+ * The converters hand the underlying arrays to the target constructors. Every match below
+ * covers only the listed dense and sparse types; any other implementation fails with a
+ * [[scala.MatchError]].
+ */
+object Breeze {
+
+  /** Adds `asBreeze` to a Flink [[Matrix]]. */
+  implicit class Matrix2BreezeConverter(matrix: Matrix) {
+    /** @return Breeze dense or CSC matrix constructed from the arrays of the Flink matrix */
+    def asBreeze: BreezeMatrix[Double] = {
+      matrix match {
+        case dense: DenseMatrix =>
+          new BreezeDenseMatrix[Double](
+            dense.numRows,
+            dense.numCols,
+            dense.data)
+
+        case sparse: SparseMatrix =>
+          new BreezeCSCMatrix[Double](
+            sparse.data,
+            sparse.numRows,
+            sparse.numCols,
+            sparse.colPtrs,
+            sparse.rowIndices
+          )
+      }
+    }
+  }
+
+  /** Adds `fromBreeze` to a Breeze matrix of doubles. */
+  implicit class Breeze2MatrixConverter(matrix: BreezeMatrix[Double]) {
+    /** @return Flink dense or sparse matrix constructed from the arrays of the Breeze matrix */
+    def fromBreeze: Matrix = {
+      matrix match {
+        case dense: BreezeDenseMatrix[Double] =>
+          // NOTE(review): assumes a non-transposed matrix without offset/custom stride, so that
+          // `data` is exactly the column-major content - confirm for sliced/transposed inputs.
+          new DenseMatrix(dense.rows, dense.cols, dense.data)
+
+        case sparse: BreezeCSCMatrix[Double] =>
+          // NOTE(review): the CSC buffers may be larger than `activeSize`; confirm whether they
+          // have to be trimmed here as it is done for sparse vectors below.
+          new SparseMatrix(sparse.rows, sparse.cols, sparse.rowIndices, sparse.colPtrs, sparse.data)
+      }
+    }
+  }
+
+  /** Adds `asBreeze` to a plain array, wrapping it into a Breeze dense vector. */
+  implicit class BreezeArrayConverter[T](array: Array[T]) {
+    /** @return Breeze dense vector constructed from the given array */
+    def asBreeze: BreezeDenseVector[T] = {
+      new BreezeDenseVector[T](array)
+    }
+  }
+
+  /** Adds `fromBreeze` to a Breeze vector of doubles. */
+  implicit class Breeze2VectorConverter(vector: BreezeVector[Double]) {
+    /** @return Flink dense or sparse vector holding the values of the Breeze vector */
+    def fromBreeze: Vector = {
+      vector match {
+        // NOTE(review): assumes offset 0 and stride 1, i.e. `data` is exactly the vector content.
+        case dense: BreezeDenseVector[Double] => new DenseVector(dense.data)
+
+        case sparse: BreezeSparseVector[Double] =>
+          // Breeze may over-allocate its index/data buffers; only the first `activeSize` slots
+          // hold valid entries. Passing the raw buffers would add bogus trailing entries.
+          val used = sparse.activeSize
+
+          if (used == sparse.index.length && used == sparse.data.length) {
+            // No slack: hand over the buffers unchanged, as before.
+            new SparseVector(sparse.length, sparse.index, sparse.data)
+          } else {
+            new SparseVector(sparse.length, sparse.index.take(used), sparse.data.take(used))
+          }
+      }
+    }
+  }
+
+  /** Adds `asBreeze` to a Flink [[Vector]]. */
+  implicit class Vector2BreezeConverter(vector: Vector) {
+    /** @return Breeze dense or sparse vector constructed from the arrays of the Flink vector */
+    def asBreeze: BreezeVector[Double] = {
+      vector match {
+        case dense: DenseVector =>
+          new BreezeDenseVector[Double](dense.data)
+
+        case sparse: SparseVector =>
+          new BreezeSparseVector[Double](sparse.indices, sparse.data, sparse.size)
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
index 72eae05..16291b8 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
@@ -134,8 +134,8 @@ case class DenseMatrix(val numRows: Int,
* @return
*/
private def locate(row: Int, col: Int): Int = {
- require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
- require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).")
+ require(0 <= row && row < numRows && 0 <= col && col < numCols,
+ (row, col) + " not in [0, " + numRows + ") x [0, " + numCols + ")")
row + col * numRows
}
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
index 6d41d47..50992a9 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
@@ -41,8 +41,7 @@ case class DenseVector(val data: Array[Double]) extends Vector {
* @return element at the given index
*/
override def apply(index: Int): Double = {
- require(0 <= index && index < data.length, s"Index $index is out of bounds " +
- s"[0, ${data.length})")
+ require(0 <= index && index < data.length, index + " not in [0, " + data.length + ")")
data(index)
}
@@ -72,8 +71,7 @@ case class DenseVector(val data: Array[Double]) extends Vector {
* @param value
*/
override def update(index: Int, value: Double): Unit = {
- require(0 <= index && index < data.length, s"Index $index is out of bounds " +
- s"[0, ${data.length})")
+ require(0 <= index && index < data.length, index + " not in [0, " + data.length + ")")
data(index) = value
}
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
index a46202c..b065630 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
@@ -112,11 +112,11 @@ class SparseMatrix(
}
private def locate(row: Int, col: Int): Int = {
- require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
- require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).")
+ require(0 <= row && row < numRows && 0 <= col && col < numCols,
+ (row, col) + " not in [0, " + numRows + ") x [0, " + numCols + ")")
val startIndex = colPtrs(col)
- val endIndex = colPtrs(col+1)
+ val endIndex = colPtrs(col + 1)
java.util.Arrays.binarySearch(rowIndices, startIndex, endIndex, row)
}
@@ -155,8 +155,8 @@ object SparseMatrix{
val entryArray = entries.toArray
entryArray.foreach{ case (row, col, _) =>
- require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
- require(0 <= col && col < numCols, s"Columm $col is out of bounds [0, $numCols).")
+ // Both bounds are half-open: an index equal to numRows/numCols is already out of range,
+ // matching the "[0, n)" ranges reported in the error message.
+ require(0 <= row && row < numRows && 0 <= col && col < numCols,
+ (row, col) + " not in [0, " + numRows + ") x [0, " + numCols + ")")
}
val COOOrdering = new Ordering[(Int, Int, Double)] {
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
index 93da362..9fa69cb 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
@@ -78,7 +78,7 @@ class SparseVector(
}
private def locate(index: Int): Int = {
- require(0 <= index && index < size, s"Index $index is out of bounds [0, $size).")
+ require(0 <= index && index < size, index + " not in [0, " + size + ")")
java.util.Arrays.binarySearch(indices, 0, indices.length, index)
}
@@ -107,6 +107,10 @@ object SparseVector {
def fromCOO(size: Int, entries: Iterable[(Int, Double)]): SparseVector = {
val entryArray = entries.toArray
+ entryArray.foreach { case (index, _) =>
+ require(0 <= index && index < size, index + " not in [0, " + size + ")")
+ }
+
val COOOrdering = new Ordering[(Int, Double)] {
override def compare(x: (Int, Double), y: (Int, Double)): Int = {
x._1 - y._1
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
index 3ab6143..4c7f254 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
@@ -23,26 +23,88 @@ package org.apache.flink.ml
* abstraction.
*/
package object math {
- implicit class RichMatrix(matrix: Matrix) extends Iterable[Double] {
+ implicit class RichMatrix(matrix: Matrix) extends Iterable[(Int, Int, Double)] {
- override def iterator: Iterator[Double] = {
- matrix match {
- case dense: DenseMatrix => dense.data.iterator
+ override def iterator: Iterator[(Int, Int, Double)] = {
+ new Iterator[(Int, Int, Double)] {
+ var index = 0
+
+ override def hasNext: Boolean = {
+ index < matrix.numRows * matrix.numCols
+ }
+
+ override def next(): (Int, Int, Double) = {
+ val row = index % matrix.numRows
+ val column = index / matrix.numRows
+
+ index += 1
+
+ (row, column, matrix(row, column))
+ }
+ }
+ }
+
+ def valueIterator: Iterator[Double] = {
+ val it = iterator
+
+ new Iterator[Double] {
+ override def hasNext: Boolean = it.hasNext
+
+ override def next(): Double = it.next._3
}
}
+
}
- implicit class RichVector(vector: Vector) extends Iterable[Double] {
- override def iterator: Iterator[Double] = {
- vector match {
- case dense: DenseVector => dense.data.iterator
+ implicit class RichVector(vector: Vector) extends Iterable[(Int, Double)] {
+
+ override def iterator: Iterator[(Int, Double)] = {
+ new Iterator[(Int, Double)] {
+ var index = 0
+
+ override def hasNext: Boolean = {
+ index < vector.size
+ }
+
+ override def next(): (Int, Double) = {
+ val resultIndex = index
+
+ index += 1
+
+ (resultIndex, vector(resultIndex))
+ }
+ }
+ }
+
+ def valueIterator: Iterator[Double] = {
+ val it = iterator
+
+ new Iterator[Double] {
+ override def hasNext: Boolean = it.hasNext
+
+ override def next(): Double = it.next._2
}
}
}
- implicit def vector2Array(vector: Vector): Array[Double] = {
+ /** Copies the values of the given vector into a newly allocated dense array.
+ *
+ * @param vector Dense or sparse vector whose values are copied
+ * @return Array containing the vector values
+ */
+ def vector2Array(vector: Vector): Array[Double] = {
vector match {
- case dense: DenseVector => dense.data
+ case dense: DenseVector => dense.data.clone
+
+ case sparse: SparseVector =>
+ val result = new Array[Double](sparse.size)
+
+ for((index, value) <- sparse) {
+ result(index) = value
+ }
+
+ result
+
}
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
index 8060d2b..9768cce 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
@@ -24,6 +24,8 @@ import org.apache.flink.configuration.Configuration
import org.apache.flink.ml.math.Vector
import org.apache.flink.ml.common._
+import org.apache.flink.ml.math.vector2Array
+
import org.apache.flink.api.scala._
import com.github.fommil.netlib.BLAS.{ getInstance => blas }
@@ -283,14 +285,14 @@ private class SquaredResiduals extends RichMapFunction[LabeledVector, Double] {
}
override def map(value: LabeledVector): Double = {
- val vector = value.vector
+ val array = vector2Array(value.vector)
val label = value.label
- val dotProduct = blas.ddot(weightVector.length, weightVector, 1, vector, 1)
+ val dotProduct = blas.ddot(weightVector.length, weightVector, 1, array, 1)
val residual = dotProduct + weight0 - label
- residual*residual
+ residual * residual
}
}
@@ -322,7 +324,7 @@ RichMapFunction[LabeledVector, (Array[Double], Double, Int)] {
}
override def map(value: LabeledVector): (Array[Double], Double, Int) = {
- val x = value.vector
+ val x = vector2Array(value.vector)
val label = value.label
val dotProduct = blas.ddot(weightVector.length, weightVector, 1, x, 1)
@@ -435,7 +437,7 @@ Transformer[ Vector, LabeledVector ] {
}
override def map(value: Vector): LabeledVector = {
- val dotProduct = blas.ddot(weights.length, weights, 1, value, 1)
+ val dotProduct = blas.ddot(weights.length, weights, 1, vector2Array(value), 1)
val prediction = dotProduct + weight0
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala
new file mode 100644
index 0000000..7084f2a
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import Breeze._
+
+import org.junit.Test
+import org.scalatest.ShouldMatchers
+
+class BreezeMathTest extends ShouldMatchers {
+
+ @Test
+ def testBreezeDenseMatrixWrapping: Unit = {
+ val numRows = 5
+ val numCols = 4
+
+ val data = Array.range(0, numRows * numCols)
+ val expectedData = Array.range(0, numRows * numCols).map(_ * 2)
+
+ val denseMatrix = DenseMatrix(numRows, numCols, data)
+ val expectedMatrix = DenseMatrix(numRows, numCols, expectedData)
+
+ val m = denseMatrix.asBreeze
+
+ val result = (m * 2.0).fromBreeze
+
+ result should equal(expectedMatrix)
+ }
+
+ @Test
+ def testBreezeSparseMatrixWrapping: Unit = {
+ val numRows = 5
+ val numCols = 4
+
+ val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols,
+ (0, 1, 1),
+ (4, 3, 13),
+ (3, 2, 45),
+ (4, 0, 12))
+
+ val expectedMatrix = SparseMatrix.fromCOO(numRows, numCols,
+ (0, 1, 2),
+ (4, 3, 26),
+ (3, 2, 90),
+ (4, 0, 24))
+
+ val sm = sparseMatrix.asBreeze
+
+ val result = (sm * 2.0).fromBreeze
+
+ result should equal(expectedMatrix)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
index 5da9fe2..66a51fe 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
@@ -32,7 +32,7 @@ class DenseVectorTest extends ShouldMatchers {
assertResult(data.length)(vector.size)
- data.zip(vector).foreach{case (expected, actual) => assertResult(expected)(actual)}
+ data.zip(vector.map(_._2)).foreach{case (expected, actual) => assertResult(expected)(actual)}
}
@Test
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
index a0e1d27..7fcdf54 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
@@ -25,9 +25,14 @@ class SparseMatrixTest extends ShouldMatchers {
@Test
def testSparseMatrixFromCOO: Unit = {
- val sparseMatrix = SparseMatrix.fromCOO(5, 5, (0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
+ val data = List[(Int, Int, Double)]((0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
(3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1))
+ val numRows = 5
+ val numCols = 5
+
+ val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols, data)
+
val expectedSparseMatrix = SparseMatrix.fromCOO(5, 5, (3, 4, 42), (2, 1, 17), (3, 3, 88),
(4, 2, 99), (1, 4, 91))
@@ -43,8 +48,22 @@ class SparseMatrixTest extends ShouldMatchers {
sparseMatrix.toDenseMatrix.data.sameElements(expectedDenseMatrix.data) should be(true)
+ val dataMap = data.
+ map{ case (row, col, value) => (row, col) -> value }.
+ groupBy{_._1}.
+ mapValues{
+ entries =>
+ entries.map(_._2).reduce(_ + _)
+ }
+
+ for(row <- 0 until numRows; col <- 0 until numCols) {
+ sparseMatrix(row, col) should be(dataMap.getOrElse((row, col), 0))
+ }
+
+ // test access to defined field even though it was set to 0
sparseMatrix(0, 1) = 10
+ // test that a non-defined field is not accessible
intercept[IllegalArgumentException]{
sparseMatrix(1, 1) = 1
}
@@ -52,18 +71,29 @@ class SparseMatrixTest extends ShouldMatchers {
@Test
def testInvalidIndexAccess: Unit = {
- val sparseVector = SparseVector.fromCOO(5, (1, 1), (3, 3), (4, 4))
+ val data = List[(Int, Int, Double)]((0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
+ (3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1))
+
+ val numRows = 5
+ val numCols = 5
+
+ val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols, data)
intercept[IllegalArgumentException] {
- sparseVector(-1)
+ sparseMatrix(-1, 4)
}
intercept[IllegalArgumentException] {
- sparseVector(5)
+ sparseMatrix(numRows, 0)
}
- sparseVector(0) should equal(0)
- sparseVector(3) should equal(3)
+ intercept[IllegalArgumentException] {
+ sparseMatrix(0, numCols)
+ }
+
+ intercept[IllegalArgumentException] {
+ sparseMatrix(3, -1)
+ }
}
@Test
http://git-wip-us.apache.org/repos/asf/flink/blob/5ddb2dd9/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
index 5e514c6..88d4878 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
@@ -25,7 +25,10 @@ class SparseVectorTest extends ShouldMatchers{
@Test
def testDataAfterInitialization: Unit = {
- val sparseVector = SparseVector.fromCOO(5, (0, 1), (2, 0), (4, 42), (0, 3))
+ val data = List[(Int, Double)]((0, 1), (2, 0), (4, 42), (0, 3))
+ val size = 5
+ val sparseVector = SparseVector.fromCOO(size, data)
+
val expectedSparseVector = SparseVector.fromCOO(5, (0, 4), (4, 42))
val expectedDenseVector = DenseVector.zeros(5)
@@ -38,11 +41,22 @@ class SparseVectorTest extends ShouldMatchers{
val denseVector = sparseVector.toDenseVector
denseVector should equal(expectedDenseVector)
+
+ val dataMap = data.
+ groupBy{_._1}.
+ mapValues{
+ entries =>
+ entries.map(_._2).reduce(_ + _)
+ }
+
+ for(index <- 0 until size) {
+ sparseVector(index) should be(dataMap.getOrElse(index, 0))
+ }
}
@Test
def testInvalidIndexAccess: Unit = {
- val sparseVector = SparseVector.fromCOO(5, (0, 1), (4, 10), (3, 5))
+ val sparseVector = SparseVector.fromCOO(5, (1, 1), (3, 3), (4, 4))
intercept[IllegalArgumentException] {
sparseVector(-1)
[3/3] flink git commit: [FLINK-1718] [ml] Unifies existing test cases
Posted by tr...@apache.org.
[FLINK-1718] [ml] Unifies existing test cases
This closes #539.
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/d2e2d79f
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/d2e2d79f
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/d2e2d79f
Branch: refs/heads/master
Commit: d2e2d79fc0052c064188940520c93bbd0c1b1d4b
Parents: 5ddb2dd
Author: Till Rohrmann <tr...@apache.org>
Authored: Tue Mar 31 17:01:12 2015 +0200
Committer: Till Rohrmann <tr...@apache.org>
Committed: Wed Apr 1 10:56:57 2015 +0200
----------------------------------------------------------------------
.../apache/flink/ml/math/BreezeMathSuite.scala | 68 ++++++++++
.../apache/flink/ml/math/BreezeMathTest.scala | 69 -----------
.../apache/flink/ml/math/DenseMatrixSuite.scala | 86 +++++++++++++
.../apache/flink/ml/math/DenseMatrixTest.scala | 89 -------------
.../apache/flink/ml/math/DenseVectorSuite.scala | 50 ++++++++
.../apache/flink/ml/math/DenseVectorTest.scala | 52 --------
.../flink/ml/math/SparseMatrixSuite.scala | 121 ++++++++++++++++++
.../apache/flink/ml/math/SparseMatrixTest.scala | 124 -------------------
.../flink/ml/math/SparseVectorSuite.scala | 90 ++++++++++++++
.../apache/flink/ml/math/SparseVectorTest.scala | 93 --------------
10 files changed, 415 insertions(+), 427 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/d2e2d79f/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathSuite.scala
new file mode 100644
index 0000000..b03f08f
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import Breeze._
+
+import org.scalatest.{Matchers, FlatSpec}
+
+class BreezeMathSuite extends FlatSpec with Matchers { // Flink matrix -> Breeze -> Flink checks
+
+ behavior of "Breeze vector conversion" // NOTE(review): both cases below use matrices - confirm
+
+ it should "convert a DenseMatrix into breeze.linalg.DenseMatrix and vice versa" in {
+ val numRows = 5
+ val numCols = 4
+
+ val data = Array.range(0, numRows * numCols) // one value per matrix cell
+ val expectedData = Array.range(0, numRows * numCols).map(_ * 2) // same values, doubled
+
+ val denseMatrix = DenseMatrix(numRows, numCols, data)
+ val expectedMatrix = DenseMatrix(numRows, numCols, expectedData)
+
+ val m = denseMatrix.asBreeze // wrap into a Breeze matrix
+
+ val result = (m * 2.0).fromBreeze // scale in Breeze, then unwrap into a Flink matrix
+
+ result should equal(expectedMatrix)
+ }
+
+ it should "convert a SparseMatrix into breeze.linalg.CSCMatrix" in {
+ val numRows = 5
+ val numCols = 4
+
+ val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols, // entries are (row, column, value)
+ (0, 1, 1),
+ (4, 3, 13),
+ (3, 2, 45),
+ (4, 0, 12))
+
+ val expectedMatrix = SparseMatrix.fromCOO(numRows, numCols, // same cells, values doubled
+ (0, 1, 2),
+ (4, 3, 26),
+ (3, 2, 90),
+ (4, 0, 24))
+
+ val sm = sparseMatrix.asBreeze // wrap into a Breeze matrix
+
+ val result = (sm * 2.0).fromBreeze // scale in Breeze, then unwrap into a Flink matrix
+
+ result should equal(expectedMatrix)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/d2e2d79f/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala
deleted file mode 100644
index 7084f2a..0000000
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/BreezeMathTest.scala
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.math
-
-import Breeze._
-
-import org.junit.Test
-import org.scalatest.ShouldMatchers
-
-class BreezeMathTest extends ShouldMatchers {
-
- @Test
- def testBreezeDenseMatrixWrapping: Unit = {
- val numRows = 5
- val numCols = 4
-
- val data = Array.range(0, numRows * numCols)
- val expectedData = Array.range(0, numRows * numCols).map(_ * 2)
-
- val denseMatrix = DenseMatrix(numRows, numCols, data)
- val expectedMatrix = DenseMatrix(numRows, numCols, expectedData)
-
- val m = denseMatrix.asBreeze
-
- val result = (m * 2.0).fromBreeze
-
- result should equal(expectedMatrix)
- }
-
- @Test
- def testBreezeSparseMatrixWrapping: Unit = {
- val numRows = 5
- val numCols = 4
-
- val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols,
- (0, 1, 1),
- (4, 3, 13),
- (3, 2, 45),
- (4, 0, 12))
-
- val expectedMatrix = SparseMatrix.fromCOO(numRows, numCols,
- (0, 1, 2),
- (4, 3, 26),
- (3, 2, 90),
- (4, 0, 24))
-
- val sm = sparseMatrix.asBreeze
-
- val result = (sm * 2.0).fromBreeze
-
- result should equal(expectedMatrix)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/d2e2d79f/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala
new file mode 100644
index 0000000..ca3d601
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import org.scalatest.{Matchers, FlatSpec}
+
+class DenseMatrixSuite extends FlatSpec with Matchers {
+
+ behavior of "Flink's DenseMatrix"
+
+ it should "contain the initialization data" in {
+ val numRows = 10
+ val numCols = 13
+
+ val data = Array.range(0, numRows*numCols)
+
+ val matrix = DenseMatrix(numRows, numCols, data)
+
+ assertResult(numRows)(matrix.numRows)
+ assertResult(numCols)(matrix.numCols)
+
+ for(row <- 0 until numRows; col <- 0 until numCols) {
+ assertResult(data(col*numRows + row))(matrix(row, col))
+ }
+ }
+
+ it should "fail in case of invalid element access" in {
+ val numRows = 10
+ val numCols = 13
+
+ val matrix = DenseMatrix.zeros(numRows, numCols)
+
+ intercept[IllegalArgumentException] {
+ matrix(-1, 2)
+ }
+
+ intercept[IllegalArgumentException] {
+ matrix(0, -1)
+ }
+
+ intercept[IllegalArgumentException] {
+ matrix(numRows, 0)
+ }
+
+ intercept[IllegalArgumentException] {
+ matrix(0, numCols)
+ }
+
+ intercept[IllegalArgumentException] {
+ matrix(numRows, numCols)
+ }
+ }
+
+ it should "be copyable" in {
+ val numRows = 4
+ val numCols = 5
+
+ val data = Array.range(0, numRows*numCols)
+
+ val denseMatrix = DenseMatrix.apply(numRows, numCols, data)
+
+ val copy = denseMatrix.copy
+
+ denseMatrix should equal(copy)
+
+ copy(0, 0) = 1
+
+ denseMatrix should not equal(copy)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/d2e2d79f/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixTest.scala
deleted file mode 100644
index 12001fc..0000000
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixTest.scala
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.math
-
-import org.junit.Test
-import org.scalatest.ShouldMatchers
-
-class DenseMatrixTest extends ShouldMatchers {
-
- @Test
- def testDataAfterInitialization: Unit = {
- val numRows = 10
- val numCols = 13
-
- val data = Array.range(0, numRows*numCols)
-
- val matrix = DenseMatrix(numRows, numCols, data)
-
- assertResult(numRows)(matrix.numRows)
- assertResult(numCols)(matrix.numCols)
-
- for(row <- 0 until numRows; col <- 0 until numCols) {
- assertResult(data(col*numRows + row))(matrix(row, col))
- }
- }
-
- @Test
- def testIllegalArgumentExceptionInCaseOfInvalidIndexAccess: Unit = {
- val numRows = 10
- val numCols = 13
-
- val matrix = DenseMatrix.zeros(numRows, numCols)
-
- intercept[IllegalArgumentException] {
- matrix(-1, 2)
- }
-
- intercept[IllegalArgumentException] {
- matrix(0, -1)
- }
-
- intercept[IllegalArgumentException] {
- matrix(numRows, 0)
- }
-
- intercept[IllegalArgumentException] {
- matrix(0, numCols)
- }
-
- intercept[IllegalArgumentException] {
- matrix(numRows, numCols)
- }
- }
-
- @Test
- def testCopy: Unit = {
- val numRows = 4
- val numCols = 5
-
- val data = Array.range(0, numRows*numCols)
-
- val denseMatrix = DenseMatrix.apply(numRows, numCols, data)
-
- val copy = denseMatrix.copy
-
-
- denseMatrix should equal(copy)
-
- copy(0, 0) = 1
-
- denseMatrix should not equal(copy)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/d2e2d79f/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorSuite.scala
new file mode 100644
index 0000000..553f672
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorSuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import org.scalatest.{Matchers, FlatSpec}
+
+class DenseVectorSuite extends FlatSpec with Matchers {
+
+ behavior of "Flink's DenseVector"
+
+ it should "contain the initialization data" in {
+ val data = Array.range(1,10)
+
+ val vector = DenseVector(data)
+
+ assertResult(data.length)(vector.size)
+
+ data.zip(vector.map(_._2)).foreach{case (expected, actual) => assertResult(expected)(actual)}
+ }
+
+ it should "fail in case of an illegal element access" in {
+ val size = 10
+
+ val vector = DenseVector.zeros(size)
+
+ intercept[IllegalArgumentException] {
+ vector(-1)
+ }
+
+ intercept[IllegalArgumentException] {
+ vector(size)
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/d2e2d79f/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
deleted file mode 100644
index 66a51fe..0000000
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.math
-
-import org.junit.Test
-import org.scalatest.ShouldMatchers
-
-
-class DenseVectorTest extends ShouldMatchers {
-
- @Test
- def testDataAfterInitialization {
- val data = Array.range(1,10)
-
- val vector = DenseVector(data)
-
- assertResult(data.length)(vector.size)
-
- data.zip(vector.map(_._2)).foreach{case (expected, actual) => assertResult(expected)(actual)}
- }
-
- @Test
- def testIllegalArgumentExceptionInCaseOfIllegalIndexAccess {
- val size = 10
-
- val vector = DenseVector.zeros(size)
-
- intercept[IllegalArgumentException] {
- vector(-1)
- }
-
- intercept[IllegalArgumentException] {
- vector(size)
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/d2e2d79f/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixSuite.scala
new file mode 100644
index 0000000..5710931
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixSuite.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import org.scalatest.{Matchers, FlatSpec}
+
+class SparseMatrixSuite extends FlatSpec with Matchers {
+
+ behavior of "Flink's SparseMatrix"
+
+ it should "be initialized from a coordinate list representation (COO)" in {
+ val data = List[(Int, Int, Double)]((0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
+ (3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1))
+
+ val numRows = 5
+ val numCols = 5
+
+ val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols, data)
+
+ val expectedSparseMatrix = SparseMatrix.fromCOO(5, 5, (3, 4, 42), (2, 1, 17), (3, 3, 88),
+ (4, 2, 99), (1, 4, 91))
+
+ val expectedDenseMatrix = DenseMatrix.zeros(5, 5)
+ expectedDenseMatrix(3, 4) = 42
+ expectedDenseMatrix(2, 1) = 17
+ expectedDenseMatrix(3, 3) = 88
+ expectedDenseMatrix(4, 2) = 99
+ expectedDenseMatrix(1, 4) = 91
+
+ sparseMatrix should equal(expectedSparseMatrix)
+ sparseMatrix should equal(expectedDenseMatrix)
+
+ sparseMatrix.toDenseMatrix.data.sameElements(expectedDenseMatrix.data) should be(true)
+
+ val dataMap = data.
+ map{ case (row, col, value) => (row, col) -> value }.
+ groupBy{_._1}.
+ mapValues{
+ entries =>
+ entries.map(_._2).reduce(_ + _)
+ }
+
+ for(row <- 0 until numRows; col <- 0 until numCols) {
+ sparseMatrix(row, col) should be(dataMap.getOrElse((row, col), 0))
+ }
+
+ // test access to defined field even though it was set to 0
+ sparseMatrix(0, 1) = 10
+
+ // test that a non-defined field is not accessible
+ intercept[IllegalArgumentException]{
+ sparseMatrix(1, 1) = 1
+ }
+ }
+
+ it should "fail when accessing zero elements or using invalid indices" in {
+ val data = List[(Int, Int, Double)]((0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
+ (3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1))
+
+ val numRows = 5
+ val numCols = 5
+
+ val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols, data)
+
+ intercept[IllegalArgumentException] {
+ sparseMatrix(-1, 4)
+ }
+
+ intercept[IllegalArgumentException] {
+ sparseMatrix(numRows, 0)
+ }
+
+ intercept[IllegalArgumentException] {
+ sparseMatrix(0, numCols)
+ }
+
+ intercept[IllegalArgumentException] {
+ sparseMatrix(3, -1)
+ }
+ }
+
+ it should "fail when elements of the COO list have invalid indices" in {
+ intercept[IllegalArgumentException]{
+ val sparseMatrix = SparseMatrix.fromCOO(5 ,5, (5, 0, 10), (0, 0, 0), (0, 1, 0), (3, 4, 43),
+ (2, 1, 17))
+ }
+
+ intercept[IllegalArgumentException]{
+ val sparseMatrix = SparseMatrix.fromCOO(5, 5, (0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
+ (-1, 4, 20))
+ }
+ }
+
+ it should "be copyable" in {
+ val sparseMatrix = SparseMatrix.fromCOO(4, 4, (0, 1, 2), (2, 3, 1), (2, 0, 42), (1, 3, 3))
+
+ val copy = sparseMatrix.copy
+
+ sparseMatrix should equal(copy)
+
+ copy(2, 3) = 2
+
+ sparseMatrix should not equal(copy)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/d2e2d79f/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
deleted file mode 100644
index 7fcdf54..0000000
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.math
-
-import org.junit.Test
-import org.scalatest.ShouldMatchers
-
-class SparseMatrixTest extends ShouldMatchers {
-
- @Test
- def testSparseMatrixFromCOO: Unit = {
- val data = List[(Int, Int, Double)]((0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
- (3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1))
-
- val numRows = 5
- val numCols = 5
-
- val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols, data)
-
- val expectedSparseMatrix = SparseMatrix.fromCOO(5, 5, (3, 4, 42), (2, 1, 17), (3, 3, 88),
- (4, 2, 99), (1, 4, 91))
-
- val expectedDenseMatrix = DenseMatrix.zeros(5, 5)
- expectedDenseMatrix(3, 4) = 42
- expectedDenseMatrix(2, 1) = 17
- expectedDenseMatrix(3, 3) = 88
- expectedDenseMatrix(4, 2) = 99
- expectedDenseMatrix(1, 4) = 91
-
- sparseMatrix should equal(expectedSparseMatrix)
- sparseMatrix should equal(expectedDenseMatrix)
-
- sparseMatrix.toDenseMatrix.data.sameElements(expectedDenseMatrix.data) should be(true)
-
- val dataMap = data.
- map{ case (row, col, value) => (row, col) -> value }.
- groupBy{_._1}.
- mapValues{
- entries =>
- entries.map(_._2).reduce(_ + _)
- }
-
- for(row <- 0 until numRows; col <- 0 until numCols) {
- sparseMatrix(row, col) should be(dataMap.getOrElse((row, col), 0))
- }
-
- // test access to defined field even though it was set to 0
- sparseMatrix(0, 1) = 10
-
- // test that a non-defined field is not accessible
- intercept[IllegalArgumentException]{
- sparseMatrix(1, 1) = 1
- }
- }
-
- @Test
- def testInvalidIndexAccess: Unit = {
- val data = List[(Int, Int, Double)]((0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
- (3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1))
-
- val numRows = 5
- val numCols = 5
-
- val sparseMatrix = SparseMatrix.fromCOO(numRows, numCols, data)
-
- intercept[IllegalArgumentException] {
- sparseMatrix(-1, 4)
- }
-
- intercept[IllegalArgumentException] {
- sparseMatrix(numRows, 0)
- }
-
- intercept[IllegalArgumentException] {
- sparseMatrix(0, numCols)
- }
-
- intercept[IllegalArgumentException] {
- sparseMatrix(3, -1)
- }
- }
-
- @Test
- def testSparseMatrixFromCOOWithInvalidIndices: Unit = {
- intercept[IllegalArgumentException]{
- val sparseMatrix = SparseMatrix.fromCOO(5 ,5, (5, 0, 10), (0, 0, 0), (0, 1, 0), (3, 4, 43),
- (2, 1, 17))
- }
-
- intercept[IllegalArgumentException]{
- val sparseMatrix = SparseMatrix.fromCOO(5, 5, (0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
- (-1, 4, 20))
- }
- }
-
- @Test
- def testSparseMatrixCopy: Unit = {
- val sparseMatrix = SparseMatrix.fromCOO(4, 4, (0, 1, 2), (2, 3, 1), (2, 0, 42), (1, 3, 3))
-
- val copy = sparseMatrix.copy
-
- sparseMatrix should equal(copy)
-
- copy(2, 3) = 2
-
- sparseMatrix should not equal(copy)
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/d2e2d79f/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorSuite.scala
new file mode 100644
index 0000000..28415e8
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import org.scalatest.{Matchers, FlatSpec}
+
+class SparseVectorSuite extends FlatSpec with Matchers {
+
+ behavior of "Flink's SparseVector"
+
+ it should "contain the initialization data provided as coordinate list (COO)" in {
+ val data = List[(Int, Double)]((0, 1), (2, 0), (4, 42), (0, 3))
+ val size = 5
+ val sparseVector = SparseVector.fromCOO(size, data)
+
+ val expectedSparseVector = SparseVector.fromCOO(5, (0, 4), (4, 42))
+ val expectedDenseVector = DenseVector.zeros(5)
+
+ expectedDenseVector(0) = 4
+ expectedDenseVector(4) = 42
+
+ sparseVector should equal(expectedSparseVector)
+ sparseVector should equal(expectedDenseVector)
+
+ val denseVector = sparseVector.toDenseVector
+
+ denseVector should equal(expectedDenseVector)
+
+ val dataMap = data.
+ groupBy{_._1}.
+ mapValues{
+ entries =>
+ entries.map(_._2).reduce(_ + _)
+ }
+
+ for(index <- 0 until size) {
+ sparseVector(index) should be(dataMap.getOrElse(index, 0))
+ }
+ }
+
+ it should "fail when accessing elements using an invalid index" in {
+ val sparseVector = SparseVector.fromCOO(5, (1, 1), (3, 3), (4, 4))
+
+ intercept[IllegalArgumentException] {
+ sparseVector(-1)
+ }
+
+ intercept[IllegalArgumentException] {
+ sparseVector(5)
+ }
+ }
+
+ it should "fail when the COO list contains elements with invalid indices" in {
+ intercept[IllegalArgumentException] {
+ val sparseVector = SparseVector.fromCOO(5, (0, 1), (-1, 34), (3, 2))
+ }
+
+ intercept[IllegalArgumentException] {
+ val sparseVector = SparseVector.fromCOO(5, (0, 1), (4,3), (5, 1))
+ }
+ }
+
+ it should "be copyable" in {
+ val sparseVector = SparseVector.fromCOO(5, (0, 1), (4, 3), (3, 2))
+
+ val copy = sparseVector.copy
+
+ sparseVector should equal(copy)
+
+ copy(3) = 3
+
+ sparseVector should not equal(copy)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/d2e2d79f/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
deleted file mode 100644
index 88d4878..0000000
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.math
-
-import org.junit.Test
-import org.scalatest.ShouldMatchers
-
-class SparseVectorTest extends ShouldMatchers{
-
- @Test
- def testDataAfterInitialization: Unit = {
- val data = List[(Int, Double)]((0, 1), (2, 0), (4, 42), (0, 3))
- val size = 5
- val sparseVector = SparseVector.fromCOO(size, data)
-
- val expectedSparseVector = SparseVector.fromCOO(5, (0, 4), (4, 42))
- val expectedDenseVector = DenseVector.zeros(5)
-
- expectedDenseVector(0) = 4
- expectedDenseVector(4) = 42
-
- sparseVector should equal(expectedSparseVector)
- sparseVector should equal(expectedDenseVector)
-
- val denseVector = sparseVector.toDenseVector
-
- denseVector should equal(expectedDenseVector)
-
- val dataMap = data.
- groupBy{_._1}.
- mapValues{
- entries =>
- entries.map(_._2).reduce(_ + _)
- }
-
- for(index <- 0 until size) {
- sparseVector(index) should be(dataMap.getOrElse(index, 0))
- }
- }
-
- @Test
- def testInvalidIndexAccess: Unit = {
- val sparseVector = SparseVector.fromCOO(5, (1, 1), (3, 3), (4, 4))
-
- intercept[IllegalArgumentException] {
- sparseVector(-1)
- }
-
- intercept[IllegalArgumentException] {
- sparseVector(5)
- }
- }
-
- @Test
- def testSparseVectorFromCOOWithInvalidIndices: Unit = {
- intercept[IllegalArgumentException] {
- val sparseVector = SparseVector.fromCOO(5, (0, 1), (-1, 34), (3, 2))
- }
-
- intercept[IllegalArgumentException] {
- val sparseVector = SparseVector.fromCOO(5, (0, 1), (4,3), (5, 1))
- }
- }
-
- @Test
- def testSparseVectorCopy: Unit = {
- val sparseVector = SparseVector.fromCOO(5, (0, 1), (4, 3), (3, 2))
-
- val copy = sparseVector.copy
-
- sparseVector should equal(copy)
-
- copy(3) = 3
-
- sparseVector should not equal(copy)
- }
-}
[2/3] flink git commit: [FLINK-1718] [ml] Adds sparse matrix and
sparse vector types
Posted by tr...@apache.org.
[FLINK-1718] [ml] Adds sparse matrix and sparse vector types
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/9219af7b
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/9219af7b
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/9219af7b
Branch: refs/heads/master
Commit: 9219af7b63321ea78af67579bdc68eecd895acaa
Parents: c635802
Author: Till Rohrmann <tr...@apache.org>
Authored: Wed Mar 25 15:27:58 2015 +0100
Committer: Till Rohrmann <tr...@apache.org>
Committed: Wed Apr 1 10:56:47 2015 +0200
----------------------------------------------------------------------
flink-staging/flink-ml/pom.xml | 6 +-
.../org/apache/flink/ml/math/DenseMatrix.scala | 125 +++++++++-
.../org/apache/flink/ml/math/DenseVector.scala | 42 ++--
.../scala/org/apache/flink/ml/math/Matrix.scala | 58 +++--
.../org/apache/flink/ml/math/SparseMatrix.scala | 235 +++++++++++++++++++
.../org/apache/flink/ml/math/SparseVector.scala | 156 ++++++++++++
.../scala/org/apache/flink/ml/math/Vector.scala | 52 ++--
.../org/apache/flink/ml/math/package.scala | 6 +-
.../apache/flink/ml/math/DenseMatrixSuite.scala | 69 ------
.../apache/flink/ml/math/DenseMatrixTest.scala | 89 +++++++
.../apache/flink/ml/math/DenseVectorSuite.scala | 50 ----
.../apache/flink/ml/math/DenseVectorTest.scala | 52 ++++
.../apache/flink/ml/math/SparseMatrixTest.scala | 94 ++++++++
.../apache/flink/ml/math/SparseVectorTest.scala | 79 +++++++
14 files changed, 926 insertions(+), 187 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/pom.xml
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/pom.xml b/flink-staging/flink-ml/pom.xml
index 4f251e5..899d266 100644
--- a/flink-staging/flink-ml/pom.xml
+++ b/flink-staging/flink-ml/pom.xml
@@ -41,9 +41,9 @@
</dependency>
<dependency>
- <groupId>com.github.fommil.netlib</groupId>
- <artifactId>core</artifactId>
- <version>1.1.2</version>
+ <groupId>org.scalanlp</groupId>
+ <artifactId>breeze_2.10</artifactId>
+ <version>0.11.1</version>
</dependency>
<dependency>
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
index f3bd630..72eae05 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseMatrix.scala
@@ -24,13 +24,15 @@ package org.apache.flink.ml.math
*
* @param numRows Number of rows
* @param numCols Number of columns
- * @param values Array of matrix elements in column major order
+ * @param data Array of matrix elements in column major order
*/
case class DenseMatrix(val numRows: Int,
val numCols: Int,
- val values: Array[Double]) extends Matrix {
+ val data: Array[Double]) extends Matrix {
- require(numRows * numCols == values.length, s"The number of values ${values.length} does " +
+ import DenseMatrix._
+
+ require(numRows * numCols == data.length, s"The number of values ${data.length} does " +
s"not correspond to its dimensions ($numRows, $numCols).")
/**
@@ -41,32 +43,129 @@ case class DenseMatrix(val numRows: Int,
* @return matrix entry at (row, col)
*/
override def apply(row: Int, col: Int): Double = {
- require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
- require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).")
+ val index = locate(row, col)
- val index = col * numRows + row
-
- values(index)
+ data(index)
}
override def toString: String = {
- s"DenseMatrix($numRows, $numCols, ${values.mkString(", ")})"
+ val result = StringBuilder.newBuilder
+ result.append(s"DenseMatrix($numRows, $numCols)\n")
+
+ val linewidth = LINE_WIDTH
+
+ val columnsFieldWidths = for(row <- 0 until math.min(numRows, MAX_ROWS)) yield {
+ var column = 0
+ var maxFieldWidth = 0
+
+ while(column * maxFieldWidth < linewidth && column < numCols) {
+ val fieldWidth = printEntry(row, column).length + 2
+
+ if(fieldWidth > maxFieldWidth) {
+ maxFieldWidth = fieldWidth
+ }
+
+ if(column * maxFieldWidth < linewidth) {
+ column += 1
+ }
+ }
+
+ (column, maxFieldWidth)
+ }
+
+ val (columns, fieldWidths) = columnsFieldWidths.unzip
+
+ val maxColumns = columns.min
+ val fieldWidth = fieldWidths.max
+
+ for(row <- 0 until math.min(numRows, MAX_ROWS)) {
+ for(col <- 0 until maxColumns) {
+ val str = printEntry(row, col)
+
+ result.append(" " * (fieldWidth - str.length) + str)
+ }
+
+ if(maxColumns < numCols) {
+ result.append("...")
+ }
+
+ result.append("\n")
+ }
+
+ if(numRows > MAX_ROWS) {
+ result.append("...\n")
+ }
+
+ result.toString()
}
override def equals(obj: Any): Boolean = {
obj match {
case dense: DenseMatrix =>
- numRows == dense.numRows && numCols == dense.numCols && values.zip(dense.values).forall {
- case (a, b) => a == b
- }
- case _ => false
+ numRows == dense.numRows && numCols == dense.numCols && data.sameElements(dense.data)
+ case _ => super.equals(obj)
+ }
+ }
+
+ /** Element wise update function
+ *
+ * @param row row index
+ * @param col column index
+ * @param value value to set at (row, col)
+ */
+ override def update(row: Int, col: Int, value: Double): Unit = {
+ val index = locate(row, col)
+
+ data(index) = value
+ }
+
+ def toSparseMatrix: SparseMatrix = {
+ val entries = for(row <- 0 until numRows; col <- 0 until numCols) yield {
+ (row, col, apply(row, col))
}
+
+ SparseMatrix.fromCOO(numRows, numCols, entries.filter(_._3 != 0))
}
+ /** Calculates the linear index of the respective matrix entry
+ *
+ * @param row
+ * @param col
+ * @return
+ */
+ private def locate(row: Int, col: Int): Int = {
+ require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
+ require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).")
+
+ row + col * numRows
+ }
+
+ /** Converts the entry at (row, col) to string
+ *
+ * @param row
+ * @param col
+ * @return
+ */
+ private def printEntry(row: Int, col: Int): String = {
+ val index = locate(row, col)
+
+ data(index).toString
+ }
+
+ /** Copies the matrix instance
+ *
+ * @return Copy of itself
+ */
+ override def copy: DenseMatrix = {
+ new DenseMatrix(numRows, numCols, data.clone)
+ }
}
object DenseMatrix {
+ val LINE_WIDTH = 100
+ val MAX_ROWS = 50
+
def apply(numRows: Int, numCols: Int, values: Array[Int]): DenseMatrix = {
new DenseMatrix(numRows, numCols, values.map(_.toDouble))
}
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
index d407a70..6d41d47 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/DenseVector.scala
@@ -22,16 +22,16 @@ package org.apache.flink.ml.math
* Dense vector implementation of [[Vector]]. The data is represented in a continuous array of
* doubles.
*
- * @param values Array of doubles to store the vector elements
+ * @param data Array of doubles to store the vector elements
*/
-case class DenseVector(val values: Array[Double]) extends Vector {
+case class DenseVector(val data: Array[Double]) extends Vector {
/**
* Number of elements in a vector
* @return
*/
override def size: Int = {
- values.length
+ data.length
}
/**
@@ -41,23 +41,19 @@ case class DenseVector(val values: Array[Double]) extends Vector {
* @return element at the given index
*/
override def apply(index: Int): Double = {
- require(0 <= index && index < values.length, s"Index $index is out of bounds " +
- s"[0, ${values.length})")
- values(index)
+ require(0 <= index && index < data.length, s"Index $index is out of bounds " +
+ s"[0, ${data.length})")
+ data(index)
}
override def toString: String = {
- s"DenseVector(${values.mkString(", ")})"
+ s"DenseVector(${data.mkString(", ")})"
}
override def equals(obj: Any): Boolean = {
obj match {
- case dense: DenseVector =>
- values.length == dense.values.length && values.zip(dense.values).forall{
- case (a,b) => a == b
- }
-
- case _ => false
+ case dense: DenseVector => data.length == dense.data.length && data.sameElements(dense.data)
+ case _ => super.equals(obj)
}
}
@@ -67,7 +63,25 @@ case class DenseVector(val values: Array[Double]) extends Vector {
* @return Copy of the vector instance
*/
override def copy: Vector = {
- DenseVector(values.clone())
+ DenseVector(data.clone())
+ }
+
+  /** Overwrites the element at the given index with the provided value
+    *
+    * @param index position of the element to overwrite; must lie in [0, size)
+    * @param value new value of the element
+    */
+  override def update(index: Int, value: Double): Unit = {
+    val inBounds = index >= 0 && index < data.length
+    require(inBounds, s"Index $index is out of bounds " +
+      s"[0, ${data.length})")
+
+    data(index) = value
+  }
+
+  /** Builds a [[SparseVector]] from the non-zero entries of this vector
+    *
+    * @return sparse vector of the same size constructed from the (index, value) pairs != 0
+    */
+  def toSparseVector: SparseVector = {
+    val nonZeroEntries = for {
+      idx <- 0 until size
+      value = data(idx)
+      if value != 0
+    } yield (idx, value)
+
+    SparseVector.fromCOO(size, nonZeroEntries)
   }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala
index 62ea85a..11b4e55 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Matrix.scala
@@ -18,28 +18,52 @@
package org.apache.flink.ml.math
-/**
- * Base trait for a matrix representation
- */
+/** Base trait for a matrix representation
+ *
+ */
trait Matrix {
- /**
- * Number of rows
- * @return
- */
+ /** Number of rows
+ *
+ * @return
+ */
def numRows: Int
- /**
- * Number of columns
- * @return
- */
+ /** Number of columns
+ *
+ * @return
+ */
def numCols: Int
- /**
- * Element wise access function
- * @param row row index
- * @param col column index
- * @return matrix entry at (row, col)
- */
+ /** Element wise access function
+ *
+ * @param row row index
+ * @param col column index
+ * @return matrix entry at (row, col)
+ */
def apply(row: Int, col: Int): Double
+
+ /** Element wise update function
+ *
+ * @param row row index
+ * @param col column index
+ * @param value value to set at (row, col)
+ */
+ def update(row: Int, col: Int, value: Double): Unit
+
+ /** Copies the matrix instance
+ *
+ * @return Copy of itself
+ */
+ def copy: Matrix
+
+  /** Two matrices are equal when their dimensions match and all entries are equal
+    *
+    * NOTE(review): no matching hashCode override is visible here - confirm one exists.
+    *
+    * @param obj object to compare with
+    * @return true if obj is a Matrix with the same dimensions and entries
+    */
+  override def equals(obj: Any): Boolean = obj match {
+    case other: Matrix =>
+      val sameShape = numRows == other.numRows && numCols == other.numCols
+
+      sameShape && (0 until numRows).forall { row =>
+        (0 until numCols).forall { col => apply(row, col) == other(row, col) }
+      }
+    case _ => false
+  }
+
}
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
new file mode 100644
index 0000000..a46202c
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseMatrix.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import scala.util.Sorting
+
+/** Sparse matrix using the compressed sparse column (CSC) representation.
+  *
+  * More details concerning the compressed sparse column (CSC) representation can be found
+  * [http://en.wikipedia.org/wiki/Sparse_matrix#Compressed_sparse_column_.28CSC_or_CCS.29].
+  *
+  * @param numRows Number of rows
+  * @param numCols Number of columns
+  * @param rowIndices Array containing the row indices of non-zero entries
+  * @param colPtrs Array containing the starting offsets in data for each column
+  * @param data Array containing the non-zero entries in column-major order
+  */
+class SparseMatrix(
+    val numRows: Int,
+    val numCols: Int,
+    val rowIndices: Array[Int],
+    val colPtrs: Array[Int],
+    val data: Array[Double])
+  extends Matrix {
+
+  /** Element wise access function
+    *
+    * @param row row index
+    * @param col column index
+    * @return matrix entry at (row, col); 0 if no entry is stored at that position
+    */
+  override def apply(row: Int, col: Int): Double = {
+
+    val index = locate(row, col)
+
+    if(index < 0){
+      0
+    } else {
+      data(index)
+    }
+  }
+
+  /** Converts this matrix into a [[DenseMatrix]] by copying every (row, col) entry
+    *
+    * @return dense matrix with the same dimensions and entries
+    */
+  def toDenseMatrix: DenseMatrix = {
+    val result = DenseMatrix.zeros(numRows, numCols)
+
+    for(row <- 0 until numRows; col <- 0 until numCols) {
+      result(row, col) = apply(row, col)
+    }
+
+    result
+  }
+
+  /** Element wise update function
+    *
+    * Only positions which already have a stored entry can be updated.
+    *
+    * @param row row index
+    * @param col column index
+    * @param value value to set at (row, col)
+    * @throws IllegalArgumentException if no entry is stored at (row, col)
+    */
+  override def update(row: Int, col: Int, value: Double): Unit = {
+    val index = locate(row, col)
+
+    if(index < 0) {
+      throw new IllegalArgumentException("Cannot update zero value of sparse matrix at index " +
+        s"($row, $col)")
+    } else {
+      data(index) = value
+    }
+  }
+
+  /** Prints a header line followed by one line "(row,col) value" per stored entry
+    *
+    * @return string representation of the matrix
+    */
+  override def toString: String = {
+    val result = StringBuilder.newBuilder
+
+    result.append(s"SparseMatrix($numRows, $numCols)\n")
+
+    var columnIndex = 0
+
+    val fieldWidth = math.max(numRows, numCols).toString.length
+    // max is undefined for an empty collection, so guard matrices without stored entries
+    val valueFieldWidth = if(data.isEmpty) 0 else data.map(_.toString.length).max + 2
+
+    for(index <- 0 until colPtrs.last) {
+      // advance to the column whose offset range contains the current data index
+      while(colPtrs(columnIndex + 1) <= index){
+        columnIndex += 1
+      }
+
+      val rowStr = rowIndices(index).toString
+      val columnStr = columnIndex.toString
+      val valueStr = data(index).toString
+
+      result.append("(" + " " * (fieldWidth - rowStr.length) + rowStr + "," +
+        " " * (fieldWidth - columnStr.length) + columnStr + ")")
+      result.append(" " * (valueFieldWidth - valueStr.length) + valueStr)
+      result.append("\n")
+    }
+
+    result.toString
+  }
+
+  /** Finds the position in data which holds the entry (row, col)
+    *
+    * @param row row index
+    * @param col column index
+    * @return index into data, or a negative value if no entry is stored at (row, col)
+    */
+  private def locate(row: Int, col: Int): Int = {
+    require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
+    require(0 <= col && col < numCols, s"Col $col is out of bounds [0, $numCols).")
+
+    val startIndex = colPtrs(col)
+    val endIndex = colPtrs(col + 1)
+
+    // search only the slice of row indices which belongs to the requested column
+    java.util.Arrays.binarySearch(rowIndices, startIndex, endIndex, row)
+  }
+
+  /** Copies the matrix instance
+    *
+    * @return Copy of itself
+    */
+  override def copy: SparseMatrix = {
+    new SparseMatrix(numRows, numCols, rowIndices.clone, colPtrs.clone(), data.clone)
+  }
+}
+
+object SparseMatrix{
+
+  /** Constructs a sparse matrix from a coordinate list (COO) representation where each entry
+    * is stored as a tuple of (rowIndex, columnIndex, value).
+    *
+    * @param numRows number of rows of the matrix
+    * @param numCols number of columns of the matrix
+    * @param entries (rowIndex, columnIndex, value) tuples
+    * @return sparse matrix containing the given entries
+    */
+  def fromCOO(numRows: Int, numCols: Int, entries: (Int, Int, Double)*): SparseMatrix = {
+    fromCOO(numRows, numCols, entries)
+  }
+
+  /** Constructs a sparse matrix from a coordinate list (COO) representation where each entry
+    * is stored as a tuple of (rowIndex, columnIndex, value).
+    *
+    * Values with identical coordinates are added up. An empty entry collection yields a matrix
+    * without stored entries.
+    *
+    * @param numRows number of rows of the matrix
+    * @param numCols number of columns of the matrix
+    * @param entries (rowIndex, columnIndex, value) tuples
+    * @return sparse matrix containing the given entries
+    * @throws IllegalArgumentException if a coordinate lies outside of the matrix
+    */
+  def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = {
+    val entryArray = entries.toArray
+
+    entryArray.foreach{ case (row, col, _) =>
+      require(0 <= row && row < numRows, s"Row $row is out of bounds [0, $numRows).")
+      require(0 <= col && col < numCols, s"Column $col is out of bounds [0, $numCols).")
+    }
+
+    if(entryArray.isEmpty) {
+      // without entries every column is empty, thus all column pointers stay at 0
+      new SparseMatrix(numRows, numCols, Array.empty[Int], new Array[Int](numCols + 1),
+        Array.empty[Double])
+    } else {
+      // order by column first and by row second to obtain the column-major layout
+      val COOOrdering =
+        Ordering.by[(Int, Int, Double), (Int, Int)](entry => (entry._2, entry._1))
+
+      Sorting.quickSort(entryArray)(COOOrdering)
+
+      val nnz = entryArray.length
+
+      val data = new Array[Double](nnz)
+      val rowIndices = new Array[Int](nnz)
+      val colPtrs = new Array[Int](numCols + 1)
+
+      var (lastRow, lastCol, lastValue) = entryArray(0)
+
+      rowIndices(0) = lastRow
+      data(0) = lastValue
+
+      var i = 1
+      var lastDataIndex = 0
+
+      while(i < nnz) {
+        val (curRow, curCol, curValue) = entryArray(i)
+
+        if(lastRow == curRow && lastCol == curCol) {
+          // add values with identical coordinates
+          data(lastDataIndex) += curValue
+        } else {
+          lastDataIndex += 1
+          data(lastDataIndex) = curValue
+          rowIndices(lastDataIndex) = curRow
+          lastRow = curRow
+        }
+
+        // every column started since the last entry begins at the current data index
+        while(lastCol < curCol) {
+          lastCol += 1
+          colPtrs(lastCol) = lastDataIndex
+        }
+
+        i += 1
+      }
+
+      // lastDataIndex now denotes the number of stored entries
+      lastDataIndex += 1
+      while(lastCol < numCols) {
+        colPtrs(lastCol + 1) = lastDataIndex
+        lastCol += 1
+      }
+
+      // merged duplicates leave unused slots at the end of the arrays which are cut off here
+      val prunedRowIndices = if(lastDataIndex < nnz) {
+        rowIndices.take(lastDataIndex)
+      } else {
+        rowIndices
+      }
+
+      val prunedData = if(lastDataIndex < nnz) {
+        data.take(lastDataIndex)
+      } else {
+        data
+      }
+
+      new SparseMatrix(numRows, numCols, prunedRowIndices, colPtrs, prunedData)
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
new file mode 100644
index 0000000..93da362
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/SparseVector.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import scala.util.Sorting
+
+/** Sparse vector implementation storing the data in two arrays. One array contains the sorted
+  * indices of the non-zero vector entries and the other the corresponding vector entries
+  *
+  * @param size Number of elements of the vector
+  * @param indices Array containing the indices of the stored entries
+  * @param data Array containing the values of the stored entries
+  */
+class SparseVector(
+    val size: Int,
+    val indices: Array[Int],
+    val data: Array[Double])
+  extends Vector {
+
+  /** Updates the element at the given index with the provided value
+    *
+    * Only indices which already have a stored entry can be updated.
+    *
+    * @param index index of the element to update
+    * @param value new value of the element
+    */
+  override def update(index: Int, value: Double): Unit = {
+    val position = locate(index)
+
+    if (position >= 0) {
+      data(position) = value
+    } else {
+      throw new IllegalArgumentException("Cannot update zero value of sparse vector at index " +
+        index)
+    }
+  }
+
+  /** Copies the vector instance
+    *
+    * @return Copy of the vector instance backed by cloned arrays
+    */
+  override def copy: Vector = {
+    val clonedIndices = indices.clone
+    val clonedData = data.clone
+
+    new SparseVector(size, clonedIndices, clonedData)
+  }
+
+  /** Element wise access function
+    *
+    * @param index index of the accessed element
+    * @return element with index; 0 if no entry is stored for the index
+    */
+  override def apply(index: Int): Double = {
+    val position = locate(index)
+
+    if (position >= 0) data(position) else 0
+  }
+
+  /** Converts this vector into a [[DenseVector]] by copying every element
+    *
+    * @return dense vector with the same size and elements
+    */
+  def toDenseVector: DenseVector = {
+    val result = DenseVector.zeros(size)
+
+    (0 until size).foreach { idx =>
+      result(idx) = apply(idx)
+    }
+
+    result
+  }
+
+  /** Finds the position in data which holds the entry for the given index
+    *
+    * @param index index of the vector element
+    * @return position in data, or a negative value if no entry is stored for the index
+    */
+  private def locate(index: Int): Int = {
+    val inBounds = index >= 0 && index < size
+    require(inBounds, s"Index $index is out of bounds [0, $size).")
+
+    java.util.Arrays.binarySearch(indices, index)
+  }
+}
+
+object SparseVector {
+
+  /** Constructs a sparse vector from a coordinate list (COO) representation where each entry
+    * is stored as a tuple of (index, value).
+    *
+    * @param size number of elements of the vector
+    * @param entries (index, value) tuples
+    * @return sparse vector containing the given entries
+    */
+  def fromCOO(size: Int, entries: (Int, Double)*): SparseVector = {
+    fromCOO(size, entries)
+  }
+
+  /** Constructs a sparse vector from a coordinate list (COO) representation where each entry
+    * is stored as a tuple of (index, value).
+    *
+    * Values with identical indices are added up. An empty entry collection yields a vector
+    * without stored entries.
+    *
+    * @param size number of elements of the vector
+    * @param entries (index, value) tuples
+    * @return sparse vector containing the given entries
+    * @throws IllegalArgumentException if an index lies outside of [0, size)
+    */
+  def fromCOO(size: Int, entries: Iterable[(Int, Double)]): SparseVector = {
+    val entryArray = entries.toArray
+
+    // validate all indices up front, mirroring the bounds check of the element access
+    entryArray.foreach { case (index, _) =>
+      require(0 <= index && index < size, s"Index $index is out of bounds [0, $size).")
+    }
+
+    if(entryArray.isEmpty) {
+      new SparseVector(size, Array.empty[Int], Array.empty[Double])
+    } else {
+      val COOOrdering = Ordering.by[(Int, Double), Int](_._1)
+
+      Sorting.quickSort(entryArray)(COOOrdering)
+
+      // calculate size of the array; -1 is a safe start value because all indices are >= 0
+      val arraySize = entryArray.foldLeft((-1, 0)){ case ((lastIndex, count), (index, _)) =>
+        if(lastIndex == index) {
+          (lastIndex, count)
+        } else {
+          (index, count + 1)
+        }
+      }._2
+
+      val indices = new Array[Int](arraySize)
+      val data = new Array[Double](arraySize)
+
+      val (index, value) = entryArray(0)
+
+      indices(0) = index
+      data(0) = value
+
+      var i = 1
+      var lastIndex = indices(0)
+      var lastDataIndex = 0
+
+      while(i < entryArray.length) {
+        val (curIndex, curValue) = entryArray(i)
+
+        if(curIndex == lastIndex) {
+          // add values with identical indices
+          data(lastDataIndex) += curValue
+        } else {
+          lastDataIndex += 1
+          data(lastDataIndex) = curValue
+          indices(lastDataIndex) = curIndex
+          lastIndex = curIndex
+        }
+
+        i += 1
+      }
+
+      new SparseVector(size, indices, data)
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala
index 20d820c..7e7c32c 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Vector.scala
@@ -18,29 +18,45 @@
package org.apache.flink.ml.math
-/**
- * Base trait for Vectors
- */
+/** Base trait for Vectors
+ *
+ */
trait Vector {
- /**
- * Number of elements in a vector
- * @return
- */
+ /** Number of elements in a vector
+ *
+ * @return
+ */
def size: Int
- /**
- * Element wise access function
- *
- * @param index index of the accessed element
- * @return element with index
- */
+ /** Element wise access function
+ *
+ * @param index index of the accessed element
+ * @return element with index
+ */
def apply(index: Int): Double
- /**
- * Copies the vector instance
- *
- * @return Copy of the vector instance
- */
+ /** Updates the element at the given index with the provided value
+ *
+ * @param index
+ * @param value
+ */
+ def update(index: Int, value: Double): Unit
+
+ /** Copies the vector instance
+ *
+ * @return Copy of the vector instance
+ */
def copy: Vector
+
+  /** Two vectors are equal when they have the same size and all elements are equal
+    *
+    * @param obj object to compare with
+    * @return true if obj is a Vector with the same size and elements
+    */
+  override def equals(obj: Any): Boolean = obj match {
+    case other: Vector =>
+      size == other.size && (0 until size).forall(idx => apply(idx) == other(idx))
+    case _ => false
+  }
}
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
index e82e38f..3ab6143 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/package.scala
@@ -27,7 +27,7 @@ package object math {
override def iterator: Iterator[Double] = {
matrix match {
- case dense: DenseMatrix => dense.values.iterator
+ case dense: DenseMatrix => dense.data.iterator
}
}
}
@@ -35,14 +35,14 @@ package object math {
implicit class RichVector(vector: Vector) extends Iterable[Double] {
override def iterator: Iterator[Double] = {
vector match {
- case dense: DenseVector => dense.values.iterator
+ case dense: DenseVector => dense.data.iterator
}
}
}
implicit def vector2Array(vector: Vector): Array[Double] = {
vector match {
- case dense: DenseVector => dense.values
+ case dense: DenseVector => dense.data
}
}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala
deleted file mode 100644
index be5db08..0000000
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixSuite.scala
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.math
-
-import org.scalatest.FlatSpec
-
-class DenseMatrixSuite extends FlatSpec {
-
- behavior of "A DenseMatrix"
-
- it should "contain the initialization data after intialization" in {
- val numRows = 10
- val numCols = 13
-
- val data = Array.range(0, numRows*numCols)
-
- val matrix = DenseMatrix(numRows, numCols, data)
-
- assertResult(numRows)(matrix.numRows)
- assertResult(numCols)(matrix.numCols)
-
- for(row <- 0 until numRows; col <- 0 until numCols) {
- assertResult(data(col*numRows + row))(matrix(row, col))
- }
- }
-
- it should "throw an IllegalArgumentException in case of an invalid index access" in {
- val numRows = 10
- val numCols = 13
-
- val matrix = DenseMatrix.zeros(numRows, numCols)
-
- intercept[IllegalArgumentException] {
- matrix(-1, 2)
- }
-
- intercept[IllegalArgumentException] {
- matrix(0, -1)
- }
-
- intercept[IllegalArgumentException] {
- matrix(numRows, 0)
- }
-
- intercept[IllegalArgumentException] {
- matrix(0, numCols)
- }
-
- intercept[IllegalArgumentException] {
- matrix(numRows, numCols)
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixTest.scala
new file mode 100644
index 0000000..12001fc
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseMatrixTest.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import org.junit.Test
+import org.scalatest.ShouldMatchers
+
+class DenseMatrixTest extends ShouldMatchers {
+
+  /** Checks the dimensions and the entries of a matrix built from an array */
+  @Test
+  def testDataAfterInitialization: Unit = {
+    val numRows = 10
+    val numCols = 13
+    val data = Array.range(0, numRows * numCols)
+
+    val matrix = DenseMatrix(numRows, numCols, data)
+
+    assertResult(numRows)(matrix.numRows)
+    assertResult(numCols)(matrix.numCols)
+
+    (0 until numRows).foreach { row =>
+      (0 until numCols).foreach { col =>
+        assertResult(data(col * numRows + row))(matrix(row, col))
+      }
+    }
+  }
+
+  /** Checks that accessing coordinates outside of the matrix is rejected */
+  @Test
+  def testIllegalArgumentExceptionInCaseOfInvalidIndexAccess: Unit = {
+    val numRows = 10
+    val numCols = 13
+
+    val matrix = DenseMatrix.zeros(numRows, numCols)
+
+    val invalidCoordinates =
+      Seq((-1, 2), (0, -1), (numRows, 0), (0, numCols), (numRows, numCols))
+
+    invalidCoordinates.foreach { case (row, col) =>
+      intercept[IllegalArgumentException] {
+        matrix(row, col)
+      }
+    }
+  }
+
+  /** Checks that a copy is equal to the original until one of them is modified */
+  @Test
+  def testCopy: Unit = {
+    val numRows = 4
+    val numCols = 5
+    val data = Array.range(0, numRows * numCols)
+
+    val denseMatrix = DenseMatrix(numRows, numCols, data)
+    val copy = denseMatrix.copy
+
+    denseMatrix should equal(copy)
+
+    copy(0, 0) = 1
+
+    denseMatrix should not equal(copy)
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorSuite.scala
deleted file mode 100644
index ae1e012..0000000
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorSuite.scala
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.math
-
-import org.scalatest.FlatSpec
-
-class DenseVectorSuite extends FlatSpec {
-
- behavior of "A DenseVector"
-
- it should "contain the initialization data after initialization" in {
- val data = Array.range(1,10)
-
- val vector = DenseVector(data)
-
- assertResult(data.length)(vector.size)
-
- data.zip(vector).foreach{case (expected, actual) => assertResult(expected)(actual)}
- }
-
- it should "throw an IllegalArgumentException in case of an illegal index access" in {
- val size = 10
-
- val vector = DenseVector.zeros(size)
-
- intercept[IllegalArgumentException] {
- vector(-1)
- }
-
- intercept[IllegalArgumentException] {
- vector(size)
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
new file mode 100644
index 0000000..5da9fe2
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/DenseVectorTest.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import org.junit.Test
+import org.scalatest.ShouldMatchers
+
+
+class DenseVectorTest extends ShouldMatchers {
+
+  /** Checks the size and the elements of a vector built from an array */
+  @Test
+  def testDataAfterInitialization: Unit = {
+    val data = Array.range(1, 10)
+    val vector = DenseVector(data)
+
+    assertResult(data.length)(vector.size)
+
+    for ((expected, actual) <- data.zip(vector)) {
+      assertResult(expected)(actual)
+    }
+  }
+
+  /** Checks that accessing indices outside of the vector is rejected */
+  @Test
+  def testIllegalArgumentExceptionInCaseOfIllegalIndexAccess: Unit = {
+    val size = 10
+    val vector = DenseVector.zeros(size)
+
+    Seq(-1, size).foreach { index =>
+      intercept[IllegalArgumentException] {
+        vector(index)
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
new file mode 100644
index 0000000..a0e1d27
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseMatrixTest.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import org.junit.Test
+import org.scalatest.ShouldMatchers
+
+class SparseMatrixTest extends ShouldMatchers {
+
+  /** Checks that duplicate coordinates are summed up and that the result equals both the
+    * expected sparse and the expected dense matrix
+    */
+  @Test
+  def testSparseMatrixFromCOO: Unit = {
+    val sparseMatrix = SparseMatrix.fromCOO(5, 5, (0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
+      (3, 3, 88), (4 , 2, 99), (1, 4, 91), (3, 4, -1))
+
+    val expectedSparseMatrix = SparseMatrix.fromCOO(5, 5, (3, 4, 42), (2, 1, 17), (3, 3, 88),
+      (4, 2, 99), (1, 4, 91))
+
+    val expectedDenseMatrix = DenseMatrix.zeros(5, 5)
+    expectedDenseMatrix(3, 4) = 42
+    expectedDenseMatrix(2, 1) = 17
+    expectedDenseMatrix(3, 3) = 88
+    expectedDenseMatrix(4, 2) = 99
+    expectedDenseMatrix(1, 4) = 91
+
+    sparseMatrix should equal(expectedSparseMatrix)
+    sparseMatrix should equal(expectedDenseMatrix)
+
+    sparseMatrix.toDenseMatrix.data.sameElements(expectedDenseMatrix.data) should be(true)
+
+    // (0, 1) was given as an explicit zero entry, hence it is stored and can be updated
+    sparseMatrix(0, 1) = 10
+
+    // (1, 1) was never given, hence it is not stored and cannot be updated
+    intercept[IllegalArgumentException]{
+      sparseMatrix(1, 1) = 1
+    }
+  }
+
+  /** Checks that accessing coordinates outside of the sparse matrix is rejected while valid
+    * coordinates return the stored value or 0
+    */
+  @Test
+  def testInvalidIndexAccess: Unit = {
+    val sparseMatrix = SparseMatrix.fromCOO(5, 5, (1, 1, 1), (3, 3, 3), (4, 4, 4))
+
+    intercept[IllegalArgumentException] {
+      sparseMatrix(-1, 0)
+    }
+
+    intercept[IllegalArgumentException] {
+      sparseMatrix(5, 0)
+    }
+
+    intercept[IllegalArgumentException] {
+      sparseMatrix(0, -1)
+    }
+
+    intercept[IllegalArgumentException] {
+      sparseMatrix(0, 5)
+    }
+
+    sparseMatrix(0, 0) should equal(0)
+    sparseMatrix(3, 3) should equal(3)
+  }
+
+  /** Checks that entries with coordinates outside of the matrix are rejected by fromCOO */
+  @Test
+  def testSparseMatrixFromCOOWithInvalidIndices: Unit = {
+    intercept[IllegalArgumentException]{
+      val sparseMatrix = SparseMatrix.fromCOO(5 ,5, (5, 0, 10), (0, 0, 0), (0, 1, 0), (3, 4, 43),
+        (2, 1, 17))
+    }
+
+    intercept[IllegalArgumentException]{
+      val sparseMatrix = SparseMatrix.fromCOO(5, 5, (0, 0, 0), (0, 1, 0), (3, 4, 43), (2, 1, 17),
+        (-1, 4, 20))
+    }
+  }
+
+  /** Checks that a copy is equal to the original until one of them is modified */
+  @Test
+  def testSparseMatrixCopy: Unit = {
+    val sparseMatrix = SparseMatrix.fromCOO(4, 4, (0, 1, 2), (2, 3, 1), (2, 0, 42), (1, 3, 3))
+
+    val copy = sparseMatrix.copy
+
+    sparseMatrix should equal(copy)
+
+    copy(2, 3) = 2
+
+    sparseMatrix should not equal(copy)
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/9219af7b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
new file mode 100644
index 0000000..5e514c6
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/math/SparseVectorTest.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.math
+
+import org.junit.Test
+import org.scalatest.ShouldMatchers
+
+/** Tests SparseVector construction via fromCOO, element access and copying. */
+class SparseVectorTest extends ShouldMatchers {
+
+  /** Entries given for the same index add up; the result equals its dense counterpart. */
+  @Test
+  def testDataAfterInitialization: Unit = {
+    val actual = SparseVector.fromCOO(5, (0, 1), (2, 0), (4, 42), (0, 3))
+
+    // Index 0 is listed twice above (values 1 and 3); the references hold 4 there.
+    val sparseReference = SparseVector.fromCOO(5, (0, 4), (4, 42))
+    val denseReference = DenseVector.zeros(5)
+    denseReference(0) = 4
+    denseReference(4) = 42
+
+    actual should equal(sparseReference)
+    actual should equal(denseReference)
+    actual.toDenseVector should equal(denseReference)
+  }
+
+  /** Reading index -1 or 5 of a vector built with size argument 5 must throw. */
+  @Test
+  def testInvalidIndexAccess: Unit = {
+    val vector = SparseVector.fromCOO(5, (0, 1), (4, 10), (3, 5))
+    for (outOfRange <- Seq(-1, 5)) {
+      intercept[IllegalArgumentException] {
+        vector(outOfRange)
+      }
+    }
+  }
+
+  /** fromCOO must reject entries whose index is -1 or 5 when the size argument is 5. */
+  @Test
+  def testSparseVectorFromCOOWithInvalidIndices: Unit = {
+    // Negative index in the middle of the entry list.
+    intercept[IllegalArgumentException] {
+      SparseVector.fromCOO(5, (0, 1), (-1, 34), (3, 2))
+    }
+
+    // Index equal to the size argument at the end of the entry list.
+    intercept[IllegalArgumentException] {
+      SparseVector.fromCOO(5, (0, 1), (4, 3), (5, 1))
+    }
+  }
+
+  /** A copy equals its source until the copy is modified. */
+  @Test
+  def testSparseVectorCopy: Unit = {
+    val original = SparseVector.fromCOO(5, (0, 1), (4, 3), (3, 2))
+    val duplicate = original.copy
+
+    original should equal(duplicate)
+    // Change the entry supplied as (3, 2) to the value 3, in the duplicate only.
+    duplicate(3) = 3
+    original should not equal(duplicate)
+  }
+}