You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2012/06/06 18:55:14 UTC
svn commit: r1346978 - in /mahout/trunk/core/src:
main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java
test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java
Author: jeastman
Date: Wed Jun 6 16:55:14 2012
New Revision: 1346978
URL: http://svn.apache.org/viewvc?rev=1346978&view=rev
Log:
MAHOUT-1028:
- Added a unit test that produced a NaN pdf value with a zero vector and/or a zero cluster center
- Added zero vector corner case detection to CosineDistanceMeasure
- All tests run
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java?rev=1346978&r1=1346977&r2=1346978&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java Wed Jun 6 16:55:14 2012
@@ -81,13 +81,18 @@ public class CosineDistanceMeasure imple
denominator = dotProduct;
}
+ // correct for zero-vector corner case
+ if (denominator == 0 && dotProduct == 0) {
+ return 1;
+ }
+
return 1.0 - dotProduct / denominator;
}
@Override
public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
-
- double lengthSquaredv = v.getLengthSquared();
+
+ double lengthSquaredv = v.getLengthSquared();
double dotProduct = v.dot(centroid);
double denominator = Math.sqrt(centroidLengthSquare) * Math.sqrt(lengthSquaredv);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java?rev=1346978&r1=1346977&r2=1346978&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java Wed Jun 6 16:55:14 2012
@@ -30,17 +30,10 @@ import org.apache.mahout.clustering.cano
import org.apache.mahout.clustering.classify.ClusterClassifier;
import org.apache.mahout.clustering.dirichlet.models.GaussianCluster;
import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
-import org.apache.mahout.clustering.iterator.CanopyClusteringPolicy;
-import org.apache.mahout.clustering.iterator.ClusterIterator;
-import org.apache.mahout.clustering.iterator.ClusteringPolicy;
-import org.apache.mahout.clustering.iterator.DirichletClusteringPolicy;
-import org.apache.mahout.clustering.iterator.DistanceMeasureCluster;
-import org.apache.mahout.clustering.iterator.FuzzyKMeansClusteringPolicy;
-import org.apache.mahout.clustering.iterator.KMeansClusteringPolicy;
-import org.apache.mahout.clustering.iterator.MeanShiftClusteringPolicy;
import org.apache.mahout.clustering.kmeans.TestKmeansClustering;
import org.apache.mahout.clustering.meanshift.MeanShiftCanopy;
import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.CosineDistanceMeasure;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.math.DenseVector;
@@ -70,6 +63,15 @@ public final class TestClusterClassifier
return new ClusterClassifier(models, new KMeansClusteringPolicy());
}
+ private static ClusterClassifier newCosineKlusterClassifier() {
+ List<Cluster> models = Lists.newArrayList();
+ DistanceMeasure measure = new CosineDistanceMeasure();
+ models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(1), 0, measure));
+ models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2), 1, measure));
+ models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(-1), 2, measure));
+ return new ClusterClassifier(models, new KMeansClusteringPolicy());
+ }
+
private static ClusterClassifier newSoftClusterClassifier() {
List<Cluster> models = Lists.newArrayList();
DistanceMeasure measure = new ManhattanDistanceMeasure();
@@ -277,8 +279,16 @@ public final class TestClusterClassifier
assertEquals(3, posterior.getModels().size());
for (Cluster cluster : posterior.getModels()) {
System.out.println(cluster.asFormatString(null));
- }
-
+ }
}
}
+
+ @Test
+ public void testCosineKlusterClassification() {
+ ClusterClassifier classifier = newCosineKlusterClassifier();
+ Vector pdf = classifier.classify(new DenseVector(2));
+ assertEquals("[0,0]", "[0.333, 0.333, 0.333]", AbstractCluster.formatVector(pdf, null));
+ pdf = classifier.classify(new DenseVector(2).assign(2));
+ assertEquals("[2,2]", "[0.545, 0.273, 0.182]", AbstractCluster.formatVector(pdf, null));
+ }
}