You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2010/07/24 05:39:33 UTC
svn commit: r978786 [1/2] - in /mahout/trunk:
core/src/main/java/org/apache/mahout/clustering/
core/src/main/java/org/apache/mahout/clustering/canopy/
core/src/main/java/org/apache/mahout/clustering/dirichlet/models/
core/src/main/java/org/apache/mahou...
Author: jeastman
Date: Sat Jul 24 03:39:30 2010
New Revision: 978786
URL: http://svn.apache.org/viewvc?rev=978786&view=rev
Log:
MAHOUT-294:
- modified most DisplayClustering subclasses to use the new sequential method on drivers. (Not Dirichlet yet)
- using file system to transmit Clusters required a rework since they were not serializing needed state
- refactored Canopy, Cluster, SoftCluster and MeanShiftCanopy significantly, abstracting shared behavior to new AbstractCluster class.
- deleted ClusterBase after moving static method to AbstractCluster
- added ClusterObservations to replace KMeansInfo and FuzzyKMeansInfo
- changed all cluster identifier formatting to include type indication
- upshot of new clusters is improved posterior statistics for all with radius() now returning stdDev(), a vector
- new radius() used in Display examples to show elliptical clusters
- adjusted unit tests and all pass
probably should have made a new JIRA for some of this
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterObservations.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java
Removed:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansInfo.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansInfo.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/VisibleCanopy.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/VisibleCluster.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedVectorWritable.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansCombiner.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansReducer.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestRandomSeedGenerator.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayMeanShift.java
mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java
mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/TestClusterDumper.java
mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java
Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java?rev=978786&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java Sat Jul 24 03:39:30 2010
@@ -0,0 +1,320 @@
+package org.apache.mahout.clustering;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.util.Iterator;
+import java.util.Locale;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.JsonVectorAdapter;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.function.SquareRootFunction;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.reflect.TypeToken;
+
+/**
+ * Base class for clusters that hold a persistent state (id, number of points, center and radius)
+ * together with transient running observation statistics (s0, s1, s2). Observations are accumulated
+ * by the observe methods and folded into the persistent state by {@link #computeParameters()}.
+ */
+public abstract class AbstractCluster implements Writable, Cluster {
+
+  /** Type token under which the JSON vector adapter is registered in {@link #asJsonString()}. */
+  private static final Type VECTOR_TYPE = new TypeToken<Vector>() {
+  }.getType();
+
+  // cluster persistent state
+  private int id;
+
+  private int numPoints;
+
+  private Vector center;
+
+  private Vector radius;
+
+  /**
+   * @param id the id to set
+   */
+  protected void setId(int id) {
+    this.id = id;
+  }
+
+  /**
+   * @param numPoints the numPoints to set
+   */
+  protected void setNumPoints(int numPoints) {
+    this.numPoints = numPoints;
+  }
+
+  /**
+   * @param center the center to set; stored by reference, not copied
+   */
+  protected void setCenter(Vector center) {
+    this.center = center;
+  }
+
+  /**
+   * @param radius the radius to set; stored by reference, not copied
+   */
+  protected void setRadius(Vector radius) {
+    this.radius = radius;
+  }
+
+  // the observation statistics, initialized by the first observation
+  private transient double s0;
+
+  private transient Vector s1;
+
+  private transient Vector s2;
+
+  /**
+   * @return the s0, the running sum of the observed weights
+   */
+  protected double getS0() {
+    return s0;
+  }
+
+  /**
+   * @return the s1, the running sum of the weighted observations (null while nothing has been observed)
+   */
+  protected Vector getS1() {
+    return s1;
+  }
+
+  /**
+   * @return the s2, the running sum of the weighted x.times(x) terms (null while nothing has been observed)
+   */
+  protected Vector getS2() {
+    return s2;
+  }
+
+  /**
+   * Fold the accumulated observation statistics into the persistent state and reset them. Does nothing
+   * when nothing has been observed (s0 == 0).
+   * <p>
+   * numPoints becomes (int) s0 and the center becomes s1 / s0. When {@code s0 > 1} the radius is recomputed
+   * from the statistics (the component stds); otherwise every radius component is set to Double.MIN_NORMAL.
+   */
+  public void computeParameters() {
+    if (s0 == 0) {
+      return;
+    }
+    numPoints = (int) s0;
+    center = s1.divide(s0);
+    // compute the component stds
+    if (s0 > 1) {
+      radius = s2.times(s0).minus(s1.times(s1)).assign(new SquareRootFunction()).divide(s0);
+    } else {
+      // the radius may never have been set on this instance, so create one shaped like the center
+      // instead of failing with a NullPointerException
+      if (radius == null) {
+        radius = center.like();
+      }
+      radius.assign(Double.MIN_NORMAL);
+    }
+    s0 = 0;
+    s1 = null;
+    s2 = null;
+  }
+
+  /**
+   * Add the statistics of the given observations to this cluster's running statistics.
+   *
+   * @param observations the ClusterObservations to accumulate; its s1 and s2 are cloned on first use
+   */
+  public void observe(ClusterObservations observations) {
+    s0 += observations.getS0();
+    if (s1 == null) {
+      s1 = observations.getS1().clone();
+    } else {
+      observations.getS1().addTo(s1);
+    }
+    if (s2 == null) {
+      s2 = observations.getS2().clone();
+    } else {
+      observations.getS2().addTo(s2);
+    }
+  }
+
+  /**
+   * @return a new ClusterObservations holding the current s0, s1 and s2 (the vectors are not copied)
+   */
+  public ClusterObservations getObservations() {
+    return new ClusterObservations(s0, s1, s2);
+  }
+
+  /**
+   * Accumulate a single weighted observation into the running statistics.
+   *
+   * @param x the observed Vector
+   * @param weight the weight applied to x
+   */
+  public void observe(Vector x, double weight) {
+    s0 += weight;
+    Vector weightedX = x.times(weight);
+    if (s1 == null) {
+      s1 = weightedX;
+    } else {
+      weightedX.addTo(s1);
+    }
+    Vector x2 = x.times(x).times(weight);
+    if (s2 == null) {
+      s2 = x2;
+    } else {
+      x2.addTo(s2);
+    }
+  }
+
+  /**
+   * Accumulate a single observation with weight 1.0.
+   *
+   * @param x the observed Vector
+   */
+  public void observe(Vector x) {
+    observe(x, 1.0);
+  }
+
+  /** Reads the id, numPoints, center and radius, in that order; the observation statistics are not read. */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    this.id = in.readInt();
+    this.numPoints = in.readInt();
+    VectorWritable temp = new VectorWritable();
+    temp.readFields(in);
+    this.center = temp.get();
+    temp.readFields(in);
+    this.radius = temp.get();
+  }
+
+  /** Writes the id, numPoints, center and radius, in that order; the observation statistics are not written. */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(id);
+    out.writeInt(numPoints);
+    VectorWritable.writeVector(out, center);
+    VectorWritable.writeVector(out, radius);
+  }
+
+  /**
+   * Format this cluster as identifier{n=numPoints c=center r=radius}, leaving out a vector that is null.
+   *
+   * @param bindings optional labels passed on to {@link #formatVector(Vector, String[])}, may be null
+   */
+  @Override
+  public String asFormatString(String[] bindings) {
+    StringBuilder buf = new StringBuilder(50);
+    buf.append(getIdentifier()).append("{n=").append(numPoints).append(" c=");
+    if (center != null) {
+      buf.append(AbstractCluster.formatVector(center, bindings));
+    }
+    buf.append(" r=");
+    if (radius != null) {
+      buf.append(AbstractCluster.formatVector(radius, bindings));
+    }
+    buf.append('}');
+    return buf.toString();
+  }
+
+  /**
+   * @return the identifier of this cluster as used in formatted output
+   */
+  public abstract String getIdentifier();
+
+  /** @return this cluster serialized to JSON by Gson with the vector adapter registered */
+  @Override
+  public String asJsonString() {
+    GsonBuilder gBuilder = new GsonBuilder();
+    gBuilder.registerTypeAdapter(VECTOR_TYPE, new JsonVectorAdapter());
+    Gson gson = gBuilder.create();
+    return gson.toJson(this, this.getClass());
+  }
+
+  @Override
+  public Vector getCenter() {
+    return center;
+  }
+
+  @Override
+  public int getId() {
+    return id;
+  }
+
+  @Override
+  public int getNumPoints() {
+    return numPoints;
+  }
+
+  @Override
+  public Vector getRadius() {
+    return radius;
+  }
+
+  /**
+   * Compute the centroid from the running observation statistics without changing any state
+   *
+   * @return s1 / s0, or the current center when nothing has been observed (s0 == 0)
+   */
+  public Vector computeCentroid() {
+    if (s0 == 0) {
+      return getCenter();
+    } else {
+      return s1.divide(s0);
+    }
+  }
+
+  /**
+   * Return a human-readable formatted string representation of the vector, not intended to be complete nor
+   * usable as an input/output representation such as Json
+   *
+   * @param v
+   *          a Vector
+   * @param bindings
+   *          optional labels by element index, may be null; an index without a label is printed as a number
+   * @return a String
+   */
+  public static String formatVector(Vector v, String[] bindings) {
+    StringBuilder buf = new StringBuilder();
+    if (v instanceof NamedVector) {
+      buf.append(((NamedVector) v).getName()).append(" = ");
+    }
+    int nzero = 0;
+    Iterator<Element> iterateNonZero = v.iterateNonZero();
+    while (iterateNonZero.hasNext()) {
+      iterateNonZero.next();
+      nzero++;
+    }
+    buf.append('[');
+    // remember where the elements start so only a separator that was really written gets trimmed
+    int elementsStart = buf.length();
+    // if vector is sparse or if we have bindings, use sparse notation
+    boolean sparseNotation = (nzero < v.size()) || (bindings != null);
+    for (int i = 0; i < v.size(); i++) {
+      double elem = v.get(i);
+      if (sparseNotation) {
+        if (elem == 0.0) {
+          continue;
+        }
+        // tolerate a bindings array shorter than the vector
+        String label = ((bindings != null) && (i < bindings.length)) ? bindings[i] : null;
+        if (label != null) {
+          buf.append(label).append(':');
+        } else {
+          buf.append(i).append(':');
+        }
+      }
+      buf.append(String.format(Locale.ENGLISH, "%.3f", elem)).append(", ");
+    }
+    if (buf.length() > elementsStart) {
+      buf.setLength(buf.length() - 2);
+    }
+    buf.append(']');
+    return buf.toString();
+  }
+}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java Sat Jul 24 03:39:30 2010
@@ -21,15 +21,11 @@ import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.lang.reflect.Type;
-import java.util.Iterator;
-import java.util.Locale;
import com.google.gson.reflect.TypeToken;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.math.JsonVectorAdapter;
-import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.Vector.Element;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
@@ -100,7 +96,7 @@ public abstract class ClusterBase implem
@Override
public String asFormatString(String[] bindings) {
StringBuilder buf = new StringBuilder();
- buf.append(getIdentifier()).append(": ").append(formatVector(getCenter(), bindings));
+ buf.append(getIdentifier()).append(": ").append(AbstractCluster.formatVector(getCenter(), bindings));
return buf.toString();
}
@@ -125,60 +121,13 @@ public abstract class ClusterBase implem
@Override
public void write(DataOutput out) throws IOException {
out.writeInt(id);
+ out.writeInt(numPoints);
}
/** Reads in the id, nothing else */
@Override
public void readFields(DataInput in) throws IOException {
id = in.readInt();
- }
-
- /**
- * Return a human-readable formatted string representation of the vector, not intended to be complete nor
- * usable as an input/output representation such as Json
- *
- * @param v
- * a Vector
- * @return a String
- */
- public static String formatVector(Vector v, String[] bindings) {
- StringBuilder buf = new StringBuilder();
- if (v instanceof NamedVector) {
- buf.append(((NamedVector) v).getName()).append(" = ");
- }
- int nzero = 0;
- Iterator<Element> iterateNonZero = v.iterateNonZero();
- while (iterateNonZero.hasNext()) {
- iterateNonZero.next();
- nzero++;
- }
- // if vector is sparse or if we have bindings, use sparse notation
- if ((nzero < v.size()) || (bindings != null)) {
- buf.append('[');
- for (int i = 0; i < v.size(); i++) {
- double elem = v.get(i);
- if (elem == 0.0) {
- continue;
- }
- String label;
- if ((bindings != null) && ((label = bindings[i]) != null)) {
- buf.append(label).append(':');
- } else {
- buf.append(i).append(':');
- }
- buf.append(String.format(Locale.ENGLISH, "%.3f", elem)).append(", ");
- }
- } else {
- buf.append('[');
- for (int i = 0; i < v.size(); i++) {
- double elem = v.get(i);
- buf.append(String.format(Locale.ENGLISH, "%.3f", elem)).append(", ");
- }
- }
- if (buf.length() > 1) {
- buf.setLength(buf.length() - 2);
- }
- buf.append(']');
- return buf.toString();
+ numPoints = in.readInt();
}
}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterObservations.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterObservations.java?rev=978786&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterObservations.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterObservations.java Sat Jul 24 03:39:30 2010
@@ -0,0 +1,133 @@
+package org.apache.mahout.clustering;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * Writable holder for a set of cluster observation statistics (s0, s1, s2) together with a combiner
+ * state counter.
+ */
+public class ClusterObservations implements Writable {
+
+  private int combinerState = 0;
+
+  private double s0;
+
+  private Vector s1;
+
+  private Vector s2;
+
+  /**
+   * Create observations with a combiner state of 0.
+   *
+   * @param s0 the s0 statistic
+   * @param s1 the s1 statistic; stored by reference, not copied
+   * @param s2 the s2 statistic; stored by reference, not copied
+   */
+  public ClusterObservations(double s0, Vector s1, Vector s2) {
+    this(0, s0, s1, s2);
+  }
+
+  /**
+   * Create observations with an explicit combiner state.
+   *
+   * @param combinerState the initial combiner state
+   * @param s0 the s0 statistic
+   * @param s1 the s1 statistic; stored by reference, not copied
+   * @param s2 the s2 statistic; stored by reference, not copied
+   */
+  public ClusterObservations(int combinerState, double s0, Vector s1, Vector s2) {
+    this.combinerState = combinerState;
+    this.s0 = s0;
+    this.s1 = s1;
+    this.s2 = s2;
+  }
+
+  /** Create empty observations whose fields are expected to be filled in by {@link #readFields(DataInput)}. */
+  public ClusterObservations() {
+  }
+
+  /** Reads the combinerState, s0, s1 and s2, in that order. */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    this.combinerState = in.readInt();
+    this.s0 = in.readDouble();
+    VectorWritable temp = new VectorWritable();
+    temp.readFields(in);
+    this.s1 = temp.get();
+    temp.readFields(in);
+    this.s2 = temp.get();
+  }
+
+  /** Writes the combinerState, s0, s1 and s2, in that order. */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(combinerState);
+    out.writeDouble(s0);
+    VectorWritable.writeVector(out, s1);
+    VectorWritable.writeVector(out, s2);
+  }
+
+  /**
+   * @return the combinerState
+   */
+  public int getCombinerState() {
+    return combinerState;
+  }
+
+  /**
+   * @return the s0
+   */
+  public double getS0() {
+    return s0;
+  }
+
+  /**
+   * @return the s1
+   */
+  public Vector getS1() {
+    return s1;
+  }
+
+  /**
+   * @return the s2
+   */
+  public Vector getS2() {
+    return s2;
+  }
+
+  /**
+   * @return a human-readable form co{s0=... s1=... s2=...}; a null vector is left out and the
+   *         combinerState is not included
+   */
+  @Override
+  public String toString() {
+    StringBuilder buf = new StringBuilder(50);
+    buf.append("co{s0=").append(s0).append(" s1=");
+    if (s1 != null) {
+      buf.append(AbstractCluster.formatVector(s1, null));
+    }
+    buf.append(" s2=");
+    if (s2 != null) {
+      buf.append(AbstractCluster.formatVector(s2, null));
+    }
+    buf.append('}');
+    return buf.toString();
+  }
+
+  /**
+   * Increment the combiner state by one.
+   *
+   * @return this, to allow call chaining
+   */
+  public ClusterObservations incrementCombinerState() {
+    combinerState++;
+    return this;
+  }
+
+}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedVectorWritable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedVectorWritable.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedVectorWritable.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedVectorWritable.java Sat Jul 24 03:39:30 2010
@@ -66,7 +66,7 @@ public class WeightedVectorWritable impl
}
public String toString() {
- return weight + ": " + (vector == null ? "null" : ClusterBase.formatVector(vector.get(), null));
+ return weight + ": " + (vector == null ? "null" : AbstractCluster.formatVector(vector.get(), null));
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java Sat Jul 24 03:39:30 2010
@@ -21,18 +21,17 @@ import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
-import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.math.AbstractVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
/**
* This class models a canopy as a center point, the number of points that are contained within it according
* to the application of some distance metric, and a point total which is the sum of all the points and is
* used to compute the centroid when needed.
*/
-public class Canopy extends ClusterBase {
+public class Canopy extends AbstractCluster {
/** Used for deserializaztion as a writable */
public Canopy() { }
@@ -40,32 +39,27 @@ public class Canopy extends ClusterBase
/**
* Create a new Canopy containing the given point and canopyId
*
- * @param point
+ * @param center
* a point in vector space
* @param canopyId
* an int identifying the canopy local to this process only
*/
- public Canopy(Vector point, int canopyId) {
+ public Canopy(Vector center, int canopyId) {
this.setId(canopyId);
- this.setCenter(new RandomAccessSparseVector(point));
- this.setPointTotal(getCenter().clone());
- this.setNumPoints(1);
+ this.setNumPoints(0);
+ this.setCenter(new RandomAccessSparseVector(center));
+ this.setRadius(center.like());
+ observe(center);
}
@Override
public void write(DataOutput out) throws IOException {
super.write(out);
- VectorWritable.writeVector(out, computeCentroid());
}
@Override
public void readFields(DataInput in) throws IOException {
super.readFields(in);
- VectorWritable temp = new VectorWritable();
- temp.readFields(in);
- this.setCenter(new RandomAccessSparseVector(temp.get()));
- this.setPointTotal(getCenter().clone());
- this.setNumPoints(1);
}
/** Format the canopy for output */
@@ -73,7 +67,6 @@ public class Canopy extends ClusterBase
return "C" + canopy.getId() + ": " + canopy.computeCentroid().asFormatString();
}
- @Override
public String asFormatString() {
return formatCanopy(this);
}
@@ -97,39 +90,12 @@ public class Canopy extends ClusterBase
return null;
}
- /**
- * Add a point to the canopy
- *
- * @param point
- * some point to add
- */
- public void addPoint(Vector point) {
- setNumPoints(getNumPoints() + 1);
- point.addTo(getPointTotal());
- }
-
@Override
public String toString() {
return getIdentifier() + ": " + getCenter().asFormatString();
}
- @Override
public String getIdentifier() {
return "C-" + getId();
}
-
- /**
- * Compute the centroid by averaging the pointTotals
- *
- * @return a RandomAccessSparseVector (required by Mapper) which is the new centroid
- */
- @Override
- public Vector computeCentroid() {
- return getPointTotal().divide(getNumPoints());
- }
-
- @Override
- public Vector getRadius() {
- return getCenter().like();
- }
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java Sat Jul 24 03:39:30 2010
@@ -25,7 +25,7 @@ import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Vector;
@@ -113,13 +113,13 @@ public class CanopyClusterer {
for (Canopy canopy : canopies) {
double dist = measure.distance(canopy.getCenter().getLengthSquared(), canopy.getCenter(), point);
if (dist < t1) {
- log.info("Added point: " + ClusterBase.formatVector(point, null) + " to canopy: " + canopy.getIdentifier());
- canopy.addPoint(point);
+ log.info("Added point: " + AbstractCluster.formatVector(point, null) + " to canopy: " + canopy.getIdentifier());
+ canopy.observe(point);
}
pointStronglyBound = pointStronglyBound || (dist < t2);
}
if (!pointStronglyBound) {
- log.info("Created new Canopy:" + nextCanopyId + " at center:" + ClusterBase.formatVector(point, null));
+ log.info("Created new Canopy:" + nextCanopyId + " at center:" + AbstractCluster.formatVector(point, null));
canopies.add(new Canopy(point, nextCanopyId++));
}
}
@@ -202,13 +202,16 @@ public class CanopyClusterer {
double dist = measure.distance(p1, p2);
// Put all points that are within distance threshold T1 into the canopy
if (dist < t1) {
- canopy.addPoint(p2);
+ canopy.observe(p2);
}
// Remove from the list all points that are within distance threshold T2
if (dist < t2) {
ptIter.remove();
}
}
+ for (Canopy c : canopies) {
+ c.computeParameters();
+ }
}
return canopies;
}
@@ -220,10 +223,10 @@ public class CanopyClusterer {
* a List<Canopy>
* @return the List<Vector>
*/
- public static List<Vector> calculateCentroids(List<Canopy> canopies) {
+ public static List<Vector> getCenters(List<Canopy> canopies) {
List<Vector> result = new ArrayList<Vector>();
for (Canopy canopy : canopies) {
- result.add(canopy.computeCentroid());
+ result.add(canopy.getCenter());
}
return result;
}
@@ -236,7 +239,7 @@ public class CanopyClusterer {
*/
public static void updateCentroids(List<Canopy> canopies) {
for (Canopy canopy : canopies) {
- canopy.setCenter(canopy.computeCentroid());
+ canopy.computeParameters();
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java Sat Jul 24 03:39:30 2010
@@ -35,8 +35,8 @@ import org.apache.hadoop.mapreduce.lib.i
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.Cluster;
-import org.apache.mahout.clustering.ClusterBase;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.clustering.kmeans.OutputLogFilter;
import org.apache.mahout.common.AbstractJob;
@@ -227,8 +227,9 @@ public class CanopyDriver extends Abstra
SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, Text.class, Canopy.class);
try {
for (Canopy canopy : canopies) {
- log.info("Writing Canopy:" + canopy.getIdentifier() + " center:" + ClusterBase.formatVector(canopy.getCenter(), null)
- + " numPoints:" + canopy.getNumPoints() + " centroid:" + ClusterBase.formatVector(canopy.computeCentroid(), null));
+ canopy.computeParameters();
+ log.info("Writing Canopy:" + canopy.getIdentifier() + " center:" + AbstractCluster.formatVector(canopy.getCenter(), null)
+ + " numPoints:" + canopy.getNumPoints() + " radius:" + AbstractCluster.formatVector(canopy.getRadius(), null));
writer.append(new Text(canopy.getIdentifier()), canopy);
}
} finally {
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java Sat Jul 24 03:39:30 2010
@@ -24,7 +24,6 @@ import java.util.List;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
class CanopyMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
@@ -47,8 +46,7 @@ class CanopyMapper extends Mapper<Writab
@Override
protected void cleanup(Context context) throws IOException, InterruptedException {
for (Canopy canopy : canopies) {
- Vector centroid = canopy.computeCentroid();
- context.write(new Text("centroid"), new VectorWritable(centroid));
+ context.write(new Text("centroid"), new VectorWritable(canopy.computeCentroid()));
}
super.cleanup(context);
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java Sat Jul 24 03:39:30 2010
@@ -35,6 +35,7 @@ public class CanopyReducer extends Reduc
canopyClusterer.addPointToCanopies(point, canopies);
}
for (Canopy canopy : canopies) {
+ canopy.computeParameters();
context.write(new Text(canopy.getIdentifier()), canopy);
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java Sat Jul 24 03:39:30 2010
@@ -22,7 +22,7 @@ import java.io.DataOutput;
import java.io.IOException;
import java.lang.reflect.Type;
-import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
@@ -152,11 +152,11 @@ public class AsymmetricSampledNormalMode
StringBuilder buf = new StringBuilder(50);
buf.append("asnm{n=").append(s0).append(" m=");
if (mean != null) {
- buf.append(ClusterBase.formatVector(mean, bindings));
+ buf.append(AbstractCluster.formatVector(mean, bindings));
}
buf.append(" sd=");
if (stdDev != null) {
- buf.append(ClusterBase.formatVector(stdDev, bindings));
+ buf.append(AbstractCluster.formatVector(stdDev, bindings));
}
buf.append('}');
return buf.toString();
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java Sat Jul 24 03:39:30 2010
@@ -21,7 +21,7 @@ import java.io.DataOutput;
import java.io.IOException;
import java.lang.reflect.Type;
-import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
@@ -106,7 +106,7 @@ public class L1Model implements Model<Ve
StringBuilder buf = new StringBuilder();
buf.append("l1m{n=").append(counter).append(" c=");
if (coefficients != null) {
- buf.append(ClusterBase.formatVector(coefficients, bindings));
+ buf.append(AbstractCluster.formatVector(coefficients, bindings));
}
buf.append('}');
return buf.toString();
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Sat Jul 24 03:39:30 2010
@@ -23,7 +23,7 @@ import java.io.IOException;
import java.lang.reflect.Type;
import java.util.Locale;
-import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.dirichlet.JsonModelAdapter;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
@@ -140,7 +140,7 @@ public class NormalModel implements Mode
StringBuilder buf = new StringBuilder();
buf.append("nm{n=").append(s0).append(" m=");
if (mean != null) {
- buf.append(ClusterBase.formatVector(mean, bindings));
+ buf.append(AbstractCluster.formatVector(mean, bindings));
}
buf.append(" sd=").append(String.format(Locale.ENGLISH, "%.2f", stdDev)).append('}');
return buf.toString();
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java Sat Jul 24 03:39:30 2010
@@ -19,7 +19,7 @@ package org.apache.mahout.clustering.dir
import java.util.Locale;
-import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.math.Vector;
public class SampledNormalModel extends NormalModel {
@@ -51,7 +51,7 @@ public class SampledNormalModel extends
StringBuilder buf = new StringBuilder();
buf.append("snm{n=").append(getS0()).append(" m=");
if (getMean() != null) {
- buf.append(ClusterBase.formatVector(getMean(), bindings));
+ buf.append(AbstractCluster.formatVector(getMean(), bindings));
}
buf.append(" sd=").append(String.format(Locale.ENGLISH, "%.2f", getStdDev())).append('}');
return buf.toString();
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java Sat Jul 24 03:39:30 2010
@@ -27,6 +27,7 @@ import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.io.SequenceFile.Writer;
import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.ClusterObservations;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.DenseVector;
@@ -115,7 +116,9 @@ public class FuzzyKMeansClusterer {
* @param clusterList
* the List<Cluster> clusters
*/
- protected static boolean runFuzzyKMeansIteration(List<Vector> points, List<SoftCluster> clusterList, FuzzyKMeansClusterer clusterer) {
+ protected static boolean runFuzzyKMeansIteration(List<Vector> points,
+ List<SoftCluster> clusterList,
+ FuzzyKMeansClusterer clusterer) {
for (Vector point : points) {
clusterer.addPointToClusters(clusterList, point);
}
@@ -162,7 +165,7 @@ public class FuzzyKMeansClusterer {
*/
public void emitPointProbToCluster(Vector point,
List<SoftCluster> clusters,
- Mapper<WritableComparable<?>, VectorWritable, Text, FuzzyKMeansInfo>.Context context)
+ Mapper<WritableComparable<?>, VectorWritable, Text, ClusterObservations>.Context context)
throws IOException, InterruptedException {
List<Double> clusterDistanceList = new ArrayList<Double>();
@@ -171,9 +174,11 @@ public class FuzzyKMeansClusterer {
}
for (int i = 0; i < clusters.size(); i++) {
- double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
- Text key = new Text(clusters.get(i).getIdentifier());
- FuzzyKMeansInfo value = new FuzzyKMeansInfo(probWeight, point);
+ SoftCluster cluster = clusters.get(i);
+ Text key = new Text(cluster.getIdentifier());
+ ClusterObservations value = new ClusterObservations(computeProbWeight(clusterDistanceList.get(i), clusterDistanceList),
+ point,
+ point.times(point));
context.write(key, value);
}
}
@@ -199,9 +204,7 @@ public class FuzzyKMeansClusterer {
* @return if the cluster is converged
*/
public boolean computeConvergence(SoftCluster cluster) {
- Vector centroid = cluster.computeCentroid();
- cluster.setConverged(measure.distance(cluster.getCenter(), centroid) <= convergenceDelta);
- return cluster.isConverged();
+ return cluster.computeConvergence(measure, convergenceDelta);
}
public double getM() {
@@ -285,22 +288,17 @@ public class FuzzyKMeansClusterer {
for (int i = 0; i < clusterList.size(); i++) {
double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
- clusterList.get(i).addPoint(point, Math.pow(probWeight, getM()));
+ clusterList.get(i).observe(point, Math.pow(probWeight, getM()));
}
}
protected boolean testConvergence(List<SoftCluster> clusters) {
boolean converged = true;
for (SoftCluster cluster : clusters) {
- if (!computeConvergence(cluster)) {
+ if (!cluster.computeConvergence(measure, convergenceDelta)) {
converged = false;
}
- }
- // update the cluster centers
- if (!converged) {
- for (SoftCluster cluster : clusters) {
- cluster.recomputeCenter();
- }
+ cluster.computeParameters();
}
return converged;
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansCombiner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansCombiner.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansCombiner.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansCombiner.java Sat Jul 24 03:39:30 2010
@@ -21,23 +21,23 @@ import java.io.IOException;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.clustering.ClusterObservations;
+
+public class FuzzyKMeansCombiner extends Reducer<Text, ClusterObservations, Text, ClusterObservations> {
-public class FuzzyKMeansCombiner extends Reducer<Text,FuzzyKMeansInfo,Text,FuzzyKMeansInfo> {
-
private FuzzyKMeansClusterer clusterer;
@Override
- protected void reduce(Text key, Iterable<FuzzyKMeansInfo> values, Context context) throws IOException, InterruptedException {
- SoftCluster cluster = new SoftCluster(key.toString().trim());
- for (FuzzyKMeansInfo value : values) {
- if (value.getCombinerPass() == 0) { // first time thru combiner
- cluster.addPoint(value.getVector(), Math.pow(value.getProbability(), clusterer.getM()));
+ protected void reduce(Text key, Iterable<ClusterObservations> values, Context context) throws IOException, InterruptedException {
+ SoftCluster cluster = new SoftCluster();
+ for (ClusterObservations value : values) {
+ if (value.getCombinerState() == 0) { // first time thru combiner
+ cluster.observe(value.getS1(), Math.pow(value.getS0(), clusterer.getM()));
} else {
- cluster.addPoints(value.getVector(), value.getProbability());
+ cluster.observe(value);
}
- value.setCombinerPass(value.getCombinerPass() + 1);
}
- context.write(key, new FuzzyKMeansInfo(cluster.getPointProbSum(), cluster.getWeightedPointTotal(), 1));
+ context.write(key, cluster.getObservations().incrementCombinerState());
}
@Override
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java Sat Jul 24 03:39:30 2010
@@ -36,7 +36,9 @@ import org.apache.hadoop.mapreduce.lib.i
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.ClusterObservations;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.clustering.kmeans.OutputLogFilter;
import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
@@ -230,7 +232,7 @@ public class FuzzyKMeansDriver extends A
Job job = new Job(conf);
job.setMapOutputKeyClass(Text.class);
- job.setMapOutputValueClass(FuzzyKMeansInfo.class);
+ job.setMapOutputValueClass(ClusterObservations.class);
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(SoftCluster.class);
job.setInputFormatClass(SequenceFileInputFormat.class);
@@ -409,6 +411,7 @@ public class FuzzyKMeansDriver extends A
boolean converged = false;
int iteration = 1;
while (!converged && iteration <= maxIterations) {
+ log.info("Fuzzy k-Means Iteration: " + iteration);
Configuration conf = new Configuration();
FileSystem fs = FileSystem.get(input.toUri(), conf);
FileStatus[] status = fs.listStatus(input, new OutputLogFilter());
@@ -434,12 +437,16 @@ public class FuzzyKMeansDriver extends A
SoftCluster.class);
try {
for (SoftCluster cluster : clusters) {
+ log.info("Writing Cluster:" + cluster.getId() + " center:" + AbstractCluster.formatVector(cluster.getCenter(), null)
+ + " numPoints:" + cluster.getNumPoints() + " radius:" + AbstractCluster.formatVector(cluster.getRadius(), null) + " to: "
+ + clustersOut.getName());
writer.append(new Text(cluster.getIdentifier()), cluster);
}
} finally {
writer.close();
}
clustersIn = clustersOut;
+ iteration++;
}
return clustersIn;
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansMapper.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansMapper.java Sat Jul 24 03:39:30 2010
@@ -26,11 +26,12 @@ import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.ClusterObservations;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-public class FuzzyKMeansMapper extends Mapper<WritableComparable<?>,VectorWritable,Text,FuzzyKMeansInfo> {
+public class FuzzyKMeansMapper extends Mapper<WritableComparable<?>,VectorWritable,Text,ClusterObservations> {
private static final Logger log = LoggerFactory.getLogger(FuzzyKMeansMapper.class);
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansReducer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansReducer.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansReducer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansReducer.java Sat Jul 24 03:39:30 2010
@@ -27,21 +27,22 @@ import org.apache.hadoop.conf.Configurat
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.clustering.ClusterObservations;
-public class FuzzyKMeansReducer extends Reducer<Text, FuzzyKMeansInfo, Text, SoftCluster> {
+public class FuzzyKMeansReducer extends Reducer<Text, ClusterObservations, Text, SoftCluster> {
private final Map<String, SoftCluster> clusterMap = new HashMap<String, SoftCluster>();
private FuzzyKMeansClusterer clusterer;
@Override
- protected void reduce(Text key, Iterable<FuzzyKMeansInfo> values, Context context) throws IOException, InterruptedException {
+ protected void reduce(Text key, Iterable<ClusterObservations> values, Context context) throws IOException, InterruptedException {
SoftCluster cluster = clusterMap.get(key.toString());
- for (FuzzyKMeansInfo value : values) {
- if (value.getCombinerPass() == 0) { // escaped from combiner
- cluster.addPoint(value.getVector(), Math.pow(value.getProbability(), clusterer.getM()));
+ for (ClusterObservations value : values) {
+ if (value.getCombinerState() == 0) { // escaped from combiner
+ cluster.observe(value.getS1(), Math.pow(value.getS0(), clusterer.getM()));
} else {
- cluster.addPoints(value.getVector(), value.getProbability());
+ cluster.observe(value);
}
}
// force convergence calculation
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java Sat Jul 24 03:39:30 2010
@@ -29,6 +29,7 @@ import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.canopy.Canopy;
import org.apache.mahout.clustering.kmeans.Cluster;
import org.slf4j.Logger;
@@ -84,7 +85,7 @@ final class FuzzyKMeansUtil {
throw new IllegalStateException(e);
}
if (valueClass.equals(Cluster.class)) {
- Cluster value = new Cluster();
+ AbstractCluster value = new Cluster();
while (reader.next(key, value)) {
// get the cluster info
SoftCluster theCluster = new SoftCluster(value.getCenter(), value.getId());
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java Sat Jul 24 03:39:30 2010
@@ -17,38 +17,11 @@
package org.apache.mahout.clustering.fuzzykmeans;
-import java.io.DataInput;
-import java.io.DataOutput;
-import java.io.IOException;
-
-import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.kmeans.Cluster;
import org.apache.mahout.math.AbstractVector;
-import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.function.Functions;
-import org.apache.mahout.math.function.SquareRootFunction;
-
-public class SoftCluster extends ClusterBase {
-
- // the current centroid is lazy evaluated and may be null
- private Vector centroid;
-
- // The Probability of belongingness sum
- private double pointProbSum;
-
- // the total of all points added to the cluster
- private Vector weightedPointTotal;
- // has the centroid converged with the center?
- private boolean converged;
-
- // track membership parameters
- private double s0;
-
- private Vector s1;
-
- private Vector s2;
+public class SoftCluster extends Cluster {
// For Writable
public SoftCluster() {
@@ -56,36 +29,12 @@ public class SoftCluster extends Cluster
/**
* Construct a new SoftCluster with the given point as its center
- *
- * @param center
- * the center point
- */
- public SoftCluster(Vector center) {
- setCenter(center.clone());
- this.pointProbSum = 0;
- this.weightedPointTotal = getCenter().like();
- }
-
- /**
- * Construct a new SoftCluster with the given point as its center
*
* @param center
* the center point
*/
public SoftCluster(Vector center, int clusterId) {
- this.setId(clusterId);
- this.setCenter(new RandomAccessSparseVector(center));
- this.pointProbSum = 0;
- this.weightedPointTotal = center.like();
- }
-
- /** Construct a new softcluster with the given clusterID */
- public SoftCluster(String clusterId) {
-
- this.setId(Integer.parseInt(clusterId.substring(1)));
- this.pointProbSum = 0;
- // this.weightedPointTotal = center.like();
- this.converged = clusterId.charAt(0) == 'V';
+ super(center, clusterId);
}
/**
@@ -121,151 +70,12 @@ public class SoftCluster extends Cluster
return null;
}
-
- @Override
- public void write(DataOutput out) throws IOException {
- out.writeInt(this.getId());
- out.writeBoolean(converged);
- Vector vector = computeCentroid();
- VectorWritable.writeVector(out, vector);
- }
-
- @Override
- public void readFields(DataInput in) throws IOException {
- this.setId(in.readInt());
- converged = in.readBoolean();
- VectorWritable temp = new VectorWritable();
- temp.readFields(in);
- this.setCenter(new RandomAccessSparseVector(temp.get()));
- this.pointProbSum = 0;
- this.weightedPointTotal = getCenter().like();
- }
-
- /**
- * Compute the centroid
- *
- * @return the new centroid
- */
- @Override
- public Vector computeCentroid() {
- if (centroid == null) {
- if (pointProbSum == 0) {
- return weightedPointTotal;
- }
- // lazy compute new centroid
- centroid = weightedPointTotal.divide(pointProbSum);
- }
- return centroid;
- }
-
- @Override
- public String toString() {
- return asFormatString(null);
- }
-
- @Override
- public String getIdentifier() {
- if (converged) {
- return "V-" + this.getId();
- } else {
- return "C-" + this.getId();
- }
- }
-
- /** Observe the point, accumulating weighted variables for std() calculation */
- private void observePoint(Vector point, double ptProb) {
- s0 += ptProb;
- Vector wtPt = point.times(ptProb);
- if (s1 == null) {
- s1 = point.clone();
- } else {
- s1 = s1.plus(wtPt);
- }
- if (s2 == null) {
- s2 = wtPt.times(wtPt);
- } else {
- s2 = s2.plus(wtPt.times(wtPt));
- }
- }
-
- /** Compute a "standard deviation" value to use as the "radius" of the cluster for display purposes */
- public double std() {
- if (s0 > 0) {
- Vector radical = s2.times(s0).minus(s1.times(s1));
- radical = radical.times(radical).assign(new SquareRootFunction());
- Vector stds = radical.assign(new SquareRootFunction()).divide(s0);
- return stds.zSum() / stds.size();
- } else {
- return 0;
- }
- }
-
- /**
- * Add the point to the SoftCluster
- *
- * @param point
- * a point to add
- */
- public void addPoint(Vector point, double ptProb) {
- observePoint(point, ptProb);
- centroid = null;
- pointProbSum += ptProb;
- if (weightedPointTotal == null) {
- weightedPointTotal = point.clone().assign(Functions.mult, ptProb);
- } else {
- point.clone().assign(Functions.mult, ptProb).addTo(weightedPointTotal);
- }
- }
-
- /**
- * Add the point to the cluster
- *
- * @param delta
- * a point to add
- */
- public void addPoints(Vector delta, double partialSumPtProb) {
- centroid = null;
- pointProbSum += partialSumPtProb;
- if (weightedPointTotal == null) {
- weightedPointTotal = delta.clone();
- } else {
- delta.addTo(weightedPointTotal);
- }
- }
-
- public double getPointProbSum() {
- return pointProbSum;
- }
-
- /** Compute the centroid and set the center to it. */
- public void recomputeCenter() {
- this.setCenter(computeCentroid());
- // set a reasonable value, for consistency with other Clusters
- setNumPoints((int) weightedPointTotal.zSum());
- pointProbSum = 0;
- weightedPointTotal = getCenter().like();
- }
-
- public Vector getWeightedPointTotal() {
- return weightedPointTotal;
- }
-
- public boolean isConverged() {
- return converged;
- }
-
- public void setConverged(boolean converged) {
- this.converged = converged;
- }
-
@Override
public String asFormatString() {
return formatCluster(this);
}
- @Override
- public Vector getRadius() {
- return getCenter().like().assign(std());
+ public String getIdentifier() {
+ return (isConverged() ? "SV-" : "SC-") + getId();
}
-
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java Sat Jul 24 03:39:30 2010
@@ -20,46 +20,23 @@ import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
-import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.AbstractVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.function.Functions;
-import org.apache.mahout.math.function.SquareRootFunction;
-public class Cluster extends ClusterBase {
-
+public class Cluster extends AbstractCluster {
+
/** Error message for unknown cluster format in output. */
private static final String ERROR_UNKNOWN_CLUSTER_FORMAT = "Unknown cluster format:\n";
-
- /** The current centroid is lazy evaluated and may be null */
- private Vector centroid;
-
- /** The total of all the points squared, used for std computation */
- private Vector pointSquaredTotal;
-
+
/** Has the centroid converged with the center? */
private boolean converged;
-
- private double std = 0.00000001;
-
- /**
- * Construct a new cluster with the given point as its center
- *
- * @param center
- * the center point
- */
- public Cluster(Vector center) {
- this.setCenter(new RandomAccessSparseVector(center));
- this.setNumPoints(0);
- this.setPointTotal(getCenter().like());
- this.pointSquaredTotal = getCenter().like();
- }
/** For (de)serialization as a Writable */
- public Cluster() { }
+ public Cluster() {
+ }
/**
* Construct a new cluster with the given point as its center
@@ -69,20 +46,11 @@ public class Cluster extends ClusterBase
*/
public Cluster(Vector center, int clusterId) {
this.setId(clusterId);
- this.setCenter(new RandomAccessSparseVector(center));
this.setNumPoints(0);
- this.setPointTotal(getCenter().like());
- this.pointSquaredTotal = getCenter().like();
+ this.setCenter(new RandomAccessSparseVector(center));
+ this.setRadius(center.like());
}
- /** Construct a new clsuter with the given id as identifier */
- public Cluster(String clusterId) {
-
- this.setId(Integer.parseInt(clusterId.substring(1)));
- this.setNumPoints(0);
- this.converged = clusterId.startsWith("V");
- }
-
/**
* Format the cluster for output
*
@@ -93,12 +61,11 @@ public class Cluster extends ClusterBase
public static String formatCluster(Cluster cluster) {
return cluster.getIdentifier() + ": " + cluster.computeCentroid().asFormatString();
}
-
- @Override
+
public String asFormatString() {
return formatCluster(this);
}
-
+
/**
* Decodes and returns a Cluster from the formattedString.
*
@@ -108,7 +75,7 @@ public class Cluster extends ClusterBase
* @throws IllegalArgumentException
* when the string is wrongly formatted
*/
- public static Cluster decodeCluster(String formattedString) {
+ public static AbstractCluster decodeCluster(String formattedString) {
int beginIndex = formattedString.indexOf('{');
if (beginIndex <= 0) {
throw new IllegalArgumentException(ERROR_UNKNOWN_CLUSTER_FORMAT + formattedString);
@@ -128,89 +95,28 @@ public class Cluster extends ClusterBase
}
return cluster;
}
-
+
@Override
public void write(DataOutput out) throws IOException {
super.write(out);
out.writeBoolean(converged);
- VectorWritable.writeVector(out, computeCentroid());
}
-
+
@Override
public void readFields(DataInput in) throws IOException {
super.readFields(in);
this.converged = in.readBoolean();
- VectorWritable temp = new VectorWritable();
- temp.readFields(in);
- this.setCenter(new RandomAccessSparseVector(temp.get()));
- this.setNumPoints(0);
- this.setPointTotal(getCenter().like());
- this.pointSquaredTotal = getCenter().like();
}
-
- /**
- * Compute the centroid by averaging the pointTotals
- *
- * @return the new centroid
- */
- @Override
- public Vector computeCentroid() {
- if (getNumPoints() == 0) {
- return getCenter();
- } else if (centroid == null) {
- // lazy compute new centroid
- centroid = getPointTotal().divide(getNumPoints());
- }
- return centroid;
- }
-
+
@Override
public String toString() {
- return getIdentifier() + ": " + getCenter().asFormatString();
+ return asFormatString(null);
}
-
- @Override
+
public String getIdentifier() {
- return (converged ? "V-" : "C-") + getId();
+ return (converged ? "VL-" : "CL-") + getId();
}
-
- /**
- * Add the point to the cluster
- *
- * @param point
- * a point to add
- */
- public void addPoint(Vector point) {
- addPoints(1, point);
- }
-
- /**
- * Add the point to the cluster
- *
- * @param count
- * the number of points in the delta
- * @param delta
- * a point to add
- */
- public void addPoints(int count, Vector delta) {
- centroid = null;
- if (getNumPoints() == 0) {
- setPointTotal(new RandomAccessSparseVector(delta.clone()));
- pointSquaredTotal = new RandomAccessSparseVector(delta.clone().assign(Functions.square));
- } else {
- delta.addTo(getPointTotal());
- delta.clone().assign(Functions.square).addTo(pointSquaredTotal);
- }
- setNumPoints(getNumPoints() + count);
- }
-
- /** Compute the centroid and set the center to it. */
- public void recomputeCenter() {
- std = getStd();
- setCenter(computeCentroid());
- centroid = null;
- }
-
+
/**
* Return if the cluster is converged by comparing its center and centroid.
*
@@ -225,28 +131,13 @@ public class Cluster extends ClusterBase
converged = measure.distance(centroid.getLengthSquared(), centroid, getCenter()) <= convergenceDelta;
return converged;
}
-
+
public boolean isConverged() {
return converged;
}
-
- private void setConverged(boolean converged) {
+
+ protected void setConverged(boolean converged) {
this.converged = converged;
}
-
- /** @return the std */
- public double getStd() {
- if (getNumPoints() == 0) {
- return std;
- }
- Vector stds = pointSquaredTotal.times(getNumPoints()).minus(getPointTotal().times(getPointTotal()))
- .assign(new SquareRootFunction()).divide(getNumPoints());
- return stds.zSum() / stds.size();
- }
-
- @Override
- public Vector getRadius() {
- return getCenter().like().assign(getStd());
- }
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java Sat Jul 24 03:39:30 2010
@@ -20,11 +20,14 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
+import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.io.SequenceFile.Writer;
import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.AbstractCluster;
+import org.apache.mahout.clustering.ClusterObservations;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Vector;
@@ -43,6 +46,8 @@ public class KMeansClusterer {
/** Distance to use for point to cluster comparison. */
private final DistanceMeasure measure;
+ private final double convergenceDelta;
+
/**
* Init the k-means clusterer with the distance measure to use for comparison.
*
@@ -52,6 +57,16 @@ public class KMeansClusterer {
*/
public KMeansClusterer(DistanceMeasure measure) {
this.measure = measure;
+ this.convergenceDelta = 0;
+ }
+
+ public KMeansClusterer(Configuration conf) throws ClassNotFoundException, InstantiationException, IllegalAccessException {
+ ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+ Class<?> cl = ccl.loadClass(conf.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
+ this.measure = (DistanceMeasure) cl.newInstance();
+ this.measure.configure(conf);
+
+ this.convergenceDelta = Double.parseDouble(conf.get(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY));
}
/**
@@ -67,7 +82,7 @@ public class KMeansClusterer {
*/
public void emitPointToNearestCluster(Vector point,
List<Cluster> clusters,
- Mapper<WritableComparable<?>, VectorWritable, Text, KMeansInfo>.Context context)
+ Mapper<WritableComparable<?>, VectorWritable, Text, ClusterObservations>.Context context)
throws IOException, InterruptedException {
Cluster nearestCluster = null;
double nearestDistance = Double.MAX_VALUE;
@@ -82,7 +97,7 @@ public class KMeansClusterer {
nearestDistance = distance;
}
}
- context.write(new Text(nearestCluster.getIdentifier()), new KMeansInfo(1, point));
+ context.write(new Text(nearestCluster.getIdentifier()), new ClusterObservations(1, point, point.times(point)));
}
/**
@@ -100,7 +115,7 @@ public class KMeansClusterer {
closestDistance = distance;
}
}
- closestCluster.addPoint(point);
+ closestCluster.observe(point, 1);
}
/**
@@ -111,18 +126,12 @@ public class KMeansClusterer {
* @return
*/
protected boolean testConvergence(List<Cluster> clusters, double distanceThreshold) {
- // test for convergence
boolean converged = true;
for (Cluster cluster : clusters) {
- if (!cluster.computeConvergence(measure, distanceThreshold)) {
+ if (!computeConvergence(cluster)) {
converged = false;
}
- }
- // update the cluster centers
- if (!converged) {
- for (Cluster cluster : clusters) {
- cluster.recomputeCenter();
- }
+ cluster.computeParameters();
}
return converged;
}
@@ -131,9 +140,9 @@ public class KMeansClusterer {
List<Cluster> clusters,
Mapper<WritableComparable<?>, VectorWritable, IntWritable, WeightedVectorWritable>.Context context)
throws IOException, InterruptedException {
- Cluster nearestCluster = null;
+ AbstractCluster nearestCluster = null;
double nearestDistance = Double.MAX_VALUE;
- for (Cluster cluster : clusters) {
+ for (AbstractCluster cluster : clusters) {
Vector clusterCenter = cluster.getCenter();
double distance = measure.distance(clusterCenter.getLengthSquared(), clusterCenter, vector);
if ((distance < nearestDistance) || (nearestCluster == null)) {
@@ -157,9 +166,9 @@ public class KMeansClusterer {
*/
protected void emitPointToNearestCluster(Vector point, List<Cluster> clusters, Writer writer) throws IOException,
InterruptedException {
- Cluster nearestCluster = null;
+ AbstractCluster nearestCluster = null;
double nearestDistance = Double.MAX_VALUE;
- for (Cluster cluster : clusters) {
+ for (AbstractCluster cluster : clusters) {
Vector clusterCenter = cluster.getCenter();
double distance = this.measure.distance(clusterCenter.getLengthSquared(), clusterCenter, point);
if (log.isDebugEnabled()) {
@@ -197,12 +206,14 @@ public class KMeansClusterer {
boolean converged = false;
int iteration = 0;
while (!converged && iteration < maxIter) {
+ log.info("Reference Iteration: " + iteration);
List<Cluster> next = new ArrayList<Cluster>();
- for (Cluster c : clustersList.get(iteration++)) {
- next.add(new Cluster(c.getCenter()));
+ for (Cluster c : clustersList.get(iteration)) {
+ next.add(new Cluster(c.getCenter(), c.getId()));
}
clustersList.add(next);
converged = runKMeansIteration(points, next, measure, distanceThreshold);
+ iteration++;
}
return clustersList;
}
@@ -231,4 +242,8 @@ public class KMeansClusterer {
return clusterer.testConvergence(clusters, distanceThreshold);
}
+ public boolean computeConvergence(Cluster cluster) {
+ return cluster.computeConvergence(measure, convergenceDelta);
+ }
+
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java Sat Jul 24 03:39:30 2010
@@ -20,17 +20,18 @@ import java.io.IOException;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.clustering.ClusterObservations;
-public class KMeansCombiner extends Reducer<Text, KMeansInfo, Text, KMeansInfo> {
+public class KMeansCombiner extends Reducer<Text, ClusterObservations, Text, ClusterObservations> {
@Override
- protected void reduce(Text key, Iterable<KMeansInfo> values, Context context) throws IOException, InterruptedException {
+ protected void reduce(Text key, Iterable<ClusterObservations> values, Context context) throws IOException, InterruptedException {
- Cluster cluster = new Cluster(key.toString());
- for (KMeansInfo value : values) {
- cluster.addPoints(value.getPoints(), value.getPointTotal());
+ Cluster cluster = new Cluster();
+ for (ClusterObservations value : values) {
+ cluster.observe(value);
}
- context.write(key, new KMeansInfo(cluster.getNumPoints(), cluster.getPointTotal()));
+ context.write(key, cluster.getObservations());
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Sat Jul 24 03:39:30 2010
@@ -34,6 +34,8 @@ import org.apache.hadoop.mapreduce.lib.i
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.clustering.AbstractCluster;
+import org.apache.mahout.clustering.ClusterObservations;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
@@ -199,7 +201,7 @@ public class KMeansDriver extends Abstra
Path clustersOut = buildClusters(input, clustersIn, output, measure, maxIterations, numReduceTasks, delta, runSequential);
if (runClustering) {
log.info("Clustering data");
- clusterData(input, clustersOut, new Path(output, Cluster.CLUSTERED_POINTS_DIR), measure, delta, runSequential);
+ clusterData(input, clustersOut, new Path(output, AbstractCluster.CLUSTERED_POINTS_DIR), measure, delta, runSequential);
}
}
@@ -276,6 +278,7 @@ public class KMeansDriver extends Abstra
boolean converged = false;
int iteration = 1;
while (!converged && iteration <= maxIterations) {
+ log.info("K-Means Iteration: " + iteration);
Configuration conf = new Configuration();
FileSystem fs = FileSystem.get(input.toUri(), conf);
FileStatus[] status = fs.listStatus(input, new OutputLogFilter());
@@ -293,7 +296,7 @@ public class KMeansDriver extends Abstra
}
}
converged = clusterer.testConvergence(clusters, Double.parseDouble(delta));
- Path clustersOut = new Path(output, Cluster.CLUSTERS_DIR + iteration);
+ Path clustersOut = new Path(output, AbstractCluster.CLUSTERS_DIR + iteration);
SequenceFile.Writer writer = new SequenceFile.Writer(fs,
conf,
new Path(clustersOut, "part-r-00000"),
@@ -301,12 +304,16 @@ public class KMeansDriver extends Abstra
Cluster.class);
try {
for (Cluster cluster : clusters) {
+ log.info("Writing Cluster:" + cluster.getId() + " center:" + AbstractCluster.formatVector(cluster.getCenter(), null)
+ + " numPoints:" + cluster.getNumPoints() + " radius:" + AbstractCluster.formatVector(cluster.getRadius(), null) + " to: "
+ + clustersOut.getName());
writer.append(new Text(cluster.getIdentifier()), cluster);
}
} finally {
writer.close();
}
clustersIn = clustersOut;
+ iteration++;
}
return clustersIn;
}
@@ -336,7 +343,7 @@ public class KMeansDriver extends Abstra
while (!converged && (iteration <= maxIterations)) {
log.info("Iteration {}", iteration);
// point the output to a new directory per iteration
- Path clustersOut = new Path(output, Cluster.CLUSTERS_DIR + iteration);
+ Path clustersOut = new Path(output, AbstractCluster.CLUSTERS_DIR + iteration);
converged = runIteration(input, clustersIn, clustersOut, measure.getClass().getName(), delta, numReduceTasks);
// now point the input to the old output directory
clustersIn = clustersOut;
@@ -378,7 +385,7 @@ public class KMeansDriver extends Abstra
Job job = new Job(conf);
job.setMapOutputKeyClass(Text.class);
- job.setMapOutputValueClass(KMeansInfo.class);
+ job.setMapOutputValueClass(ClusterObservations.class);
job.setOutputKeyClass(Text.class);
job.setOutputValueClass(Cluster.class);
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java Sat Jul 24 03:39:30 2010
@@ -25,10 +25,11 @@ import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.ClusterObservations;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.VectorWritable;
-public class KMeansMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, KMeansInfo> {
+public class KMeansMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, ClusterObservations> {
private KMeansClusterer clusterer;
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java Sat Jul 24 03:39:30 2010
@@ -26,28 +26,27 @@ import org.apache.hadoop.conf.Configurat
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.clustering.ClusterObservations;
import org.apache.mahout.common.distance.DistanceMeasure;
-public class KMeansReducer extends Reducer<Text, KMeansInfo, Text, Cluster> {
+public class KMeansReducer extends Reducer<Text, ClusterObservations, Text, Cluster> {
private Map<String, Cluster> clusterMap;
- private double convergenceDelta;
-
- private DistanceMeasure measure;
+ private KMeansClusterer clusterer;
@Override
- protected void reduce(Text key, Iterable<KMeansInfo> values, Context context)
- throws IOException, InterruptedException {
+ protected void reduce(Text key, Iterable<ClusterObservations> values, Context context) throws IOException, InterruptedException {
Cluster cluster = clusterMap.get(key.toString());
- for (KMeansInfo delta : values) {
- cluster.addPoints(delta.getPoints(), delta.getPointTotal());
+ for (ClusterObservations delta : values) {
+ cluster.observe(delta);
}
// force convergence calculation
- boolean converged = cluster.computeConvergence(this.measure, this.convergenceDelta);
+ boolean converged = clusterer.computeConvergence(cluster);
if (converged) {
context.getCounter("Clustering", "Converged Clusters").increment(1);
}
+ cluster.computeParameters();
context.write(new Text(cluster.getIdentifier()), cluster);
}
@@ -56,13 +55,7 @@ public class KMeansReducer extends Reduc
super.setup(context);
Configuration conf = context.getConfiguration();
try {
- ClassLoader ccl = Thread.currentThread().getContextClassLoader();
- Class<?> cl = ccl.loadClass(conf.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
- this.measure = (DistanceMeasure) cl.newInstance();
- this.measure.configure(conf);
-
- this.convergenceDelta = Double.parseDouble(conf.get(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY));
-
+ this.clusterer = new KMeansClusterer(conf);
this.clusterMap = new HashMap<String, Cluster>();
String path = conf.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
@@ -93,7 +86,7 @@ public class KMeansReducer extends Reduc
public void setup(List<Cluster> clusters, DistanceMeasure measure) {
setClusterMap(clusters);
- this.measure = measure;
+ this.clusterer = new KMeansClusterer(measure);
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java Sat Jul 24 03:39:30 2010
@@ -85,7 +85,7 @@ public final class RandomSeedGenerator {
VectorWritable value = (VectorWritable) reader.getValueClass().newInstance();
while (reader.next(key, value)) {
Cluster newCluster = new Cluster(value.get(), nextClusterId++);
- newCluster.addPoint(value.get());
+ newCluster.observe(value.get(), 1);
Text newText = new Text(key.toString());
int currentSize = chosenTexts.size();
if (currentSize < k) {
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java?rev=978786&r1=978785&r2=978786&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java Sat Jul 24 03:39:30 2010
@@ -22,50 +22,33 @@ import java.io.DataOutput;
import java.io.IOException;
import java.lang.reflect.Type;
-import com.google.gson.reflect.TypeToken;
-import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.kmeans.Cluster;
import org.apache.mahout.math.JsonVectorAdapter;
-import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.list.IntArrayList;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
+import com.google.gson.reflect.TypeToken;
/**
* This class models a canopy as a center point, the number of points that are contained within it according
* to the application of some distance metric, and a point total which is the sum of all the points and is
* used to compute the centroid when needed.
*/
-public class MeanShiftCanopy extends ClusterBase {
+public class MeanShiftCanopy extends Cluster {
- private static final Type VECTOR_TYPE = new TypeToken<Vector>() { }.getType();
-
- // TODO: this is problematic, but how else to encode membership?
+ private static final Type VECTOR_TYPE = new TypeToken<Vector>() {
+ }.getType();
+
+ // TODO: this is still problematic from a scalability perspective, but how else to encode membership?
private IntArrayList boundPoints = new IntArrayList();
-
- private boolean converged;
-
- public MeanShiftCanopy() {
- }
-
- /** Create a new Canopy with the given canopyId */
- /*
- * public MeanShiftCanopy(String id) { this.setId(Integer.parseInt(id.substring(1))); this.setCenter(null);
- * this.setPointTotal(null); this.setNumPoints(0); }
- */
/**
- * Create a new Canopy containing the given point
- *
- * @param point
- * a Vector
- */
- /*
- * public MeanShiftCanopy(Vector point) { this.setCenter(point); this.setPointTotal(point.clone());
- * this.setNumPoints(1); this.boundPoints.add(point); }
+ * Used for Writable
*/
+ public MeanShiftCanopy() {
+ }
/**
* Create a new Canopy containing the given point
@@ -74,13 +57,10 @@ public class MeanShiftCanopy extends Clu
* a Vector
*/
public MeanShiftCanopy(Vector point, int id) {
- this.setId(id);
- this.setCenter(point);
- this.setPointTotal(new RandomAccessSparseVector(point.clone()));
- this.setNumPoints(1);
- this.boundPoints.add(id);
+ super(point, id);
+ boundPoints.add(id);
}
-
+
/**
* Create a new Canopy containing the given point, id and bound points
*
@@ -96,66 +76,18 @@ public class MeanShiftCanopy extends Clu
MeanShiftCanopy(Vector point, int id, IntArrayList boundPoints, boolean converged) {
this.setId(id);
this.setCenter(point);
- this.setPointTotal(new RandomAccessSparseVector(point));
+ this.setRadius(point.like());
this.setNumPoints(1);
this.boundPoints = boundPoints;
- this.converged = converged;
- }
-
- /**
- * Add a point to the canopy some number of times
- *
- * @param point
- * a Vector to add
- * @param nPoints
- * the number of times to add the point
- * @throws org.apache.mahout.math.CardinalityException
- * if the cardinalities disagree
- */
- void addPoints(Vector point, int nPoints) {
- setNumPoints(getNumPoints() + nPoints);
- Vector subTotal = nPoints == 1 ? point.clone() : point.times(nPoints);
- if (getPointTotal() == null) {
- setPointTotal(new RandomAccessSparseVector(subTotal));
- } else {
- subTotal.addTo(getPointTotal());
- }
- }
-
-
- /**
- * Compute the centroid by normalizing the pointTotal
- *
- * @return a Vector which is the new centroid
- */
- @Override
- public Vector computeCentroid() {
- if (getNumPoints() <= 1) {
- return getCenter();
- } else {
- return getPointTotal().divide(getNumPoints());
- }
+ setConverged(converged);
}
-
+
public IntArrayList getBoundPoints() {
return boundPoints;
}
-
- public int getCanopyId() {
- return getId();
- }
-
- @Override
- public String getIdentifier() {
- return (converged ? "V-" : "C-") + getId();
- }
-
- public boolean isConverged() {
- return converged;
- }
-
+
/**
- * The receiver overlaps the given canopy. Touch it and add my bound points to it.
+ * The receiver overlaps the given canopy. Add my bound points to it.
*
* @param canopy
* an existing MeanShiftCanopy
@@ -163,12 +95,7 @@ public class MeanShiftCanopy extends Clu
void merge(MeanShiftCanopy canopy) {
boundPoints.addAllOf(canopy.boundPoints);
}
-
- @Override
- public String toString() {
- return formatCanopy(this);
- }
-
+
/**
* The receiver touches the given canopy. Add respective centers.
*
@@ -176,56 +103,52 @@ public class MeanShiftCanopy extends Clu
* an existing MeanShiftCanopy
*/
void touch(MeanShiftCanopy canopy) {
- canopy.addPoints(getCenter(), boundPoints.size());
- addPoints(canopy.getCenter(), canopy.boundPoints.size());
+ canopy.observe(getCenter(), boundPoints.size());
+ observe(canopy.getCenter(), canopy.boundPoints.size());
}
-
+
@Override
public void readFields(DataInput in) throws IOException {
super.readFields(in);
- VectorWritable temp = new VectorWritable();
- temp.readFields(in);
- this.setCenter(temp.get());
int numpoints = in.readInt();
this.boundPoints = new IntArrayList();
for (int i = 0; i < numpoints; i++) {
this.boundPoints.add(in.readInt());
}
}
-
+
@Override
public void write(DataOutput out) throws IOException {
super.write(out);
- VectorWritable.writeVector(out, computeCentroid());
out.writeInt(boundPoints.size());
for (int v : boundPoints.elements()) {
out.writeInt(v);
}
}
-
+
public MeanShiftCanopy shallowCopy() {
MeanShiftCanopy result = new MeanShiftCanopy();
result.setId(this.getId());
result.setCenter(this.getCenter());
- result.setPointTotal(this.getPointTotal());
+ result.setRadius(this.getRadius());
result.setNumPoints(this.getNumPoints());
result.boundPoints = this.boundPoints;
return result;
}
-
+
@Override
public String asFormatString() {
return formatCanopy(this);
}
-
+
public void setBoundPoints(IntArrayList boundPoints) {
this.boundPoints = boundPoints;
}
-
- public void setConverged(boolean converged) {
- this.converged = converged;
+
+ public String getIdentifier() {
+ return (isConverged() ? "MSV-" : "MSC-") + getId();
}
-
+
/** Format the canopy for output */
public static String formatCanopy(MeanShiftCanopy canopy) {
GsonBuilder gBuilder = new GsonBuilder();
@@ -233,7 +156,7 @@ public class MeanShiftCanopy extends Clu
Gson gson = gBuilder.create();
return gson.toJson(canopy, MeanShiftCanopy.class);
}
-
+
/**
* Decodes and returns a Canopy from the formattedString
*
@@ -248,9 +171,4 @@ public class MeanShiftCanopy extends Clu
return gson.fromJson(formattedString, MeanShiftCanopy.class);
}
- @Override
- public Vector getRadius() {
- return getCenter().like();
- }
-
}