You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@mahout.apache.org by td...@apache.org on 2012/10/11 03:17:11 UTC

svn commit: r1396888 - in /mahout/trunk: core/src/test/java/org/apache/mahout/clustering/meanshift/ math/src/main/java/org/apache/mahout/math/ math/src/test/java/org/apache/mahout/math/

Author: tdunning
Date: Thu Oct 11 01:17:10 2012
New Revision: 1396888

URL: http://svn.apache.org/viewvc?rev=1396888&view=rev
Log:
MAHOUT-1086 - Deal with round-off errors in computing L_2 distances.  Add a special case to get higher accuracy when the vector difference is small, merge AbstractVectorTest and AbstractTestVector, and fix a like() bug in Centroid and WeightedVector.

Removed:
    mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractTestVector.java
Modified:
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/Centroid.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/WeightedVector.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/CentroidTest.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/TestDenseVector.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSequentialAccessSparseVector.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/WeightedVectorTest.java

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java Thu Oct 11 01:17:10 2012
@@ -492,7 +492,7 @@ public final class TestMeanShift extends
     ToolRunner.run(conf, new MeanShiftCanopyDriver(), args);
     Path outPart = new Path(output, "clusters-3-final/part-r-00000");
     long count = HadoopUtil.countRecords(outPart, conf);
-    assertEquals("count", 4, count);
+    assertEquals("count", 3, count);
     Iterator<?> iterator = new SequenceFileValueIterator<Writable>(outPart,
         true, conf);
     while (iterator.hasNext()) {

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java Thu Oct 11 01:17:10 2012
@@ -21,12 +21,13 @@ import org.apache.mahout.common.RandomUt
 import org.apache.mahout.math.function.DoubleDoubleFunction;
 import org.apache.mahout.math.function.DoubleFunction;
 import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.set.OpenIntHashSet;
 
 import java.util.Iterator;
 
 /** Implementations of generic capabilities like sum of elements and dot products */
 public abstract class AbstractVector implements Vector, LengthCachingVector {
-  
+
   private static final double LOG2 = Math.log(2.0);
 
   private int size;
@@ -155,12 +156,12 @@ public abstract class AbstractVector imp
     }
     return result;
   }
