You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ro...@apache.org on 2010/02/22 19:26:17 UTC
svn commit: r915007 - in
/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering:
canopy/Canopy.java kmeans/Cluster.java
Author: robinanil
Date: Mon Feb 22 18:26:17 2010
New Revision: 915007
URL: http://svn.apache.org/viewvc?rev=915007&view=rev
Log:
MAHOUT-297 First cut changes for kmeans and canopy for 0.3, rest for 0.4
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java?rev=915007&r1=915006&r2=915007&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java Mon Feb 22 18:26:17 2010
@@ -25,6 +25,7 @@
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.mahout.clustering.ClusterBase;
import org.apache.mahout.math.AbstractVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
@@ -48,8 +49,8 @@
*/
public Canopy(Vector point, int canopyId) {
this.setId(canopyId);
- this.setCenter(point.clone());
- this.setPointTotal(point.clone());
+ this.setCenter(new RandomAccessSparseVector(point.clone()));
+ this.setPointTotal(getCenter().clone());
this.setNumPoints(1);
}
@@ -64,7 +65,7 @@
super.readFields(in);
VectorWritable temp = new VectorWritable();
temp.readFields(in);
- this.setCenter(temp.get());
+ this.setCenter(new RandomAccessSparseVector(temp.get()));
this.setPointTotal(getCenter().clone());
this.setNumPoints(1);
}
@@ -106,8 +107,7 @@
*/
public void addPoint(Vector point) {
setNumPoints(getNumPoints() + 1);
- setPointTotal(getPointTotal().plus(point));
-
+ point.addTo(getPointTotal());
}
/**
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java?rev=915007&r1=915006&r2=915007&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java Mon Feb 22 18:26:17 2010
@@ -23,8 +23,10 @@
import org.apache.mahout.clustering.ClusterBase;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.AbstractVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.SquareRootFunction;
public class Cluster extends ClusterBase {
@@ -33,13 +35,13 @@
private static final String ERROR_UNKNOWN_CLUSTER_FORMAT = "Unknown cluster format:\n";
/** The current centroid is lazy evaluated and may be null */
- private Vector centroid = null;
+ private Vector centroid;
/** The total of all the points squared, used for std computation */
- private Vector pointSquaredTotal = null;
+ private Vector pointSquaredTotal;
/** Has the centroid converged with the center? */
- private boolean converged = false;
+ private boolean converged;
/**
* Format the cluster for output
@@ -100,7 +102,7 @@
this.converged = in.readBoolean();
VectorWritable temp = new VectorWritable();
temp.readFields(in);
- this.setCenter(temp.get());
+ this.setCenter(new RandomAccessSparseVector(temp.get()));
this.setNumPoints(0);
this.setPointTotal(getCenter().like());
this.pointSquaredTotal = getCenter().like();
@@ -130,7 +132,7 @@
*/
public Cluster(Vector center) {
super();
- this.setCenter(center);
+ this.setCenter(new RandomAccessSparseVector(center));
this.setNumPoints(0);
this.setPointTotal(center.like());
this.pointSquaredTotal = center.like();
@@ -148,10 +150,10 @@
public Cluster(Vector center, int clusterId) {
super();
this.setId(clusterId);
- this.setCenter(center);
+ this.setCenter(new RandomAccessSparseVector(center));
this.setNumPoints(0);
this.setPointTotal(center.like());
- this.pointSquaredTotal = center.like();
+ this.pointSquaredTotal = getCenter().like();
}
/** Construct a new cluster with the given id as identifier */
@@ -195,10 +197,10 @@
setNumPoints(getNumPoints() + count);
if (getPointTotal() == null) {
setPointTotal(delta.clone());
- pointSquaredTotal = delta.times(delta);
+ pointSquaredTotal = new RandomAccessSparseVector(delta.clone().assign(Functions.square));
} else {
delta.addTo(getPointTotal());
- delta.times(delta).addTo(pointSquaredTotal);
+ delta.clone().assign(Functions.square).addTo(pointSquaredTotal);
}
}