You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2012/10/11 03:17:11 UTC
svn commit: r1396888 - in /mahout/trunk:
core/src/test/java/org/apache/mahout/clustering/meanshift/
math/src/main/java/org/apache/mahout/math/
math/src/test/java/org/apache/mahout/math/
Author: tdunning
Date: Thu Oct 11 01:17:10 2012
New Revision: 1396888
URL: http://svn.apache.org/viewvc?rev=1396888&view=rev
Log:
MAHOUT-1086 - Deal with round-off errors in computing L_2 distances. Add special case to get higher accuracy when vector difference is small, merge AbstractVectorTest and AbstractTestVector, fix like() bug in Centroid and WeightedVector.
Removed:
mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractTestVector.java
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/Centroid.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/WeightedVector.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/CentroidTest.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/TestDenseVector.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSequentialAccessSparseVector.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/WeightedVectorTest.java
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java Thu Oct 11 01:17:10 2012
@@ -492,7 +492,7 @@ public final class TestMeanShift extends
ToolRunner.run(conf, new MeanShiftCanopyDriver(), args);
Path outPart = new Path(output, "clusters-3-final/part-r-00000");
long count = HadoopUtil.countRecords(outPart, conf);
- assertEquals("count", 4, count);
+ assertEquals("count", 3, count);
Iterator<?> iterator = new SequenceFileValueIterator<Writable>(outPart,
true, conf);
while (iterator.hasNext()) {
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/AbstractVector.java Thu Oct 11 01:17:10 2012
@@ -21,12 +21,13 @@ import org.apache.mahout.common.RandomUt
import org.apache.mahout.math.function.DoubleDoubleFunction;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.set.OpenIntHashSet;
import java.util.Iterator;
/** Implementations of generic capabilities like sum of elements and dot products */
public abstract class AbstractVector implements Vector, LengthCachingVector {
-
+
private static final double LOG2 = Math.log(2.0);
private int size;
@@ -155,12 +156,12 @@ public abstract class AbstractVector imp
}
return result;
}
-
- public double dotSelf() {
+
+ protected double dotSelf() {
double result = 0.0;
- Iterator<Element> iter = iterateNonZero();
- while (iter.hasNext()) {
- double value = iter.next().get();
+ Iterator<Element> i = iterateNonZero();
+ while (i.hasNext()) {
+ double value = i.next().get();
result += value * value;
}
return result;
@@ -216,18 +217,18 @@ public abstract class AbstractVector imp
public Vector normalize(double power) {
return divide(norm(power));
}
-
+
@Override
public Vector logNormalize() {
return logNormalize(2.0, Math.sqrt(dotSelf()));
}
-
+
@Override
public Vector logNormalize(double power) {
return logNormalize(power, norm(power));
}
-
- public Vector logNormalize(double power, double normLength) {
+
+ public Vector logNormalize(double power, double normLength) {
// we can special case certain powers
if (Double.isInfinite(power) || power <= 1.0) {
throw new IllegalArgumentException("Power must be > 1 and < infinity");
@@ -293,8 +294,8 @@ public abstract class AbstractVector imp
}
@Override
- public void setLengthSquared(double d2) {
- lengthSquared = d2;
+ public void invalidateCachedLength() {
+ lengthSquared = -1;
}
@Override
@@ -303,34 +304,189 @@ public abstract class AbstractVector imp
throw new CardinalityException(size, v.size());
}
// if this and v has a cached lengthSquared, dot product is quickest way to compute this.
- if (lengthSquared >= 0 && v instanceof LengthCachingVector && v.getLengthSquared() >= 0) {
- return lengthSquared + v.getLengthSquared() - 2 * this.dot(v);
+ double d1;
+ double d2;
+ double dot;
+ if (lengthSquared >= 0) {
+ // our length squared is cached, so use it
+ // (the Math.max applied to the result below is a slight antidote to round-off errors)
+ d1 = lengthSquared;
+ d2 = v.getLengthSquared();
+ dot = this.dot(v);
+ } else {
+ // our length is not cached... compute it and the dot product in one pass for speed
+ d1 = 0;
+ d2 = v.getLengthSquared();
+ dot = 0;
+ final Iterator<Element> i = iterateNonZero();
+ while (i.hasNext()) {
+ Element e = i.next();
+ double value = e.get();
+ d1 += value * value;
+ dot += value * v.getQuick(e.index());
+ }
+ lengthSquared = d1;
+ // again, round-off errors may be present in d1 and dot; the result is checked below
}
- Vector sparseAccessed;
- Vector randomlyAccessed;
- if (lengthSquared >= 0.0) {
- randomlyAccessed = this;
- sparseAccessed = v;
- } else { // TODO: could be further optimized, figure out which one is smaller, etc
- randomlyAccessed = v;
- sparseAccessed = this;
+
+ double r = d1 + d2 - 2 * dot;
+ if (r > 1e-3 * (d1 + d2)) {
+ return Math.max(0, r);
+ } else {
+ if (this.isSequentialAccess()) {
+ if (v.isSequentialAccess()) {
+ return mergeDiff(this, v);
+ } else {
+ return randomScanDiff(this, v);
+ }
+ } else {
+ return randomScanDiff(v, this);
+ }
}
+ }
- Iterator<Element> it = sparseAccessed.iterateNonZero();
- double d = randomlyAccessed.getLengthSquared();
- double d2 = 0;
- double dot = 0;
- while (it.hasNext()) {
- Element e = it.next();
- double value = e.get();
- d2 += value * value;
- dot += value * randomlyAccessed.getQuick(e.index());
+ /**
+ * Computes the squared difference of two vectors where iterateNonZero
+ * is efficient for each vector, but where the order of iteration is not
+ * known. This forces us to access most elements of v2 via get(), which
+ * would be very inefficient for some kinds of vectors.
+ *
+ * Note that this static method is exposed at a package level for testing purposes only.
+ * @param v1 The vector that we access only via iterateNonZero
+ * @param v2 The vector that we access via iterateNonZero and via random-access get(index)
+ * @return The squared difference between v1 and v2.
+ */
+ static double randomScanDiff(Vector v1, Vector v2) {
+ // keeps a set of the indexes we visited by iterating over v1. This should be
+ // almost all of the elements of v2 because we only call this method if the
+ // difference is small.
+ OpenIntHashSet visited = new OpenIntHashSet();
+
+ double r = 0;
+
+ // walk through non-zeros of v1
+ Iterator<Element> i = v1.iterateNonZero();
+ while (i.hasNext()) {
+ Element e1 = i.next();
+ visited.add(e1.index());
+ double x = e1.get() - v2.get(e1.index());
+ r += x * x;
+ }
+
+ // now walk through neglected elements of v2
+ i = v2.iterateNonZero();
+ while (i.hasNext()) {
+ Element e2 = i.next();
+ if (!visited.contains(e2.index())) {
+ // if not visited already then v1's value here would be zero.
+ double x = e2.get();
+ r += x * x;
+ }
}
- if (sparseAccessed instanceof LengthCachingVector) {
- ((LengthCachingVector) sparseAccessed).setLengthSquared(d2);
+
+ return r;
+ }
+
+ /**
+ * Computes the squared difference of two vectors where iterateNonZero returns
+ * elements in index order for both vectors. This allows a merge to be used to
+ * compute the difference. A merge allows a single sequential pass over each
+ * vector and should be faster than any alternative.
+ *
+ * Note that this static method is exposed at a package level for testing purposes only.
+ * @param v1 The first vector.
+ * @param v2 The second vector.
+ * @return The squared difference between the two vectors.
+ */
+ static double mergeDiff(Vector v1, Vector v2) {
+ Iterator<Element> i1 = v1.iterateNonZero();
+ Iterator<Element> i2 = v2.iterateNonZero();
+
+ // v1 is empty?
+ if (!i1.hasNext()) {
+ return v2.getLengthSquared();
+ }
+
+ // v2 is empty?
+ if (!i2.hasNext()) {
+ return v1.getLengthSquared();
+ }
+
+ Element e1 = i1.next();
+ Element e2 = i2.next();
+
+ double r = 0;
+ while (e1 != null && e2 != null) {
+ // eat elements of v1 that precede all in v2
+ while (e1 != null && e1.index() < e2.index()) {
+ double x = e1.get();
+ r += x * x;
+
+ if (i1.hasNext()) {
+ e1 = i1.next();
+ } else {
+ e1 = null;
+ }
+ }
+
+ // at this point we have three possibilities: e1 == null, e1 matches e2, or
+ // e2 precedes e1. Here we handle the e1 == null and e2 < e1 cases
+ while (e2 != null && (e1 == null || e2.index() < e1.index())) {
+ double x = e2.get();
+ r += x * x;
+
+ if (i2.hasNext()) {
+ e2 = i2.next();
+ } else {
+ e2 = null;
+ }
+ }
+
+ // and now we handle the e1 == e2 case. For convenience, we
+ // grab as many of these as possible. Given that we are called here
+ // only when v1 and v2 are nearly equal, this loop should dominate
+ while (e1 != null && e2 != null && e1.index() == e2.index()) {
+ double x = e1.get() - e2.get();
+ r += x * x;
+
+ if (i1.hasNext()) {
+ e1 = i1.next();
+ } else {
+ e1 = null;
+ }
+
+ if (i2.hasNext()) {
+ e2 = i2.next();
+ } else {
+ e2 = null;
+ }
+ }
+ }
+
+ // one of i1 or i2 is exhausted here, but the other may not be
+ while (e1 != null ) {
+ double x = e1.get();
+ r += x * x;
+
+ if (i1.hasNext()) {
+ e1 = i1.next();
+ } else {
+ e1 = null;
+ }
}
- //assert d > -1.0e-9; // round-off errors should never be too far off!
- return Math.abs(d + d2 - 2 * dot);
+
+ while (e2 != null) {
+ double x = e2.get();
+ r += x * x;
+
+ if (i2.hasNext()) {
+ e2 = i2.next();
+ } else {
+ e2 = null;
+ }
+ }
+ // both v1 and v2 have been completely processed
+ return r;
}
@Override
@@ -348,7 +504,7 @@ public abstract class AbstractVector imp
}
return result;
}
-
+
@Override
public int maxValueIndex() {
int result = -1;
@@ -473,7 +629,7 @@ public abstract class AbstractVector imp
if (x == 1.0) {
return result;
}
-
+
Iterator<Element> iter = result.iterateNonZero();
while (iter.hasNext()) {
Element element = iter.next();
@@ -600,7 +756,7 @@ public abstract class AbstractVector imp
@Override
public final int size() {
- return size;
+ return size;
}
@Override
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/Centroid.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/Centroid.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/Centroid.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/Centroid.java Thu Oct 11 01:17:10 2012
@@ -25,8 +25,8 @@ import org.apache.mahout.math.function.D
*/
public class Centroid extends WeightedVector {
public Centroid(WeightedVector original) {
- super(original.size(), original.getWeight(), original.getIndex());
- delegate = original.like();
+ super(original.getWeight(), original.getIndex());
+ delegate = original.getVector().like();
delegate.assign(original);
}
@@ -67,7 +67,12 @@ public class Centroid extends WeightedVe
setWeight(totalWeight);
}
- /**
+ @Override
+ public Vector like() {
+ return new Centroid(getIndex(), getVector().like(), getWeight());
+ }
+
+ /**
* Gets the index of this centroid. Use getIndex instead to maintain standard names.
*/
@Deprecated
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/DelegatingVector.java Thu Oct 11 01:17:10 2012
@@ -32,14 +32,13 @@ import java.util.Iterator;
public class DelegatingVector implements Vector, LengthCachingVector {
protected Vector delegate;
- public DelegatingVector(int size) {
- delegate = new DenseVector(size);
- }
-
public DelegatingVector(Vector v) {
delegate = v;
}
+ protected DelegatingVector() {
+ }
+
public Vector getVector() {
return delegate;
}
@@ -126,14 +125,10 @@ public class DelegatingVector implements
return delegate.getLengthSquared();
}
- // not normally called because the delegate vector is who would need this and
- // they will call their own version of this method. In fact, if the delegate is
- // also a delegating vector the same logic will apply recursively down to the first
- // non-delegating vector. This makes this very hard to test except in trivial ways.
@Override
- public void setLengthSquared(double d2) {
+ public void invalidateCachedLength() {
if (delegate instanceof LengthCachingVector) {
- ((LengthCachingVector) delegate).setLengthSquared(d2);
+ ((LengthCachingVector) delegate).invalidateCachedLength();
}
}
@@ -275,7 +270,7 @@ public class DelegatingVector implements
@Override
public Vector like() {
- return delegate.like();
+ return new DelegatingVector(delegate.like());
}
@Override
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/DenseVector.java Thu Oct 11 01:17:10 2012
@@ -56,7 +56,7 @@ public class DenseVector extends Abstrac
/**
* Copy-constructor (for use in turning a sparse vector into a dense one, for example)
- * @param vector
+ * @param vector The vector to copy
*/
public DenseVector(Vector vector) {
super(vector.size());
@@ -95,7 +95,7 @@ public class DenseVector extends Abstrac
}
@Override
- public double dotSelf() {
+ protected double dotSelf() {
double result = 0.0;
int max = size();
for (int i = 0; i < max; i++) {
@@ -117,13 +117,13 @@ public class DenseVector extends Abstrac
@Override
public void setQuick(int index, double value) {
- lengthSquared = -1.0;
+ invalidateCachedLength();
values[index] = value;
}
@Override
public Vector assign(double value) {
- this.lengthSquared = -1;
+ invalidateCachedLength();
Arrays.fill(values, value);
return this;
}
@@ -145,7 +145,7 @@ public class DenseVector extends Abstrac
values[i] = function.apply(values[i], other.getQuick(i));
}
}
- lengthSquared = -1;
+ invalidateCachedLength();
return this;
}
@@ -197,21 +197,6 @@ public class DenseVector extends Abstrac
return super.equals(o);
}
- @Override
- public double getLengthSquared() {
- if (lengthSquared >= 0.0) {
- return lengthSquared;
- }
-
- double result = 0.0;
- for (double value : values) {
- result += value * value;
-
- }
- lengthSquared = result;
- return result;
- }
-
public void addAll(Vector v) {
if (size() != v.size()) {
throw new CardinalityException(size(), v.size());
@@ -281,7 +266,7 @@ public class DenseVector extends Abstrac
@Override
public void set(double value) {
- lengthSquared = -1;
+ invalidateCachedLength();
values[index] = value;
}
}
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/LengthCachingVector.java Thu Oct 11 01:17:10 2012
@@ -21,12 +21,15 @@ package org.apache.mahout.math;
* Marker interface for vectors that may cache their squared length.
*/
interface LengthCachingVector {
+ /**
+ * Gets the currently cached squared length or if there is none, recalculates
+ * the value and returns that.
+ * @return The sum of the squares of all elements in the vector.
+ */
double getLengthSquared();
/**
- * This is a very dangerous method to call. Passing in a wrong value can
- * completely screw up distance computations and normalization.
- * @param d2 The new value for the squared length cache.
+ * Invalidates the length cache. This should be called by all mutators of the vector.
*/
- void setLengthSquared(double d2);
+ void invalidateCachedLength();
}
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/RandomAccessSparseVector.java Thu Oct 11 01:17:10 2012
@@ -132,7 +132,7 @@ public class RandomAccessSparseVector ex
@Override
public void setQuick(int index, double value) {
- lengthSquared = -1.0;
+ invalidateCachedLength();
if (value == 0.0) {
values.removeKey(index);
} else {
@@ -225,7 +225,7 @@ public class RandomAccessSparseVector ex
@Override
public void set(double value) {
- lengthSquared = -1;
+ invalidateCachedLength();
if (value == 0.0) {
values.removeKey(index);
} else {
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/SequentialAccessSparseVector.java Thu Oct 11 01:17:10 2012
@@ -186,7 +186,7 @@ public class SequentialAccessSparseVecto
@Override
public void setQuick(int index, double value) {
- lengthSquared = -1;
+ invalidateCachedLength();
values.set(index, value);
}
@@ -285,7 +285,7 @@ public class SequentialAccessSparseVecto
@Override
public void set(double value) {
- lengthSquared = -1;
+ invalidateCachedLength();
values.getValues()[offset] = value;
}
}
@@ -321,7 +321,7 @@ public class SequentialAccessSparseVecto
@Override
public void set(double value) {
- lengthSquared = -1;
+ invalidateCachedLength();
if (index == values.getIndices()[nextOffset]) {
values.getValues()[nextOffset] = value;
} else {
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/WeightedVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/WeightedVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/WeightedVector.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/WeightedVector.java Thu Oct 11 01:17:10 2012
@@ -21,79 +21,84 @@ package org.apache.mahout.math;
* Decorates a vector with a floating point weight and an index.
*/
public class WeightedVector extends DelegatingVector implements Comparable<WeightedVector> {
- private static final int INVALID_INDEX = -1;
- private double weight;
- private int index;
-
- protected WeightedVector(int size, double weight, int index) {
- super(size);
- this.weight = weight;
- this.index = index;
- }
-
- public WeightedVector(Vector v, double weight, int index) {
- super(v);
- this.weight = weight;
- this.index = index;
- }
-
- public WeightedVector(Vector v, Vector projection, int index) {
- super(v);
- this.index = index;
- this.weight = v.dot(projection);
- }
-
- public static WeightedVector project(Vector v, Vector projection) {
- return project(v, projection, INVALID_INDEX);
- }
-
- public static WeightedVector project(Vector v, Vector projection, int index) {
- return new WeightedVector(v, projection, index);
- }
-
- public double getWeight() {
- return weight;
- }
-
-
- @Override
- public int compareTo(WeightedVector other) {
- if (this == other) {
- return 0;
- }
- int r = Double.compare(weight, other.getWeight());
- if (r == 0 || Math.abs(weight - other.getWeight()) < 1.0e-8) {
- double diff = this.minus(other).norm(1);
- if (diff < 1.0e-12) {
- return 0;
- } else {
- for (Vector.Element element : this) {
- r = Double.compare(element.get(), other.get(element.index()));
- if (r != 0) {
- return r;
- }
- }
- return 0;
- }
- } else {
+ private static final int INVALID_INDEX = -1;
+ private double weight;
+ private int index;
+
+ protected WeightedVector(double weight, int index) {
+ super();
+ this.weight = weight;
+ this.index = index;
+ }
+
+ public WeightedVector(Vector v, double weight, int index) {
+ super(v);
+ this.weight = weight;
+ this.index = index;
+ }
+
+ public WeightedVector(Vector v, Vector projection, int index) {
+ super(v);
+ this.index = index;
+ this.weight = v.dot(projection);
+ }
+
+ public static WeightedVector project(Vector v, Vector projection) {
+ return project(v, projection, INVALID_INDEX);
+ }
+
+ public static WeightedVector project(Vector v, Vector projection, int index) {
+ return new WeightedVector(v, projection, index);
+ }
+
+ public double getWeight() {
+ return weight;
+ }
+
+
+ @Override
+ public int compareTo(WeightedVector other) {
+ if (this == other) {
+ return 0;
+ }
+ int r = Double.compare(weight, other.getWeight());
+ if (r == 0 || Math.abs(weight - other.getWeight()) < 1e-8) {
+ double diff = this.minus(other).norm(1);
+ if (diff < 1e-12) {
+ return 0;
+ } else {
+ for (Vector.Element element : this) {
+ r = Double.compare(element.get(), other.get(element.index()));
+ if (r != 0) {
return r;
+ }
}
- }
-
- public int getIndex() {
- return index;
- }
-
- public void setWeight(double newWeight) {
- this.weight = newWeight;
- }
-
- public void setIndex(int index) {
- this.index = index;
- }
-
- @Override
- public String toString() {
- return String.format("index=%d, weight=%.2f, v=%s", index, weight, getVector());
- }
+ return 0;
+ }
+ } else {
+ return r;
+ }
+ }
+
+ public int getIndex() {
+ return index;
+ }
+
+ public void setWeight(double newWeight) {
+ this.weight = newWeight;
+ }
+
+ public void setIndex(int index) {
+ this.index = index;
+ }
+
+ @Override
+ public Vector like() {
+ return new WeightedVector(getVector().like(), weight, index);
+ }
+
+ @Override
+ public String toString() {
+ return String.format("index=%d, weight=%.2f, v=%s", index, weight, getVector());
+ }
}
Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/AbstractVectorTest.java Thu Oct 11 01:17:10 2012
@@ -3,8 +3,10 @@ package org.apache.mahout.math;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.jet.random.Normal;
+import org.junit.Before;
import org.junit.Test;
+import java.util.Iterator;
import java.util.Random;
/**
@@ -18,6 +20,17 @@ import java.util.Random;
public abstract class AbstractVectorTest<T extends Vector> extends MahoutTestCase {
private static final double FUZZ = 1.0e-13;
+ private static final double[] values = {1.1, 2.2, 3.3};
+ private static final double[] gold = {0.0, 1.1, 0.0, 2.2, 0.0, 3.3, 0.0};
+ private Vector test;
+
+ private static void checkIterator(Iterator<Vector.Element> nzIter, double[] values) {
+ while (nzIter.hasNext()) {
+ Vector.Element elt = nzIter.next();
+ assertEquals(elt.index() + " Value: " + values[elt.index()]
+ + " does not equal: " + elt.get(), values[elt.index()], elt.get(), 0.0);
+ }
+ }
public abstract T vectorToTest(int size);
@@ -27,13 +40,14 @@ public abstract class AbstractVectorTest
T v0 = vectorToTest(20);
Random gen = RandomUtils.getRandom();
Vector v1 = v0.assign(new Normal(0, 1, gen));
- Vector v2 = vectorToTest(20).assign(new Normal(0, 1, gen));
+ // verify that v0 and v1 share state and are the same instance (assign returns this)
assertEquals(v0.get(12), v1.get(12), 0);
v0.set(12, gen.nextDouble());
assertEquals(v0.get(12), v1.get(12), 0);
assertSame(v0, v1);
+ Vector v2 = vectorToTest(20).assign(new Normal(0, 1, gen));
Vector dv1 = new DenseVector(v1);
Vector dv2 = new DenseVector(v2);
Vector sv1 = new RandomAccessSparseVector(v1);
@@ -147,4 +161,420 @@ public abstract class AbstractVectorTest
}
+
+ abstract Vector generateTestVector(int cardinality);
+
+ Vector getTestVector() {
+ return test;
+ }
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ test = generateTestVector(2 * values.length + 1);
+ for (int i = 0; i < values.length; i++) {
+ test.set(2*i + 1, values[i]);
+ }
+ }
+
+ @Test
+ public void testCardinality() {
+ assertEquals("size", 7, test.size());
+ }
+
+ @Test
+ public void testIterator() {
+ Iterator<Vector.Element> iterator = test.iterateNonZero();
+ checkIterator(iterator, gold);
+
+ iterator = test.iterator();
+ checkIterator(iterator, gold);
+
+ double[] doubles = {0.0, 5.0, 0, 3.0};
+ RandomAccessSparseVector zeros = new RandomAccessSparseVector(doubles.length);
+ for (int i = 0; i < doubles.length; i++) {
+ zeros.setQuick(i, doubles[i]);
+ }
+ iterator = zeros.iterateNonZero();
+ checkIterator(iterator, doubles);
+ iterator = zeros.iterator();
+ checkIterator(iterator, doubles);
+
+ doubles = new double[]{0.0, 0.0, 0, 0.0};
+ zeros = new RandomAccessSparseVector(doubles.length);
+ for (int i = 0; i < doubles.length; i++) {
+ zeros.setQuick(i, doubles[i]);
+ }
+ iterator = zeros.iterateNonZero();
+ checkIterator(iterator, doubles);
+ iterator = zeros.iterator();
+ checkIterator(iterator, doubles);
+
+ }
+
+ @Test
+ public void testIteratorSet() {
+ Vector clone = test.clone();
+ Iterator<Vector.Element> it = clone.iterateNonZero();
+ while (it.hasNext()) {
+ Vector.Element e = it.next();
+ e.set(e.get() * 2.0);
+ }
+ it = clone.iterateNonZero();
+ while (it.hasNext()) {
+ Vector.Element e = it.next();
+ assertEquals(test.get(e.index()) * 2.0, e.get(), EPSILON);
+ }
+ clone = test.clone();
+ it = clone.iterator();
+ while (it.hasNext()) {
+ Vector.Element e = it.next();
+ e.set(e.get() * 2.0);
+ }
+ it = clone.iterator();
+ while (it.hasNext()) {
+ Vector.Element e = it.next();
+ assertEquals(test.get(e.index()) * 2.0, e.get(), EPSILON);
+ }
+ }
+
+ @Test
+ public void testCopy() {
+ Vector copy = test.clone();
+ for (int i = 0; i < test.size(); i++) {
+ assertEquals("copy [" + i + ']', test.get(i), copy.get(i), EPSILON);
+ }
+ }
+
+ @Test
+ public void testGet() {
+ for (int i = 0; i < test.size(); i++) {
+ if (i % 2 == 0) {
+ assertEquals("get [" + i + ']', 0.0, test.get(i), EPSILON);
+ } else {
+ assertEquals("get [" + i + ']', values[i/2], test.get(i), EPSILON);
+ }
+ }
+ }
+
+ @Test(expected = IndexException.class)
+ public void testGetOver() {
+ test.get(test.size());
+ }
+
+ @Test(expected = IndexException.class)
+ public void testGetUnder() {
+ test.get(-1);
+ }
+
+ @Test
+ public void testSet() {
+ test.set(3, 4.5);
+ for (int i = 0; i < test.size(); i++) {
+ if (i % 2 == 0) {
+ assertEquals("get [" + i + ']', 0.0, test.get(i), EPSILON);
+ } else if (i == 3) {
+ assertEquals("set [" + i + ']', 4.5, test.get(i), EPSILON);
+ } else {
+ assertEquals("set [" + i + ']', values[i/2], test.get(i), EPSILON);
+ }
+ }
+ }
+
+ @Test
+ public void testSize() {
+ assertEquals("size", 3, test.getNumNondefaultElements());
+ }
+
+ @Test
+ public void testViewPart() {
+ Vector part = test.viewPart(1, 2);
+ assertEquals("part size", 2, part.getNumNondefaultElements());
+ for (int i = 0; i < part.size(); i++) {
+ assertEquals("part[" + i + ']', test.get(i+1), part.get(i), EPSILON);
+ }
+ }
+
+ @Test(expected = IndexException.class)
+ public void testViewPartUnder() {
+ test.viewPart(-1, values.length);
+ }
+
+ @Test(expected = IndexException.class)
+ public void testViewPartOver() {
+ test.viewPart(2, 7);
+ }
+
+ @Test(expected = IndexException.class)
+ public void testViewPartCardinality() {
+ test.viewPart(1, 8);
+ }
+
+ @Test
+ public void testSparseDoubleVectorInt() {
+ Vector val = new RandomAccessSparseVector(4);
+ assertEquals("size", 4, val.size());
+ for (int i = 0; i < 4; i++) {
+ assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
+ }
+ }
+
+ @Test
+ public void testDot() {
+ double res = test.dot(test);
+ double expected = 3.3 * 3.3 + 2.2 * 2.2 + 1.1 * 1.1;
+ assertEquals("dot", expected, res, EPSILON);
+ }
+
+ @Test
+ public void testDot2() {
+ Vector test2 = test.clone();
+ test2.set(1, 0.0);
+ test2.set(3, 0.0);
+ assertEquals(3.3 * 3.3, test2.dot(test), EPSILON);
+ }
+
+ @Test(expected = CardinalityException.class)
+ public void testDotCardinality() {
+ test.dot(new DenseVector(test.size() + 1));
+ }
+
+ @Test
+ public void testNormalize() {
+ Vector val = test.normalize();
+ double mag = Math.sqrt(1.1 * 1.1 + 2.2 * 2.2 + 3.3 * 3.3);
+ for (int i = 0; i < test.size(); i++) {
+ if (i % 2 == 0) {
+ assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
+ } else {
+ assertEquals("dot", values[i/2] / mag, val.get(i), EPSILON);
+ }
+ }
+ }
+
+ @Test
+ public void testMinus() {
+ Vector val = test.minus(test);
+ assertEquals("size", test.size(), val.size());
+ for (int i = 0; i < test.size(); i++) {
+ assertEquals("get [" + i + ']', 0.0, val.get(i), EPSILON);
+ }
+
+ val = test.minus(test).minus(test);
+ assertEquals("cardinality", test.size(), val.size());
+ for (int i = 0; i < test.size(); i++) {
+ assertEquals("get [" + i + ']', 0.0, val.get(i) + test.get(i), EPSILON);
+ }
+
+ Vector val1 = test.plus(1);
+ val = val1.minus(test);
+ for (int i = 0; i < test.size(); i++) {
+ assertEquals("get [" + i + ']', 1.0, val.get(i), EPSILON);
+ }
+
+ val1 = test.plus(-1);
+ val = val1.minus(test);
+ for (int i = 0; i < test.size(); i++) {
+ assertEquals("get [" + i + ']', -1.0, val.get(i), EPSILON);
+ }
+ }
+
+ /** plus(1) must keep the size and yield 1.0 at even indices and values[i/2] + 1.0 at odd index i. */
+ @Test
+ public void testPlusDouble() {
+   Vector shifted = test.plus(1);
+   assertEquals("size", test.size(), shifted.size());
+   for (int i = 0; i < test.size(); i++) {
+     double expected = (i % 2 == 0) ? 1.0 : values[i/2] + 1.0;
+     assertEquals("get [" + i + ']', expected, shifted.get(i), EPSILON);
+   }
+ }
+
+ /** Adding the fixture to itself must keep the size, leave even indices at 0.0 and double the odd ones. */
+ @Test
+ public void testPlusVector() {
+   Vector doubled = test.plus(test);
+   assertEquals("size", test.size(), doubled.size());
+   for (int index = 0; index < test.size(); index++) {
+     String label = "get [" + index + ']';
+     if (index % 2 != 0) {
+       assertEquals(label, values[index/2] * 2.0, doubled.get(index), EPSILON);
+       continue;
+     }
+     assertEquals(label, 0.0, doubled.get(index), EPSILON);
+   }
+ }
+
+ /** plus() with a vector one element longer than the fixture must raise CardinalityException. */
+ @Test(expected = CardinalityException.class)
+ public void testPlusVectorCardinality() {
+   Vector oneTooLong = new DenseVector(test.size() + 1);
+   test.plus(oneTooLong);
+ }
+
+ /** times(3) must keep the size, leave even indices at 0.0 and triple the value at each odd index. */
+ @Test
+ public void testTimesDouble() {
+   Vector tripled = test.times(3);
+   assertEquals("size", test.size(), tripled.size());
+   int i = 0;
+   while (i < test.size()) {
+     double expected = (i % 2 == 0) ? 0.0 : values[i/2] * 3.0;
+     assertEquals("get [" + i + ']', expected, tripled.get(i), EPSILON);
+     i++;
+   }
+ }
+
+ /** divide(3) must keep the size, leave even indices at 0.0 and hold values[i/2] / 3.0 at odd index i. */
+ @Test
+ public void testDivideDouble() {
+   Vector third = test.divide(3);
+   assertEquals("size", test.size(), third.size());
+   for (int i = 0; i < test.size(); i++) {
+     boolean zeroSlot = i % 2 == 0;
+     double expected;
+     if (zeroSlot) {
+       expected = 0.0;
+     } else {
+       expected = values[i/2] / 3.0;
+     }
+     assertEquals("get [" + i + ']', expected, third.get(i), EPSILON);
+   }
+ }
+
+ /** Element-wise times(test) must keep the size, leave even indices at 0.0 and square each odd-index value. */
+ @Test
+ public void testTimesVector() {
+   Vector squared = test.times(test);
+   assertEquals("size", test.size(), squared.size());
+   for (int i = 0; i < test.size(); i++) {
+     double expected = 0.0;
+     if (i % 2 != 0) {
+       expected = values[i/2] * values[i/2];
+     }
+     assertEquals("get [" + i + ']', expected, squared.get(i), EPSILON);
+   }
+ }
+
+ /** times() with a vector one element longer than the fixture must raise CardinalityException. */
+ @Test(expected = CardinalityException.class)
+ public void testTimesVectorCardinality() {
+   Vector oneTooLong = new DenseVector(test.size() + 1);
+   test.times(oneTooLong);
+ }
+
+ /** zSum() must equal the running total of the entries of values. */
+ @Test
+ public void testZSum() {
+   double total = 0;
+   for (int k = 0; k < values.length; k++) {
+     total += values[k];
+   }
+   assertEquals("wrong zSum", total, test.zSum(), EPSILON);
+ }
+
+ /** getDistanceSquared(other) must agree with minus(other).getLengthSquared() to within 10.0E-7. */
+ @Test
+ public void testGetDistanceSquared() {
+   Vector other = new RandomAccessSparseVector(test.size());
+   int[] positions = {1, 2, 3, 4};
+   double[] entries = {-2, -5, -9, 1};
+   for (int k = 0; k < positions.length; k++) {
+     other.set(positions[k], entries[k]);
+   }
+   double viaMinus = test.minus(other).getLengthSquared();
+   double direct = test.getDistanceSquared(other);
+   assertTrue("a.getDistanceSquared(b) != a.minus(b).getLengthSquared",
+       Math.abs(viaMinus - direct) < 10.0E-7);
+ }
+
+ /** assign(0) must overwrite every element of the vector with zero. */
+ @Test
+ public void testAssignDouble() {
+   test.assign(0);
+   // Walk the whole vector: a bound of values.length leaves trailing elements
+   // unchecked whenever the vector holds more elements than values does.
+   for (int i = 0; i < test.size(); i++) {
+     assertEquals("value[" + i + ']', 0.0, test.getQuick(i), EPSILON);
+   }
+ }
+
+ /** assign(double[]) with an all-zero array of matching length must zero every element. */
+ @Test
+ public void testAssignDoubleArray() {
+   double[] array = new double[test.size()];
+   test.assign(array);
+   // Verify as many elements as were assigned; a bound of values.length stops
+   // short whenever the vector holds more elements than values does.
+   for (int i = 0; i < array.length; i++) {
+     assertEquals("value[" + i + ']', 0.0, test.getQuick(i), EPSILON);
+   }
+ }
+
+ /** assign(double[]) with an array one element longer than the fixture must raise CardinalityException. */
+ @Test(expected = CardinalityException.class)
+ public void testAssignDoubleArrayCardinality() {
+   int tooMany = test.size() + 1;
+   test.assign(new double[tooMany]);
+ }
+
+ /** assign(Vector) from a newly created DenseVector of the same size must leave every element at 0.0. */
+ @Test
+ public void testAssignVector() {
+   Vector other = new DenseVector(test.size());
+   test.assign(other);
+   // Check the full cardinality; a bound of values.length leaves trailing
+   // elements unchecked whenever the vector holds more elements than values does.
+   for (int i = 0; i < test.size(); i++) {
+     assertEquals("value[" + i + ']', 0.0, test.getQuick(i), EPSILON);
+   }
+ }
+
+ /** assign(Vector) with a vector one element shorter than the fixture must raise CardinalityException. */
+ @Test(expected = CardinalityException.class)
+ public void testAssignVectorCardinality() {
+   int tooFew = test.size() - 1;
+   test.assign(new DenseVector(tooFew));
+ }
+
+ /** assign(Functions.NEGATE) must flip the sign of every element in place. */
+ @Test
+ public void testAssignUnaryFunction() {
+   test.assign(Functions.NEGATE);
+   // Check every element using the even-zero / odd-value layout that sibling tests such as
+   // testNormalize rely on (odd index i holds values[i / 2]). Comparing -values[i] with
+   // element i + 2 for i = 1 only exercised a single element.
+   for (int i = 0; i < test.size(); i++) {
+     double expected = i % 2 == 0 ? 0.0 : -values[i / 2];
+     assertEquals("value[" + i + ']', expected, test.getQuick(i), EPSILON);
+   }
+ }
+
+ /** assign(test, Functions.PLUS) adds the vector to itself, so every element must double. */
+ @Test
+ public void testAssignBinaryFunction() {
+   test.assign(test, Functions.PLUS);
+   // Cover the whole vector and look the original value up as values[i / 2], the layout
+   // sibling tests such as testPlusVector use; values[i - 1] is only right for i = 1.
+   for (int i = 0; i < test.size(); i++) {
+     if (i % 2 == 0) {
+       assertEquals("get [" + i + ']', 0.0, test.get(i), EPSILON);
+     } else {
+       assertEquals("value[" + i + ']', 2 * values[i / 2], test.getQuick(i), EPSILON);
+     }
+   }
+ }
+
+ /** assign(Functions.plus(4)) must add 4 to every element, including the zero entries. */
+ @Test
+ public void testAssignBinaryFunction2() {
+   test.assign(Functions.plus(4));
+   // Cover the whole vector and look the original value up as values[i / 2], the layout
+   // sibling tests such as testPlusDouble use; values[i - 1] is only right for i = 1.
+   for (int i = 0; i < test.size(); i++) {
+     if (i % 2 == 0) {
+       assertEquals("get [" + i + ']', 4.0, test.get(i), EPSILON);
+     } else {
+       assertEquals("value[" + i + ']', values[i / 2] + 4, test.getQuick(i), EPSILON);
+     }
+   }
+ }
+
+ /** assign(Functions.mult(4)) must multiply every element by 4, leaving the zero entries at zero. */
+ @Test
+ public void testAssignBinaryFunction3() {
+   test.assign(Functions.mult(4));
+   // Cover the whole vector and look the original value up as values[i / 2], the layout
+   // sibling tests such as testTimesDouble use; values[i - 1] is only right for i = 1.
+   for (int i = 0; i < test.size(); i++) {
+     if (i % 2 == 0) {
+       assertEquals("get [" + i + ']', 0.0, test.get(i), EPSILON);
+     } else {
+       assertEquals("value[" + i + ']', values[i / 2] * 4, test.getQuick(i), EPSILON);
+     }
+   }
+ }
+
+ /** like() must return a vector whose class is assignable to the fixture's class and whose size matches. */
+ @Test
+ public void testLike() {
+   Vector sibling = test.like();
+   Class<?> fixtureType = test.getClass();
+   assertTrue("not like", fixtureType.isAssignableFrom(sibling.getClass()));
+   assertEquals("size", test.size(), sibling.size());
+ }
+
+ /** cross(test) must be a size-by-size matrix whose (row, col) entry is test[row] * test[col]. */
+ @Test
+ public void testCrossProduct() {
+   Matrix outer = test.cross(test);
+   assertEquals("row size", test.size(), outer.rowSize());
+   assertEquals("col size", test.size(), outer.columnSize());
+   for (int r = 0; r < outer.rowSize(); r++) {
+     for (int c = 0; c < outer.columnSize(); c++) {
+       double product = test.getQuick(r) * test.getQuick(c);
+       assertEquals("cross[" + r + "][" + c + ']', product, outer.getQuick(r, c), EPSILON);
+     }
+   }
+ }
}
Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/CentroidTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/CentroidTest.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/CentroidTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/CentroidTest.java Thu Oct 11 01:17:10 2012
@@ -59,4 +59,15 @@ public class CentroidTest extends Abstra
public Vector vectorToTest(int size) {
return new Centroid(new WeightedVector(new DenseVector(size), 3.15, 51));
}
+
+ /** The fixture must report 7 non-default elements. */
+ @Override
+ public void testSize() {
+   int nonDefault = getTestVector().getNumNondefaultElements();
+   assertEquals("size", 7, nonDefault);
+ }
+
+ /** Builds the fixture: a DenseVector of the given cardinality wrapped in a WeightedVector (arguments 3.14, 53) and then a Centroid. */
+ @Override
+ Vector generateTestVector(int cardinality) {
+   DenseVector storage = new DenseVector(cardinality);
+   WeightedVector weighted = new WeightedVector(storage, 3.14, 53);
+   return new Centroid(weighted);
+ }
+
}
Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/TestDenseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/TestDenseVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/TestDenseVector.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/TestDenseVector.java Thu Oct 11 01:17:10 2012
@@ -17,7 +17,9 @@
package org.apache.mahout.math;
-public final class TestDenseVector extends AbstractTestVector {
+import org.apache.mahout.math.function.Functions;
+
+public final class TestDenseVector extends AbstractVectorTest<DenseVector> {
@Override
Vector generateTestVector(int cardinality) {
@@ -29,4 +31,10 @@ public final class TestDenseVector exten
assertEquals("size", 7, getTestVector().getNumNondefaultElements());
}
+ /** Returns a DenseVector of the requested size after assigning Functions.random() to it. */
+ @Override
+ public DenseVector vectorToTest(int size) {
+   DenseVector filled = new DenseVector(size);
+   filled.assign(Functions.random());
+   return filled;
+ }
}
Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/TestRandomAccessSparseVector.java Thu Oct 11 01:17:10 2012
@@ -17,11 +17,26 @@
package org.apache.mahout.math;
-public final class TestRandomAccessSparseVector extends AbstractTestVector {
+import org.apache.mahout.common.RandomUtils;
+
+import java.util.Random;
+
+public final class TestRandomAccessSparseVector extends AbstractVectorTest<RandomAccessSparseVector> {
@Override
Vector generateTestVector(int cardinality) {
return new RandomAccessSparseVector(cardinality);
}
+
+ /** Returns a sparse vector of the requested size after three set() calls at random indices with Gaussian values. */
+ @Override
+ public RandomAccessSparseVector vectorToTest(int size) {
+   RandomAccessSparseVector sparse = new RandomAccessSparseVector(size);
+   Random random = RandomUtils.getRandom();
+   int remaining = 3;
+   while (remaining > 0) {
+     // Draw the index before the value so the generator is consumed in the same order.
+     int position = random.nextInt(sparse.size());
+     double value = random.nextGaussian();
+     sparse.set(position, value);
+     remaining--;
+   }
+   return sparse;
+ }
+
}
Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSequentialAccessSparseVector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSequentialAccessSparseVector.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSequentialAccessSparseVector.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/TestSequentialAccessSparseVector.java Thu Oct 11 01:17:10 2012
@@ -17,18 +17,20 @@
package org.apache.mahout.math;
+import org.apache.mahout.common.RandomUtils;
import org.junit.Test;
-public final class TestSequentialAccessSparseVector extends AbstractTestVector {
+import java.util.Random;
+
+public final class TestSequentialAccessSparseVector extends AbstractVectorTest<SequentialAccessSparseVector> {
@Override
Vector generateTestVector(int cardinality) {
return new SequentialAccessSparseVector(cardinality);
}
- @Override
@Test
- public void testDot2() {
+ public void testDotSuperBig() {
Vector w = new SequentialAccessSparseVector(Integer.MAX_VALUE, 12);
w.set(1, 0.4);
w.set(2, 0.4);
@@ -37,6 +39,17 @@ public final class TestSequentialAccessS
Vector v = new SequentialAccessSparseVector(Integer.MAX_VALUE, 12);
v.set(3, 1);
- assertEquals("dot2", -0.666666667, v.dot(w), EPSILON);
+ assertEquals("super-big", -0.666666667, v.dot(w), EPSILON);
+ }
+
+
+ /** Returns a sparse vector of the requested size after three set() calls at random indices with Gaussian values. */
+ @Override
+ public SequentialAccessSparseVector vectorToTest(int size) {
+   SequentialAccessSparseVector sparse = new SequentialAccessSparseVector(size);
+   Random random = RandomUtils.getRandom();
+   int remaining = 3;
+   while (remaining > 0) {
+     // Draw the index before the value so the generator is consumed in the same order.
+     int position = random.nextInt(sparse.size());
+     double value = random.nextGaussian();
+     sparse.set(position, value);
+     remaining--;
+   }
+   return sparse;
}
}
\ No newline at end of file
Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/VectorTest.java Thu Oct 11 01:17:10 2012
@@ -886,4 +886,37 @@ public final class VectorTest extends Ma
assertFalse(left.hashCode() == right.hashCode());
}
+ // Checks AbstractVector.mergeDiff on two SequentialAccessSparseVectors of cardinality 20 while
+ // entries are added step by step. Each expected value is a sum of squared element-wise
+ // differences, and the delta of 0 demands exact floating-point equality.
+ @Test
+ public void testMergeDiff() {
+ Vector left = new SequentialAccessSparseVector(20);
+ Vector right = new SequentialAccessSparseVector(20);
+
+ // No entries set on either side.
+ assertEquals(0, AbstractVector.mergeDiff(left, right), 0);
+
+ // One entry that exists only on the left.
+ left.set(5, 1.5);
+ assertEquals(1.5 * 1.5, AbstractVector.mergeDiff(left, right), 0);
+
+ // Add an entry that exists only on the right, at a lower index.
+ right.set(4, 3.1);
+ assertEquals(3.1 * 3.1 + 1.5 * 1.5, AbstractVector.mergeDiff(left, right), 0);
+
+ // Add another left-only entry in front of both.
+ left.set(3, 1.2);
+ assertEquals(1.2 * 1.2 + 3.1 * 3.1 + 1.5 * 1.5, AbstractVector.mergeDiff(left, right), 0);
+
+ // Index 6 holds the same value on both sides and adds nothing; right-only index 8 adds 2 * 2.
+ left.set(6, 2);
+ right.set(6, 2);
+ right.set(8, 2);
+ assertEquals(1.2 * 1.2 + 3.1 * 3.1 + 1.5 * 1.5 + 2 * 2, AbstractVector.mergeDiff(left, right), 0);
+ }
+
+ // Checks AbstractVector.randomScanDiff on two SequentialAccessSparseVectors of cardinality 20:
+ // index 4 is set on both sides (1.1 vs 1.2), index 6 only on the left, index 7 only on the right.
+ // The expected value is written as a sum of squared differences and compared with a delta of 0.
+ @Test
+ public void testRandomScanDiff() {
+ Vector left = new SequentialAccessSparseVector(20);
+ Vector right = new SequentialAccessSparseVector(20);
+ left.set(4, 1.1);
+ left.set(6, 2.1);
+ right.set(7, 3.1);
+ right.set(4, 1.2);
+
+ // NOTE(review): the delta of 0 requires the result to match this literal expression exactly;
+ // confirm that is intended given that 1.1 - 1.2 is not exactly -.1 in binary floating point.
+ assertEquals(.1 * .1 + 2.1 * 2.1 + 3.1 * 3.1, AbstractVector.randomScanDiff(left, right), 0);
+ }
}
Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/WeightedVectorTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/WeightedVectorTest.java?rev=1396888&r1=1396887&r2=1396888&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/WeightedVectorTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/WeightedVectorTest.java Thu Oct 11 01:17:10 2012
@@ -74,4 +74,14 @@ public class WeightedVectorTest extends
WeightedVector v5 = WeightedVector.project(q.viewColumn(0), qx);
assertEquals(Math.sqrt(0.5), v5.getWeight(), 1.0e-13);
}
+
+ /** The fixture must report 7 non-default elements. */
+ @Override
+ public void testSize() {
+   int nonDefault = getTestVector().getNumNondefaultElements();
+   assertEquals("size", 7, nonDefault);
+ }
+
+ /** Builds the fixture: a DenseVector of the given cardinality wrapped in a WeightedVector (arguments 3.14, 53). */
+ @Override
+ Vector generateTestVector(int cardinality) {
+   DenseVector storage = new DenseVector(cardinality);
+   return new WeightedVector(storage, 3.14, 53);
+ }
}