You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@mahout.apache.org by td...@apache.org on 2012/10/11 08:23:45 UTC

svn commit: r1396920 - in /mahout/trunk/math/src: main/java/org/apache/mahout/math/SequentialAccessSparseVector.java test/java/org/apache/mahout/math/AbstractVectorTest.java

Author: tdunning
Date: Thu Oct 11 06:23:44 2012
New Revision: 1396920

URL: http://svn.apache.org/viewvc?rev=1396920&view=rev
Log:
MAHOUT-1091 - Add test to demonstrate broken iterator in SequentialAccessSparseVector (and add fix)

Modified:
    mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java?rev=1396920&r1=1396919&r2=1396920&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java Thu Oct 11 06:23:44 2012
@@ -253,10 +253,16 @@ public class SequentialAccessSparseVecto
     protected Element computeNext() {
       int numMappings = values.getNumMappings();
       if (numMappings <= 0 || element.getNextIndex() > values.getIndices()[numMappings - 1]) {
-        return endOfData();
+        if (element.index() >= SequentialAccessSparseVector.this.size() - 1) {
+          return endOfData();
+        } else {
+          element.advanceIndex();
+          return element;
+        }
+      } else {
+        element.advanceIndex();
+        return element;
       }
-      element.advanceIndex();
-      return element;
     }
 
   }
@@ -297,7 +303,7 @@ public class SequentialAccessSparseVecto
 
     void advanceIndex() {
       index++;
-      if (index > values.getIndices()[nextOffset]) {
+      if (nextOffset < values.getNumMappings() && index > values.getIndices()[nextOffset]) {
         nextOffset++;
       }
     }
@@ -308,10 +314,11 @@ public class SequentialAccessSparseVecto
 
     @Override
     public double get() {
-      if (index == values.getIndices()[nextOffset]) {
+      if (nextOffset < values.getNumMappings() && index == values.getIndices()[nextOffset]) {
         return values.getValues()[nextOffset];
+      } else {
+        return OrderedIntDoubleMapping.DEFAULT_VALUE;
       }
-      return OrderedIntDoubleMapping.DEFAULT_VALUE;
     }
 
     @Override
@@ -322,7 +329,7 @@ public class SequentialAccessSparseVecto
     @Override
     public void set(double value) {
       invalidateCachedLength();
-      if (index == values.getIndices()[nextOffset]) {
+      if (nextOffset < values.getNumMappings() && index == values.getIndices()[nextOffset]) {
         values.getValues()[nextOffset] = value;
       } else {
         // Yes, this works; the offset into indices of the new value's index will still be nextOffset

Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java?rev=1396920&r1=1396919&r2=1396920&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java Thu Oct 11 06:23:44 2012
@@ -3,6 +3,7 @@ package org.apache.mahout.math;
 import org.apache.mahout.common.RandomUtils;
 import org.apache.mahout.math.function.Functions;
 import org.apache.mahout.math.jet.random.Normal;
+import org.apache.mahout.math.random.MultiNormal;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -574,6 +575,54 @@ public abstract class AbstractVectorTest
       for (int col = 0; col < result.columnSize(); col++) {
         assertEquals("cross[" + row + "][" + col + ']', test.getQuick(row)
             * test.getQuick(col), result.getQuick(row, col), EPSILON);
+
+      }
+    }
+  }
+
+  @Test
+  public void testIterators() {
+    final T v0 = vectorToTest(20);
+
+    double sum = 0;
+    int elements = 0;
+    int nonZero = 0;
+    for (Vector.Element element : v0) {
+      elements++;
+      sum += element.get();
+      if (element.get() != 0) {
+        nonZero++;
+      }
+    }
+
+    int nonZeroIterated = 0;
+    final Iterator<Vector.Element> i = v0.iterateNonZero();
+    while (i.hasNext()) {
+      i.next();
+      nonZeroIterated++;
+    }
+    assertEquals(20, elements);
+    assertEquals(v0.size(), elements);
+    assertEquals(nonZeroIterated, nonZero);
+    assertEquals(v0.zSum(), sum, 0);
+  }
+
+  @Test
+  public void testSmallDistances() {
+    for (double fuzz : new double[]{1e-5, 1e-6, 1e-7, 1e-8, 1e-9, 1e-10}) {
+      MultiNormal x = new MultiNormal(fuzz, new ConstantVector(0, 20));
+      for (int i = 0; i < 10000; i++) {
+        final T v1 = vectorToTest(20);
+        Vector v2 = v1.plus(x.sample());
+        if (1 + fuzz * fuzz > 1) {
+          String msg = String.format("fuzz = %.1g, >", fuzz);
+          assertTrue(msg, v1.getDistanceSquared(v2) > 0);
+          assertTrue(msg, v2.getDistanceSquared(v1) > 0);
+        } else {
+          String msg = String.format("fuzz = %.1g, >=", fuzz);
+          assertTrue(msg, v1.getDistanceSquared(v2) >= 0);
+          assertTrue(msg, v2.getDistanceSquared(v1) >= 0);
+        }
       }
     }
   }