You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by jm...@apache.org on 2010/02/20 06:55:38 UTC

svn commit: r912079 - in /lucene/mahout/trunk/math/src: main/java/org/apache/mahout/math/SequentialAccessSparseVector.java test/java/org/apache/mahout/math/VectorTest.java

Author: jmannix
Date: Sat Feb 20 05:55:37 2010
New Revision: 912079

URL: http://svn.apache.org/viewvc?rev=912079&view=rev
Log:
Fix getLengthSquared() returning a stale value after mutation, and fix broken mutation through iterateAll() elements. Adds new unit tests.

Modified:
    lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
    lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java

Modified: lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java?rev=912079&r1=912078&r2=912079&view=diff
==============================================================================
--- lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java (original)
+++ lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java Sat Feb 20 05:55:37 2010
@@ -107,6 +107,7 @@
   }
 
   public void setQuick(int index, double value) {
+    lengthSquared = -1;
     values.set(index, value);
   }
 
@@ -217,8 +218,8 @@
   private abstract static class AbstractElement implements Element {
     int offset;
     final OrderedIntDoubleMapping mapping;
-    final int[] indices;
-    final double[] values;
+    int[] indices;
+    double[] values;
 
     AbstractElement(int ind, SequentialAccessSparseVector v) {
       offset = ind;
@@ -250,15 +251,19 @@
     }
 
     public void set(double value) {
-      v.lengthSquared = -1;
-      if(value != 0.0) mapping.set(indices[offset], value);
+      v.set(offset, value);
+      // indices and values may have changed, must re-grab them.
+      indices = mapping.getIndices();
+      values = mapping.getValues();
     }
   }
 
   private static final class SparseElement extends AbstractElement {
 
+    SequentialAccessSparseVector v;
     SparseElement(int ind, SequentialAccessSparseVector v) {
       super(ind, v);
+      this.v = v;
     }
 
     public double get() {
@@ -270,6 +275,7 @@
     }
 
     public void set(double value) {
+      v.lengthSquared = -1;
       values[offset] = value;
     }
   }

Modified: lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java?rev=912079&r1=912078&r2=912079&view=diff
==============================================================================
--- lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java (original)
+++ lucene/mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java Sat Feb 20 05:55:37 2010
@@ -21,6 +21,7 @@
 import com.google.gson.GsonBuilder;
 import com.google.gson.reflect.TypeToken;
 import junit.framework.TestCase;
+import org.apache.mahout.math.function.Functions;
 
 import static org.apache.mahout.math.function.Functions.*;
 
@@ -191,6 +192,103 @@
         Math.abs(expected - v.getDistanceSquared(w)) < 1e-6);
   }
 
+  public void testGetLengthSquared() throws Exception {
+    Vector v = new DenseVector(5);
+    setUpV(v);
+    doTestGetLengthSquared(v);
+    v = new RandomAccessSparseVector(5);
+    setUpV(v);
+    doTestGetLengthSquared(v);
+    v = new SequentialAccessSparseVector(5);
+    setUpV(v);
+    doTestGetLengthSquared(v);
+  }
+
+  public static double lengthSquaredSlowly(Vector v) {
+    double d = 0;
+    for(int i=0; i<v.size(); i++) {
+      d += (v.get(i) * v.get(i));
+    }
+    return d;
+  }
+
+  public void doTestGetLengthSquared(Vector v) throws Exception {
+    double expected = lengthSquaredSlowly(v);
+    assertTrue("v.getLengthSquared() != sum_of_squared_elements(v)",
+        expected == v.getLengthSquared());
+
+    v.set(v.size()/2, v.get(v.size()/2) + 1.0);
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via set() fails to change lengthSquared", expected, v.getLengthSquared());
+
+    v.setQuick(v.size()/5, v.get(v.size()/5) + 1.0);
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via setQuick() fails to change lengthSquared", expected, v.getLengthSquared());
+
+    Iterator<Vector.Element> it = v.iterateAll();
+    while(it.hasNext()) {
+      Vector.Element e = it.next();
+      if(e.index() == v.size() - 2) {
+        e.set(e.get() - 5.0);
+      }
+    }
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via dense iterator.set fails to change lengthSquared", expected, v.getLengthSquared());
+
+    it = v.iterateNonZero();
+    int i=0;
+    while(it.hasNext()) {
+      i++;
+      Vector.Element e = it.next();
+      if(i == v.getNumNondefaultElements() - 1) {
+        e.set(e.get() - 5.0);
+      }
+    }
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via sparse iterator.set fails to change lengthSquared", expected, v.getLengthSquared());
+
+    v.assign(3.0);
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via assign(double) fails to change lengthSquared", expected, v.getLengthSquared());
+
+    v.assign(Functions.square);
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via assign(square) fails to change lengthSquared", expected, v.getLengthSquared());
+
+    v.assign(new double[v.size()]);
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via assign(double[]) fails to change lengthSquared", expected, v.getLengthSquared());
+
+    v.getElement(v.size()/2).set(2.5);
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via v.getElement().set() fails to change lengthSquared", expected, v.getLengthSquared());
+
+    v.normalize();
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via normalize() fails to change lengthSquared", expected, v.getLengthSquared());
+
+    v.set(0, 1.5);
+    v.normalize(1.0);
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via normalize(double) fails to change lengthSquared", expected, v.getLengthSquared());
+
+    v.times(2.0);
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via times(double) fails to change lengthSquared", expected, v.getLengthSquared());
+
+    v.times(v);
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via times(vector) fails to change lengthSquared", expected, v.getLengthSquared());
+
+    v.assign(Functions.pow, 3.0);
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via assign(pow, 3.0) fails to change lengthSquared", expected, v.getLengthSquared());
+
+    v.assign(v, Functions.plus);
+    expected = lengthSquaredSlowly(v);
+    assertEquals("mutation via assign(v,plus) fails to change lengthSquared", expected, v.getLengthSquared());
+  }
+
   public void testNormalize() throws Exception {
     RandomAccessSparseVector vec1 = new RandomAccessSparseVector(3);