You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@mahout.apache.org by td...@apache.org on 2012/10/11 08:23:45 UTC
svn commit: r1396920 - in /mahout/trunk/math/src:
main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
test/java/org/apache/mahout/math/AbstractVectorTest.java
Author: tdunning
Date: Thu Oct 11 06:23:44 2012
New Revision: 1396920
URL: http://svn.apache.org/viewvc?rev=1396920&view=rev
Log:
MAHOUT-1091 - Add test to demonstrate broken iterator in SequentialAccessSparseVector (and add fix)
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java?rev=1396920&r1=1396919&r2=1396920&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java Thu Oct 11 06:23:44 2012
@@ -253,10 +253,16 @@ public class SequentialAccessSparseVecto
protected Element computeNext() {
int numMappings = values.getNumMappings();
if (numMappings <= 0 || element.getNextIndex() > values.getIndices()[numMappings - 1]) {
- return endOfData();
+ if (element.index() >= SequentialAccessSparseVector.this.size() - 1) {
+ return endOfData();
+ } else {
+ element.advanceIndex();
+ return element;
+ }
+ } else {
+ element.advanceIndex();
+ return element;
}
- element.advanceIndex();
- return element;
}
}
@@ -297,7 +303,7 @@ public class SequentialAccessSparseVecto
void advanceIndex() {
index++;
- if (index > values.getIndices()[nextOffset]) {
+ if (nextOffset < values.getNumMappings() && index > values.getIndices()[nextOffset]) {
nextOffset++;
}
}
@@ -308,10 +314,11 @@ public class SequentialAccessSparseVecto
@Override
public double get() {
- if (index == values.getIndices()[nextOffset]) {
+ if (nextOffset < values.getNumMappings() && index == values.getIndices()[nextOffset]) {
return values.getValues()[nextOffset];
+ } else {
+ return OrderedIntDoubleMapping.DEFAULT_VALUE;
}
- return OrderedIntDoubleMapping.DEFAULT_VALUE;
}
@Override
@@ -322,7 +329,7 @@ public class SequentialAccessSparseVecto
@Override
public void set(double value) {
invalidateCachedLength();
- if (index == values.getIndices()[nextOffset]) {
+ if (nextOffset < values.getNumMappings() && index == values.getIndices()[nextOffset]) {
values.getValues()[nextOffset] = value;
} else {
// Yes, this works; the offset into indices of the new value's index will still be nextOffset
Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java?rev=1396920&r1=1396919&r2=1396920&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java Thu Oct 11 06:23:44 2012
@@ -3,6 +3,7 @@ package org.apache.mahout.math;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.jet.random.Normal;
+import org.apache.mahout.math.random.MultiNormal;
import org.junit.Before;
import org.junit.Test;
@@ -574,6 +575,54 @@ public abstract class AbstractVectorTest
for (int col = 0; col < result.columnSize(); col++) {
assertEquals("cross[" + row + "][" + col + ']', test.getQuick(row)
* test.getQuick(col), result.getQuick(row, col), EPSILON);
+
+ }
+ }
+ }
+
+ @Test
+ public void testIterators() {
+ final T v0 = vectorToTest(20);
+
+ double sum = 0;
+ int elements = 0;
+ int nonZero = 0;
+ for (Vector.Element element : v0) {
+ elements++;
+ sum += element.get();
+ if (element.get() != 0) {
+ nonZero++;
+ }
+ }
+
+ int nonZeroIterated = 0;
+ final Iterator<Vector.Element> i = v0.iterateNonZero();
+ while (i.hasNext()) {
+ i.next();
+ nonZeroIterated++;
+ }
+ assertEquals(20, elements);
+ assertEquals(v0.size(), elements);
+ assertEquals(nonZeroIterated, nonZero);
+ assertEquals(v0.zSum(), sum, 0);
+ }
+
+ @Test
+ public void testSmallDistances() {
+ for (double fuzz : new double[]{1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10}) {
+ MultiNormal x = new MultiNormal(fuzz, new ConstantVector(0, 20));
+ for (int i = 0; i < 10000; i++) {
+ final T v1 = vectorToTest(20);
+ Vector v2 = v1.plus(x.sample());
+ if (1 + fuzz * fuzz > 1) {
+ String msg = String.format("fuzz = %.1g, >", fuzz);
+ assertTrue(msg, v1.getDistanceSquared(v2) > 0);
+ assertTrue(msg, v2.getDistanceSquared(v1) > 0);
+ } else {
+ String msg = String.format("fuzz = %.1g, >=", fuzz);
+ assertTrue(msg, v1.getDistanceSquared(v2) >= 0);
+ assertTrue(msg, v2.getDistanceSquared(v1) >= 0);
+ }
}
}
}