You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@mahout.apache.org by ro...@apache.org on 2010/02/22 19:26:17 UTC

svn commit: r915007 - in /lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering: canopy/Canopy.java kmeans/Cluster.java

Author: robinanil
Date: Mon Feb 22 18:26:17 2010
New Revision: 915007

URL: http://svn.apache.org/viewvc?rev=915007&view=rev
Log:
MAHOUT-297: First-cut changes for k-means and Canopy for the 0.3 release; the remaining changes are deferred to 0.4.

Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java?rev=915007&r1=915006&r2=915007&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java Mon Feb 22 18:26:17 2010
@@ -25,6 +25,7 @@
 import org.apache.hadoop.mapred.OutputCollector;
 import org.apache.mahout.clustering.ClusterBase;
 import org.apache.mahout.math.AbstractVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 
@@ -48,8 +49,8 @@
    */
   public Canopy(Vector point, int canopyId) {
     this.setId(canopyId);
-    this.setCenter(point.clone());
-    this.setPointTotal(point.clone());
+    this.setCenter(new RandomAccessSparseVector(point.clone()));
+    this.setPointTotal(getCenter().clone());
     this.setNumPoints(1);
   }
   
@@ -64,7 +65,7 @@
     super.readFields(in);
     VectorWritable temp = new VectorWritable();
     temp.readFields(in);
-    this.setCenter(temp.get());
+    this.setCenter(new RandomAccessSparseVector(temp.get()));
     this.setPointTotal(getCenter().clone());
     this.setNumPoints(1);
   }
@@ -106,8 +107,7 @@
    */
   public void addPoint(Vector point) {
     setNumPoints(getNumPoints() + 1);
-    setPointTotal(getPointTotal().plus(point));
-    
+    point.addTo(getPointTotal());
   }
   
   /**

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java?rev=915007&r1=915006&r2=915007&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java Mon Feb 22 18:26:17 2010
@@ -23,8 +23,10 @@
 import org.apache.mahout.clustering.ClusterBase;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.math.AbstractVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
 import org.apache.mahout.math.function.SquareRootFunction;
 
 public class Cluster extends ClusterBase {
@@ -33,13 +35,13 @@
   private static final String ERROR_UNKNOWN_CLUSTER_FORMAT = "Unknown cluster format:\n";
   
   /** The current centroid is lazy evaluated and may be null */
-  private Vector centroid = null;
+  private Vector centroid;
   
   /** The total of all the points squared, used for std computation */
-  private Vector pointSquaredTotal = null;
+  private Vector pointSquaredTotal;
   
   /** Has the centroid converged with the center? */
-  private boolean converged = false;
+  private boolean converged;
   
   /**
    * Format the cluster for output
@@ -100,7 +102,7 @@
     this.converged = in.readBoolean();
     VectorWritable temp = new VectorWritable();
     temp.readFields(in);
-    this.setCenter(temp.get());
+    this.setCenter(new RandomAccessSparseVector(temp.get()));
     this.setNumPoints(0);
     this.setPointTotal(getCenter().like());
     this.pointSquaredTotal = getCenter().like();
@@ -130,7 +132,7 @@
    */
   public Cluster(Vector center) {
     super();
-    this.setCenter(center);
+    this.setCenter(new RandomAccessSparseVector(center));
     this.setNumPoints(0);
     this.setPointTotal(center.like());
     this.pointSquaredTotal = center.like();
@@ -148,10 +150,10 @@
   public Cluster(Vector center, int clusterId) {
     super();
     this.setId(clusterId);
-    this.setCenter(center);
+    this.setCenter(new RandomAccessSparseVector(center));
     this.setNumPoints(0);
     this.setPointTotal(center.like());
-    this.pointSquaredTotal = center.like();
+    this.pointSquaredTotal = getCenter().like();
   }
   
   /** Construct a new clsuter with the given id as identifier */
@@ -195,10 +197,10 @@
     setNumPoints(getNumPoints() + count);
     if (getPointTotal() == null) {
       setPointTotal(delta.clone());
-      pointSquaredTotal = delta.times(delta);
+      pointSquaredTotal = new RandomAccessSparseVector(delta.clone().assign(Functions.square));
     } else {
       delta.addTo(getPointTotal());
-      delta.times(delta).addTo(pointSquaredTotal);
+      delta.clone().assign(Functions.square).addTo(pointSquaredTotal);
     }
   }