You are viewing a plain-text version of this content. (The hyperlink to the canonical version was lost when this copy was converted to plain text.)
Posted to commits@mahout.apache.org by ro...@apache.org on 2010/02/13 22:08:12 UTC
svn commit: r909914 [1/5] - in /lucene/mahout/trunk/core/src:
main/java/org/apache/mahout/clustering/
main/java/org/apache/mahout/clustering/canopy/
main/java/org/apache/mahout/clustering/dirichlet/
main/java/org/apache/mahout/clustering/dirichlet/models/ ... [subject line truncated in the archive; the full list of modified files follows below]
Author: robinanil
Date: Sat Feb 13 21:07:53 2010
New Revision: 909914
URL: http://svn.apache.org/viewvc?rev=909914&view=rev
Log:
MAHOUT-291
Mahout Clustering Code Cleanup
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Printable.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusteringJob.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyConfigKeys.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonClusterAdapter.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/JsonModelAdapter.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/UncommonDistributions.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalDistribution.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/ModelDistribution.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModelDistribution.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/VectorModelDistribution.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansCombiner.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansConfigKeys.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansInfo.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansJob.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansOutput.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansReducer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansInfo.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyConfigKeys.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyReducer.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java Sat Feb 13 21:07:53 2010
@@ -34,133 +34,135 @@
import com.google.gson.reflect.TypeToken;
public abstract class ClusterBase implements Writable, Printable {
-
- /**
- * Return a human-readable formatted string representation of the vector, not intended
- * to be complete nor usable as an input/output representation such as Json
- *
- * @param v a Vector
- * @return a String
- */
- public static String formatVector(Vector v, String[] bindings) {
- StringBuilder buf = new StringBuilder();
- int nzero = 0;
- Iterator<Element> iterateNonZero = v.iterateNonZero();
- while (iterateNonZero.hasNext()) {
- iterateNonZero.next();
- nzero++;
- }
- // if vector is sparse or if we have bindings, use sparse notation
- if (nzero < v.size() || bindings != null) {
- buf.append('[');
- for (int i = 0; i < v.size(); i++) {
- double elem = v.get(i);
- if (elem == 0.0)
- continue;
- String label;
- if (bindings != null && (label = bindings[i]) != null)
- buf.append(label).append(':');
- else
- buf.append(i).append(':');
- buf.append(String.format("%.3f", elem)).append(", ");
- }
- } else {
- buf.append('[');
- for (int i = 0; i < v.size(); i++) {
- double elem = v.get(i);
- buf.append(String.format("%.3f", elem)).append(", ");
- }
- }
- buf.setLength(buf.length() - 2);
- buf.append(']');
- return buf.toString();
- }
-
// this cluster's clusterId
private int id;
-
+
// the current cluster center
private Vector center = new RandomAccessSparseVector(0);
-
+
// the number of points in the cluster
- private int numPoints = 0;
-
+ private int numPoints;
+
// the Vector total of all points added to the cluster
- private Vector pointTotal = null;
-
+ private Vector pointTotal;
+
public int getId() {
return id;
}
-
+
public void setId(int id) {
this.id = id;
}
-
+
public Vector getCenter() {
return center;
}
-
+
public void setCenter(Vector center) {
this.center = center;
}
-
+
public int getNumPoints() {
return numPoints;
}
-
+
public void setNumPoints(int numPoints) {
this.numPoints = numPoints;
}
-
+
public Vector getPointTotal() {
return pointTotal;
}
-
+
public void setPointTotal(Vector pointTotal) {
this.pointTotal = pointTotal;
}
-
+
/**
* @deprecated
* @return
*/
@Deprecated
public abstract String asFormatString();
-
+
@Override
public String asFormatString(String[] bindings) {
StringBuilder buf = new StringBuilder();
- buf.append(getIdentifier()).append(": ").append(formatVector(computeCentroid(), bindings));
+ buf.append(getIdentifier()).append(": ").append(ClusterBase.formatVector(computeCentroid(), bindings));
return buf.toString();
}
-
+
public abstract Vector computeCentroid();
-
+
public abstract Object getIdentifier();
-
+
@Override
public String asJsonString() {
- Type vectorType = new TypeToken<Vector>() {
- }.getType();
+ Type vectorType = new TypeToken<Vector>() { }.getType();
GsonBuilder gBuilder = new GsonBuilder();
gBuilder.registerTypeAdapter(vectorType, new JsonVectorAdapter());
Gson gson = gBuilder.create();
return gson.toJson(this, this.getClass());
}
-
+
/**
* Simply writes out the id, and that's it!
- *
- * @param out The {@link java.io.DataOutput}
+ *
+ * @param out
+ * The {@link java.io.DataOutput}
*/
@Override
public void write(DataOutput out) throws IOException {
out.writeInt(id);
}
-
+
/** Reads in the id, nothing else */
@Override
public void readFields(DataInput in) throws IOException {
id = in.readInt();
}
+
+ /**
+ * Return a human-readable formatted string representation of the vector, not intended to be complete nor
+ * usable as an input/output representation such as Json
+ *
+ * @param v
+ * a Vector
+ * @return a String
+ */
+ public static String formatVector(Vector v, String[] bindings) {
+ StringBuilder buf = new StringBuilder();
+ int nzero = 0;
+ Iterator<Element> iterateNonZero = v.iterateNonZero();
+ while (iterateNonZero.hasNext()) {
+ iterateNonZero.next();
+ nzero++;
+ }
+ // if vector is sparse or if we have bindings, use sparse notation
+ if ((nzero < v.size()) || (bindings != null)) {
+ buf.append('[');
+ for (int i = 0; i < v.size(); i++) {
+ double elem = v.get(i);
+ if (elem == 0.0) {
+ continue;
+ }
+ String label;
+ if ((bindings != null) && ((label = bindings[i]) != null)) {
+ buf.append(label).append(':');
+ } else {
+ buf.append(i).append(':');
+ }
+ buf.append(String.format("%.3f", elem)).append(", ");
+ }
+ } else {
+ buf.append('[');
+ for (int i = 0; i < v.size(); i++) {
+ double elem = v.get(i);
+ buf.append(String.format("%.3f", elem)).append(", ");
+ }
+ }
+ buf.setLength(buf.length() - 2);
+ buf.append(']');
+ return buf.toString();
+ }
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Printable.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Printable.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Printable.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Printable.java Sat Feb 13 21:07:53 2010
@@ -15,29 +15,29 @@
*/
package org.apache.mahout.clustering;
-
/**
- * Implementations of this interface have a printable representation. This representation
- * may be enhanced by an optional Vector label bindings dictionary.
- *
+ * Implementations of this interface have a printable representation. This representation may be enhanced by
+ * an optional Vector label bindings dictionary.
+ *
*/
public interface Printable {
-
+
/**
* Produce a custom, printable representation of the receiver.
*
- * @param bindings an optional String[] containing labels used to format the primary
- * Vector/s of this implementation.
+ * @param bindings
+ * an optional String[] containing labels used to format the primary Vector/s of this
+ * implementation.
* @return a String
*/
- public String asFormatString(String[] bindings);
-
+ String asFormatString(String[] bindings);
+
/**
- * Produce a printable representation of the receiver using Json. (Label bindings
- * are transient and not part of the Json representation)
+ * Produce a printable representation of the receiver using Json. (Label bindings are transient and not part
+ * of the Json representation)
*
* @return a Json String
*/
- public String asJsonString();
-
+ String asJsonString();
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java Sat Feb 13 21:07:53 2010
@@ -17,6 +17,10 @@
package org.apache.mahout.clustering.canopy;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.mahout.clustering.ClusterBase;
@@ -24,26 +28,23 @@
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
-import java.io.DataInput;
-import java.io.DataOutput;
-import java.io.IOException;
-
/**
- * This class models a canopy as a center point, the number of points that are contained within it according to the
- * application of some distance metric, and a point total which is the sum of all the points and is used to compute the
- * centroid when needed.
+ * This class models a canopy as a center point, the number of points that are contained within it according
+ * to the application of some distance metric, and a point total which is the sum of all the points and is
+ * used to compute the centroid when needed.
*/
public class Canopy extends ClusterBase {
-
+
/** Used for deserializaztion as a writable */
- public Canopy() {
- }
-
+ public Canopy() { }
+
/**
* Create a new Canopy containing the given point and canopyId
- *
- * @param point a point in vector space
- * @param canopyId an int identifying the canopy local to this process only
+ *
+ * @param point
+ * a point in vector space
+ * @param canopyId
+ * an int identifying the canopy local to this process only
*/
public Canopy(Vector point, int canopyId) {
this.setId(canopyId);
@@ -51,13 +52,13 @@
this.setPointTotal(point.clone());
this.setNumPoints(1);
}
-
+
@Override
public void write(DataOutput out) throws IOException {
super.write(out);
VectorWritable.writeVector(out, computeCentroid());
}
-
+
@Override
public void readFields(DataInput in) throws IOException {
super.readFields(in);
@@ -67,22 +68,22 @@
this.setPointTotal(getCenter().clone());
this.setNumPoints(1);
}
-
+
/** Format the canopy for output */
public static String formatCanopy(Canopy canopy) {
- return "C" + canopy.getId() + ": "
- + canopy.computeCentroid().asFormatString();
+ return "C" + canopy.getId() + ": " + canopy.computeCentroid().asFormatString();
}
-
+
@Override
public String asFormatString() {
- return formatCanopy(this);
+ return Canopy.formatCanopy(this);
}
-
+
/**
* Decodes and returns a Canopy from the formattedString
- *
- * @param formattedString a String prouced by formatCanopy
+ *
+ * @param formattedString
+ * a String prouced by formatCanopy
* @return a new Canopy
*/
public static Canopy decodeCanopy(String formattedString) {
@@ -90,49 +91,48 @@
String id = formattedString.substring(0, beginIndex);
String centroid = formattedString.substring(beginIndex);
if (id.charAt(0) == 'C') {
- int canopyId = Integer.parseInt(formattedString.substring(1,
- beginIndex - 2));
+ int canopyId = Integer.parseInt(formattedString.substring(1, beginIndex - 2));
Vector canopyCentroid = AbstractVector.decodeVector(centroid);
return new Canopy(canopyCentroid, canopyId);
}
return null;
}
-
+
/**
* Add a point to the canopy
- *
- * @param point some point to add
+ *
+ * @param point
+ * some point to add
*/
public void addPoint(Vector point) {
setNumPoints(getNumPoints() + 1);
setPointTotal(getPointTotal().plus(point));
-
+
}
-
+
/**
* Emit the point to the collector, keyed by the canopy's formatted representation
- *
- * @param point a point to emit.
+ *
+ * @param point
+ * a point to emit.
*/
- public void emitPoint(Vector point, OutputCollector<Text, Vector> collector)
- throws IOException {
+ public void emitPoint(Vector point, OutputCollector<Text,Vector> collector) throws IOException {
collector.collect(new Text(this.getIdentifier()), point);
}
-
+
@Override
public String toString() {
return getIdentifier() + " - " + getCenter().asFormatString();
}
-
+
@Override
public String getIdentifier() {
return "C" + getId();
}
-
-
+
/**
* Compute the centroid by averaging the pointTotals
- *
+ *
* @return a RandomAccessSparseVector (required by Mapper) which is the new centroid
*/
@Override
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java Sat Feb 13 21:07:53 2010
@@ -23,45 +23,46 @@
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reporter;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
public class CanopyClusterer {
-
+
private int nextCanopyId;
// the T1 distance threshold
private double t1;
-
+
// the T2 distance threshold
private double t2;
-
+
// the distance measure
private DistanceMeasure measure;
-
- //private int nextClusterId = 0;
+
+ // private int nextClusterId = 0;
public CanopyClusterer(DistanceMeasure measure, double t1, double t2) {
this.t1 = t1;
this.t2 = t2;
this.measure = measure;
}
-
+
public CanopyClusterer(JobConf job) {
this.configure(job);
}
-
+
/**
* Configure the Canopy and its distance measure
*
- * @param job the JobConf for this job
+ * @param job
+ * the JobConf for this job
*/
public void configure(JobConf job) {
try {
ClassLoader ccl = Thread.currentThread().getContextClassLoader();
- Class<?> cl = ccl.loadClass(job
- .get(CanopyConfigKeys.DISTANCE_MEASURE_KEY));
+ Class<?> cl = ccl.loadClass(job.get(CanopyConfigKeys.DISTANCE_MEASURE_KEY));
measure = (DistanceMeasure) cl.newInstance();
measure.configure(job);
} catch (ClassNotFoundException e) {
@@ -75,55 +76,61 @@
t2 = Double.parseDouble(job.get(CanopyConfigKeys.T2_KEY));
nextCanopyId = 0;
}
-
+
/** Configure the Canopy for unit tests */
public void config(DistanceMeasure aMeasure, double aT1, double aT2) {
measure = aMeasure;
t1 = aT1;
t2 = aT2;
}
-
+
/**
- * This is the same algorithm as the reference but inverted to iterate over
- * existing canopies instead of the points. Because of this it does not need
- * to actually store the points, instead storing a total points vector and the
- * number of points. From this a centroid can be computed.
+ * This is the same algorithm as the reference but inverted to iterate over existing canopies instead of the
+ * points. Because of this it does not need to actually store the points, instead storing a total points
+ * vector and the number of points. From this a centroid can be computed.
* <p/>
* This method is used by the CanopyReducer.
*
- * @param point the point to be added
- * @param canopies the List<Canopy> to be appended
+ * @param point
+ * the point to be added
+ * @param canopies
+ * the List<Canopy> to be appended
+ * @param reporter
+ * Object to report status to the MR interface
*/
- public void addPointToCanopies(Vector point, List<Canopy> canopies) {
+ public void addPointToCanopies(Vector point, List<Canopy> canopies, Reporter reporter) {
boolean pointStronglyBound = false;
for (Canopy canopy : canopies) {
- double dist = measure.distance(canopy.getCenter().getLengthSquared(),
- canopy.getCenter(), point);
+ double dist = measure.distance(canopy.getCenter().getLengthSquared(), canopy.getCenter(), point);
if (dist < t1) {
canopy.addPoint(point);
}
pointStronglyBound = pointStronglyBound || (dist < t2);
}
if (!pointStronglyBound) {
+ reporter.setStatus("Created new Canopy:" + nextCanopyId);
canopies.add(new Canopy(point, nextCanopyId++));
}
}
-
+
/**
- * This method is used by the CanopyMapper to perform canopy inclusion tests
- * and to emit the point and its covering canopies to the output. The
- * CanopyCombiner will then sum the canopy points and produce the centroids.
+ * This method is used by the CanopyMapper to perform canopy inclusion tests and to emit the point and its
+ * covering canopies to the output. The CanopyCombiner will then sum the canopy points and produce the
+ * centroids.
*
- * @param point the point to be added
- * @param canopies the List<Canopy> to be appended
- * @param collector an OutputCollector in which to emit the point
+ * @param point
+ * the point to be added
+ * @param canopies
+ * the List<Canopy> to be appended
+ * @param collector
+ * an OutputCollector in which to emit the point
*/
- public void emitPointToNewCanopies(Vector point, List<Canopy> canopies,
- OutputCollector<Text, Vector> collector) throws IOException {
+ public void emitPointToNewCanopies(Vector point,
+ List<Canopy> canopies,
+ OutputCollector<Text,Vector> collector) throws IOException {
boolean pointStronglyBound = false;
for (Canopy canopy : canopies) {
- double dist = measure.distance(canopy.getCenter().getLengthSquared(),
- canopy.getCenter(), point);
+ double dist = measure.distance(canopy.getCenter().getLengthSquared(), canopy.getCenter(), point);
if (dist < t1) {
canopy.emitPoint(point, collector);
}
@@ -135,26 +142,28 @@
canopy.emitPoint(point, collector);
}
}
-
+
/**
- * This method is used by the CanopyMapper to perform canopy inclusion tests
- * and to emit the point keyed by its covering canopies to the output. if the
- * point is not covered by any canopies (due to canopy centroid clustering),
- * emit the point to the closest covering canopy.
+ * This method is used by the CanopyMapper to perform canopy inclusion tests and to emit the point keyed by
+ * its covering canopies to the output. if the point is not covered by any canopies (due to canopy centroid
+ * clustering), emit the point to the closest covering canopy.
*
- * @param point the point to be added
- * @param canopies the List<Canopy> to be appended
- * @param collector an OutputCollector in which to emit the point
+ * @param point
+ * the point to be added
+ * @param canopies
+ * the List<Canopy> to be appended
+ * @param collector
+ * an OutputCollector in which to emit the point
*/
- public void emitPointToExistingCanopies(Vector point, List<Canopy> canopies,
- OutputCollector<Text, VectorWritable> collector) throws IOException {
+ public void emitPointToExistingCanopies(Vector point,
+ List<Canopy> canopies,
+ OutputCollector<Text,VectorWritable> collector) throws IOException {
double minDist = Double.MAX_VALUE;
Canopy closest = null;
boolean isCovered = false;
VectorWritable vw = new VectorWritable();
for (Canopy canopy : canopies) {
- double dist = measure.distance(canopy.getCenter().getLengthSquared(),
- canopy.getCenter(), point);
+ double dist = measure.distance(canopy.getCenter().getLengthSquared(), canopy.getCenter(), point);
if (dist < t1) {
isCovered = true;
vw.set(point);
@@ -171,15 +180,15 @@
collector.collect(new Text(closest.getIdentifier()), vw);
}
}
-
+
/**
* Return if the point is covered by the canopy
*
- * @param point a point
+ * @param point
+ * a point
* @return if the point is covered
*/
public boolean canopyCovers(Canopy canopy, Vector point) {
- return measure.distance(canopy.getCenter().getLengthSquared(),
- canopy.getCenter(), point) < t1;
+ return measure.distance(canopy.getCenter().getLengthSquared(), canopy.getCenter(), point) < t1;
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusteringJob.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusteringJob.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusteringJob.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusteringJob.java Sat Feb 13 21:07:53 2010
@@ -17,6 +17,8 @@
package org.apache.mahout.clustering.canopy;
+import java.io.IOException;
+
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
@@ -30,24 +32,24 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
-
/**
- * Runs the {@link org.apache.mahout.clustering.canopy.CanopyDriver#runJob(String, String, String, double, double)}
- * and then {@link org.apache.mahout.clustering.canopy.ClusterDriver#runJob(String, String, String, String, double, double)}.
+ * Runs the
+ * {@link org.apache.mahout.clustering.canopy.CanopyDriver#runJob(String, String, String, double, double)} and
+ * then
+ * {@link org.apache.mahout.clustering.canopy.ClusterDriver#runJob(String, String, String, String, double, double)}
+ * .
*/
public final class CanopyClusteringJob {
-
+
private static final Logger log = LoggerFactory.getLogger(CanopyClusteringJob.class);
-
+
/** The default name of the canopies output sub-directory. */
public static final String DEFAULT_CANOPIES_OUTPUT_DIRECTORY = "/canopies";
/** The default name of the directory used to output clusters. */
public static final String DEFAULT_CLUSTER_OUTPUT_DIRECTORY = ClusterDriver.DEFAULT_CLUSTER_OUTPUT_DIRECTORY;
-
- private CanopyClusteringJob() {
- }
-
+
+ private CanopyClusteringJob() { }
+
/**
* @param args
*/
@@ -55,83 +57,87 @@
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
-
+
Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
- abuilder.withName("input").withMinimum(1).withMaximum(1).create()).
- withDescription("The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
-
+ abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+
Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
- abuilder.withName("output").withMinimum(1).withMaximum(1).create()).
- withDescription("The Path to put the output in").withShortName("o").create();
-
+ abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Path to put the output in").withShortName("o").create();
+
Option measureClassOpt = obuilder.withLongName("distance").withRequired(false).withArgument(
- abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).
- withDescription("The Distance Measure to use. Default is SquaredEuclidean").withShortName("m").create();
-
+ abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Distance Measure to use. Default is SquaredEuclidean").withShortName("m").create();
+
Option vectorClassOpt = obuilder.withLongName("vectorClass").withRequired(false).withArgument(
- abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).
- withDescription("The Vector implementation class name. Default is RandomAccessSparseVector.class").withShortName("v").create();
+ abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Vector implementation class name. Default is RandomAccessSparseVector.class").withShortName("v")
+ .create();
Option t1Opt = obuilder.withLongName("t1").withRequired(true).withArgument(
- abuilder.withName("t1").withMinimum(1).withMaximum(1).create()).
- withDescription("t1").withShortName("t1").create();
+ abuilder.withName("t1").withMinimum(1).withMaximum(1).create()).withDescription("t1").withShortName(
+ "t1").create();
Option t2Opt = obuilder.withLongName("t2").withRequired(true).withArgument(
- abuilder.withName("t2").withMinimum(1).withMaximum(1).create()).
- withDescription("t2").withShortName("t2").create();
-
-
- Option helpOpt = obuilder.withLongName("help").
- withDescription("Print out help").withShortName("h").create();
-
- Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt)
- .withOption(measureClassOpt).withOption(vectorClassOpt)
- .withOption(t1Opt).withOption(t2Opt)
- .withOption(helpOpt).create();
-
-
+ abuilder.withName("t2").withMinimum(1).withMaximum(1).create()).withDescription("t2").withShortName(
+ "t2").create();
+
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+ .create();
+
+ Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(
+ measureClassOpt).withOption(vectorClassOpt).withOption(t1Opt).withOption(t2Opt).withOption(helpOpt)
+ .create();
+
try {
Parser parser = new Parser();
parser.setGroup(group);
CommandLine cmdLine = parser.parse(args);
-
+
if (cmdLine.hasOption(helpOpt)) {
CommandLineUtil.printHelp(group);
return;
}
-
+
String input = cmdLine.getValue(inputOpt).toString();
String output = cmdLine.getValue(outputOpt).toString();
String measureClass = SquaredEuclideanDistanceMeasure.class.getName();
if (cmdLine.hasOption(measureClassOpt)) {
measureClass = cmdLine.getValue(measureClassOpt).toString();
}
-
- //Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ?
- // RandomAccessSparseVector.class
- // : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
+
+ // Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ?
+ // RandomAccessSparseVector.class
+ // : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
double t1 = Double.parseDouble(cmdLine.getValue(t1Opt).toString());
double t2 = Double.parseDouble(cmdLine.getValue(t2Opt).toString());
-
- runJob(input, output, measureClass, t1, t2);
-
+
+ CanopyClusteringJob.runJob(input, output, measureClass, t1, t2);
+
} catch (OptionException e) {
- log.error("Exception", e);
+ CanopyClusteringJob.log.error("Exception", e);
CommandLineUtil.printHelp(group);
}
}
-
+
/**
* Run the job
- *
- * @param input the input pathname String
- * @param output the output pathname String
- * @param measureClassName the DistanceMeasure class name
- * @param t1 the T1 distance threshold
- * @param t2 the T2 distance threshold
+ *
+ * @param input
+ * the input pathname String
+ * @param output
+ * the output pathname String
+ * @param measureClassName
+ * the DistanceMeasure class name
+ * @param t1
+ * the T1 distance threshold
+ * @param t2
+ * the T2 distance threshold
*/
- public static void runJob(String input, String output,
- String measureClassName, double t1, double t2) throws IOException {
- CanopyDriver.runJob(input, output + DEFAULT_CANOPIES_OUTPUT_DIRECTORY, measureClassName, t1, t2);
- ClusterDriver.runJob(input, output + DEFAULT_CANOPIES_OUTPUT_DIRECTORY, output, measureClassName, t1, t2);
+ public static void runJob(String input, String output, String measureClassName, double t1, double t2) throws IOException {
+ CanopyDriver.runJob(input, output + CanopyClusteringJob.DEFAULT_CANOPIES_OUTPUT_DIRECTORY,
+ measureClassName, t1, t2);
+ ClusterDriver.runJob(input, output + CanopyClusteringJob.DEFAULT_CANOPIES_OUTPUT_DIRECTORY, output,
+ measureClassName, t1, t2);
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyConfigKeys.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyConfigKeys.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyConfigKeys.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyConfigKeys.java Sat Feb 13 21:07:53 2010
@@ -18,11 +18,11 @@
package org.apache.mahout.clustering.canopy;
public interface CanopyConfigKeys {
-
- String T1_KEY = "org.apache.mahout.clustering.canopy.t1";
- String CANOPY_PATH_KEY = "org.apache.mahout.clustering.canopy.path";
- String T2_KEY = "org.apache.mahout.clustering.canopy.t2";
- // keys used by Driver, Mapper, Combiner & Reducer
- String DISTANCE_MEASURE_KEY = "org.apache.mahout.clustering.canopy.measure";
-
+
+ String T1_KEY = "org.apache.mahout.clustering.canopy.t1";
+ String CANOPY_PATH_KEY = "org.apache.mahout.clustering.canopy.path";
+ String T2_KEY = "org.apache.mahout.clustering.canopy.t2";
+ // keys used by Driver, Mapper, Combiner & Reducer
+ String DISTANCE_MEASURE_KEY = "org.apache.mahout.clustering.canopy.measure";
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java Sat Feb 13 21:07:53 2010
@@ -17,6 +17,8 @@
package org.apache.mahout.clustering.canopy;
+import java.io.IOException;
+
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
@@ -41,116 +43,118 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
-
public final class CanopyDriver {
-
+
private static final Logger log = LoggerFactory.getLogger(CanopyDriver.class);
-
- private CanopyDriver() {
- }
-
+
+ private CanopyDriver() { }
+
public static void main(String[] args) throws IOException {
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
-
+
Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
- abuilder.withName("input").withMinimum(1).withMaximum(1).create()).
- withDescription("The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
-
+ abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+
Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
- abuilder.withName("output").withMinimum(1).withMaximum(1).create()).
- withDescription("The Path to put the output in").withShortName("o").create();
-
+ abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Path to put the output in").withShortName("o").create();
+
Option measureClassOpt = obuilder.withLongName("distance").withRequired(false).withArgument(
- abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).
- withDescription("The Distance Measure to use. Default is SquaredEuclidean").withShortName("m").create();
-
+ abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Distance Measure to use. Default is SquaredEuclidean").withShortName("m").create();
+
Option vectorClassOpt = obuilder.withLongName("vectorClass").withRequired(false).withArgument(
- abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).
- withDescription("The Vector implementation class name. Default is RandomAccessSparseVector.class").withShortName("v").create();
+ abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Vector implementation class name. Default is RandomAccessSparseVector.class").withShortName("v")
+ .create();
Option t1Opt = obuilder.withLongName("t1").withRequired(true).withArgument(
- abuilder.withName("t1").withMinimum(1).withMaximum(1).create()).
- withDescription("t1").withShortName("t1").create();
+ abuilder.withName("t1").withMinimum(1).withMaximum(1).create()).withDescription("t1").withShortName(
+ "t1").create();
Option t2Opt = obuilder.withLongName("t2").withRequired(true).withArgument(
- abuilder.withName("t2").withMinimum(1).withMaximum(1).create()).
- withDescription("t2").withShortName("t2").create();
-
-
- Option helpOpt = obuilder.withLongName("help").
- withDescription("Print out help").withShortName("h").create();
-
- Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt)
- .withOption(measureClassOpt).withOption(vectorClassOpt)
- .withOption(t1Opt).withOption(t2Opt)
- .withOption(helpOpt).create();
-
+ abuilder.withName("t2").withMinimum(1).withMaximum(1).create()).withDescription("t2").withShortName(
+ "t2").create();
+
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+ .create();
+
+ Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(
+ measureClassOpt).withOption(vectorClassOpt).withOption(t1Opt).withOption(t2Opt).withOption(helpOpt)
+ .create();
+
try {
Parser parser = new Parser();
parser.setGroup(group);
CommandLine cmdLine = parser.parse(args);
-
+
if (cmdLine.hasOption(helpOpt)) {
CommandLineUtil.printHelp(group);
return;
}
-
+
String input = cmdLine.getValue(inputOpt).toString();
String output = cmdLine.getValue(outputOpt).toString();
String measureClass = SquaredEuclideanDistanceMeasure.class.getName();
if (cmdLine.hasOption(measureClassOpt)) {
measureClass = cmdLine.getValue(measureClassOpt).toString();
}
-
- //Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ?
- // RandomAccessSparseVector.class
- // : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
+
+ // Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ?
+ // RandomAccessSparseVector.class
+ // : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
double t1 = Double.parseDouble(cmdLine.getValue(t1Opt).toString());
double t2 = Double.parseDouble(cmdLine.getValue(t2Opt).toString());
-
- runJob(input, output, measureClass, t1, t2);
+
+ CanopyDriver.runJob(input, output, measureClass, t1, t2);
} catch (OptionException e) {
- log.error("Exception", e);
+ CanopyDriver.log.error("Exception", e);
CommandLineUtil.printHelp(group);
-
+
}
}
-
+
/**
* Run the job
- *
- * @param input the input pathname String
- * @param output the output pathname String
- * @param measureClassName the DistanceMeasure class name
- * @param t1 the T1 distance threshold
- * @param t2 the T2 distance threshold
+ *
+ * @param input
+ * the input pathname String
+ * @param output
+ * the output pathname String
+ * @param measureClassName
+ * the DistanceMeasure class name
+ * @param t1
+ * the T1 distance threshold
+ * @param t2
+ * the T2 distance threshold
*/
- public static void runJob(String input, String output,
- String measureClassName, double t1, double t2) throws IOException {
- log.info("Input: {} Out: {} Measure: {} t1: {} t2: {}", new Object[] {input, output, measureClassName, t1, t2});
+ public static void runJob(String input, String output, String measureClassName, double t1, double t2) throws IOException {
+ CanopyDriver.log.info("Input: {} Out: {} Measure: {} t1: {} t2: {}", new Object[] {input, output,
+ measureClassName, t1,
+ t2});
Configurable client = new JobClient();
JobConf conf = new JobConf(CanopyDriver.class);
conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, measureClassName);
conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(t1));
conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(t2));
-
+
conf.setInputFormat(SequenceFileInputFormat.class);
-
+
conf.setMapOutputKeyClass(Text.class);
conf.setMapOutputValueClass(VectorWritable.class);
conf.setOutputKeyClass(Text.class);
conf.setOutputValueClass(Canopy.class);
-
+
FileInputFormat.setInputPaths(conf, new Path(input));
Path outPath = new Path(output);
FileOutputFormat.setOutputPath(conf, outPath);
-
+
conf.setMapperClass(CanopyMapper.class);
conf.setReducerClass(CanopyReducer.class);
conf.setNumReduceTasks(1);
conf.setOutputFormat(SequenceFileOutputFormat.class);
-
+
client.setConf(conf);
FileSystem dfs = FileSystem.get(outPath.toUri(), conf);
if (dfs.exists(outPath)) {
@@ -158,5 +162,5 @@
}
JobClient.runJob(conf);
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java Sat Feb 13 21:07:53 2010
@@ -17,6 +17,10 @@
package org.apache.mahout.clustering.canopy;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.JobConf;
@@ -27,32 +31,30 @@
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
public class CanopyMapper extends MapReduceBase implements
- Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
-
+ Mapper<WritableComparable<?>,VectorWritable,Text,VectorWritable> {
+
private final List<Canopy> canopies = new ArrayList<Canopy>();
-
- private OutputCollector<Text, VectorWritable> outputCollector;
-
+
+ private OutputCollector<Text,VectorWritable> outputCollector;
+
private CanopyClusterer canopyClusterer;
@Override
- public void map(WritableComparable<?> key, VectorWritable point,
- OutputCollector<Text, VectorWritable> output, Reporter reporter) throws IOException {
+ public void map(WritableComparable<?> key,
+ VectorWritable point,
+ OutputCollector<Text,VectorWritable> output,
+ Reporter reporter) throws IOException {
outputCollector = output;
- canopyClusterer.addPointToCanopies(point.get(), canopies);
+ canopyClusterer.addPointToCanopies(point.get(), canopies, reporter);
}
-
+
@Override
public void configure(JobConf job) {
super.configure(job);
canopyClusterer = new CanopyClusterer(job);
}
-
+
@Override
public void close() throws IOException {
for (Canopy canopy : canopies) {
@@ -63,5 +65,5 @@
}
super.close();
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java Sat Feb 13 21:07:53 2010
@@ -17,6 +17,11 @@
package org.apache.mahout.clustering.canopy;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
@@ -26,34 +31,30 @@
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-
-public class CanopyReducer extends MapReduceBase implements
- Reducer<Text, VectorWritable, Text, Canopy> {
-
+public class CanopyReducer extends MapReduceBase implements Reducer<Text,VectorWritable,Text,Canopy> {
+
private final List<Canopy> canopies = new ArrayList<Canopy>();
-
+
private CanopyClusterer canopyClusterer;
@Override
- public void reduce(Text key, Iterator<VectorWritable> values,
- OutputCollector<Text, Canopy> output, Reporter reporter) throws IOException {
+ public void reduce(Text key,
+ Iterator<VectorWritable> values,
+ OutputCollector<Text,Canopy> output,
+ Reporter reporter) throws IOException {
while (values.hasNext()) {
Vector point = values.next().get();
- canopyClusterer.addPointToCanopies(point, canopies);
+ canopyClusterer.addPointToCanopies(point, canopies, reporter);
}
for (Canopy canopy : canopies) {
output.collect(new Text(canopy.getIdentifier()), canopy);
}
}
-
+
@Override
public void configure(JobConf job) {
super.configure(job);
canopyClusterer = new CanopyClusterer(job);
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java Sat Feb 13 21:07:53 2010
@@ -17,6 +17,8 @@
package org.apache.mahout.clustering.canopy;
+import java.io.IOException;
+
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
@@ -42,58 +44,54 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
-
-public class ClusterDriver {
-
- private static final Logger log = LoggerFactory.getLogger(ClusterDriver.class);
-
+public final class ClusterDriver {
+
public static final String DEFAULT_CLUSTER_OUTPUT_DIRECTORY = "/clusters";
-
- private ClusterDriver() {
- }
-
+
+ private static final Logger log = LoggerFactory.getLogger(ClusterDriver.class);
+
+ private ClusterDriver() { }
+
public static void main(String[] args) throws IOException {
-
+
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
-
+
Option vectorClassOpt = obuilder.withLongName("vectorClass").withRequired(false).withArgument(
- abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).
- withDescription("The Vector implementation class name. Default is RandomAccessSparseVector.class")
- .withShortName("v").create();
+ abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Vector implementation class name. Default is RandomAccessSparseVector.class").withShortName("v")
+ .create();
Option t1Opt = obuilder.withLongName("t1").withRequired(true).withArgument(
- abuilder.withName("t1").withMinimum(1).withMaximum(1).create()).
- withDescription("t1").withShortName("t1").create();
+ abuilder.withName("t1").withMinimum(1).withMaximum(1).create()).withDescription("t1").withShortName(
+ "t1").create();
Option t2Opt = obuilder.withLongName("t2").withRequired(true).withArgument(
- abuilder.withName("t2").withMinimum(1).withMaximum(1).create()).
- withDescription("t2").withShortName("t2").create();
-
+ abuilder.withName("t2").withMinimum(1).withMaximum(1).create()).withDescription("t2").withShortName(
+ "t2").create();
+
Option pointsOpt = obuilder.withLongName("points").withRequired(true).withArgument(
- abuilder.withName("points").withMinimum(1).withMaximum(1).create()).
- withDescription("The path containing the points").withShortName("p").create();
-
+ abuilder.withName("points").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The path containing the points").withShortName("p").create();
+
Option canopiesOpt = obuilder.withLongName("canopies").withRequired(true).withArgument(
- abuilder.withName("canopies").withMinimum(1).withMaximum(1).create()).
- withDescription("The location of the canopies, as a Path").withShortName("c").create();
-
+ abuilder.withName("canopies").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The location of the canopies, as a Path").withShortName("c").create();
+
Option measureClassOpt = obuilder.withLongName("distance").withRequired(false).withArgument(
- abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).
- withDescription("The Distance Measure to use. Default is SquaredEuclidean").withShortName("m").create();
-
+ abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Distance Measure to use. Default is SquaredEuclidean").withShortName("m").create();
+
Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
- abuilder.withName("output").withMinimum(1).withMaximum(1).create()).
- withDescription("The Path to put the output in").withShortName("o").create();
-
- Option helpOpt = obuilder.withLongName("help").
- withDescription("Print out help").withShortName("h").create();
-
- Group group = gbuilder.withName("Options").withOption(vectorClassOpt)
- .withOption(t1Opt).withOption(t2Opt)
+ abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Path to put the output in").withShortName("o").create();
+
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+ .create();
+
+ Group group = gbuilder.withName("Options").withOption(vectorClassOpt).withOption(t1Opt).withOption(t2Opt)
.withOption(pointsOpt).withOption(canopiesOpt).withOption(measureClassOpt).withOption(outputOpt)
.withOption(helpOpt).create();
-
+
try {
Parser parser = new Parser();
parser.setGroup(group);
@@ -102,7 +100,7 @@
CommandLineUtil.printHelp(group);
return;
}
-
+
String measureClass = SquaredEuclideanDistanceMeasure.class.getName();
if (cmdLine.hasOption(measureClassOpt)) {
measureClass = cmdLine.getValue(measureClassOpt).toString();
@@ -110,57 +108,67 @@
String output = cmdLine.getValue(outputOpt).toString();
String canopies = cmdLine.getValue(canopiesOpt).toString();
String points = cmdLine.getValue(pointsOpt).toString();
- //Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ?
- // RandomAccessSparseVector.class
- // : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
+ // Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ?
+ // RandomAccessSparseVector.class
+ // : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
double t1 = Double.parseDouble(cmdLine.getValue(t1Opt).toString());
double t2 = Double.parseDouble(cmdLine.getValue(t2Opt).toString());
-
- runJob(points, canopies, output, measureClass, t1, t2);
-
+
+ ClusterDriver.runJob(points, canopies, output, measureClass, t1, t2);
+
} catch (OptionException e) {
- log.error("Exception", e);
+ ClusterDriver.log.error("Exception", e);
CommandLineUtil.printHelp(group);
}
-
-
+
}
-
+
/**
* Run the job
- *
- * @param points the input points directory pathname String
- * @param canopies the input canopies directory pathname String
- * @param output the output directory pathname String
- * @param measureClassName the DistanceMeasure class name
- * @param t1 the T1 distance threshold
- * @param t2 the T2 distance threshold
+ *
+ * @param points
+ * the input points directory pathname String
+ * @param canopies
+ * the input canopies directory pathname String
+ * @param output
+ * the output directory pathname String
+ * @param measureClassName
+ * the DistanceMeasure class name
+ * @param t1
+ * the T1 distance threshold
+ * @param t2
+ * the T2 distance threshold
*/
- public static void runJob(String points, String canopies, String output,
- String measureClassName, double t1, double t2) throws IOException {
+ public static void runJob(String points,
+ String canopies,
+ String output,
+ String measureClassName,
+ double t1,
+ double t2) throws IOException {
Configurable client = new JobClient();
JobConf conf = new JobConf(ClusterDriver.class);
-
+
conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, measureClassName);
conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(t1));
conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(t2));
conf.set(CanopyConfigKeys.CANOPY_PATH_KEY, canopies);
-
+
conf.setInputFormat(SequenceFileInputFormat.class);
-
- /*conf.setMapOutputKeyClass(Text.class);
- conf.setMapOutputValueClass(RandomAccessSparseVector.class);*/
+
+ /*
+ * conf.setMapOutputKeyClass(Text.class); conf.setMapOutputValueClass(RandomAccessSparseVector.class);
+ */
conf.setOutputKeyClass(Text.class);
conf.setOutputValueClass(VectorWritable.class);
conf.setOutputFormat(SequenceFileOutputFormat.class);
-
+
FileInputFormat.setInputPaths(conf, new Path(points));
- Path outPath = new Path(output + DEFAULT_CLUSTER_OUTPUT_DIRECTORY);
+ Path outPath = new Path(output + ClusterDriver.DEFAULT_CLUSTER_OUTPUT_DIRECTORY);
FileOutputFormat.setOutputPath(conf, outPath);
-
+
conf.setMapperClass(ClusterMapper.class);
conf.setReducerClass(IdentityReducer.class);
-
+
client.setConf(conf);
FileSystem dfs = FileSystem.get(outPath.toUri(), conf);
if (dfs.exists(outPath)) {
@@ -168,5 +176,5 @@
}
JobClient.runJob(conf);
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java Sat Feb 13 21:07:53 2010
@@ -17,6 +17,10 @@
package org.apache.mahout.clustering.canopy;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
@@ -30,39 +34,38 @@
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
public class ClusterMapper extends MapReduceBase implements
- Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
-
+ Mapper<WritableComparable<?>,VectorWritable,Text,VectorWritable> {
+
private CanopyClusterer canopyClusterer;
private final List<Canopy> canopies = new ArrayList<Canopy>();
-
+
@Override
- public void map(WritableComparable<?> key, VectorWritable point,
- OutputCollector<Text, VectorWritable> output, Reporter reporter) throws IOException {
+ public void map(WritableComparable<?> key,
+ VectorWritable point,
+ OutputCollector<Text,VectorWritable> output,
+ Reporter reporter) throws IOException {
canopyClusterer.emitPointToExistingCanopies(point.get(), canopies, output);
}
-
+
/**
* Configure the mapper by providing its canopies. Used by unit tests.
- *
- * @param canopies a List<Canopy>
+ *
+ * @param canopies
+ * a List<Canopy>
*/
public void config(List<Canopy> canopies) {
this.canopies.clear();
this.canopies.addAll(canopies);
}
-
+
@Override
public void configure(JobConf job) {
super.configure(job);
canopyClusterer = new CanopyClusterer(job);
-
+
String canopyPath = job.get(CanopyConfigKeys.CANOPY_PATH_KEY);
- if (canopyPath != null && canopyPath.length() > 0) {
+ if ((canopyPath != null) && (canopyPath.length() > 0)) {
try {
Path path = new Path(canopyPath + "/part-00000");
FileSystem fs = FileSystem.get(path.toUri(), job);
@@ -86,7 +89,7 @@
}
}
}
-
+
public boolean canopyCovers(Canopy canopy, Vector point) {
return canopyClusterer.canopyCovers(canopy, point);
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java Sat Feb 13 21:07:53 2010
@@ -31,61 +31,60 @@
import com.google.gson.reflect.TypeToken;
public class DirichletCluster<O> implements Writable, Printable {
-
+
@Override
public void readFields(DataInput in) throws IOException {
this.totalCount = in.readDouble();
- this.model = readModel(in);
+ this.model = DirichletCluster.readModel(in);
}
-
+
@Override
public void write(DataOutput out) throws IOException {
out.writeDouble(totalCount);
- writeModel(out, model);
+ DirichletCluster.writeModel(out, model);
}
-
+
private Model<O> model; // the model for this iteration
-
+
private double totalCount; // total count of observations for the model
-
+
public DirichletCluster(Model<O> model, double totalCount) {
super();
this.model = model;
this.totalCount = totalCount;
}
-
+
public DirichletCluster(Model<O> model) {
super();
this.model = model;
this.totalCount = 0.0;
}
-
+
public DirichletCluster() {
super();
}
-
+
public Model<O> getModel() {
return model;
}
-
+
public void setModel(Model<O> model) {
this.model = model;
this.totalCount += model.count();
}
-
+
public double getTotalCount() {
return totalCount;
}
-
- private static final Type clusterType = new TypeToken<DirichletCluster<Vector>>() {
- }.getType();
-
+
+ private static final Type clusterType = new TypeToken<DirichletCluster<Vector>>() { }.getType();
+
/** Reads a typed Model instance from the input stream */
public static <O> Model<O> readModel(DataInput in) throws IOException {
String modelClassName = in.readUTF();
Model<O> model;
try {
- model = (Model<O>) Class.forName(modelClassName).asSubclass(Model.class).newInstance();
+ model = Class.forName(modelClassName).asSubclass(Model.class).newInstance();
} catch (ClassNotFoundException e) {
throw new IllegalStateException(e);
} catch (IllegalAccessException e) {
@@ -96,24 +95,24 @@
model.readFields(in);
return model;
}
-
+
/** Writes a typed Model instance to the output stream */
public static void writeModel(DataOutput out, Model<?> model) throws IOException {
out.writeUTF(model.getClass().getName());
model.write(out);
}
-
+
@Override
public String asFormatString(String[] bindings) {
return model.toString();
}
-
+
@Override
public String asJsonString() {
GsonBuilder builder = new GsonBuilder();
builder.registerTypeAdapter(Model.class, new JsonModelAdapter());
Gson gson = builder.create();
- return gson.toJson(this, clusterType);
+ return gson.toJson(this, DirichletCluster.clusterType);
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java Sat Feb 13 21:07:53 2010
@@ -17,43 +17,61 @@
package org.apache.mahout.clustering.dirichlet;
+import java.util.ArrayList;
+import java.util.List;
+
import org.apache.mahout.clustering.dirichlet.models.Model;
import org.apache.mahout.clustering.dirichlet.models.ModelDistribution;
import org.apache.mahout.math.DenseVector;
-import org.apache.mahout.math.function.TimesFunction;
import org.apache.mahout.math.Vector;
-
-import java.util.ArrayList;
-import java.util.List;
+import org.apache.mahout.math.function.TimesFunction;
/**
- * Performs Bayesian mixture modeling. <p/> The idea is that we use a probabilistic mixture of a number of models that
- * we use to explain some observed data. The idea here is that each observed data point is assumed to have come from one
- * of the models in the mixture, but we don't know which. The way we deal with that is to use a so-called latent
- * parameter which specifies which model each data point came from. <p/> In addition, since this is a Bayesian
- * clustering algorithm, we don't want to actually commit to any single explanation, but rather to sample from the
- * distribution of models and latent assignments of data points to models given the observed data and the prior
- * distributions of model parameters. <p/> This sampling process is initialized by taking models at random from the
- * prior distribution for models. <p/> Then, we iteratively assign points to the different models using the mixture
- * probabilities and the degree of fit between the point and each model expressed as a probability that the point was
- * generated by that model. <p/> After points are assigned, new parameters for each model are sampled from the posterior
- * distribution for the model parameters considering all of the observed data points that were assigned to the model.
- * Models without any data points are also sampled, but since they have no points assigned, the new samples are
- * effectively taken from the prior distribution for model parameters. <p/> The result is a number of samples that
- * represent mixing probabilities, models and assignment of points to models. If the total number of possible models is
- * substantially larger than the number that ever have points assigned to them, then this algorithm provides a (nearly)
- * non-parametric clustering algorithm. <p/> These samples can give us interesting information that is lacking from a
- * normal clustering that consists of a single assignment of points to clusters. Firstly, by examining the number of
- * models in each sample that actually has any points assigned to it, we can get information about how many models
- * (clusters) that the data support. <p/> Morevoer, by examining how often two points are assigned to the same model, we
- * can get an approximate measure of how likely these points are to be explained by the same model. Such soft
- * membership information is difficult to come by with conventional clustering methods. <p/> Finally, we can get an idea
- * of the stability of how the data can be described. Typically, aspects of the data with lots of data available wind
- * up with stable descriptions while at the edges, there are aspects that are phenomena that we can't really commit to a
- * solid description, but it is still clear that the well supported explanations are insufficient to explain these
- * additional aspects. <p/> One thing that can be difficult about these samples is that we can't always assign a
- * correlation between the models in the different samples. Probably the best way to do this is to look for overlap in
- * the assignments of data observations to the different models. <p/>
+ * Performs Bayesian mixture modeling.
+ * <p/>
+ * The idea is that we use a probabilistic mixture of a number of models that we use to explain some observed
+ * data. The idea here is that each observed data point is assumed to have come from one of the models in the
+ * mixture, but we don't know which. The way we deal with that is to use a so-called latent parameter which
+ * specifies which model each data point came from.
+ * <p/>
+ * In addition, since this is a Bayesian clustering algorithm, we don't want to actually commit to any single
+ * explanation, but rather to sample from the distribution of models and latent assignments of data points to
+ * models given the observed data and the prior distributions of model parameters.
+ * <p/>
+ * This sampling process is initialized by taking models at random from the prior distribution for models.
+ * <p/>
+ * Then, we iteratively assign points to the different models using the mixture probabilities and the degree
+ * of fit between the point and each model expressed as a probability that the point was generated by that
+ * model.
+ * <p/>
+ * After points are assigned, new parameters for each model are sampled from the posterior distribution for
+ * the model parameters considering all of the observed data points that were assigned to the model. Models
+ * without any data points are also sampled, but since they have no points assigned, the new samples are
+ * effectively taken from the prior distribution for model parameters.
+ * <p/>
+ * The result is a number of samples that represent mixing probabilities, models and assignment of points to
+ * models. If the total number of possible models is substantially larger than the number that ever have
+ * points assigned to them, then this algorithm provides a (nearly) non-parametric clustering algorithm.
+ * <p/>
+ * These samples can give us interesting information that is lacking from a normal clustering that consists of
+ * a single assignment of points to clusters. Firstly, by examining the number of models in each sample that
+ * actually has any points assigned to it, we can get information about how many models (clusters) the
+ * data support.
+ * <p/>
+ * Moreover, by examining how often two points are assigned to the same model, we can get an approximate
+ * measure of how likely these points are to be explained by the same model. Such soft membership information
+ * is difficult to come by with conventional clustering methods.
+ * <p/>
+ * Finally, we can get an idea of the stability of how the data can be described. Typically, aspects of the
+ * data with lots of data available wind up with stable descriptions while at the edges, there are aspects
+ * that are phenomena that we can't really commit to a solid description, but it is still clear that the well
+ * supported explanations are insufficient to explain these additional aspects.
+ * <p/>
+ * One thing that can be difficult about these samples is that we can't always assign a correlation between
+ * the models in the different samples. Probably the best way to do this is to look for overlap in the
+ * assignments of data observations to the different models.
+ * <p/>
+ *
* <pre>
* \theta_i ~ prior()
* \lambda_i ~ Dirichlet(\alpha_0)
@@ -62,50 +80,59 @@
* </pre>
*/
public class DirichletClusterer<O> {
-
+
// observed data
private final List<O> sampleData;
-
+
// the ModelDistribution for the computation
private final ModelDistribution<O> modelFactory;
-
+
// the state of the clustering process
private final DirichletState<O> state;
-
+
private final int thin;
-
+
private final int burnin;
-
+
private final int numClusters;
-
+
private final List<Model<O>[]> clusterSamples = new ArrayList<Model<O>[]>();
-
+
/**
* Create a new instance on the sample data with the given additional parameters
- *
- * @param sampleData the observed data to be clustered
- * @param modelFactory the ModelDistribution to use
- * @param alpha_0 the double value for the beta distributions
- * @param numClusters the int number of clusters
- * @param thin the int thinning interval, used to report every n iterations
- * @param burnin the int burnin interval, used to suppress early iterations
+ *
+ * @param sampleData
+ * the observed data to be clustered
+ * @param modelFactory
+ * the ModelDistribution to use
+ * @param alpha_0
+ * the double value for the beta distributions
+ * @param numClusters
+ * the int number of clusters
+ * @param thin
+ * the int thinning interval, used to report every n iterations
+ * @param burnin
+ * the int burnin interval, used to suppress early iterations
*/
public DirichletClusterer(List<O> sampleData,
- ModelDistribution<O> modelFactory, double alpha_0,
- int numClusters, int thin, int burnin) {
+ ModelDistribution<O> modelFactory,
+ double alpha_0,
+ int numClusters,
+ int thin,
+ int burnin) {
this.sampleData = sampleData;
this.modelFactory = modelFactory;
this.thin = thin;
this.burnin = burnin;
this.numClusters = numClusters;
- state = new DirichletState<O>(modelFactory, numClusters, alpha_0,
- thin, burnin);
+ state = new DirichletState<O>(modelFactory, numClusters, alpha_0, thin, burnin);
}
-
+
/**
* Iterate over the sample data, obtaining cluster samples periodically and returning them.
- *
- * @param numIterations the int number of iterations to perform
+ *
+ * @param numIterations
+ * the int number of iterations to perform
* @return a List<List<Model<Observation>>> of the observed models
*/
public List<Model<O>[]> cluster(int numIterations) {
@@ -114,19 +141,19 @@
}
return clusterSamples;
}
-
+
/**
- * Perform one iteration of the clustering process, iterating over the samples to build a new array of models, then
- * updating the state for the next iteration
- *
- * @param state the DirichletState<Observation> of this iteration
+ * Perform one iteration of the clustering process, iterating over the samples to build a new array of
+ * models, then updating the state for the next iteration
+ *
+ * @param state
+ * the DirichletState<O> of this iteration
*/
private void iterate(int iteration, DirichletState<O> state) {
-
+
// create new posterior models
- Model<O>[] newModels = modelFactory.sampleFromPosterior(state
- .getModels());
-
+ Model<O>[] newModels = modelFactory.sampleFromPosterior(state.getModels());
+
// iterate over the samples, assigning each to a model
for (O x : sampleData) {
// compute normalized vector of probabilities that x is described by each model
@@ -137,7 +164,7 @@
// ask the selected model to observe the datum
newModels[k].observe(x);
}
-
+
// periodically add models to the cluster samples after the burn-in period
if ((iteration >= burnin) && (iteration % thin == 0)) {
clusterSamples.add(newModels);
@@ -145,17 +172,18 @@
// update the state from the new models
state.update(newModels);
}
-
+
/**
- * Compute a normalized vector of probabilities that x is described by each model using the mixture and the model
- * pdfs
- *
- * @param state the DirichletState<Observation> of this iteration
- * @param x an Observation
+ * Compute a normalized vector of probabilities that x is described by each model using the mixture and the
+ * model pdfs
+ *
+ * @param state
+ * the DirichletState<O> of this iteration
+ * @param x
+ * an Observation
* @return the Vector of probabilities
*/
- private Vector normalizedProbabilities(DirichletState<O> state,
- O x) {
+ private Vector normalizedProbabilities(DirichletState<O> state, O x) {
Vector pi = new DenseVector(numClusters);
double max = 0;
for (int k = 0; k < numClusters; k++) {
@@ -169,5 +197,5 @@
pi.assign(new TimesFunction(), 1.0 / max);
return pi;
}
-
+
}