You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2012/06/06 18:55:14 UTC

svn commit: r1346978 - in /mahout/trunk/core/src: main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java

Author: jeastman
Date: Wed Jun  6 16:55:14 2012
New Revision: 1346978

URL: http://svn.apache.org/viewvc?rev=1346978&view=rev
Log:
MAHOUT-1028:
- Added a unit test that reproduces the NaN pdf value produced by a zero vector and/or a zero cluster center
- Added zero-vector corner-case detection to CosineDistanceMeasure
- All tests run successfully

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java?rev=1346978&r1=1346977&r2=1346978&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java Wed Jun  6 16:55:14 2012
@@ -81,13 +81,18 @@ public class CosineDistanceMeasure imple
       denominator = dotProduct;
     }
     
+    // correct for zero-vector corner case
+    if (denominator == 0 && dotProduct == 0) {
+      return 1;
+    }
+    
     return 1.0 - dotProduct / denominator;
   }
   
   @Override
   public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
-
-    double lengthSquaredv =  v.getLengthSquared();
+    
+    double lengthSquaredv = v.getLengthSquared();
     
     double dotProduct = v.dot(centroid);
     double denominator = Math.sqrt(centroidLengthSquare) * Math.sqrt(lengthSquaredv);

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java?rev=1346978&r1=1346977&r2=1346978&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java Wed Jun  6 16:55:14 2012
@@ -30,17 +30,10 @@ import org.apache.mahout.clustering.cano
 import org.apache.mahout.clustering.classify.ClusterClassifier;
 import org.apache.mahout.clustering.dirichlet.models.GaussianCluster;
 import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
-import org.apache.mahout.clustering.iterator.CanopyClusteringPolicy;
-import org.apache.mahout.clustering.iterator.ClusterIterator;
-import org.apache.mahout.clustering.iterator.ClusteringPolicy;
-import org.apache.mahout.clustering.iterator.DirichletClusteringPolicy;
-import org.apache.mahout.clustering.iterator.DistanceMeasureCluster;
-import org.apache.mahout.clustering.iterator.FuzzyKMeansClusteringPolicy;
-import org.apache.mahout.clustering.iterator.KMeansClusteringPolicy;
-import org.apache.mahout.clustering.iterator.MeanShiftClusteringPolicy;
 import org.apache.mahout.clustering.kmeans.TestKmeansClustering;
 import org.apache.mahout.clustering.meanshift.MeanShiftCanopy;
 import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.CosineDistanceMeasure;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
 import org.apache.mahout.math.DenseVector;
@@ -70,6 +63,15 @@ public final class TestClusterClassifier
     return new ClusterClassifier(models, new KMeansClusteringPolicy());
   }
   
+  private static ClusterClassifier newCosineKlusterClassifier() {
+    List<Cluster> models = Lists.newArrayList();
+    DistanceMeasure measure = new CosineDistanceMeasure();
+    models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(1), 0, measure));
+    models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2), 1, measure));
+    models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(-1), 2, measure));
+    return new ClusterClassifier(models, new KMeansClusteringPolicy());
+  }
+
   private static ClusterClassifier newSoftClusterClassifier() {
     List<Cluster> models = Lists.newArrayList();
     DistanceMeasure measure = new ManhattanDistanceMeasure();
@@ -277,8 +279,16 @@ public final class TestClusterClassifier
       assertEquals(3, posterior.getModels().size());
       for (Cluster cluster : posterior.getModels()) {
         System.out.println(cluster.asFormatString(null));
-      }
-      
+      }     
     }
   }
+  
+  @Test
+  public void testCosineKlusterClassification() {
+    ClusterClassifier classifier = newCosineKlusterClassifier();
+    Vector pdf = classifier.classify(new DenseVector(2));
+    assertEquals("[0,0]", "[0.333, 0.333, 0.333]", AbstractCluster.formatVector(pdf, null));
+    pdf = classifier.classify(new DenseVector(2).assign(2));
+    assertEquals("[2,2]", "[0.545, 0.273, 0.182]", AbstractCluster.formatVector(pdf, null));
+  }
 }