You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2014/06/07 04:32:09 UTC
[1/3] git commit: MAHOUT-1574 - Add sparse handling to rows and
columns of DiagonalMatrix
Repository: mahout
Updated Branches:
refs/heads/master 9cf90546d -> 5083f5835
MAHOUT-1574 - Add sparse handling to rows and columns of DiagonalMatrix
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/dd78ed94
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/dd78ed94
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/dd78ed94
Branch: refs/heads/master
Commit: dd78ed9479559cd222f24fa0be57655cf2e3075b
Parents: 9cf9054
Author: Ted Dunning <td...@apache.org>
Authored: Fri Jun 6 19:19:03 2014 -0700
Committer: Ted Dunning <td...@apache.org>
Committed: Fri Jun 6 19:19:03 2014 -0700
----------------------------------------------------------------------
.../org/apache/mahout/math/DiagonalMatrix.java | 206 ++++++++++++++++++-
.../apache/mahout/math/DiagonalMatrixTest.java | 43 ++++
2 files changed, 244 insertions(+), 5 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/mahout/blob/dd78ed94/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java b/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
index 2a027f7..3e20a4a 100644
--- a/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/DiagonalMatrix.java
@@ -17,6 +17,9 @@
package org.apache.mahout.math;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
public class DiagonalMatrix extends AbstractMatrix implements MatrixTimesOps {
private final Vector diagonal;
@@ -60,6 +63,195 @@ public class DiagonalMatrix extends AbstractMatrix implements MatrixTimesOps {
throw new UnsupportedOperationException("Can't assign a row to a diagonal matrix");
}
+  /**
+   * Returns a view of one row of this diagonal matrix.
+   *
+   * @param row index of the row to view
+   * @return a {@link SingleElementVector} whose only storage-backed position is {@code row}
+   */
+  @Override
+  public Vector viewRow(int row) {
+    Vector rowView = new SingleElementVector(row);
+    return rowView;
+  }
+
+  /**
+   * Returns a view of one column of this diagonal matrix.
+   *
+   * @param row index of the column to view (despite its name, this parameter is the column index)
+   * @return a {@link SingleElementVector} whose only storage-backed position is the given index
+   */
+  @Override
+  public Vector viewColumn(int row) {
+    Vector columnView = new SingleElementVector(row);
+    return columnView;
+  }
+
+  /**
+   * Special class to implement views of rows and columns of a diagonal matrix.
+   * <p>
+   * The view has the same size as the diagonal of the enclosing matrix. Only the position
+   * {@code index} is backed by storage (the matching entry of the diagonal); every other
+   * position reads as zero and rejects writes.
+   */
+  public class SingleElementVector extends AbstractVector {
+    // Position of the single diagonal-backed element within this view.
+    private final int index;
+
+    /**
+     * @param index the row (equivalently, column) of the diagonal matrix that this vector views
+     */
+    public SingleElementVector(int index) {
+      super(diagonal.size());
+      this.index = index;
+    }
+
+    /**
+     * Returns the diagonal entry when {@code index} is the viewed position, and zero otherwise.
+     */
+    @Override
+    public double getQuick(int index) {
+      if (index == this.index) {
+        return diagonal.get(index);
+      } else {
+        return 0;
+      }
+    }
+
+    /**
+     * Writes through to the diagonal when {@code index} is the viewed position.
+     *
+     * @throws IllegalArgumentException if any other position is addressed
+     */
+    @Override
+    public void set(int index, double value) {
+      if (index == this.index) {
+        diagonal.set(index, value);
+      } else {
+        throw new IllegalArgumentException("Can't set off-diagonal element of diagonal matrix");
+      }
+    }
+
+    /**
+     * Iterates over exactly one element: the diagonal-backed position of this view.
+     * The element is reported even if the stored diagonal value happens to be zero.
+     */
+    @Override
+    protected Iterator<Element> iterateNonZero() {
+      return new Iterator<Element>() {
+        // true until the single element has been handed out
+        boolean more = true;
+
+        @Override
+        public boolean hasNext() {
+          return more;
+        }
+
+        @Override
+        public Element next() {
+          if (more) {
+            more = false;
+            return new Element() {
+              @Override
+              public double get() {
+                return diagonal.get(index);
+              }
+
+              @Override
+              public int index() {
+                return index;
+              }
+
+              @Override
+              public void set(double value) {
+                diagonal.set(index, value);
+              }
+            };
+          } else {
+            throw new NoSuchElementException("Only one non-zero element in a row or column of a diagonal matrix");
+          }
+        }
+
+        @Override
+        public void remove() {
+          throw new UnsupportedOperationException("Can't remove from vector view");
+        }
+      };
+    }
+
+    /**
+     * Iterates over every position of the view, from 0 up to {@code size() - 1}.
+     * <p>
+     * A single {@link Element} instance is reused for all positions; it always reflects the
+     * position most recently returned by {@code next()}.
+     */
+    @Override
+    protected Iterator<Element> iterator() {
+      return new Iterator<Element>() {
+        // Position most recently returned by next(). Starts at -1 so that the first call to
+        // next() yields index 0; starting at 0 would skip the first element entirely.
+        int i = -1;
+
+        Element r = new Element() {
+          @Override
+          public double get() {
+            if (i == index) {
+              return diagonal.get(index);
+            } else {
+              return 0;
+            }
+          }
+
+          @Override
+          public int index() {
+            return i;
+          }
+
+          @Override
+          public void set(double value) {
+            if (i == index) {
+              diagonal.set(index, value);
+            } else {
+              throw new IllegalArgumentException("Can't set any element but diagonal");
+            }
+          }
+        };
+
+        @Override
+        public boolean hasNext() {
+          return i < SingleElementVector.this.size() - 1;
+        }
+
+        @Override
+        public Element next() {
+          if (hasNext()) {
+            i++;
+            return r;
+          } else {
+            throw new NoSuchElementException("Attempted to access passed last element of vector");
+          }
+        }
+
+
+        @Override
+        public void remove() {
+          throw new UnsupportedOperationException("Default operation");
+        }
+      };
+    }
+
+    // NOTE(review): this relies on a two-argument DiagonalMatrix constructor that is not visible
+    // here; confirm that it really builds a rows x columns matrix and is not a
+    // (value, size) overload selected through int-to-double widening.
+    @Override
+    protected Matrix matrixLike(int rows, int columns) {
+      return new DiagonalMatrix(rows, columns);
+    }
+
+    @Override
+    public boolean isDense() {
+      return false;
+    }
+
+    @Override
+    public boolean isSequentialAccess() {
+      return true;
+    }
+
+    /**
+     * Not supported by this view.
+     *
+     * @throws UnsupportedOperationException always
+     */
+    @Override
+    public void mergeUpdates(OrderedIntDoubleMapping updates) {
+      throw new UnsupportedOperationException("Default operation");
+    }
+
+    /**
+     * Returns a {@link DenseVector} with the same size as this view.
+     */
+    @Override
+    public Vector like() {
+      return new DenseVector(size());
+    }
+
+    /**
+     * Writes through to the diagonal when {@code index} is the viewed position.
+     *
+     * @throws IllegalArgumentException if any other position is addressed
+     */
+    @Override
+    public void setQuick(int index, double value) {
+      if (index == this.index) {
+        diagonal.set(this.index, value);
+      } else {
+        throw new IllegalArgumentException("Can't set off-diagonal element of DiagonalMatrix");
+      }
+    }
+
+    /**
+     * Always reports one element: the single diagonal-backed position.
+     */
+    @Override
+    public int getNumNondefaultElements() {
+      return 1;
+    }
+
+    @Override
+    public double getLookupCost() {
+      return 0;
+    }
+
+    @Override
+    public double getIteratorAdvanceCost() {
+      return 1;
+    }
+
+    @Override
+    public boolean isAddConstantTime() {
+      return false;
+    }
+  }
+
/**
* Provides a view of the diagonal of a matrix.
*/
@@ -147,22 +339,26 @@ public class DiagonalMatrix extends AbstractMatrix implements MatrixTimesOps {
+ /**
+ * Multiplies with {@code that} as the right operand by scaling each row of {@code that}
+ * by the diagonal entry with the same index.
+ *
+ * @param that the right operand; must have as many rows as the diagonal has entries
+ * @return a matrix obtained from {@code that.like()} holding the scaled rows
+ * @throws IllegalArgumentException if the row count of {@code that} differs from the diagonal size
+ */
@Override
public Matrix timesRight(Matrix that) {
- if (that.numRows() != diagonal.size())
+ if (that.numRows() != diagonal.size()) {
throw new IllegalArgumentException("Incompatible number of rows in the right operand of matrix multiplication.");
+ }
Matrix m = that.like();
- for (int row = 0; row < diagonal.size(); row++)
+ for (int row = 0; row < diagonal.size(); row++) {
m.assignRow(row, that.viewRow(row).times(diagonal.getQuick(row)));
+ }
return m;
}
+ /**
+ * Multiplies with {@code that} as the left operand by scaling each column of {@code that}
+ * by the diagonal entry with the same index.
+ *
+ * @param that the left operand; must have as many columns as the diagonal has entries
+ * @return a matrix obtained from {@code that.like()} holding the scaled columns
+ * @throws IllegalArgumentException if the column count of {@code that} differs from the diagonal size
+ */
@Override
public Matrix timesLeft(Matrix that) {
- if (that.numCols() != diagonal.size())
+ if (that.numCols() != diagonal.size()) {
throw new IllegalArgumentException(
- "Incompatible number of rows in the left operand of matrix-matrix multiplication.");
+ "Incompatible number of columns in the left operand of matrix-matrix multiplication.");
+ }
Matrix m = that.like();
- for (int col = 0; col < diagonal.size(); col++)
+ for (int col = 0; col < diagonal.size(); col++) {
m.assignColumn(col, that.viewColumn(col).times(diagonal.getQuick(col)));
+ }
return m;
}
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/dd78ed94/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java
----------------------------------------------------------------------
diff --git a/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java b/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java
index 5b3a278..2ca7be0 100644
--- a/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java
+++ b/math/src/test/java/org/apache/mahout/math/DiagonalMatrixTest.java
@@ -18,8 +18,11 @@
package org.apache.mahout.math;
import org.apache.mahout.math.function.Functions;
+import org.junit.Assert;
import org.junit.Test;
+import java.util.Iterator;
+
public class DiagonalMatrixTest extends MahoutTestCase {
@Test
public void testBasics() {
@@ -46,4 +49,44 @@ public class DiagonalMatrixTest extends MahoutTestCase {
assertEquals(100, a.times(m.transpose()).aggregate(Functions.PLUS, Functions.ABS), 1.0e-10);
}
+  /**
+   * Checks that row and column views of a diagonal matrix are sparse, expose the diagonal
+   * value at the expected position, read as zero elsewhere, and write through to the matrix.
+   */
+  @Test
+  public void testSparsity() {
+    // diagonal holds the squares 0, 1, 4, ..., 81
+    Vector squares = new DenseVector(10);
+    for (int k = 0; k < 10; k++) {
+      squares.set(k, k * k);
+    }
+    DiagonalMatrix diag = new DiagonalMatrix(squares);
+
+    Assert.assertFalse(diag.viewRow(0).isDense());
+    Assert.assertFalse(diag.viewColumn(0).isDense());
+
+    for (int k = 0; k < 10; k++) {
+      Vector rowView = diag.viewRow(k);
+      assertEquals(k * k, rowView.zSum(), 0);
+      assertEquals(k * k, rowView.get(k), 0);
+
+      Vector columnView = diag.viewColumn(k);
+      assertEquals(k * k, columnView.zSum(), 0);
+      assertEquals(k * k, columnView.get(k), 0);
+    }
+
+    // the non-zero iterator of row 7 must yield exactly one element: (7, 49)
+    Iterator<Vector.Element> nonZeroes = diag.viewRow(7).nonZeroes().iterator();
+    assertTrue(nonZeroes.hasNext());
+    Vector.Element only = nonZeroes.next();
+    assertEquals(7, only.index());
+    assertEquals(49, only.get(), 0);
+    assertFalse(nonZeroes.hasNext());
+
+    // off-diagonal positions read as zero
+    assertEquals(0, diag.viewRow(5).get(3), 0);
+    assertEquals(0, diag.viewColumn(8).get(3), 0);
+
+    // writing through a view changes the underlying matrix
+    diag.viewRow(3).set(3, 1);
+    assertEquals(1, diag.get(3, 3), 0);
+
+    for (Vector.Element entry : diag.viewRow(6).all()) {
+      double expected = entry.index() == 6 ? 36 : 0;
+      assertEquals(expected, entry.get(), 0);
+    }
+  }
}
[3/3] git commit: MAHOUT-1574 - Add tests to demonstrate correctness
and efficiency
Posted by td...@apache.org.
MAHOUT-1574 - Add tests to demonstrate correctness and efficiency
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/5083f583
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/5083f583
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/5083f583
Branch: refs/heads/master
Commit: 5083f58359dd103ac6d8e72a50c53bb20a5df14e
Parents: 0f4f5de
Author: Ted Dunning <td...@apache.org>
Authored: Fri Jun 6 19:21:47 2014 -0700
Committer: Ted Dunning <td...@apache.org>
Committed: Fri Jun 6 19:21:47 2014 -0700
----------------------------------------------------------------------
.../apache/mahout/math/TestSparseRowMatrix.java | 127 +++++++++++++++++++
1 file changed, 127 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/mahout/blob/5083f583/math/src/test/java/org/apache/mahout/math/TestSparseRowMatrix.java
----------------------------------------------------------------------
diff --git a/math/src/test/java/org/apache/mahout/math/TestSparseRowMatrix.java b/math/src/test/java/org/apache/mahout/math/TestSparseRowMatrix.java
index 0f71506..08174be 100644
--- a/math/src/test/java/org/apache/mahout/math/TestSparseRowMatrix.java
+++ b/math/src/test/java/org/apache/mahout/math/TestSparseRowMatrix.java
@@ -17,6 +17,14 @@
package org.apache.mahout.math;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.jet.random.Gamma;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Random;
+
public final class TestSparseRowMatrix extends MatrixTest {
@Override
@@ -31,4 +39,123 @@ public final class TestSparseRowMatrix extends MatrixTest {
}
+  /**
+   * Multiplies a 1000 x 2000 SparseRowMatrix by a 2000 x 1000 SparseRowMatrix and spot-checks
+   * 1000 randomly chosen entries of the product against explicit row/column dot products.
+   * The JUnit timeout fails the test if it takes longer than 5000 ms.
+   */
+  @Test(timeout=5000)
+  public void testTimesSparseEfficiency() {
+    Random raw = RandomUtils.getRandom();
+    Gamma gen = new Gamma(0.1, 0.1, raw);
+
+    // build two large sequential sparse matrices and multiply them
+    Matrix x = new SparseRowMatrix(1000, 2000, false);
+    for (int i = 0; i < 1000; i++) {
+      int[] values = new int[1000];
+      for (int k = 0; k < 1000; k++) {
+        // clamp to the last valid slot; a bound of values.length would index past the array
+        int j = (int) Math.min(values.length - 1, gen.nextDouble());
+        values[j]++;
+      }
+      for (int j = 0; j < 1000; j++) {
+        if (values[j] > 0) {
+          x.set(i, j, values[j]);
+        }
+      }
+    }
+
+    Matrix y = new SparseRowMatrix(2000, 1000, false);
+    for (int i = 0; i < 2000; i++) {
+      int[] values = new int[1000];
+      for (int k = 0; k < 1000; k++) {
+        // clamp to the last valid slot; a bound of values.length would index past the array
+        int j = (int) Math.min(values.length - 1, gen.nextDouble());
+        values[j]++;
+      }
+      for (int j = 0; j < 1000; j++) {
+        if (values[j] > 0) {
+          y.set(i, j, values[j]);
+        }
+      }
+    }
+
+    long t0 = System.nanoTime();
+    Matrix z = x.times(y);
+    double elapsedTime = (System.nanoTime() - t0) * 1e-6;
+    System.out.printf("done in %.1f ms\n", elapsedTime);
+
+    for (int k = 0; k < 1000; k++) {
+      // -10 * log(u) is unbounded above (and infinite for u == 0), so keep the sampled
+      // coordinates inside the product's dimensions
+      int i = (int) Math.min(x.rowSize() - 1, -10 * Math.log(raw.nextDouble()));
+      int j = (int) Math.min(y.columnSize() - 1, -10 * Math.log(raw.nextDouble()));
+      Assert.assertEquals(x.viewRow(i).dot(y.viewColumn(j)), z.get(i, j), 1e-12);
+    }
+  }
+
+  /**
+   * Multiplies a 1000 x 2000 SparseRowMatrix by a 2000 x 20 DenseMatrix and checks every entry
+   * of the product against an explicit row/column dot product.
+   * The JUnit timeout fails the test if it takes longer than 5000 ms.
+   */
+  @Test(timeout=5000)
+  public void testTimesDenseEfficiency() {
+    Random raw = RandomUtils.getRandom();
+    Gamma gen = new Gamma(0.1, 0.1, raw);
+
+    // build a sequential sparse matrix and a dense matrix and multiply them
+    Matrix x = new SparseRowMatrix(1000, 2000, false);
+    for (int i = 0; i < 1000; i++) {
+      int[] values = new int[1000];
+      for (int k = 0; k < 1000; k++) {
+        // clamp to the last valid slot; a bound of values.length would index past the array
+        int j = (int) Math.min(values.length - 1, gen.nextDouble());
+        values[j]++;
+      }
+      for (int j = 0; j < 1000; j++) {
+        if (values[j] > 0) {
+          x.set(i, j, values[j]);
+        }
+      }
+    }
+
+    Matrix y = new DenseMatrix(2000, 20);
+    for (int i = 0; i < 2000; i++) {
+      for (int j = 0; j < 20; j++) {
+        y.set(i, j, raw.nextDouble());
+      }
+    }
+
+    long t0 = System.nanoTime();
+    Matrix z = x.times(y);
+    double elapsedTime = (System.nanoTime() - t0) * 1e-6;
+    System.out.printf("done in %.1f ms\n", elapsedTime);
+
+    for (int i = 0; i < 1000; i++) {
+      for (int j = 0; j < 20; j++) {
+        Assert.assertEquals(x.viewRow(i).dot(y.viewColumn(j)), z.get(i, j), 1e-12);
+      }
+    }
+  }
+
+  /**
+   * Multiplies a 1000 x 2000 SparseRowMatrix by a DiagonalMatrix built from a 2000-element
+   * vector and checks every non-zero of the product against the matching entry of the sparse
+   * matrix scaled by the diagonal value for its column.
+   * The JUnit timeout fails the test if it takes longer than 5000 ms.
+   */
+  @Test(timeout=5000)
+  public void testTimesOtherSparseEfficiency() {
+    Random raw = RandomUtils.getRandom();
+    Gamma gen = new Gamma(0.1, 0.1, raw);
+
+    // build a sequential sparse matrix and a diagonal matrix and multiply them
+    Matrix x = new SparseRowMatrix(1000, 2000, false);
+    for (int i = 0; i < 1000; i++) {
+      int[] values = new int[1000];
+      for (int k = 0; k < 1000; k++) {
+        // clamp to the last valid slot; a bound of values.length would index past the array
+        int j = (int) Math.min(values.length - 1, gen.nextDouble());
+        values[j]++;
+      }
+      for (int j = 0; j < 1000; j++) {
+        if (values[j] > 0) {
+          x.set(i, j, values[j]);
+        }
+      }
+    }
+
+    Vector d = new DenseVector(2000).assign(Functions.random());
+    Matrix y = new DiagonalMatrix(d);
+
+    long t0 = System.nanoTime();
+    Matrix z = x.times(y);
+    double elapsedTime = (System.nanoTime() - t0) * 1e-6;
+    System.out.printf("done in %.1f ms\n", elapsedTime);
+
+    for (MatrixSlice row : z) {
+      for (Vector.Element element : row.nonZeroes()) {
+        assertEquals(x.get(row.index(), element.index()) * d.get(element.index()), element.get(), 1e-12);
+      }
+    }
+  }
}
[2/3] git commit: MAHOUT-1574 - Make SparseRowMatrix handle
multiplication efficiently for special cases of SRM,
DenseMatrix and other kinds of sparse matrices. Speedup on the given test is
at least several thousand x
Posted by td...@apache.org.
MAHOUT-1574 - Make SparseRowMatrix handle multiplication efficiently for the special cases of SparseRowMatrix (SRM), DenseMatrix, and other kinds of sparse matrices. The speedup on the given test is at least several-thousand-fold.
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/0f4f5dec
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/0f4f5dec
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/0f4f5dec
Branch: refs/heads/master
Commit: 0f4f5dec7f3a7de7161d1fabf3e02418099f4446
Parents: dd78ed9
Author: Ted Dunning <td...@apache.org>
Authored: Fri Jun 6 19:20:35 2014 -0700
Committer: Ted Dunning <td...@apache.org>
Committed: Fri Jun 6 19:20:35 2014 -0700
----------------------------------------------------------------------
.../org/apache/mahout/math/SparseRowMatrix.java | 43 ++++++++++++++++++++
1 file changed, 43 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/mahout/blob/0f4f5dec/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
----------------------------------------------------------------------
diff --git a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
index 5829289..294a69d 100644
--- a/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
+++ b/math/src/main/java/org/apache/mahout/math/SparseRowMatrix.java
@@ -17,6 +17,8 @@
package org.apache.mahout.math;
+import org.apache.mahout.math.function.Functions;
+
/**
* sparse matrix with general element values whose rows are accessible quickly. Implemented as a row array of
* either SequentialAccessSparseVectors or RandomAccessSparseVectors.
@@ -177,4 +179,45 @@ public class SparseRowMatrix extends AbstractMatrix {
return scm;
}
+  /**
+   * Multiplies this matrix by {@code other}, visiting only the non-zero elements of each row
+   * of this matrix and accumulating the matching rows of {@code other}.
+   * <p>
+   * Three cases are distinguished:
+   * <ul>
+   *   <li>{@code other} is a SparseRowMatrix: its row vectors are read directly and the
+   *       result is a SparseRowMatrix;</li>
+   *   <li>the first row of {@code other} reports itself as dense: each result row is
+   *       accumulated in a DenseVector and the result comes from {@code other.like(...)};</li>
+   *   <li>otherwise: rows of {@code other} are read through {@code viewRow} and the result is
+   *       a SparseRowMatrix.</li>
+   * </ul>
+   *
+   * @param other the right operand; must have as many rows as this matrix has columns
+   * @return the product, with {@code rowSize()} rows and {@code other.columnSize()} columns
+   * @throws IllegalArgumentException if the inner dimensions do not agree
+   */
+  @Override
+  public Matrix times(Matrix other) {
+    // Without this guard a mismatch either fails with an index error deep inside the loops
+    // or, when no non-zero happens to fall outside the range, silently yields a wrong product.
+    if (columnSize() != other.rowSize()) {
+      throw new IllegalArgumentException("Incompatible dimensions for matrix multiplication: left operand has "
+          + columnSize() + " columns, right operand has " + other.rowSize() + " rows");
+    }
+
+    if (other instanceof SparseRowMatrix) {
+      SparseRowMatrix y = (SparseRowMatrix) other;
+      SparseRowMatrix result = (SparseRowMatrix) like(rowSize(), other.columnSize());
+
+      for (int i = 0; i < rows; i++) {
+        Vector row = rowVectors[i];
+        for (Vector.Element element : row.nonZeroes()) {
+          result.rowVectors[i].assign(y.rowVectors[element.index()], Functions.plusMult(element.get()));
+        }
+      }
+      return result;
+    } else {
+      // only probe the first row when there is one; an operand without rows takes the sparse path
+      if (other.rowSize() > 0 && other.viewRow(0).isDense()) {
+        // result is dense, but can be computed relatively cheaply
+        Matrix result = other.like(rowSize(), other.columnSize());
+
+        for (int i = 0; i < rows; i++) {
+          Vector row = rowVectors[i];
+          Vector r = new DenseVector(other.columnSize());
+          for (Vector.Element element : row.nonZeroes()) {
+            r.assign(other.viewRow(element.index()), Functions.plusMult(element.get()));
+          }
+          result.viewRow(i).assign(r);
+        }
+        return result;
+      } else {
+        // other is sparse, but not something we understand intimately
+        SparseRowMatrix result = (SparseRowMatrix) like(rowSize(), other.columnSize());
+
+        for (int i = 0; i < rows; i++) {
+          Vector row = rowVectors[i];
+          for (Vector.Element element : row.nonZeroes()) {
+            result.rowVectors[i].assign(other.viewRow(element.index()), Functions.plusMult(element.get()));
+          }
+        }
+        return result;
+      }
+    }
+  }
}