-  
-  public double dotSelf() {
+
+  protected double dotSelf() {
     double result = 0.0;
-    Iterator<Element> iter = iterateNonZero();
-    while (iter.hasNext()) {
-      double value = iter.next().get();
+    Iterator<Element> i = iterateNonZero();
+    while (i.hasNext()) {
+      double value = i.next().get();
       result += value * value;
     }
     return result;
@@ -216,18 +217,18 @@ public abstract class AbstractVector imp
   public Vector normalize(double power) {
     return divide(norm(power));
   }
-  
+
   @Override
   public Vector logNormalize() {
     return logNormalize(2.0, Math.sqrt(dotSelf()));
   }
-  
+
   @Override
   public Vector logNormalize(double power) {
     return logNormalize(power, norm(power));
   }
-  
-  public Vector logNormalize(double power, double normLength) {   
+
+  public Vector logNormalize(double power, double normLength) {
     // we can special case certain powers
     if (Double.isInfinite(power) || power <= 1.0) {
       throw new IllegalArgumentException("Power must be > 1 and < infinity");
@@ -293,8 +294,8 @@ public abstract class AbstractVector imp
   }
 
   @Override
-  public void setLengthSquared(double d2) {
-    lengthSquared = d2;
+  public void invalidateCachedLength() {
+    lengthSquared = -1;
   }
 
   @Override
@@ -303,34 +304,189 @@ public abstract class AbstractVector imp
       throw new CardinalityException(size, v.size());
     }
     // if this and v has a cached lengthSquared, dot product is quickest way to compute this.
-    if (lengthSquared >= 0 && v instanceof LengthCachingVector && v.getLengthSquared() >= 0) {
-      return lengthSquared + v.getLengthSquared() - 2 * this.dot(v);
+    double d1;
+    double d2;
+    double dot;
+    if (lengthSquared >= 0) {
+      // our length squared is cached.  use it
+      // the max is (slight) antidote to round-off errors
+      d1 = lengthSquared;
+      d2 = v.getLengthSquared();
+      dot = this.dot(v);
+    } else {
+      // our length is not cached... compute it and the dot product in one pass for speed
+      d1 = 0;
+      d2 = v.getLengthSquared();
+      dot = 0;
+      final Iterator<Element> i = iterateNonZero();
+      while (i.hasNext()) {
+        Element e = i.next();
+        double value = e.get();
+        d1 += value * value;
+        dot += value * v.getQuick(e.index());
+      }
+      lengthSquared = d1;
+      // again, round-off errors may be present
     }
-    Vector sparseAccessed;
-    Vector randomlyAccessed;
-    if (lengthSquared >= 0.0) {
-      randomlyAccessed = this;
-      sparseAccessed = v;
-    } else { // TODO: could be further optimized, figure out which one is smaller, etc
-      randomlyAccessed = v;
-      sparseAccessed = this;
+
+    double r = d1 + d2 - 2 * dot;
+    if (r > 1e-3 * (d1 + d2)) {
+      return Math.max(0, r);
+    } else {
+      if (this.isSequentialAccess()) {
+        if (v.isSequentialAccess()) {
+          return mergeDiff(this, v);
+        } else {
+          return randomScanDiff(this, v);
+        }
+      } else {
+        return randomScanDiff(v, this);
+      }
     }
+  }
 
-    Iterator<Element> it = sparseAccessed.iterateNonZero();
-    double d = randomlyAccessed.getLengthSquared();
-    double d2 = 0;
-    double dot = 0;
-    while (it.hasNext()) {
-      Element e = it.next();
-      double value = e.get();
-      d2 += value * value;
-      dot += value * randomlyAccessed.getQuick(e.index());
+  /**
+   * Computes the squared difference of two vectors where iterateNonZero
+   * is efficient for each vector, but where the order of iteration is not
+   * known.  This forces us to access most elements of v2 via get(), which
+   * would be very inefficient for some kinds of vectors.
+   *
+   * Note that this static method is exposed at a package level for testing purposes only.
+   * @param v1  The vector that we access only via iterateNonZero
+   * @param v2  The vector that we access via iterateNonZero and via Element.get()
+   * @return The squared difference between v1 and v2.
+   */
+   static double randomScanDiff(Vector v1, Vector v2) {
+    // keeps a list of elements we visited by iterating over v1.  This should be
+    // almost all of the elements of v2 because we only call this method if the
+    // difference is small.
+    OpenIntHashSet visited = new OpenIntHashSet();
+
+    double r = 0;
+
+    // walk through non-zeros of v1
+    Iterator<Element> i = v1.iterateNonZero();
+    while (i.hasNext()) {
+      Element e1 = i.next();
+      visited.add(e1.index());
+      double x = e1.get() - v2.get(e1.index());
+      r += x * x;
+    }
+
+    // now walk through neglected elements of v2
+    i = v2.iterateNonZero();
+    while (i.hasNext()) {
+      Element e2 = i.next();
+      if (!visited.contains(e2.index())) {
+        // if not visited already then v1's value here would be zero.
+        double x = e2.get();
+        r += x * x;
+      }
     }
-    if (sparseAccessed instanceof LengthCachingVector) {
-      ((LengthCachingVector) sparseAccessed).setLengthSquared(d2);
+
+    return r;
+  }
+
+  /**
+   * Computes the squared difference of two vectors where iterateNonZero returns
+   * elements in index order for both vectors.  This allows a merge to be used to
+   * compute the difference.  A merge allows a single sequential pass over each
+   * vector and should be faster than any alternative.
+   *
+   * Note that this static method is exposed at a package level for testing purposes only.
+   * @param v1  The first vector.
+   * @param v2  The second vector.
+   * @return The squared difference between the two vectors.
+   */
+  static double mergeDiff(Vector v1, Vector v2) {
+    Iterator<Element> i1 = v1.iterateNonZero();
+    Iterator<Element> i2 = v2.iterateNonZero();
+
+    // v1 is empty?
+    if (!i1.hasNext()) {
+      return v2.getLengthSquared();
+    }
+
+    // v2 is empty?
+    if (!i2.hasNext()) {
+      return v1.getLengthSquared();
+    }
+
+    Element e1 = i1.next();
+    Element e2 = i2.next();
+
+    double r = 0;
+    while (e1 != null && e2 != null) {
+      // eat elements of v1 that precede all in v2
+      while (e1 != null && e1.index() < e2.index()) {
+        double x = e1.get();
+        r += x * x;
+
+        if (i1.hasNext()) {
+          e1 = i1.next();
+        } else {
+          e1 = null;
+        }
+      }
+
+      // at this point we have three possibilities, e1 == null or e1 matches e2 or
+      // e2 precedes e1.  Here we handle the e2 < e1 case
+      while (e2 != null && (e1 == null || e2.index() < e1.index())) {
+        double x = e2.get();
+        r += x * x;
+
+        if (i2.hasNext()) {
+          e2 = i2.next();
+        } else {
+          e2 = null;
+        }
+      }
+
+      // and now we handle the e1 == e2 case.  For convenience, we
+      // grab as many of these as possible.  Given that we are called here
+      // only when v1 and v2 are nearly equal, this loop should dominate
+      while (e1 != null && e2 != null && e1.index() == e2.index()) {
+        double x = e1.get() - e2.get();
+        r += x * x;
+
+        if (i1.hasNext()) {
+          e1 = i1.next();
+        } else {
+          e1 = null;
+        }
+
+        if (i2.hasNext()) {
+          e2 = i2.next();
+        } else {
+          e2 = null;
+        }
+      }
+    }
+
+    // one of i1 or i2 is exhausted here, but the other may not be
+    while (e1 != null ) {
+      double x = e1.get();
+      r += x * x;
+
+      if (i1.hasNext()) {
+        e1 = i1.next();
+      } else {
+        e1 = null;
+      }
     }
-    //assert d > -1.0e-9; // round-off errors should never be too far off!
-    return Math.abs(d + d2 - 2 * dot);
+
+    while (e2 != null) {
+      double x = e2.get();
+      r += x * x;
+
+      if (i2.hasNext()) {
+        e2 = i2.next();
+      } else {
+        e2 = null;
+      }
+    }
+    // both v1 and v2 have been completely processed
+    return r;
   }
 
   @Override
@@ -348,7 +504,7 @@ public abstract class AbstractVector imp
     }
     return result;
   }
-  
+
   @Override
   public int maxValueIndex() {
     int result = -1;
@@ -473,7 +629,7 @@ public abstract class AbstractVector imp
     if (x == 1.0) {
       return result;
     }
-    
+
     Iterator<Element> iter = result.iterateNonZero();
     while (iter.hasNext()) {
       Element element = iter.next();
@@ -600,7 +756,7 @@ public abstract class AbstractVector imp
 
   @Override
   public final int size() {
-    return size;  
+    return size;
   }
 
   @Override

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/Centroid.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/Centroid.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/Centroid.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/Centroid.java Thu Oct 11 01:17:10 2012
@@ -25,8 +25,8 @@ import org.apache.mahout.math.function.D
  */
 public class Centroid extends WeightedVector {
     public Centroid(WeightedVector original) {
-        super(original.size(), original.getWeight(), original.getIndex());
-        delegate = original.like();
+        super(original.getWeight(), original.getIndex());
+        delegate = original.getVector().like();
         delegate.assign(original);
     }
 
@@ -67,7 +67,12 @@ public class Centroid extends WeightedVe
         setWeight(totalWeight);
     }
 
-    /**
+  @Override
+  public Vector like() {
+    return new Centroid(getIndex(), getVector().like(), getWeight());
+  }
+
+  /**
      * Gets the index of this centroid.  Use getIndex instead to maintain standard names.
      */
     @Deprecated

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java Thu Oct 11 01:17:10 2012
@@ -32,14 +32,13 @@ import java.util.Iterator;
 public class DelegatingVector implements Vector, LengthCachingVector {
   protected Vector delegate;
 
-  public DelegatingVector(int size) {
-    delegate = new DenseVector(size);
-  }
-
   public DelegatingVector(Vector v) {
     delegate = v;
   }
 
+  protected DelegatingVector() {
+  }
+
   public Vector getVector() {
     return delegate;
   }
@@ -126,14 +125,10 @@ public class DelegatingVector implements
     return delegate.getLengthSquared();
   }
 
-  // not normally called because the delegate vector is who would need this and
-  // they will call their own version of this method.  In fact, if the delegate is
-  // also a delegating vector the same logic will apply recursively down to the first
-  // non-delegating vector.  This makes this very hard to test except in trivial ways.
   @Override
-  public void setLengthSquared(double d2) {
+  public void invalidateCachedLength() {
     if (delegate instanceof LengthCachingVector) {
-      ((LengthCachingVector) delegate).setLengthSquared(d2);
+      ((LengthCachingVector) delegate).invalidateCachedLength();
     }
   }
 
@@ -275,7 +270,7 @@ public class DelegatingVector implements
 
   @Override
   public Vector like() {
-    return delegate.like();
+    return new DelegatingVector(delegate.like());
   }
 
   @Override

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java Thu Oct 11 01:17:10 2012
@@ -56,7 +56,7 @@ public class DenseVector extends Abstrac
 
   /**
    * Copy-constructor (for use in turning a sparse vector into a dense one, for example)
-   * @param vector
+   * @param vector The vector to copy
    */
   public DenseVector(Vector vector) {
     super(vector.size());
@@ -95,7 +95,7 @@ public class DenseVector extends Abstrac
   }
 
   @Override
-  public double dotSelf() {
+  protected double dotSelf() {
     double result = 0.0;
     int max = size();
     for (int i = 0; i < max; i++) {
@@ -117,13 +117,13 @@ public class DenseVector extends Abstrac
 
   @Override
   public void setQuick(int index, double value) {
-    lengthSquared = -1.0;
+    invalidateCachedLength();
     values[index] = value;
   }
   
   @Override
   public Vector assign(double value) {
-    this.lengthSquared = -1;
+    invalidateCachedLength();
     Arrays.fill(values, value);
     return this;
   }
@@ -145,7 +145,7 @@ public class DenseVector extends Abstrac
         values[i] = function.apply(values[i], other.getQuick(i));
       }
     }
-    lengthSquared = -1;
+    invalidateCachedLength();
     return this;
   }
 
@@ -197,21 +197,6 @@ public class DenseVector extends Abstrac
     return super.equals(o);
   }
 
-  @Override
-  public double getLengthSquared() {
-    if (lengthSquared >= 0.0) {
-      return lengthSquared;
-    }
-
-    double result = 0.0;
-    for (double value : values) {
-      result += value * value;
-
-    }
-    lengthSquared = result;
-    return result;
-  }
-
   public void addAll(Vector v) {
     if (size() != v.size()) {
       throw new CardinalityException(size(), v.size());
@@ -281,7 +266,7 @@ public class DenseVector extends Abstrac
 
     @Override
     public void set(double value) {
-      lengthSquared = -1;
+      invalidateCachedLength();
       values[index] = value;
     }
   }

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java Thu Oct 11 01:17:10 2012
@@ -21,12 +21,15 @@ package org.apache.mahout.math;
  * Marker interface for vectors that may cache their squared length.
  */
 interface LengthCachingVector {
+  /**
+   * Gets the currently cached squared length or if there is none, recalculates
+   * the value and returns that.
+   * @return The sum of the squares of all elements in the vector.
+   */
   double getLengthSquared();
 
   /**
-   * This is a very dangerous method to call.  Passing in a wrong value can
-   * completely screw up distance computations and normalization.
-   * @param d2  The new value for the squared length cache.
+   * Invalidates the length cache.  This should be called by all mutators of the vector.
    */
-  void setLengthSquared(double d2);
+  void invalidateCachedLength();
 }

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java Thu Oct 11 01:17:10 2012
@@ -132,7 +132,7 @@ public class RandomAccessSparseVector ex
 
   @Override
   public void setQuick(int index, double value) {
-    lengthSquared = -1.0;
+    invalidateCachedLength();
     if (value == 0.0) {
       values.removeKey(index);
     } else {
@@ -225,7 +225,7 @@ public class RandomAccessSparseVector ex
 
     @Override
     public void set(double value) {
-      lengthSquared = -1;
+      invalidateCachedLength();
       if (value == 0.0) {
         values.removeKey(index);
       } else {

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java Thu Oct 11 01:17:10 2012
@@ -186,7 +186,7 @@ public class SequentialAccessSparseVecto
 
   @Override
   public void setQuick(int index, double value) {
-    lengthSquared = -1;
+    invalidateCachedLength();
     values.set(index, value);
   }
 
@@ -285,7 +285,7 @@ public class SequentialAccessSparseVecto
 
     @Override
     public void set(double value) {
-      lengthSquared = -1;
+      invalidateCachedLength();
       values.getValues()[offset] = value;
     }
   }
@@ -321,7 +321,7 @@ public class SequentialAccessSparseVecto
 
     @Override
     public void set(double value) {
-      lengthSquared = -1;
+      invalidateCachedLength();
       if (index == values.getIndices()[nextOffset]) {
         values.getValues()[nextOffset] = value;
       } else {

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/WeightedVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/WeightedVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/WeightedVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/WeightedVector.java Thu Oct 11 01:17:10 2012
@@ -21,79 +21,84 @@ package org.apache.mahout.math;
  * Decorates a vector with a floating point weight and an index.
  */
 public class WeightedVector extends DelegatingVector implements Comparable<WeightedVector> {
-    private static final int INVALID_INDEX = -1;
-    private double weight;
-    private int index;
-
-    protected WeightedVector(int size, double weight, int index) {
-        super(size);
-        this.weight = weight;
-        this.index = index;
-    }
-
-    public WeightedVector(Vector v, double weight, int index) {
-        super(v);
-        this.weight = weight;
-        this.index = index;
-    }
-
-    public WeightedVector(Vector v, Vector projection, int index) {
-        super(v);
-        this.index = index;
-        this.weight = v.dot(projection);
-    }
-
-    public static WeightedVector project(Vector v, Vector projection) {
-        return project(v, projection, INVALID_INDEX);
-    }
-
-    public static WeightedVector project(Vector v, Vector projection, int index) {
-        return new WeightedVector(v, projection, index);
-    }
-
-    public double getWeight() {
-        return weight;
-    }
-
-
-    @Override
-    public int compareTo(WeightedVector other) {
-        if (this == other) {
-            return 0;
-        }
-        int r = Double.compare(weight, other.getWeight());
-        if (r == 0 || Math.abs(weight - other.getWeight()) < 1.0e-8) {
-            double diff = this.minus(other).norm(1);
-            if (diff < 1.0e-12) {
-                return 0;
-            } else {
-                for (Vector.Element element : this) {
-                    r = Double.compare(element.get(), other.get(element.index()));
-                    if (r != 0) {
-                        return r;
-                    }
-                }
-                return 0;
-            }
-        } else {
+  private static final int INVALID_INDEX = -1;
+  private double weight;
+  private int index;
+
+  protected WeightedVector(double weight, int index) {
+    super();
+    this.weight = weight;
+    this.index = index;
+  }
+
+  public WeightedVector(Vector v, double weight, int index) {
+    super(v);
+    this.weight = weight;
+    this.index = index;
+  }
+
+  public WeightedVector(Vector v, Vector projection, int index) {
+    super(v);
+    this.index = index;
+    this.weight = v.dot(projection);
+  }
+
+  public static WeightedVector project(Vector v, Vector projection) {
+    return project(v, projection, INVALID_INDEX);
+  }
+
+  public static WeightedVector project(Vector v, Vector projection, int index) {
+    return new WeightedVector(v, projection, index);
+  }
+
+  public double getWeight() {
+    return weight;
+  }
+
+
+  @Override
+  public int compareTo(WeightedVector other) {
+    if (this == other) {
+      return 0;
+    }
+    int r = Double.compare(weight, other.getWeight());
+    if (r == 0 || Math.abs(weight - other.getWeight()) < 1e-8) {
+      double diff = this.minus(other).norm(1);
+      if (diff < 1e-12) {
+        return 0;
+      } else {
+        for (Vector.Element element : this) {
+          r = Double.compare(element.get(), other.get(element.index()));
+          if (r != 0) {
             return r;
+          }
         }
-    }
-
-    public int getIndex() {
-        return index;
-    }
-
-    public void setWeight(double newWeight) {
-        this.weight = newWeight;
-    }
-
-    public void setIndex(int index) {
-        this.index = index;
-    }
-
-    @Override
-    public String toString() {
-        return String.format("index=%d, weight=%.2f, v=%s", index, weight, getVector());
-    }
+        return 0;
+      }
+    } else {
+      return r;
+    }
+  }
+
+  public int getIndex() {
+    return index;
+  }
+
+  public void setWeight(double newWeight) {
+    this.weight = newWeight;
+  }
+
+  public void setIndex(int index) {
+    this.index = index;
+  }
+
+  @Override
+  public Vector like() {
+    return new WeightedVector(getVector().like(), weight, index);
+  }
+
+  @Override
+  public String toString() {
+    return String.format("index=%d, weight=%.2f, v=%s", index, weight, getVector());
+  }
 }

Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java Thu Oct 11 01:17:10 2012
@@ -3,8 +3,10 @@ package org.apache.mahout.math;
 import org.apache.mahout.common.RandomUtils;
 import org.apache.mahout.math.function.Functions;
 import org.apache.mahout.math.jet.random.Normal;
+import org.junit.Before;
 import org.junit.Test;
 
+import java.util.Iterator;
 import java.util.Random;
 
 /**
@@ -18,6 +20,17 @@ import java.util.Random;
 public abstract class AbstractVectorTest<T extends Vector> extends MahoutTestCase {
 
   private static final double FUZZ = 1.0e-13;
+  private static final double[] values = {1.1, 2.2, 3.3};
+  private static final double[] gold = {0.0, 1.1, 0.0, 2.2, 0.0, 3.3, 0.0};
+  private Vector test;
+
+  private static void checkIterator(Iterator<Vector.Element> nzIter, double[] values) {
+    while (nzIter.hasNext()) {
+      Vector.Element elt = nzIter.next();
+      assertEquals(elt.index() + " Value: " + values[elt.index()]
+          + " does not equal: " + elt.get(), values[elt.index()], elt.get(), 0.0);
+    }
+  }
 
   public abstract T vectorToTest(int size);
 
@@ -27,13 +40,14 @@ public abstract class AbstractVectorTest
     T v0 = vectorToTest(20);
     Random gen = RandomUtils.getRandom();
     Vector v1 = v0.assign(new Normal(0, 1, gen));
-    Vector v2 = vectorToTest(20).assign(new Normal(0, 1, gen));
 
+    // verify that v0 and v1 share and are identical
     assertEquals(v0.get(12), v1.get(12), 0);
     v0.set(12, gen.nextDouble());
     assertEquals(v0.get(12), v1.get(12), 0);
     assertSame(v0, v1);
 
+    Vector v2 = vectorToTest(20).assign(new Normal(0, 1, gen));
     Vector dv1 = new DenseVector(v1);
     Vector dv2 = new DenseVector(v2);
     Vector sv1 = new RandomAccessSparseVector(v1);
@@ -147,4 +161,420 @@ public abstract class AbstractVectorTest
 
 
   }
+
+  /** Returns a new vector of the given cardinality; setUp writes the fixture into it. */
+  abstract Vector generateTestVector(int cardinality);
+  /** Exposes the fixture so subclasses can make their own assertions against it. */
+  Vector getTestVector() { return test; }
+
+  /** Builds the fixture: {@code values} go into the odd slots in order; even slots are untouched. */
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    test = generateTestVector(2 * values.length + 1);
+    for (int k = 0, slot = 1; k < values.length; k++, slot += 2) {
+      test.set(slot, values[k]);
+    }
+  }
+
+  /** The fixture holds one slot per {@code gold} entry. */
+  @Test
+  public void testCardinality() {
+    assertEquals("size", gold.length, test.size());
+  }
+
+  /**
+   * Both iterators of the fixture, and of sparse vectors filled via {@code setQuick} (with
+   * explicit zeros, or all zeros), must report the expected value at every index they visit.
+   */
+  @Test
+  public void testIterator() {
+    checkIterator(test.iterateNonZero(), gold);
+    checkIterator(test.iterator(), gold);
+
+    double[] mixed = {0.0, 5.0, 0, 3.0};
+    checkSparseIterators(mixed);
+
+    double[] allZero = {0.0, 0.0, 0, 0.0};
+    checkSparseIterators(allZero);
+  }
+
+  /**
+   * Copies {@code expected} into a new {@link RandomAccessSparseVector} with setQuick, then
+   * checks both its non-zero iterator and its full iterator against {@code expected}.
+   */
+  private static void checkSparseIterators(double[] expected) {
+    RandomAccessSparseVector sparse = new RandomAccessSparseVector(expected.length);
+    for (int i = 0; i < expected.length; i++) {
+      sparse.setQuick(i, expected[i]);
+    }
+    checkIterator(sparse.iterateNonZero(), expected);
+    checkIterator(sparse.iterator(), expected);
+  }
+
+  /** Doubling each element through Element.set must write through, for both iterators. */
+  @Test
+  public void testIteratorSet() {
+    for (boolean nonZeroOnly : new boolean[] {true, false}) {
+      Vector clone = test.clone();
+      for (Iterator<Vector.Element> it = iteratorOf(clone, nonZeroOnly); it.hasNext(); ) {
+        Vector.Element element = it.next();
+        element.set(element.get() * 2.0);
+      }
+      // Re-read through a fresh iterator of the same kind and compare with the fixture.
+      for (Iterator<Vector.Element> it = iteratorOf(clone, nonZeroOnly); it.hasNext(); ) {
+        Vector.Element element = it.next();
+        assertEquals(test.get(element.index()) * 2.0, element.get(), EPSILON);
+      }
+    }
+  }
+
+  /**
+   * Returns {@code v.iterateNonZero()} when {@code nonZeroOnly} is true, otherwise
+   * {@code v.iterator()}.
+   */
+  private static Iterator<Vector.Element> iteratorOf(Vector v, boolean nonZeroOnly) {
+    return nonZeroOnly ? v.iterateNonZero() : v.iterator();
+  }
+
+  /** A clone must agree with the fixture at every index. */
+  @Test
+  public void testCopy() {
+    Vector copy = test.clone();
+    for (int index = 0; index < test.size(); index++) {
+      assertEquals("copy [" + index + ']', test.get(index), copy.get(index), EPSILON);
+    }
+  }
+
+  /** Every slot reads back its {@code gold} value: zeros at even indices, {@code values} at odd. */
+  @Test
+  public void testGet() {
+    for (int i = 0; i < test.size(); i++) {
+      assertEquals("get [" + i + ']', gold[i], test.get(i), EPSILON);
+    }
+  }
+
+  /** Reading one past the last index must be rejected. */
+  @Test(expected = IndexException.class)
+  public void testGetOver() {
+    test.get(test.size());
+  }
+
+  /** Reading a negative index must be rejected. */
+  @Test(expected = IndexException.class)
+  public void testGetUnder() {
+    test.get(-1);
+  }
+
+  /** Overwriting slot 3 changes only that slot. */
+  @Test
+  public void testSet() {
+    test.set(3, 4.5);
+    for (int i = 0; i < test.size(); i++) {
+      double actual = test.get(i);
+      if (i % 2 == 0) {
+        assertEquals("get [" + i + ']', 0.0, actual, EPSILON);
+      } else {
+        double expected = i == 3 ? 4.5 : values[i / 2];
+        assertEquals("set [" + i + ']', expected, actual, EPSILON);
+      }
+    }
+  }
+
+  @Test
+  public void testSize() {
+    assertEquals("size", values.length, test.getNumNondefaultElements()); // setUp fills only odd slots
+  }
+
+  /** A two-element view starting at index 1 mirrors the fixture at the shifted indices. */
+  @Test
+  public void testViewPart() {
+    Vector part = test.viewPart(1, 2);
+    assertEquals("part size", 2, part.getNumNondefaultElements());
+    for (int offset = 0; offset < part.size(); offset++) {
+      assertEquals("part[" + offset + ']', test.get(1 + offset), part.get(offset), EPSILON);
+    }
+  }
+
+  // Each of the following views falls outside the size-7 fixture and must be rejected.
+  @Test(expected = IndexException.class)
+  public void testViewPartUnder() {
+    test.viewPart(-1, values.length); // negative offset
+  }
+  @Test(expected = IndexException.class)
+  public void testViewPartOver() {
+    test.viewPart(2, 7); // offset 2 plus length 7 runs past the end
+  }
+  @Test(expected = IndexException.class)
+  public void testViewPartCardinality() {
+    test.viewPart(1, 8); // longer than the vector itself
+  }
+
+  @Test
+  public void testSparseDoubleVectorInt() {
+    Vector empty = new RandomAccessSparseVector(4); // a new sparse vector is expected to be all zeros
+    assertEquals("size", 4, empty.size());
+    for (int i = 0; i < empty.size(); i++) {
+      assertEquals("get [" + i + ']', 0.0, empty.get(i), EPSILON);
+    }
+  }
+
+  /** The fixture dotted with itself is the sum of the squares of {@code values}. */
+  @Test
+  public void testDot() {
+    double expected = 3.3 * 3.3 + 2.2 * 2.2 + 1.1 * 1.1;
+    assertEquals("dot", expected, test.dot(test), EPSILON);
+  }
+
+  @Test
+  public void testDot2() {
+    Vector partial = test.clone();
+    partial.set(1, 0.0);
+    partial.set(3, 0.0);
+    assertEquals(3.3 * 3.3, partial.dot(test), EPSILON); // only slot 5 still overlaps the fixture
+  }
+
+  @Test(expected = CardinalityException.class)
+  public void testDotCardinality() {
+    test.dot(new DenseVector(test.size() + 1)); // operands differ in size by one
+  }
+
+  /** Normalizing divides every slot by the Euclidean norm of {@code values}; zeros stay zero. */
+  @Test
+  public void testNormalize() {
+    Vector unit = test.normalize();
+    double norm = Math.sqrt(1.1 * 1.1 + 2.2 * 2.2 + 3.3 * 3.3);
+    for (int i = 0; i < test.size(); i++) {
+      if (i % 2 != 0) {
+        assertEquals("dot", values[i / 2] / norm, unit.get(i), EPSILON);
+      } else {
+        assertEquals("get [" + i + ']', 0.0, unit.get(i), EPSILON);
+      }
+    }
+  }
+
+  /** Exercises minus() against the fixture itself, chained, and against shifted copies. */
+  @Test
+  public void testMinus() {
+    Vector zero = test.minus(test);
+    assertEquals("size", test.size(), zero.size());
+    for (int i = 0; i < test.size(); i++) {
+      assertEquals("get [" + i + ']', 0.0, zero.get(i), EPSILON);
+    }
+
+    // (test - test) - test should equal -test.
+    Vector negated = test.minus(test).minus(test);
+    assertEquals("cardinality", test.size(), negated.size());
+    for (int i = 0; i < test.size(); i++) {
+      assertEquals("get [" + i + ']', 0.0, negated.get(i) + test.get(i), EPSILON);
+    }
+
+    // (test + c) - test should be c in every slot, for c = 1 and c = -1.
+    for (double shift : new double[] {1.0, -1.0}) {
+      Vector shifted = test.plus(shift);
+      Vector difference = shifted.minus(test);
+      for (int i = 0; i < test.size(); i++) {
+        assertEquals("get [" + i + ']', shift, difference.get(i), EPSILON);
+      }
+    }
+  }
+
+  /**
+   * Adding the scalar 1 shifts every slot of the fixture, zero slots included, by 1.
+   */
+  @Test
+  public void testPlusDouble() {
+    Vector result = test.plus(1);
+    assertEquals("size", test.size(), result.size());
+    for (int i = 0; i < test.size(); i++) {
+      double before = gold[i]; // fixture value at slot i
+      assertEquals("get [" + i + ']', before + 1.0, result.get(i), EPSILON);
+    }
+  }
+
+  /**
+   * Adding the fixture to itself doubles every slot; the zero slots stay zero.
+   */
+  @Test
+  public void testPlusVector() {
+    Vector result = test.plus(test);
+    assertEquals("size", test.size(), result.size());
+    for (int i = 0; i < test.size(); i++) {
+      double before = gold[i]; // fixture value at slot i
+      assertEquals("get [" + i + ']', before * 2.0, result.get(i), EPSILON);
+    }
+  }
+
+  @Test(expected = CardinalityException.class)
+  public void testPlusVectorCardinality() {
+    test.plus(new DenseVector(test.size() + 1)); // operands differ in size by one
+  }
+
+  /**
+   * Multiplying by the scalar 3 triples every slot; the zero slots stay zero.
+   */
+  @Test
+  public void testTimesDouble() {
+    Vector result = test.times(3);
+    assertEquals("size", test.size(), result.size());
+    for (int i = 0; i < test.size(); i++) {
+      double before = gold[i]; // fixture value at slot i
+      assertEquals("get [" + i + ']', before * 3.0, result.get(i), EPSILON);
+    }
+  }
+
+  /**
+   * Dividing by the scalar 3 scales every slot by one third; the zero slots stay zero.
+   */
+  @Test
+  public void testDivideDouble() {
+    Vector result = test.divide(3);
+    assertEquals("size", test.size(), result.size());
+    for (int i = 0; i < test.size(); i++) {
+      double before = gold[i]; // fixture value at slot i
+      assertEquals("get [" + i + ']', before / 3.0, result.get(i), EPSILON);
+    }
+  }
+
+  /**
+   * Element-wise multiplication of the fixture with itself squares every slot.
+   */
+  @Test
+  public void testTimesVector() {
+    Vector result = test.times(test);
+    assertEquals("size", test.size(), result.size());
+    for (int i = 0; i < test.size(); i++) {
+      double before = gold[i]; // fixture value at slot i
+      assertEquals("get [" + i + ']', before * before, result.get(i), EPSILON);
+    }
+  }
+
+  @Test(expected = CardinalityException.class)
+  public void testTimesVectorCardinality() {
+    test.times(new DenseVector(test.size() + 1)); // operands differ in size by one
+  }
+
+  @Test
+  public void testZSum() {
+    double sum = 0;
+    for (int k = 0; k < values.length; k++) {
+      sum += values[k];
+    }
+    assertEquals("wrong zSum", sum, test.zSum(), EPSILON); // the zero slots contribute nothing
+  }
+
+  /** getDistanceSquared must agree with the squared length of the explicit difference. */
+  @Test
+  public void testGetDistanceSquared() {
+    Vector other = new RandomAccessSparseVector(test.size());
+    other.set(1, -2);
+    other.set(2, -5);
+    other.set(3, -9);
+    other.set(4, 1);
+    double gap = Math.abs(test.minus(other).getLengthSquared() - test.getDistanceSquared(other));
+    assertTrue("a.getDistanceSquared(b) != a.minus(b).getLengthSquared", gap < 10.0E-7);
+  }
+
+  /** Assigning the scalar 0 clears every slot of the fixture. */
+  @Test
+  public void testAssignDouble() {
+    test.assign(0);
+    for (int i = 0; i < test.size(); i++) {
+      assertEquals("value[" + i + ']', 0.0, test.getQuick(i), EPSILON);
+    }
+  }
+
+  /** Assigning an all-zero array of matching length clears every slot of the fixture. */
+  @Test
+  public void testAssignDoubleArray() {
+    test.assign(new double[test.size()]);
+    for (int i = 0; i < test.size(); i++) {
+      assertEquals("value[" + i + ']', 0.0, test.getQuick(i), EPSILON);
+    }
+  }
+
+  @Test(expected = CardinalityException.class)
+  public void testAssignDoubleArrayCardinality() {
+    test.assign(new double[test.size() + 1]); // one element longer than the fixture
+  }
+
+  /** Assigning a new DenseVector of matching size clears every slot of the fixture. */
+  @Test
+  public void testAssignVector() {
+    Vector zeros = new DenseVector(test.size());
+    test.assign(zeros);
+    for (int i = 0; i < test.size(); i++) {
+      assertEquals("value[" + i + ']', 0.0, test.getQuick(i), EPSILON);
+    }
+  }
+
+  @Test(expected = CardinalityException.class)
+  public void testAssignVectorCardinality() {
+    test.assign(new DenseVector(test.size() - 1)); // one element shorter than the fixture
+  }
+
+  /**
+   * Negating in place must flip the sign of every slot; comparing against the negated
+   * {@code gold} array covers the zero slots as well as the {@code values} slots.
+   */
+  @Test
+  public void testAssignUnaryFunction() {
+    test.assign(Functions.NEGATE);
+    for (int i = 0; i < test.size(); i++) {
+      assertEquals("value[" + i + ']', -gold[i], test.getQuick(i), EPSILON);
+    }
+  }
+
+  /**
+   * Adding the fixture to itself in place must double every slot. Slot {@code i} of the
+   * fixture holds {@code gold[i]}: {@code values[i / 2]} at odd {@code i}, 0 elsewhere.
+   */
+  @Test
+  public void testAssignBinaryFunction() {
+    test.assign(test, Functions.PLUS);
+    for (int i = 0; i < test.size(); i++) {
+      assertEquals("value[" + i + ']', 2 * gold[i], test.getQuick(i), EPSILON);
+    }
+  }
+
+  /** Applying {@code plus(4)} in place adds 4 to every slot, turning the zero slots into 4. */
+  @Test
+  public void testAssignBinaryFunction2() {
+    test.assign(Functions.plus(4));
+    for (int i = 0; i < test.size(); i++) {
+      assertEquals("value[" + i + ']', gold[i] + 4, test.getQuick(i), EPSILON);
+    }
+  }
+
+  /**
+   * Applying {@code mult(4)} in place must scale every slot by 4; the zero slots stay zero.
+   */
+  @Test
+  public void testAssignBinaryFunction3() {
+    test.assign(Functions.mult(4));
+    for (int i = 0; i < test.size(); i++) {
+      assertEquals("value[" + i + ']', gold[i] * 4, test.getQuick(i), EPSILON);
+    }
+  }
+
+  @Test
+  public void testLike() {
+    Vector fresh = test.like(); // must match the fixture's class (or a subclass) and size
+    assertTrue("not like", test.getClass().isAssignableFrom(fresh.getClass()));
+    assertEquals("size", test.size(), fresh.size());
+  }
+
+  /** The cross product of the fixture with itself holds test[r] * test[c] at (r, c). */
+  @Test
+  public void testCrossProduct() {
+    Matrix m = test.cross(test);
+    assertEquals("row size", test.size(), m.rowSize());
+    assertEquals("col size", test.size(), m.columnSize());
+    for (int r = 0; r < m.rowSize(); r++) {
+      for (int c = 0; c < m.columnSize(); c++) {
+        assertEquals("cross[" + r + "][" + c + ']', test.getQuick(r) * test.getQuick(c), m.getQuick(r, c), EPSILON);
+      }
+    }
+  }
 }

Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/CentroidTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/CentroidTest.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/CentroidTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/CentroidTest.java Thu Oct 11 01:17:10 2012
@@ -59,4 +59,15 @@ public class CentroidTest extends Abstra
   public Vector vectorToTest(int size) {
     return new Centroid(new WeightedVector(new DenseVector(size), 3.15, 51));
   }
+
+  @Override
+  public void testSize() {
+    int nonDefault = getTestVector().getNumNondefaultElements();
+    assertEquals("size", 7, nonDefault); // backed by a DenseVector; expected to count all 7 fixture slots
+  }
+
+  @Override
+  Vector generateTestVector(int cardinality) {
+    return new Centroid(new WeightedVector(new DenseVector(cardinality), 3.14, 53));
+  }
 }

Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/TestDenseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/TestDenseVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/TestDenseVector.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/TestDenseVector.java Thu Oct 11 01:17:10 2012
@@ -17,7 +17,9 @@
 
 package org.apache.mahout.math;
 
-public final class TestDenseVector extends AbstractTestVector {
+import org.apache.mahout.math.function.Functions;
+
+public final class TestDenseVector extends AbstractVectorTest<DenseVector> {
 
   @Override
   Vector generateTestVector(int cardinality) {
@@ -29,4 +31,10 @@ public final class TestDenseVector exten
     assertEquals("size", 7, getTestVector().getNumNondefaultElements());
   }
 
+  @Override
+  public DenseVector vectorToTest(int size) {
+    DenseVector randomized = new DenseVector(size);
+    randomized.assign(Functions.random()); // overwrite every slot via Functions.random()
+    return randomized;
+  }
 }

Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java Thu Oct 11 01:17:10 2012
@@ -17,11 +17,26 @@
 
 package org.apache.mahout.math;
 
-public final class TestRandomAccessSparseVector extends AbstractTestVector {
+import org.apache.mahout.common.RandomUtils;
+
+import java.util.Random;
+
+public final class TestRandomAccessSparseVector extends AbstractVectorTest<RandomAccessSparseVector> {
 
   @Override
   Vector generateTestVector(int cardinality) {
     return new RandomAccessSparseVector(cardinality);
   }
 
+
+  @Override
+  public RandomAccessSparseVector vectorToTest(int size) {
+    Random gen = RandomUtils.getRandom();
+    RandomAccessSparseVector sparse = new RandomAccessSparseVector(size);
+    for (int remaining = 3; remaining > 0; remaining--) {
+      sparse.set(gen.nextInt(sparse.size()), gen.nextGaussian()); // up to 3 entries; repeats overwrite
+    }
+    return sparse;
+  }
+
 }

Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSequentialAccessSparseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSequentialAccessSparseVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSequentialAccessSparseVector.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSequentialAccessSparseVector.java Thu Oct 11 01:17:10 2012
@@ -17,18 +17,20 @@
 
 package org.apache.mahout.math;
 
+import org.apache.mahout.common.RandomUtils;
 import org.junit.Test;
 
-public final class TestSequentialAccessSparseVector extends AbstractTestVector {
+import java.util.Random;
+
+public final class TestSequentialAccessSparseVector extends AbstractVectorTest<SequentialAccessSparseVector> {
 
   @Override
   Vector generateTestVector(int cardinality) {
     return new SequentialAccessSparseVector(cardinality);
   }
 
-  @Override
   @Test
-  public void testDot2() {
+  public void testDotSuperBig() {
     Vector w = new SequentialAccessSparseVector(Integer.MAX_VALUE, 12);
     w.set(1, 0.4);
     w.set(2, 0.4);
@@ -37,6 +39,17 @@ public final class TestSequentialAccessS
     Vector v = new SequentialAccessSparseVector(Integer.MAX_VALUE, 12);
     v.set(3, 1);
 
-    assertEquals("dot2", -0.666666667, v.dot(w), EPSILON);
+    assertEquals("super-big", -0.666666667, v.dot(w), EPSILON);
+  }
+
+
+  @Override
+  public SequentialAccessSparseVector vectorToTest(int size) {
+    Random gen = RandomUtils.getRandom();
+    SequentialAccessSparseVector sparse = new SequentialAccessSparseVector(size);
+    for (int remaining = 3; remaining > 0; remaining--) {
+      sparse.set(gen.nextInt(sparse.size()), gen.nextGaussian()); // up to 3 entries; repeats overwrite
+    }
+    return sparse;
   }
 }
\ No newline at end of file

Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java Thu Oct 11 01:17:10 2012
@@ -886,4 +886,37 @@ public final class VectorTest extends Ma
     assertFalse(left.hashCode() == right.hashCode());
   }
 
+  /** mergeDiff must return the squared L2 distance between two vectors as entries are added. */
+  @Test
+  public void testMergeDiff() {
+    Vector left = new SequentialAccessSparseVector(20);
+    Vector right = new SequentialAccessSparseVector(20);
+    // Tolerance instead of exact equality, so the test does not pin mergeDiff's summation order.
+    assertEquals(0, AbstractVector.mergeDiff(left, right), EPSILON);
+    // an index set only in left
+    left.set(5, 1.5);
+    assertEquals(1.5 * 1.5, AbstractVector.mergeDiff(left, right), EPSILON);
+    // an index set only in right
+    right.set(4, 3.1);
+    assertEquals(3.1 * 3.1 + 1.5 * 1.5, AbstractVector.mergeDiff(left, right), EPSILON);
+    left.set(3, 1.2);
+    assertEquals(1.2 * 1.2 + 3.1 * 3.1 + 1.5 * 1.5, AbstractVector.mergeDiff(left, right), EPSILON);
+    // equal values at index 6 cancel; index 8 is set only in right
+    left.set(6, 2);
+    right.set(6, 2);
+    right.set(8, 2);
+    assertEquals(1.2 * 1.2 + 3.1 * 3.1 + 1.5 * 1.5 + 2 * 2, AbstractVector.mergeDiff(left, right), EPSILON);
+  }
+
+  @Test
+  public void testRandomScanDiff() {
+    Vector left = new SequentialAccessSparseVector(20);
+    Vector right = new SequentialAccessSparseVector(20);
+    left.set(4, 1.1);
+    left.set(6, 2.1);
+    right.set(7, 3.1);
+    right.set(4, 1.2);
+    // 1.1 - 1.2 is not exactly -0.1 in binary floating point, so an exact (delta 0) match is fragile.
+    assertEquals(0.1 * 0.1 + 2.1 * 2.1 + 3.1 * 3.1, AbstractVector.randomScanDiff(left, right), EPSILON);
+  }
 }

Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/WeightedVectorTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/WeightedVectorTest.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/WeightedVectorTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/WeightedVectorTest.java Thu Oct 11 01:17:10 2012
@@ -74,4 +74,14 @@ public class WeightedVectorTest extends 
     WeightedVector v5 = WeightedVector.project(q.viewColumn(0), qx);
     assertEquals(Math.sqrt(0.5), v5.getWeight(), 1.0e-13);
   }
+
+  @Override
+  public void testSize() {
+    int nonDefault = getTestVector().getNumNondefaultElements();
+    assertEquals("size", 7, nonDefault); // backed by a DenseVector; expected to count all 7 fixture slots
+  }
+  @Override
+  Vector generateTestVector(int cardinality) {
+    return new WeightedVector(new DenseVector(cardinality), 3.14, 53);
+  }
 }