You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@mahout.apache.org by gs...@apache.org on 2009/08/08 03:42:06 UTC

svn commit: r802283 - in /lucene/mahout/trunk/core/src: main/java/org/apache/mahout/clustering/kmeans/ main/java/org/apache/mahout/matrix/ main/java/org/apache/mahout/utils/ test/java/org/apache/mahout/matrix/

Author: gsingers
Date: Sat Aug  8 01:42:06 2009
New Revision: 802283

URL: http://svn.apache.org/viewvc?rev=802283&view=rev
Log:
MAHOUT-121: distance calculation improvements

Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseVector.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/SquaredEuclideanDistanceMeasure.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseVector.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java?rev=802283&r1=802282&r2=802283&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java Sat Aug  8 01:42:06 2009
@@ -52,9 +52,6 @@
   // the current centroid is lazy evaluated and may be null
   private Vector centroid = null;
 
-  // the standard deviation of the covered points
-  private double std;
-
 
   // the total of all the points squared, used for std computation
   private Vector pointSquaredTotal = null;
@@ -171,8 +168,9 @@
     Cluster nearestCluster = null;
     double nearestDistance = Double.MAX_VALUE;
     for (Cluster cluster : clusters) {
-      double distance = measure.distance(point, cluster.getCenter());
-      if (nearestCluster == null || distance < nearestDistance) {
+      Vector clusterCenter = cluster.getCenter();
+      double distance = measure.distance(clusterCenter.getLengthSquared(), clusterCenter, point);
+      if (distance < nearestDistance || nearestCluster == null ) {
         nearestCluster = cluster;
         nearestDistance = distance;
       }
@@ -187,8 +185,9 @@
     Cluster nearestCluster = null;
     double nearestDistance = Double.MAX_VALUE;
     for (Cluster cluster : clusters) {
-      double distance = measure.distance(point, cluster.getCenter());
-      if (nearestCluster == null || distance < nearestDistance) {
+      Vector clusterCenter = cluster.getCenter();
+      double distance = measure.distance(clusterCenter.getLengthSquared(), clusterCenter, point);
+      if (distance < nearestDistance || nearestCluster == null) {
         nearestCluster = cluster;
         nearestDistance = distance;
       }
@@ -209,10 +208,6 @@
     } else if (centroid == null) {
       // lazy compute new centroid
       centroid = pointTotal.divide(numPoints);
-      Vector stds = pointSquaredTotal.times(numPoints).minus(
-          pointTotal.times(pointTotal)).assign(new SquareRootFunction())
-          .divide(numPoints);
-      std = stds.zSum() / 2;
     }
     return centroid;
   }
@@ -323,7 +318,10 @@
 
   /** @return the std */
   public double getStd() {
-    return std;
+    Vector stds = pointSquaredTotal.times(numPoints).minus(
+          pointTotal.times(pointTotal)).assign(new SquareRootFunction())
+          .divide(numPoints);
+    return stds.zSum() / 2;
   }
 
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseVector.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseVector.java?rev=802283&r1=802282&r2=802283&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseVector.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/matrix/SparseVector.java Sat Aug  8 01:42:06 2009
@@ -289,11 +289,11 @@
     this.values = values;
   }
 
-  private Double lengthSquared = null;
+  private double lengthSquared;
 
   @Override
   public double getLengthSquared() {
-    if (lengthSquared != null) {
+    if (lengthSquared < 0.0) {
       return lengthSquared;
     }
     double result = 0.0;

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/SquaredEuclideanDistanceMeasure.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/SquaredEuclideanDistanceMeasure.java?rev=802283&r1=802282&r2=802283&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/SquaredEuclideanDistanceMeasure.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/SquaredEuclideanDistanceMeasure.java Sat Aug  8 01:42:06 2009
@@ -80,9 +80,6 @@
     if (centroid.size() != v.size()) {
       throw new CardinalityException();
     }
-
-    double result = centroidLengthSquare;
-    result += v.getDistanceSquared(centroid);
-    return result;
+    return centroidLengthSquare + v.getDistanceSquared(centroid);
   }
 }

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseVector.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseVector.java?rev=802283&r1=802282&r2=802283&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseVector.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/matrix/TestSparseVector.java Sat Aug  8 01:42:06 2009
@@ -43,10 +43,8 @@
 
   public void testAsFormatString() {
     String formatString = test.asFormatString();
-    assertEquals(
-        "format",
-        "{\"class\":\"org.apache.mahout.matrix.SparseVector\",\"vector\":\"{\\\"values\\\":{\\\"indices\\\":[1,2,3],\\\"values\\\":[1.1,2.2,3.3],\\\"numMappings\\\":3},\\\"cardinality\\\":5}\"}",
-        formatString);
+    Vector vec = AbstractVector.decodeVector(formatString);
+    assertEquals(vec, test);
   }
 
   public void testCardinality() {