Posted to commits@mahout.apache.org by ro...@apache.org on 2010/02/13 22:08:12 UTC
svn commit: r909914 [4/5] - in /lucene/mahout/trunk/core/src:
main/java/org/apache/mahout/clustering/
main/java/org/apache/mahout/clustering/canopy/
main/java/org/apache/mahout/clustering/dirichlet/
main/java/org/apache/mahout/clustering/dirichlet/mode...
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Sat Feb 13 21:07:53 2010
@@ -16,6 +16,8 @@
*/
package org.apache.mahout.clustering.kmeans;
+import java.io.IOException;
+
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
@@ -43,85 +45,87 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
-
-public class KMeansDriver {
-
+public final class KMeansDriver {
+
/** The name of the directory used to output final results. */
public static final String DEFAULT_OUTPUT_DIRECTORY = "/points";
-
+
private static final Logger log = LoggerFactory.getLogger(KMeansDriver.class);
-
- private KMeansDriver() {
- }
-
- /** @param args Expects 7 args and they all correspond to the order of the params in {@link #runJob} */
+
+ private KMeansDriver() { }
+
+ /**
+ * @param args
+ * Expects 7 args and they all correspond to the order of the params in {@link #runJob}
+ */
public static void main(String[] args) throws Exception {
-
+
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
-
+
Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
- abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
- "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
-
+ abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+
Option clustersOpt = obuilder
.withLongName("clusters")
.withRequired(true)
.withArgument(abuilder.withName("clusters").withMinimum(1).withMaximum(1).create())
.withDescription(
- "The input centroids, as Vectors. Must be a SequenceFile of Writable, Cluster/Canopy. "
- + "If k is also specified, then a random set of vectors will be selected and written out to this path first")
+ "The input centroids, as Vectors. Must be a SequenceFile of Writable, Cluster/Canopy. "
+ + "If k is also specified, then a random set of vectors will be selected and written out to this path first")
.withShortName("c").create();
-
+
Option kOpt = obuilder
.withLongName("k")
.withRequired(false)
.withArgument(abuilder.withName("k").withMinimum(1).withMaximum(1).create())
.withDescription(
- "The k in k-Means. If specified, then a random selection of k Vectors will be chosen as the Centroid and written to the clusters output path.")
+ "The k in k-Means. If specified, then a random selection of k Vectors will be chosen as the Centroid and written to the clusters output path.")
.withShortName("k").create();
-
+
Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
- abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
- "The Path to put the output in").withShortName("o").create();
-
+ abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Path to put the output in").withShortName("o").create();
+
Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).withDescription(
- "If set, overwrite the output directory").withShortName("w").create();
-
+ "If set, overwrite the output directory").withShortName("w").create();
+
Option measureClassOpt = obuilder.withLongName("distance").withRequired(false).withArgument(
- abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).withDescription(
- "The Distance Measure to use. Default is SquaredEuclidean").withShortName("m").create();
-
+ abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Distance Measure to use. Default is SquaredEuclidean").withShortName("m").create();
+
Option convergenceDeltaOpt = obuilder.withLongName("convergence").withRequired(false).withArgument(
- abuilder.withName("convergence").withMinimum(1).withMaximum(1).create()).withDescription(
- "The threshold below which the clusters are considered to be converged. Default is 0.5").withShortName("d")
- .create();
-
+ abuilder.withName("convergence").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The threshold below which the clusters are considered to be converged. Default is 0.5")
+ .withShortName("d").create();
+
Option maxIterationsOpt = obuilder.withLongName("max").withRequired(false).withArgument(
- abuilder.withName("max").withMinimum(1).withMaximum(1).create()).withDescription(
- "The maximum number of iterations to perform. Default is 20").withShortName("x").create();
-
+ abuilder.withName("max").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The maximum number of iterations to perform. Default is 20").withShortName("x").create();
+
Option vectorClassOpt = obuilder.withLongName("vectorClass").withRequired(false).withArgument(
- abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).withDescription(
- "The Vector implementation class name. Default is RandomAccessSparseVector.class").withShortName("v").create();
-
+ abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Vector implementation class name. Default is RandomAccessSparseVector.class").withShortName("v")
+ .create();
+
Option numReduceTasksOpt = obuilder.withLongName("numReduce").withRequired(false).withArgument(
- abuilder.withName("numReduce").withMinimum(1).withMaximum(1).create()).withDescription(
- "The number of reduce tasks").withShortName("r").create();
-
- Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
-
- Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(clustersOpt).withOption(outputOpt)
- .withOption(measureClassOpt).withOption(convergenceDeltaOpt).withOption(maxIterationsOpt).withOption(
- numReduceTasksOpt).withOption(kOpt).withOption(vectorClassOpt).withOption(overwriteOutput).withOption(
- helpOpt).create();
+ abuilder.withName("numReduce").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The number of reduce tasks").withShortName("r").create();
+
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+ .create();
+
+ Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(clustersOpt).withOption(
+ outputOpt).withOption(measureClassOpt).withOption(convergenceDeltaOpt).withOption(maxIterationsOpt)
+ .withOption(numReduceTasksOpt).withOption(kOpt).withOption(vectorClassOpt)
+ .withOption(overwriteOutput).withOption(helpOpt).create();
try {
Parser parser = new Parser();
parser.setGroup(group);
CommandLine cmdLine = parser.parse(args);
-
+
if (cmdLine.hasOption(helpOpt)) {
CommandLineUtil.printHelp(group);
return;
@@ -137,10 +141,11 @@
if (cmdLine.hasOption(convergenceDeltaOpt)) {
convergenceDelta = Double.parseDouble(cmdLine.getValue(convergenceDeltaOpt).toString());
}
-
- //Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ? RandomAccessSparseVector.class
- // : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
-
+
+ // Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ?
+ // RandomAccessSparseVector.class
+ // : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
+
int maxIterations = 20;
if (cmdLine.hasOption(maxIterationsOpt)) {
maxIterations = Integer.parseInt(cmdLine.getValue(maxIterationsOpt).toString());
@@ -153,72 +158,102 @@
HadoopUtil.overwriteOutput(output);
}
if (cmdLine.hasOption(kOpt)) {
- clusters = RandomSeedGenerator
- .buildRandom(input, clusters, Integer.parseInt(cmdLine.getValue(kOpt).toString())).toString();
+ clusters = RandomSeedGenerator.buildRandom(input, clusters,
+ Integer.parseInt(cmdLine.getValue(kOpt).toString())).toString();
}
- runJob(input, clusters, output, measureClass, convergenceDelta, maxIterations, numReduceTasks);
+ KMeansDriver.runJob(input, clusters, output, measureClass, convergenceDelta, maxIterations,
+ numReduceTasks);
} catch (OptionException e) {
- log.error("Exception", e);
+ KMeansDriver.log.error("Exception", e);
CommandLineUtil.printHelp(group);
}
}
-
+
/**
* Run the job using supplied arguments
*
- * @param input the directory pathname for input points
- * @param clustersIn the directory pathname for initial & computed clusters
- * @param output the directory pathname for output points
- * @param measureClass the classname of the DistanceMeasure
- * @param convergenceDelta the convergence delta value
- * @param maxIterations the maximum number of iterations
- * @param numReduceTasks the number of reducers
+ * @param input
+ * the directory pathname for input points
+ * @param clustersIn
+ * the directory pathname for initial & computed clusters
+ * @param output
+ * the directory pathname for output points
+ * @param measureClass
+ * the classname of the DistanceMeasure
+ * @param convergenceDelta
+ * the convergence delta value
+ * @param maxIterations
+ * the maximum number of iterations
+ * @param numReduceTasks
+ * the number of reducers
*/
- public static void runJob(String input, String clustersIn, String output, String measureClass,
- double convergenceDelta, int maxIterations, int numReduceTasks) {
+ public static void runJob(String input,
+ String clustersIn,
+ String output,
+ String measureClass,
+ double convergenceDelta,
+ int maxIterations,
+ int numReduceTasks) {
// iterate until the clusters converge
String delta = Double.toString(convergenceDelta);
- if (log.isInfoEnabled()) {
- log.info("Input: {} Clusters In: {} Out: {} Distance: {}", new Object[] {input, clustersIn, output, measureClass});
- log.info("convergence: {} max Iterations: {} num Reduce Tasks: {} Input Vectors: {}",
- new Object[] {convergenceDelta, maxIterations, numReduceTasks, VectorWritable.class.getName()});
+ if (KMeansDriver.log.isInfoEnabled()) {
+ KMeansDriver.log.info("Input: {} Clusters In: {} Out: {} Distance: {}", new Object[] {input,
+ clustersIn,
+ output,
+ measureClass});
+ KMeansDriver.log.info("convergence: {} max Iterations: {} num Reduce Tasks: {} Input Vectors: {}",
+ new Object[] {convergenceDelta, maxIterations, numReduceTasks, VectorWritable.class.getName()});
}
boolean converged = false;
int iteration = 0;
- while (!converged && iteration < maxIterations) {
- log.info("Iteration {}", iteration);
+ while (!converged && (iteration < maxIterations)) {
+ KMeansDriver.log.info("Iteration {}", iteration);
// point the output to a new directory per iteration
String clustersOut = output + "/clusters-" + iteration;
- converged = runIteration(input, clustersIn, clustersOut, measureClass, delta, numReduceTasks, iteration);
+ converged = KMeansDriver.runIteration(input, clustersIn, clustersOut, measureClass, delta,
+ numReduceTasks, iteration);
// now point the input to the old output directory
clustersIn = output + "/clusters-" + iteration;
iteration++;
}
// now actually cluster the points
- log.info("Clustering ");
- runClustering(input, clustersIn, output + DEFAULT_OUTPUT_DIRECTORY, measureClass, delta);
+ KMeansDriver.log.info("Clustering ");
+ KMeansDriver.runClustering(input, clustersIn, output + KMeansDriver.DEFAULT_OUTPUT_DIRECTORY,
+ measureClass, delta);
}
-
+
/**
* Run the job using supplied arguments
*
- * @param input the directory pathname for input points
- * @param clustersIn the directory pathname for input clusters
- * @param clustersOut the directory pathname for output clusters
- * @param measureClass the classname of the DistanceMeasure
- * @param convergenceDelta the convergence delta value
- * @param numReduceTasks the number of reducer tasks
- * @param iteration The iteration number
+ * @param input
+ * the directory pathname for input points
+ * @param clustersIn
+ * the directory pathname for input clusters
+ * @param clustersOut
+ * the directory pathname for output clusters
+ * @param measureClass
+ * the classname of the DistanceMeasure
+ * @param convergenceDelta
+ * the convergence delta value
+ * @param numReduceTasks
+ * the number of reducer tasks
+ * @param iteration
+ * The iteration number
* @return true if the iteration successfully runs
*/
- private static boolean runIteration(String input, String clustersIn, String clustersOut, String measureClass,
- String convergenceDelta, int numReduceTasks, int iteration) {
+ private static boolean runIteration(String input,
+ String clustersIn,
+ String clustersOut,
+ String measureClass,
+ String convergenceDelta,
+ int numReduceTasks,
+ int iteration) {
JobConf conf = new JobConf(KMeansDriver.class);
conf.setMapOutputKeyClass(Text.class);
conf.setMapOutputValueClass(KMeansInfo.class);
conf.setOutputKeyClass(Text.class);
conf.setOutputValueClass(Cluster.class);
-
+
FileInputFormat.setInputPaths(conf, new Path(input));
Path outPath = new Path(clustersOut);
FileOutputFormat.setOutputPath(conf, outPath);
@@ -236,65 +271,80 @@
try {
JobClient.runJob(conf);
FileSystem fs = FileSystem.get(outPath.toUri(), conf);
- return isConverged(clustersOut, conf, fs);
+ return KMeansDriver.isConverged(clustersOut, conf, fs);
} catch (IOException e) {
- log.warn(e.toString(), e);
+ KMeansDriver.log.warn(e.toString(), e);
return true;
}
}
-
+
/**
* Run the job using supplied arguments
*
- * @param input the directory pathname for input points
- * @param clustersIn the directory pathname for input clusters
- * @param output the directory pathname for output points
- * @param measureClass the classname of the DistanceMeasure
- * @param convergenceDelta the convergence delta value
+ * @param input
+ * the directory pathname for input points
+ * @param clustersIn
+ * the directory pathname for input clusters
+ * @param output
+ * the directory pathname for output points
+ * @param measureClass
+ * the classname of the DistanceMeasure
+ * @param convergenceDelta
+ * the convergence delta value
*/
- private static void runClustering(String input, String clustersIn, String output, String measureClass,
- String convergenceDelta) {
- if (log.isInfoEnabled()) {
- log.info("Running Clustering");
- log.info("Input: {} Clusters In: {} Out: {} Distance: {}",
- new Object[] {input, clustersIn, output, measureClass});
- log.info("convergence: {} Input Vectors: {}", convergenceDelta, VectorWritable.class.getName());
+ private static void runClustering(String input,
+ String clustersIn,
+ String output,
+ String measureClass,
+ String convergenceDelta) {
+ if (KMeansDriver.log.isInfoEnabled()) {
+ KMeansDriver.log.info("Running Clustering");
+ KMeansDriver.log.info("Input: {} Clusters In: {} Out: {} Distance: {}", new Object[] {input,
+ clustersIn,
+ output,
+ measureClass});
+ KMeansDriver.log.info("convergence: {} Input Vectors: {}", convergenceDelta, VectorWritable.class
+ .getName());
}
JobConf conf = new JobConf(KMeansDriver.class);
conf.setInputFormat(SequenceFileInputFormat.class);
conf.setOutputFormat(SequenceFileOutputFormat.class);
-
+
conf.setMapOutputKeyClass(Text.class);
conf.setMapOutputValueClass(VectorWritable.class);
conf.setOutputKeyClass(Text.class);
// the output is the cluster id
conf.setOutputValueClass(Text.class);
-
+
FileInputFormat.setInputPaths(conf, new Path(input));
Path outPath = new Path(output);
FileOutputFormat.setOutputPath(conf, outPath);
-
+
conf.setMapperClass(KMeansClusterMapper.class);
conf.setNumReduceTasks(0);
conf.set(KMeansConfigKeys.CLUSTER_PATH_KEY, clustersIn);
conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY, measureClass);
conf.set(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
-
+
try {
JobClient.runJob(conf);
} catch (IOException e) {
- log.warn(e.toString(), e);
+ KMeansDriver.log.warn(e.toString(), e);
}
}
-
+
/**
* Return whether all of the Clusters in the parts in the filePath have converged
*
- * @param filePath the file path to the single file containing the clusters
- * @param conf the JobConf
- * @param fs the FileSystem
+ * @param filePath
+ * the file path to the single file containing the clusters
+ * @param conf
+ * the JobConf
+ * @param fs
+ * the FileSystem
* @return true if all Clusters are converged
- * @throws IOException if there was an IO error
+ * @throws IOException
+ * if there was an IO error
*/
private static boolean isConverged(String filePath, JobConf conf, FileSystem fs) throws IOException {
FileStatus[] parts = fs.listStatus(new Path(filePath));
@@ -305,11 +355,11 @@
Writable key;
try {
key = (Writable) reader.getKeyClass().newInstance();
- } catch (InstantiationException e) {// shouldn't happen
- log.error("Exception", e);
+ } catch (InstantiationException e) { // shouldn't happen
+ KMeansDriver.log.error("Exception", e);
throw new IllegalStateException(e);
} catch (IllegalAccessException e) {
- log.error("Exception", e);
+ KMeansDriver.log.error("Exception", e);
throw new IllegalStateException(e);
}
Cluster value = new Cluster();
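
The seven runJob arguments mirror the command-line options above, and main() now seeds random centroids via RandomSeedGenerator when -k is given. A minimal sketch of driving the class programmatically, assuming the 0.3-era API in this diff (the paths and the distance-measure class name are illustrative placeholders):

    import org.apache.mahout.clustering.kmeans.KMeansDriver;
    import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;

    public class KMeansJobSketch {
      public static void main(String[] args) throws Exception {
        String input = "testdata/points";      // SequenceFile of Writable, Vector
        String clusters = "testdata/clusters"; // seed clusters (Cluster/Canopy)
        // As in main() above, sample k random seed vectors when none exist yet:
        clusters = RandomSeedGenerator.buildRandom(input, clusters, 10).toString();
        KMeansDriver.runJob(input, clusters, "output",
            "org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure",
            0.5, // convergence delta (the documented default)
            20,  // maximum iterations (the documented default)
            1);  // number of reduce tasks
      }
    }
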
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansInfo.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansInfo.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansInfo.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansInfo.java Sat Feb 13 21:07:53 2010
@@ -17,41 +17,40 @@
package org.apache.mahout.clustering.kmeans;
-import org.apache.hadoop.io.Writable;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
-
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
-public class KMeansInfo implements Writable {
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+public class KMeansInfo implements Writable {
+
private int points;
private Vector pointTotal;
-
- public KMeansInfo() {
- }
-
+
+ public KMeansInfo() { }
+
public KMeansInfo(int points, Vector pointTotal) {
this.points = points;
this.pointTotal = pointTotal;
}
-
+
public int getPoints() {
return points;
}
-
+
public Vector getPointTotal() {
return pointTotal;
}
-
+
@Override
public void write(DataOutput out) throws IOException {
out.writeInt(points);
VectorWritable.writeVector(out, pointTotal);
}
-
+
@Override
public void readFields(DataInput in) throws IOException {
this.points = in.readInt();
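
KMeansInfo is the mappers' partial sum for a cluster: a point count plus a running vector total. Its Writable contract round-trips through a plain byte stream; a minimal sketch, with arbitrary illustrative values:

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;

    import org.apache.mahout.clustering.kmeans.KMeansInfo;
    import org.apache.mahout.math.DenseVector;

    public class KMeansInfoRoundTrip {
      public static void main(String[] args) throws Exception {
        KMeansInfo info = new KMeansInfo(3, new DenseVector(new double[] {1.0, 2.0, 3.0}));
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        info.write(new DataOutputStream(bytes));

        KMeansInfo copy = new KMeansInfo();
        copy.readFields(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        System.out.println(copy.getPoints() + " " + copy.getPointTotal()); // 3 and the same totals
      }
    }
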
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java Sat Feb 13 21:07:53 2010
@@ -16,6 +16,10 @@
*/
package org.apache.mahout.clustering.kmeans;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.JobConf;
@@ -26,23 +30,20 @@
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.VectorWritable;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
public class KMeansMapper extends MapReduceBase implements
- Mapper<WritableComparable<?>, VectorWritable, Text, KMeansInfo> {
-
+ Mapper<WritableComparable<?>,VectorWritable,Text,KMeansInfo> {
+
private KMeansClusterer clusterer;
private final List<Cluster> clusters = new ArrayList<Cluster>();
-
+
@Override
- public void map(WritableComparable<?> key, VectorWritable point,
- OutputCollector<Text, KMeansInfo> output, Reporter reporter)
- throws IOException {
- this.clusterer.emitPointToNearestCluster(point.get(), this.clusters, output);
+ public void map(WritableComparable<?> key,
+ VectorWritable point,
+ OutputCollector<Text,KMeansInfo> output,
+ Reporter reporter) throws IOException {
+ this.clusterer.emitPointToNearestCluster(point.get(), this.clusters, output);
}
-
+
/**
* Configure the mapper by providing its clusters. Used by unit tests.
*
@@ -53,27 +54,26 @@
this.clusters.clear();
this.clusters.addAll(clusters);
}
-
+
@Override
public void configure(JobConf job) {
super.configure(job);
try {
ClassLoader ccl = Thread.currentThread().getContextClassLoader();
- Class<?> cl = ccl.loadClass(job
- .get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
+ Class<?> cl = ccl.loadClass(job.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
DistanceMeasure measure = (DistanceMeasure) cl.newInstance();
measure.configure(job);
-
+
this.clusterer = new KMeansClusterer(measure);
-
+
String clusterPath = job.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
- if (clusterPath != null && clusterPath.length() > 0) {
+ if ((clusterPath != null) && (clusterPath.length() > 0)) {
KMeansUtil.configureWithClusterInfo(clusterPath, clusters);
if (clusters.isEmpty()) {
throw new IllegalStateException("Cluster is empty!");
}
}
-
+
} catch (ClassNotFoundException e) {
throw new IllegalStateException(e);
} catch (IllegalAccessException e) {
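
The null-safe CLUSTER_PATH_KEY check above is what lets unit tests configure the mapper without any cluster files on HDFS. A sketch of that wiring, assuming the mapper's test hook is config(List<Cluster>) like the reducer's below, and that Cluster has a (Vector, int) constructor:

    // In the same package so package-level classes resolve; the Cluster ctor
    // and the config(...) hook are assumptions based on code shown in this commit.
    package org.apache.mahout.clustering.kmeans;

    import java.util.ArrayList;
    import java.util.List;

    import org.apache.hadoop.io.Text;
    import org.apache.hadoop.mapred.JobConf;
    import org.apache.hadoop.mapred.OutputCollector;
    import org.apache.hadoop.mapred.Reporter;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.VectorWritable;

    public class KMeansMapperSketch {
      public static void main(String[] args) throws Exception {
        JobConf conf = new JobConf();
        conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY,
            "org.apache.mahout.common.distance.EuclideanDistanceMeasure");
        KMeansMapper mapper = new KMeansMapper();
        mapper.configure(conf); // CLUSTER_PATH_KEY is unset, so no HDFS read happens

        List<Cluster> seeds = new ArrayList<Cluster>();
        seeds.add(new Cluster(new DenseVector(new double[] {0.0, 0.0}), 0)); // assumed ctor
        mapper.config(seeds); // the unit-test hook described in the javadoc above

        OutputCollector<Text,KMeansInfo> out = new OutputCollector<Text,KMeansInfo>() {
          @Override
          public void collect(Text key, KMeansInfo value) {
            System.out.println(key + ": " + value.getPoints() + " point(s)");
          }
        };
        mapper.map(new Text("0"), new VectorWritable(new DenseVector(new double[] {1.0, 1.0})),
            out, Reporter.NULL);
      }
    }
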
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java Sat Feb 13 21:07:53 2010
@@ -16,14 +16,6 @@
*/
package org.apache.mahout.clustering.kmeans;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.mapred.MapReduceBase;
-import org.apache.hadoop.mapred.OutputCollector;
-import org.apache.hadoop.mapred.Reducer;
-import org.apache.hadoop.mapred.Reporter;
-import org.apache.mahout.common.distance.DistanceMeasure;
-
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
@@ -31,18 +23,27 @@
import java.util.List;
import java.util.Map;
-public class KMeansReducer extends MapReduceBase implements
- Reducer<Text, KMeansInfo, Text, Cluster> {
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reducer;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.common.distance.DistanceMeasure;
- private Map<String, Cluster> clusterMap;
+public class KMeansReducer extends MapReduceBase implements Reducer<Text,KMeansInfo,Text,Cluster> {
+
+ private Map<String,Cluster> clusterMap;
private double convergenceDelta;
private DistanceMeasure measure;
-
+
@Override
- public void reduce(Text key, Iterator<KMeansInfo> values,
- OutputCollector<Text, Cluster> output, Reporter reporter) throws IOException {
+ public void reduce(Text key,
+ Iterator<KMeansInfo> values,
+ OutputCollector<Text,Cluster> output,
+ Reporter reporter) throws IOException {
Cluster cluster = clusterMap.get(key.toString());
-
+
while (values.hasNext()) {
KMeansInfo delta = values.next();
cluster.addPoints(delta.getPoints(), delta.getPointTotal());
@@ -54,23 +55,21 @@
}
output.collect(new Text(cluster.getIdentifier()), cluster);
}
-
+
@Override
public void configure(JobConf job) {
-
+
super.configure(job);
try {
ClassLoader ccl = Thread.currentThread().getContextClassLoader();
- Class<?> cl = ccl.loadClass(job
- .get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
+ Class<?> cl = ccl.loadClass(job.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
this.measure = (DistanceMeasure) cl.newInstance();
this.measure.configure(job);
-
- this.convergenceDelta = Double.parseDouble(job
- .get(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY));
-
- this.clusterMap = new HashMap<String, Cluster>();
-
+
+ this.convergenceDelta = Double.parseDouble(job.get(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY));
+
+ this.clusterMap = new HashMap<String,Cluster>();
+
String path = job.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
if (path.length() > 0) {
List<Cluster> clusters = new ArrayList<Cluster>();
@@ -88,18 +87,18 @@
throw new IllegalStateException(e);
}
}
-
+
private void setClusterMap(List<Cluster> clusters) {
- clusterMap = new HashMap<String, Cluster>();
+ clusterMap = new HashMap<String,Cluster>();
for (Cluster cluster : clusters) {
clusterMap.put(cluster.getIdentifier(), cluster);
}
clusters.clear();
}
-
+
public void config(List<Cluster> clusters) {
setClusterMap(clusters);
-
+
}
-
+
}
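
For orientation, the reduce loop above folds every KMeansInfo delta for a cluster id into one running (count, total) pair; dividing the total by the count then yields the new center that convergence is tested against. A standalone sketch of that arithmetic (the merge is associative, which is why partial sums from many mappers combine safely; names here are placeholders):

    import org.apache.mahout.clustering.kmeans.KMeansInfo;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Vector;

    public class CenterMergeSketch {
      static Vector recomputeCenter(Iterable<KMeansInfo> deltas, int dimension) {
        int points = 0;
        Vector total = new DenseVector(dimension); // starts as the zero vector
        for (KMeansInfo delta : deltas) {          // stands in for the values iterator
          points += delta.getPoints();
          total = total.plus(delta.getPointTotal());
        }
        return total.divide(points);               // the recomputed cluster center
      }
    }
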
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java Sat Feb 13 21:07:53 2010
@@ -17,6 +17,10 @@
package org.apache.mahout.clustering.kmeans;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.FileUtil;
@@ -30,26 +34,20 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
final class KMeansUtil {
-
+
private static final Logger log = LoggerFactory.getLogger(KMeansUtil.class);
-
- private KMeansUtil() {
- }
-
+
+ private KMeansUtil() { }
+
/** Configure the mapper with the cluster info */
- public static void configureWithClusterInfo(String clusterPathStr,
- List<Cluster> clusters) {
-
+ public static void configureWithClusterInfo(String clusterPathStr, List<Cluster> clusters) {
+
// Get the path location where the cluster Info is stored
JobConf job = new JobConf(KMeansUtil.class);
Path clusterPath = new Path(clusterPathStr + "/*");
List<Path> result = new ArrayList<Path>();
-
+
// filter out the files
PathFilter clusterFileFilter = new PathFilter() {
@Override
@@ -57,17 +55,17 @@
return path.getName().startsWith("part");
}
};
-
+
try {
// get all filtered file names in result list
FileSystem fs = clusterPath.getFileSystem(job);
- FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(fs.globStatus(
- clusterPath, clusterFileFilter)), clusterFileFilter);
-
+ FileStatus[] matches = fs.listStatus(
+ FileUtil.stat2Paths(fs.globStatus(clusterPath, clusterFileFilter)), clusterFileFilter);
+
for (FileStatus match : matches) {
result.add(fs.makeQualified(match.getPath()));
}
-
+
// iterate thru the result path list
for (Path path : result) {
SequenceFile.Reader reader = null;
@@ -77,11 +75,11 @@
Writable key;
try {
key = (Writable) reader.getKeyClass().newInstance();
- } catch (InstantiationException e) {//Should not be possible
- log.error("Exception", e);
+ } catch (InstantiationException e) { // Should not be possible
+ KMeansUtil.log.error("Exception", e);
throw new IllegalStateException(e);
} catch (IllegalAccessException e) {
- log.error("Exception", e);
+ KMeansUtil.log.error("Exception", e);
throw new IllegalStateException(e);
}
if (valueClass.equals(Cluster.class)) {
@@ -104,11 +102,11 @@
IOUtils.quietClose(reader);
}
}
-
+
} catch (IOException e) {
- log.info("Exception occurred in loading clusters:", e);
+ KMeansUtil.log.info("Exception occurred in loading clusters:", e);
throw new IllegalStateException(e);
}
}
-
+
}
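
The glob-plus-filter idiom above (globStatus with a PathFilter, then listStatus over the matched paths with the same filter) is reusable on its own. A minimal sketch against the Hadoop FileSystem API, with an illustrative directory name:

    import java.util.ArrayList;
    import java.util.List;

    import org.apache.hadoop.fs.FileStatus;
    import org.apache.hadoop.fs.FileSystem;
    import org.apache.hadoop.fs.FileUtil;
    import org.apache.hadoop.fs.Path;
    import org.apache.hadoop.fs.PathFilter;
    import org.apache.hadoop.mapred.JobConf;

    public class PartFileGlobSketch {
      public static void main(String[] args) throws Exception {
        Path glob = new Path("output/clusters-0/*"); // illustrative directory
        PathFilter partFilter = new PathFilter() {
          @Override
          public boolean accept(Path path) {
            return path.getName().startsWith("part"); // same test as above
          }
        };
        FileSystem fs = glob.getFileSystem(new JobConf());
        // The filter applies twice: once while expanding the glob, and again
        // when listing the matched paths, mirroring configureWithClusterInfo.
        FileStatus[] matches =
            fs.listStatus(FileUtil.stat2Paths(fs.globStatus(glob, partFilter)), partFilter);
        List<Path> parts = new ArrayList<Path>();
        for (FileStatus match : matches) {
          parts.add(fs.makeQualified(match.getPath()));
        }
        System.out.println(parts);
      }
    }
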
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java Sat Feb 13 21:07:53 2010
@@ -17,6 +17,11 @@
package org.apache.mahout.clustering.kmeans;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
@@ -29,28 +34,23 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
-
-
/**
- * Given an Input Path containing a {@link org.apache.hadoop.io.SequenceFile}, randomly select k vectors and write them
- * to the output file as a {@link org.apache.mahout.clustering.kmeans.Cluster} representing the initial centroid to use.
+ * Given an Input Path containing a {@link org.apache.hadoop.io.SequenceFile}, randomly select k vectors and
+ * write them to the output file as {@link org.apache.mahout.clustering.kmeans.Cluster}s representing the
+ * initial centroids to use.
* <p/>
*/
public final class RandomSeedGenerator {
-
+
private static final Logger log = LoggerFactory.getLogger(RandomSeedGenerator.class);
-
+
public static final String K = "k";
-
- private RandomSeedGenerator() {
- }
-
- public static Path buildRandom(String input, String output,
- int k) throws IOException, IllegalAccessException, InstantiationException {
+
+ private RandomSeedGenerator() { }
+
+ public static Path buildRandom(String input, String output, int k) throws IOException,
+ IllegalAccessException,
+ InstantiationException {
// delete the output directory
JobConf conf = new JobConf(RandomSeedGenerator.class);
Path outPath = new Path(output);
@@ -61,7 +61,7 @@
fs.mkdirs(outPath);
Path outFile = new Path(outPath, "part-randomSeed");
if (fs.exists(outFile)) {
- log.warn("Deleting {}", outFile);
+ RandomSeedGenerator.log.warn("Deleting {}", outFile);
fs.delete(outFile, false);
}
boolean newFile = fs.createNewFile(outFile);
@@ -83,7 +83,9 @@
int nextClusterId = 0;
for (FileStatus fileStatus : inputFiles) {
- if(fileStatus.isDir() == true) continue; // select only the top level files
+ if (fileStatus.isDir()) {
+ continue; // select only the top level files
+ }
SequenceFile.Reader reader = new SequenceFile.Reader(fs, fileStatus.getPath(), conf);
Writable key = (Writable) reader.getKeyClass().newInstance();
VectorWritable value = (VectorWritable) reader.getValueClass().newInstance();
@@ -109,10 +111,10 @@
for (int i = 0; i < k; i++) {
writer.append(chosenTexts.get(i), chosenClusters.get(i));
}
- log.info("Wrote {} vectors to {}", k, outFile);
+ RandomSeedGenerator.log.info("Wrote {} vectors to {}", k, outFile);
writer.close();
}
-
+
return outFile;
}
}
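
The selection loop in buildRandom is only partly visible in these hunks; keeping k of the incoming vectors uniformly at random in one pass is exactly what classic reservoir sampling guarantees. A standalone sketch of the technique, offered as an assumption about the elided portion rather than a copy of it:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Random;

    public class ReservoirSketch {
      // After one pass, each of the n stream items survives with probability k/n.
      static <T> List<T> reservoirSample(Iterable<T> stream, int k, Random random) {
        List<T> chosen = new ArrayList<T>();
        int seen = 0;
        for (T item : stream) {
          seen++;
          if (chosen.size() < k) {
            chosen.add(item);        // fill the reservoir first
          } else {
            int j = random.nextInt(seen); // uniform in [0, seen)
            if (j < k) {
              chosen.set(j, item);   // replace with probability k/seen
            }
          }
        }
        return chosen;
      }
    }
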
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java Sat Feb 13 21:07:53 2010
@@ -26,50 +26,49 @@
import org.apache.hadoop.io.WritableComparator;
/**
-* Saves two ints, x and y.
-*/
+ * Saves two ints, x and y.
+ */
public class IntPairWritable implements WritableComparable<IntPairWritable> {
-
+
private int x;
private int y;
-
+
/** For serialization purposes only */
- public IntPairWritable() {
- }
-
+ public IntPairWritable() { }
+
public IntPairWritable(int x, int y) {
this.x = x;
this.y = y;
}
-
+
public void setX(int x) {
this.x = x;
}
-
+
public int getX() {
return x;
}
-
+
public void setY(int y) {
this.y = y;
}
-
+
public int getY() {
return y;
}
-
+
@Override
public void write(DataOutput dataOutput) throws IOException {
dataOutput.writeInt(x);
dataOutput.writeInt(y);
}
-
+
@Override
public void readFields(DataInput dataInput) throws IOException {
x = dataInput.readInt();
y = dataInput.readInt();
}
-
+
@Override
public int compareTo(IntPairWritable that) {
if (this.x < that.getX()) {
@@ -80,51 +79,52 @@
return this.y < that.getY() ? -1 : this.y > that.getY() ? 1 : 0;
}
}
-
+
+ @Override
public boolean equals(Object o) {
- if (this == o) {
+ if (this == o) {
return true;
} else if (!(o instanceof IntPairWritable)) {
return false;
}
-
+
IntPairWritable that = (IntPairWritable) o;
-
- return that.getX() == this.x && this.y == that.getY();
+
+ return (that.getX() == this.x) && (this.y == that.getY());
}
-
+
@Override
public int hashCode() {
return 43 * x + y;
}
-
+
@Override
public String toString() {
return "(" + x + ", " + y + ')';
}
-
+
static {
WritableComparator.define(IntPairWritable.class, new Comparator());
}
-
+
public static class Comparator extends WritableComparator implements Serializable {
public Comparator() {
super(IntPairWritable.class);
}
-
+
@Override
public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
if (l1 != 8) {
throw new IllegalArgumentException();
}
- int int11 = readInt(b1, s1);
- int int21 = readInt(b2, s2);
+ int int11 = WritableComparator.readInt(b1, s1);
+ int int21 = WritableComparator.readInt(b2, s2);
if (int11 != int21) {
return int11 - int21;
}
-
- int int12 = readInt(b1, s1 + 4);
- int int22 = readInt(b2, s2 + 4);
+
+ int int12 = WritableComparator.readInt(b1, s1 + 4);
+ int int22 = WritableComparator.readInt(b2, s2 + 4);
return int12 - int22;
}
}
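
The updated Comparator still compares the eight serialized bytes as two big-endian ints without deserializing, which is what makes raw comparison cheap during the shuffle sort. A small check that the raw path agrees with compareTo, assuming Hadoop's DataOutputBuffer for the byte images:

    import org.apache.hadoop.io.DataOutputBuffer;
    import org.apache.mahout.clustering.lda.IntPairWritable;

    public class IntPairComparatorCheck {
      public static void main(String[] args) throws Exception {
        IntPairWritable a = new IntPairWritable(1, 5);
        IntPairWritable b = new IntPairWritable(1, 9);
        DataOutputBuffer bufA = new DataOutputBuffer();
        DataOutputBuffer bufB = new DataOutputBuffer();
        a.write(bufA);
        b.write(bufB);
        int raw = new IntPairWritable.Comparator().compare(
            bufA.getData(), 0, bufA.getLength(), bufB.getData(), 0, bufB.getLength());
        // raw < 0: x ties, so the second int decides, matching a.compareTo(b)
        System.out.println(raw);
      }
    }
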
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java Sat Feb 13 21:07:53 2010
@@ -39,117 +39,114 @@
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
-import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
- * Estimates an LDA model from a corpus of documents,
- * which are SparseVectors of word counts. At each
- * phase, it outputs a matrix of log probabilities of
- * each topic.
+ * Estimates an LDA model from a corpus of documents, which are SparseVectors of word counts. At each phase,
+ * it outputs a matrix of log probabilities of each topic.
*/
public final class LDADriver {
-
+
static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
-
+
static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";
static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
-
+
static final String TOPIC_SMOOTHING_KEY = "org.apache.mahout.clustering.lda.topicSmoothing";
-
+
static final int LOG_LIKELIHOOD_KEY = -2;
static final int TOPIC_SUM_KEY = -1;
-
+
static final double OVERALL_CONVERGENCE = 1.0E-5;
-
+
private static final Logger log = LoggerFactory.getLogger(LDADriver.class);
-
- private LDADriver() {
- }
-
- public static void main(String[] args) throws ClassNotFoundException,
- IOException, InterruptedException {
-
+
+ private LDADriver() { }
+
+ public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {
+
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
-
+
Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
- abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
- "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
-
+ abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+
Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
- abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
- "The Output Working Directory").withShortName("o").create();
-
+ abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Output Working Directory").withShortName("o").create();
+
Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).withDescription(
- "If set, overwrite the output directory").withShortName("w").create();
-
+ "If set, overwrite the output directory").withShortName("w").create();
+
Option topicsOpt = obuilder.withLongName("numTopics").withRequired(true).withArgument(
- abuilder.withName("numTopics").withMinimum(1).withMaximum(1).create()).withDescription(
- "The number of topics").withShortName("k").create();
-
+ abuilder.withName("numTopics").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The number of topics").withShortName("k").create();
+
Option wordsOpt = obuilder.withLongName("numWords").withRequired(true).withArgument(
- abuilder.withName("numWords").withMinimum(1).withMaximum(1).create()).withDescription(
- "The total number of words in the corpus").withShortName("v").create();
-
- Option topicSmOpt = obuilder.withLongName("topicSmoothing").withRequired(false).withArgument(abuilder
- .withName("topicSmoothing").withDefault(-1.0).withMinimum(0).withMaximum(1).create()).withDescription(
- "Topic smoothing parameter. Default is 50/numTopics.").withShortName("a").create();
-
+ abuilder.withName("numWords").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The total number of words in the corpus").withShortName("v").create();
+
+ Option topicSmOpt = obuilder.withLongName("topicSmoothing").withRequired(false).withArgument(
+ abuilder.withName("topicSmoothing").withDefault(-1.0).withMinimum(0).withMaximum(1).create())
+ .withDescription("Topic smoothing parameter. Default is 50/numTopics.").withShortName("a").create();
+
Option maxIterOpt = obuilder.withLongName("maxIter").withRequired(false).withArgument(
- abuilder.withName("maxIter").withDefault(-1).withMinimum(0).withMaximum(1).create()).withDescription(
- "Max iterations to run (or until convergence). -1 (default) waits until convergence.").create();
-
+ abuilder.withName("maxIter").withDefault(-1).withMinimum(0).withMaximum(1).create()).withDescription(
+ "Max iterations to run (or until convergence). -1 (default) waits until convergence.").create();
+
Option numReducOpt = obuilder.withLongName("numReducers").withRequired(false).withArgument(
- abuilder.withName("numReducers").withDefault(10).withMinimum(0).withMaximum(1).create()).withDescription(
- "Max iterations to run (or until convergence). Default 10").create();
-
- Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
-
+ abuilder.withName("numReducers").withDefault(10).withMinimum(0).withMaximum(1).create())
+ .withDescription("Max iterations to run (or until convergence). Default 10").create();
+
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+ .create();
+
Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(
- topicsOpt).withOption(wordsOpt).withOption(topicSmOpt).withOption(maxIterOpt).withOption(
- numReducOpt).withOption(overwriteOutput).withOption(helpOpt).create();
+ topicsOpt).withOption(wordsOpt).withOption(topicSmOpt).withOption(maxIterOpt).withOption(numReducOpt)
+ .withOption(overwriteOutput).withOption(helpOpt).create();
try {
Parser parser = new Parser();
parser.setGroup(group);
CommandLine cmdLine = parser.parse(args);
-
+
if (cmdLine.hasOption(helpOpt)) {
CommandLineUtil.printHelp(group);
return;
}
String input = cmdLine.getValue(inputOpt).toString();
String output = cmdLine.getValue(outputOpt).toString();
-
+
int maxIterations = -1;
if (cmdLine.hasOption(maxIterOpt)) {
maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString());
}
-
+
int numReduceTasks = 2;
if (cmdLine.hasOption(numReducOpt)) {
numReduceTasks = Integer.parseInt(cmdLine.getValue(numReducOpt).toString());
}
-
+
int numTopics = 20;
if (cmdLine.hasOption(topicsOpt)) {
numTopics = Integer.parseInt(cmdLine.getValue(topicsOpt).toString());
}
-
+
int numWords = 20;
if (cmdLine.hasOption(wordsOpt)) {
numWords = Integer.parseInt(cmdLine.getValue(wordsOpt).toString());
}
-
+
if (cmdLine.hasOption(overwriteOutput)) {
HadoopUtil.overwriteOutput(output);
}
-
+
double topicSmoothing = -1.0;
if (cmdLine.hasOption(topicSmOpt)) {
topicSmoothing = Double.parseDouble(cmdLine.getValue(topicSmOpt).toString());
@@ -157,72 +154,81 @@
if (topicSmoothing < 1) {
topicSmoothing = 50.0 / numTopics;
}
-
- runJob(input, output, numTopics, numWords, topicSmoothing, maxIterations,
- numReduceTasks);
-
+
+ LDADriver.runJob(input, output, numTopics, numWords, topicSmoothing, maxIterations, numReduceTasks);
+
} catch (OptionException e) {
- log.error("Exception", e);
+ LDADriver.log.error("Exception", e);
CommandLineUtil.printHelp(group);
}
}
-
+
/**
* Run the job using supplied arguments
- *
- * @param input the directory pathname for input points
- * @param output the directory pathname for output points
- * @param numTopics the number of topics
- * @param numWords the number of words
- * @param topicSmoothing pseudocounts for each topic, typically small < .5
- * @param maxIterations the maximum number of iterations
- * @param numReducers the number of Reducers desired
+ *
+ * @param input
+ * the directory pathname for input points
+ * @param output
+ * the directory pathname for output points
+ * @param numTopics
+ * the number of topics
+ * @param numWords
+ * the number of words
+ * @param topicSmoothing
+ * pseudocounts for each topic, typically small < .5
+ * @param maxIterations
+ * the maximum number of iterations
+ * @param numReducers
+ * the number of Reducers desired
* @throws IOException
*/
- public static void runJob(String input, String output, int numTopics,
- int numWords, double topicSmoothing, int maxIterations, int numReducers)
- throws IOException, InterruptedException, ClassNotFoundException {
-
+ public static void runJob(String input,
+ String output,
+ int numTopics,
+ int numWords,
+ double topicSmoothing,
+ int maxIterations,
+ int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
+
String stateIn = output + "/state-0";
- writeInitialState(stateIn, numTopics, numWords);
+ LDADriver.writeInitialState(stateIn, numTopics, numWords);
double oldLL = Double.NEGATIVE_INFINITY;
boolean converged = false;
-
- for (int iteration = 0; (maxIterations < 1 || iteration < maxIterations) && !converged; iteration++) {
- log.info("Iteration {}", iteration);
+
+ for (int iteration = 0; ((maxIterations < 1) || (iteration < maxIterations)) && !converged; iteration++) {
+ LDADriver.log.info("Iteration {}", iteration);
// point the output to a new directory per iteration
String stateOut = output + "/state-" + (iteration + 1);
- double ll = runIteration(input, stateIn, stateOut, numTopics,
- numWords, topicSmoothing, numReducers);
+ double ll = LDADriver.runIteration(input, stateIn, stateOut, numTopics, numWords, topicSmoothing,
+ numReducers);
double relChange = (oldLL - ll) / oldLL;
-
+
// now point the input to the old output directory
- log.info("Iteration {} finished. Log Likelihood: {}", iteration, ll);
- log.info("(Old LL: {})", oldLL);
- log.info("(Rel Change: {})", relChange);
-
- converged = iteration > 2 && relChange < OVERALL_CONVERGENCE;
+ LDADriver.log.info("Iteration {} finished. Log Likelihood: {}", iteration, ll);
+ LDADriver.log.info("(Old LL: {})", oldLL);
+ LDADriver.log.info("(Rel Change: {})", relChange);
+
+ converged = (iteration > 2) && (relChange < LDADriver.OVERALL_CONVERGENCE);
stateIn = stateOut;
oldLL = ll;
}
}
-
- private static void writeInitialState(String statePath,
- int numTopics, int numWords) throws IOException {
+
+ private static void writeInitialState(String statePath, int numTopics, int numWords) throws IOException {
Path dir = new Path(statePath);
Configuration job = new Configuration();
FileSystem fs = dir.getFileSystem(job);
-
+
IntPairWritable kw = new IntPairWritable();
DoubleWritable v = new DoubleWritable();
-
+
Random random = RandomUtils.getRandom();
-
+
for (int k = 0; k < numTopics; ++k) {
Path path = new Path(dir, "part-" + k);
- SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path,
- IntPairWritable.class, DoubleWritable.class);
-
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, IntPairWritable.class,
+ DoubleWritable.class);
+
kw.setX(k);
double total = 0.0; // total number of pseudo counts we made
for (int w = 0; w < numWords; ++w) {
@@ -233,64 +239,75 @@
v.set(Math.log(pseudocount));
writer.append(kw, v);
}
-
- kw.setY(TOPIC_SUM_KEY);
+
+ kw.setY(LDADriver.TOPIC_SUM_KEY);
v.set(Math.log(total));
writer.append(kw, v);
-
+
writer.close();
}
}
-
+
private static double findLL(String statePath, Configuration job) throws IOException {
Path dir = new Path(statePath);
FileSystem fs = dir.getFileSystem(job);
-
+
double ll = 0.0;
-
+
IntPairWritable key = new IntPairWritable();
DoubleWritable value = new DoubleWritable();
for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
Path path = status.getPath();
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
while (reader.next(key, value)) {
- if (key.getX() == LOG_LIKELIHOOD_KEY) {
+ if (key.getX() == LDADriver.LOG_LIKELIHOOD_KEY) {
ll = value.get();
break;
}
}
reader.close();
}
-
+
return ll;
}
-
+
/**
* Run the job using supplied arguments
- *
- * @param input the directory pathname for input points
- * @param stateIn the directory pathname for input state
- * @param stateOut the directory pathname for output state
- * @param numTopics the number of clusters
- * @param numReducers the number of Reducers desired
+ *
+ * @param input
+ * the directory pathname for input points
+ * @param stateIn
+ * the directory pathname for input state
+ * @param stateOut
+ * the directory pathname for output state
+ * @param numTopics
+ * the number of clusters
+ * @param numReducers
+ * the number of Reducers desired
*/
- public static double runIteration(String input, String stateIn,
- String stateOut, int numTopics, int numWords, double topicSmoothing,
- int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
+ public static double runIteration(String input,
+ String stateIn,
+ String stateOut,
+ int numTopics,
+ int numWords,
+ double topicSmoothing,
+ int numReducers) throws IOException,
+ InterruptedException,
+ ClassNotFoundException {
Configuration conf = new Configuration();
- conf.set(STATE_IN_KEY, stateIn);
- conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
- conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
- conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
-
+ conf.set(LDADriver.STATE_IN_KEY, stateIn);
+ conf.set(LDADriver.NUM_TOPICS_KEY, Integer.toString(numTopics));
+ conf.set(LDADriver.NUM_WORDS_KEY, Integer.toString(numWords));
+ conf.set(LDADriver.TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
+
Job job = new Job(conf);
-
+
job.setOutputKeyClass(IntPairWritable.class);
job.setOutputValueClass(DoubleWritable.class);
FileInputFormat.addInputPaths(job, input);
Path outPath = new Path(stateOut);
FileOutputFormat.setOutputPath(job, outPath);
-
+
job.setMapperClass(LDAMapper.class);
job.setReducerClass(LDAReducer.class);
job.setCombinerClass(LDAReducer.class);
@@ -298,24 +315,24 @@
job.setOutputFormatClass(SequenceFileOutputFormat.class);
job.setInputFormatClass(SequenceFileInputFormat.class);
job.setJarByClass(LDADriver.class);
-
+
job.waitForCompletion(true);
- return findLL(stateOut, conf);
+ return LDADriver.findLL(stateOut, conf);
}
-
+
static LDAState createState(Configuration job) throws IOException {
- String statePath = job.get(STATE_IN_KEY);
- int numTopics = Integer.parseInt(job.get(NUM_TOPICS_KEY));
- int numWords = Integer.parseInt(job.get(NUM_WORDS_KEY));
- double topicSmoothing = Double.parseDouble(job.get(TOPIC_SMOOTHING_KEY));
-
+ String statePath = job.get(LDADriver.STATE_IN_KEY);
+ int numTopics = Integer.parseInt(job.get(LDADriver.NUM_TOPICS_KEY));
+ int numWords = Integer.parseInt(job.get(LDADriver.NUM_WORDS_KEY));
+ double topicSmoothing = Double.parseDouble(job.get(LDADriver.TOPIC_SMOOTHING_KEY));
+
Path dir = new Path(statePath);
FileSystem fs = dir.getFileSystem(job);
-
+
DenseMatrix pWgT = new DenseMatrix(numTopics, numWords);
double[] logTotals = new double[numTopics];
double ll = 0.0;
-
+
IntPairWritable key = new IntPairWritable();
DoubleWritable value = new DoubleWritable();
for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
@@ -324,15 +341,15 @@
while (reader.next(key, value)) {
int topic = key.getX();
int word = key.getY();
- if (word == TOPIC_SUM_KEY) {
+ if (word == LDADriver.TOPIC_SUM_KEY) {
logTotals[topic] = value.get();
if (Double.isInfinite(value.get())) {
throw new IllegalArgumentException();
}
- } else if (topic == LOG_LIKELIHOOD_KEY) {
+ } else if (topic == LDADriver.LOG_LIKELIHOOD_KEY) {
ll = value.get();
} else {
- if (!(topic >= 0 && word >= 0)) {
+ if (!((topic >= 0) && (word >= 0))) {
throw new IllegalArgumentException(topic + " " + word);
}
if (pWgT.getQuick(topic, word) != 0.0) {
@@ -346,8 +363,7 @@
}
reader.close();
}
-
- return new LDAState(numTopics, numWords, topicSmoothing,
- pWgT, logTotals, ll);
+
+ return new LDAState(numTopics, numWords, topicSmoothing, pWgT, logTotals, ll);
}
}
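
As with the k-means driver, runJob can be called directly with the arguments documented above. A minimal sketch, assuming the signature in this diff; the paths and vocabulary size are illustrative, and maxIterations = -1 means iterate until OVERALL_CONVERGENCE:

    import org.apache.mahout.clustering.lda.LDADriver;

    public class LDAJobSketch {
      public static void main(String[] args) throws Exception {
        int numTopics = 20;
        LDADriver.runJob("vectors",     // SequenceFile of Writable, Vector
            "lda-output",               // working dir; state-0, state-1, ... land here
            numTopics,
            50000,                      // numWords: vocabulary size of the corpus
            50.0 / numTopics,           // the topicSmoothing default, as in main() above
            -1,                         // run until the relative LL change converges
            2);                         // numReducers, matching main()'s default
      }
    }
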
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java Sat Feb 13 21:07:53 2010
@@ -22,157 +22,151 @@
import java.util.Map;
import org.apache.commons.math.special.Gamma;
-import org.apache.mahout.math.function.BinaryFunction;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.BinaryFunction;
/**
- * Class for performing infererence on a document, which involves
- * computing (an approximation to) p(word|topic) for each word and
- * topic, and a prior distribution p(topic) for each topic.
+ * Class for performing inference on a document, which involves computing (an approximation to)
+ * p(word|topic) for each word and topic, and a prior distribution p(topic) for each topic.
*/
public class LDAInference {
-
+
private static final double E_STEP_CONVERGENCE = 1.0E-6;
private static final int MAX_ITER = 20;
-
+
public LDAInference(LDAState state) {
this.state = state;
}
-
+
/**
- * An estimate of the probabilitys for each document.
- * Gamma(k) is the probability of seeing topic k in
- * the document, phi(k,w) is the probability of
- * topic k generating w in this document.
- */
+ * An estimate of the probabilities for each document. Gamma(k) is the probability of seeing topic k in the
+ * document, phi(k,w) is the probability of topic k generating w in this document.
+ */
public static class InferredDocument {
-
+
private final Vector wordCounts;
private final Vector gamma; // p(topic)
private final Matrix mphi; // log p(columnMap(w)|t)
- private final Map<Integer, Integer> columnMap; // maps words into the matrix's column map
+ private final Map<Integer,Integer> columnMap; // maps words into the matrix's column map
public final double logLikelihood;
-
+
public double phi(int k, int w) {
return mphi.getQuick(k, columnMap.get(w));
}
-
- InferredDocument(Vector wordCounts, Vector gamma,
- Map<Integer, Integer> columnMap, Matrix phi,
- double ll) {
+
+ InferredDocument(Vector wordCounts, Vector gamma, Map<Integer,Integer> columnMap, Matrix phi, double ll) {
this.wordCounts = wordCounts;
this.gamma = gamma;
this.mphi = phi;
this.columnMap = columnMap;
this.logLikelihood = ll;
}
-
+
public Vector getWordCounts() {
return wordCounts;
}
-
+
public Vector getGamma() {
return gamma;
}
}
-
+
/**
- * Performs inference on the given document, returning
- * an InferredDocument.
- */
+ * Performs inference on the given document, returning an InferredDocument.
+ */
public InferredDocument infer(Vector wordCounts) {
double docTotal = wordCounts.zSum();
int docLength = wordCounts.size();
-
+
// initialize variational approximation to p(z|doc)
Vector gamma = new DenseVector(state.numTopics);
gamma.assign(state.topicSmoothing + docTotal / state.numTopics);
Vector nextGamma = new DenseVector(state.numTopics);
-
+
DenseMatrix phi = new DenseMatrix(state.numTopics, docLength);
-
+
// digamma is expensive, precompute
- Vector digammaGamma = digamma(gamma);
+ Vector digammaGamma = LDAInference.digamma(gamma);
// and log normalize:
- double digammaSumGamma = digamma(gamma.zSum());
+ double digammaSumGamma = LDAInference.digamma(gamma.zSum());
digammaGamma = digammaGamma.plus(-digammaSumGamma);
-
- Map<Integer, Integer> columnMap = new HashMap<Integer, Integer>();
-
+
+ Map<Integer,Integer> columnMap = new HashMap<Integer,Integer>();
+
int iteration = 0;
-
+
boolean converged = false;
double oldLL = 1;
- while (!converged && iteration < MAX_ITER) {
+ while (!converged && (iteration < LDAInference.MAX_ITER)) {
nextGamma.assign(state.topicSmoothing); // nG := alpha, for all topics
-
+
int mapping = 0;
- for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
- iter.hasNext();) {
- Vector.Element e = iter.next();
+ for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
+ Vector.Element e = iter.next();
int word = e.index();
Vector phiW = eStepForWord(word, digammaGamma);
phi.assignColumn(mapping, phiW);
if (iteration == 0) { // first iteration
columnMap.put(word, mapping);
}
-
+
for (int k = 0; k < nextGamma.size(); ++k) {
double g = nextGamma.getQuick(k);
nextGamma.setQuick(k, g + e.get() * Math.exp(phiW.get(k)));
}
-
+
mapping++;
}
-
+
Vector tempG = gamma;
gamma = nextGamma;
nextGamma = tempG;
-
+
// digamma is expensive, precompute
- digammaGamma = digamma(gamma);
+ digammaGamma = LDAInference.digamma(gamma);
// and log normalize:
- digammaSumGamma = digamma(gamma.zSum());
+ digammaSumGamma = LDAInference.digamma(gamma.zSum());
digammaGamma = digammaGamma.plus(-digammaSumGamma);
-
+
double ll = computeLikelihood(wordCounts, columnMap, phi, gamma, digammaGamma);
assert !Double.isNaN(ll);
- converged = oldLL < 0 && ((oldLL - ll) / oldLL < E_STEP_CONVERGENCE);
-
+ converged = (oldLL < 0) && ((oldLL - ll) / oldLL < LDAInference.E_STEP_CONVERGENCE);
+
oldLL = ll;
iteration++;
}
-
+
return new InferredDocument(wordCounts, gamma, columnMap, phi, oldLL);
}
-
+
private final LDAState state;
-
- private double computeLikelihood(Vector wordCounts, Map<Integer, Integer> columnMap,
- Matrix phi, Vector gamma, Vector digammaGamma) {
+
+ private double computeLikelihood(Vector wordCounts,
+ Map<Integer,Integer> columnMap,
+ Matrix phi,
+ Vector gamma,
+ Vector digammaGamma) {
double ll = 0.0;
-
+
// log normalizer for q(gamma);
ll += Gamma.logGamma(state.topicSmoothing * state.numTopics);
ll -= state.numTopics * Gamma.logGamma(state.topicSmoothing);
assert !Double.isNaN(ll) : state.topicSmoothing + " " + state.numTopics;
-
+
// now for the rest of q(gamma);
for (int k = 0; k < state.numTopics; ++k) {
ll += (state.topicSmoothing - gamma.get(k)) * digammaGamma.get(k);
ll += Gamma.logGamma(gamma.get(k));
-
+
}
ll -= Gamma.logGamma(gamma.zSum());
assert !Double.isNaN(ll);
-
-
+
// for each word
- for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
- iter.hasNext();) {
+ for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
Vector.Element e = iter.next();
int w = e.index();
double n = e.get();
@@ -181,19 +175,18 @@
for (int k = 0; k < state.numTopics; k++) {
double llPart = 0.0;
llPart += Math.exp(phi.get(k, mapping))
- * (digammaGamma.get(k) - phi.get(k, mapping)
- + state.logProbWordGivenTopic(w, k));
-
+ * (digammaGamma.get(k) - phi.get(k, mapping) + state.logProbWordGivenTopic(w, k));
+
ll += llPart * n;
-
- assert state.logProbWordGivenTopic(w, k) < 0;
+
+ assert state.logProbWordGivenTopic(w, k) < 0;
assert !Double.isNaN(llPart);
}
}
assert ll <= 0;
return ll;
}
-
+
/**
* Compute log q(k|w,doc) for each topic k, for a given word.
*/
@@ -203,7 +196,7 @@
for (int k = 0; k < state.numTopics; ++k) { // update q(k|w)'s param phi
phi.set(k, state.logProbWordGivenTopic(word, k) + digammaGamma.get(k));
phiTotal = LDAUtil.logSum(phiTotal, phi.get(k));
-
+
assert !Double.isNaN(phiTotal);
assert !Double.isNaN(state.logProbWordGivenTopic(word, k));
assert !Double.isInfinite(state.logProbWordGivenTopic(word, k));
@@ -211,57 +204,53 @@
}
return phi.plus(-phiTotal); // log normalize
}
-
-
+
private static Vector digamma(Vector v) {
Vector digammaGamma = new DenseVector(v.size());
digammaGamma.assign(v, new BinaryFunction() {
@Override
public double apply(double unused, double g) {
- return digamma(g);
+ return LDAInference.digamma(g);
}
});
return digammaGamma;
}
-
+
/**
* Approximation to the digamma function, from Radford Neal.
- *
- * Original License:
- * Copyright (c) 1995-2003 by Radford M. Neal
- *
- * Permission is granted for anyone to copy, use, modify, or distribute this
- * program and accompanying programs and documents for any purpose, provided
- * this copyright notice is retained and prominently displayed, along with
- * a note saying that the original programs are available from Radford Neal's
- * web page, and note is made of any changes made to the programs. The
- * programs and documents are distributed without any warranty, express or
- * implied. As the programs were written for research purposes only, they have
- * not been tested to the degree that would be advisable in any important
- * application. All use of these programs is entirely at the user's own risk.
- *
- *
+ *
+ * Original License: Copyright (c) 1995-2003 by Radford M. Neal
+ *
+ * Permission is granted for anyone to copy, use, modify, or distribute this program and accompanying
+ * programs and documents for any purpose, provided this copyright notice is retained and prominently
+ * displayed, along with a note saying that the original programs are available from Radford Neal's web
+ * page, and note is made of any changes made to the programs. The programs and documents are distributed
+ * without any warranty, express or implied. As the programs were written for research purposes only, they
+ * have not been tested to the degree that would be advisable in any important application. All use of these
+ * programs is entirely at the user's own risk.
+ *
+ *
* Ported to Java for Mahout.
- *
+ *
*/
private static double digamma(double x) {
double r = 0.0;
-
+
while (x <= 5) {
r -= 1 / x;
x += 1;
}
-
+
double f = 1.0 / (x * x);
- double t = f * (-1 / 12.0
- + f * (1 / 120.0
- + f * (-1 / 252.0
- + f * (1 / 240.0
- + f * (-1 / 132.0
- + f * (691 / 32760.0
- + f * (-1 / 12.0
- + f * 3617.0 / 8160.0)))))));
+ double t = f
+ * (-1 / 12.0 + f
+ * (1 / 120.0 + f
+ * (-1 / 252.0 + f
+ * (1 / 240.0 + f
+ * (-1 / 132.0 + f
+ * (691 / 32760.0 + f
+ * (-1 / 12.0 + f * 3617.0 / 8160.0)))))));
return r + Math.log(x) - 0.5 / x + t;
}
-
+
}
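
For reference, the per-word E-step implemented in infer() above updates gamma as gamma_k = alpha + sum_w n_w * exp(phi_{k,w}), where phi is already log-normalized per word. A minimal sketch of that inner update on plain arrays (the array layout and names here are assumptions for illustration, not the Mahout API):

public final class EStepGammaSketch {
  private EStepGammaSketch() { }

  // phi[k][w] holds log q(k|w), log-normalized over k for each word w;
  // counts[w] holds the word count n_w; alpha is the topic smoothing.
  static double[] updateGamma(double alpha, double[][] phi, double[] counts) {
    int numTopics = phi.length;
    double[] nextGamma = new double[numTopics];
    java.util.Arrays.fill(nextGamma, alpha); // nG := alpha, for all topics
    for (int w = 0; w < counts.length; w++) {
      for (int k = 0; k < numTopics; k++) {
        nextGamma[k] += counts[w] * Math.exp(phi[k][w]); // n_w * q(k|w)
      }
    }
    return nextGamma;
  }

  public static void main(String[] args) {
    // Two topics, one word seen 3 times, q(k|w) = {0.5, 0.5}, alpha = 0.1:
    // each gamma_k should come out to 0.1 + 3 * 0.5 = 1.6.
    double[][] phi = { { Math.log(0.5) }, { Math.log(0.5) } };
    double[] counts = { 3.0 };
    for (double g : updateGamma(0.1, phi, counts)) {
      System.out.println(g); // 1.6
    }
  }
}
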
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java Sat Feb 13 21:07:53 2010
@@ -29,39 +29,35 @@
import org.apache.mahout.math.VectorWritable;
/**
-* Runs inference on the input documents (which are
-* sparse vectors of word counts) and outputs
-* the sufficient statistics for the word-topic
-* assignments.
-*/
-public class LDAMapper extends
- Mapper<WritableComparable<?>, VectorWritable, IntPairWritable, DoubleWritable> {
-
+ * Runs inference on the input documents (which are sparse vectors of word counts) and outputs the sufficient
+ * statistics for the word-topic assignments.
+ */
+public class LDAMapper extends Mapper<WritableComparable<?>,VectorWritable,IntPairWritable,DoubleWritable> {
+
private LDAState state;
private LDAInference infer;
-
+
@Override
- public void map(WritableComparable<?> key, VectorWritable wordCountsWritable, Context context)
- throws IOException, InterruptedException {
+ public void map(WritableComparable<?> key, VectorWritable wordCountsWritable, Context context) throws IOException,
+ InterruptedException {
Vector wordCounts = wordCountsWritable.get();
LDAInference.InferredDocument doc = infer.infer(wordCounts);
-
+
double[] logTotals = new double[state.numTopics];
Arrays.fill(logTotals, Double.NEGATIVE_INFINITY);
-
+
// Output sufficient statistics for each word. == pseudo-log counts.
IntPairWritable kw = new IntPairWritable();
DoubleWritable v = new DoubleWritable();
- for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
- iter.hasNext();) {
+ for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
Vector.Element e = iter.next();
int w = e.index();
kw.setY(w);
for (int k = 0; k < state.numTopics; ++k) {
v.set(doc.phi(k, w) + Math.log(e.get()));
-
+
kw.setX(k);
-
+
// output (topic, word)'s logProb contribution
context.write(kw, v);
logTotals[k] = LDAUtil.logSum(logTotals[k], v.get());
@@ -77,19 +73,19 @@
assert !Double.isNaN(v.get());
context.write(kw, v);
}
-
+
// Output log-likelihoods.
kw.setX(LDADriver.LOG_LIKELIHOOD_KEY);
kw.setY(LDADriver.LOG_LIKELIHOOD_KEY);
v.set(doc.logLikelihood);
context.write(kw, v);
}
-
+
public void configure(LDAState myState) {
this.state = myState;
this.infer = new LDAInference(state);
}
-
+
public void configure(Configuration job) {
try {
LDAState myState = LDADriver.createState(job);
@@ -98,11 +94,10 @@
throw new IllegalStateException("Error creating LDA State!", e);
}
}
-
+
@Override
protected void setup(Context context) {
configure(context.getConfiguration());
}
-
-
+
}
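
Each record the mapper writes is a pseudo-log count: for word w with count n, topic k receives doc.phi(k, w) + Math.log(n), so exponentiating the emitted value recovers n * q(k|w). A quick check with assumed numbers (not taken from any actual run):

public final class PseudoLogCountSketch {
  private PseudoLogCountSketch() { }

  public static void main(String[] args) {
    // Assume q(k|w) = 0.25, i.e. phi(k, w) = log 0.25, and the word occurs 4 times.
    double phi = Math.log(0.25);
    double count = 4.0;
    double emitted = phi + Math.log(count); // what the mapper would write
    // exp(emitted) = count * q(k|w) = 4 * 0.25 = 1.0
    System.out.println(Math.exp(emitted)); // 1.0
  }
}
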
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java Sat Feb 13 21:07:53 2010
@@ -19,20 +19,16 @@
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.mapreduce.Reducer;
-
/**
-* A very simple reducer which simply logSums the
-* input doubles and outputs a new double for sufficient
-* statistics, and sums log likelihoods.
-*/
-public class LDAReducer extends
- Reducer<IntPairWritable, DoubleWritable, IntPairWritable, DoubleWritable> {
-
+ * A simple reducer that logSums the input doubles into a new double for the sufficient
+ * statistics, and sums log likelihoods.
+ */
+public class LDAReducer extends Reducer<IntPairWritable,DoubleWritable,IntPairWritable,DoubleWritable> {
+
@Override
- public void reduce(IntPairWritable topicWord, Iterable<DoubleWritable> values,
- Context context)
- throws java.io.IOException, InterruptedException {
-
+ public void reduce(IntPairWritable topicWord, Iterable<DoubleWritable> values, Context context) throws java.io.IOException,
+ InterruptedException {
+
// sum likelihoods
if (topicWord.getY() == LDADriver.LOG_LIKELIHOOD_KEY) {
double accum = 0.0;
@@ -58,7 +54,7 @@
}
context.write(topicWord, new DoubleWritable(accum));
}
-
+
}
-
+
}
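
The reduce side folds its values with logSum, accumulating log(sum_i exp(v_i)) without ever leaving log space; Double.NEGATIVE_INFINITY (= log 0) is the identity for that fold. A standalone sketch of the accumulation (logSumAll is a hypothetical helper written for this note):

public final class LogSumFoldSketch {
  private LogSumFoldSketch() { }

  // Same identity as LDAUtil.logSum: log(exp(a) + exp(b)), computed stably.
  static double logSum(double a, double b) {
    return a == Double.NEGATIVE_INFINITY ? b : b == Double.NEGATIVE_INFINITY ? a
        : a < b ? b + Math.log(1 + Math.exp(a - b)) : a + Math.log(1 + Math.exp(b - a));
  }

  static double logSumAll(double[] logValues) {
    double accum = Double.NEGATIVE_INFINITY; // log(0), the fold's identity
    for (double v : logValues) {
      accum = logSum(accum, v);
    }
    return accum; // = log(sum of exp(v_i))
  }

  public static void main(String[] args) {
    // Folding {log 1, log 2, log 3} should give log 6.
    double[] vs = { Math.log(1), Math.log(2), Math.log(3) };
    System.out.println(Math.exp(logSumAll(vs))); // 6.0 (up to rounding)
  }
}
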
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java Sat Feb 13 21:07:53 2010
@@ -19,15 +19,19 @@
import org.apache.mahout.math.Matrix;
public class LDAState {
- public final int numTopics;
- public final int numWords;
+ public final int numTopics;
+ public final int numWords;
public final double topicSmoothing;
private final Matrix topicWordProbabilities; // log p(w|t) for topic=1..nTopics
private final double[] logTotals; // log \sum p(w|t) for topic=1..nTopics
public final double logLikelihood; // log \sum p(w|t) for topic=1..nTopics
-
- public LDAState(int numTopics, int numWords, double topicSmoothing,
- Matrix topicWordProbabilities, double[] logTotals, double ll) {
+
+ public LDAState(int numTopics,
+ int numWords,
+ double topicSmoothing,
+ Matrix topicWordProbabilities,
+ double[] logTotals,
+ double ll) {
this.numWords = numWords;
this.numTopics = numTopics;
this.topicSmoothing = topicSmoothing;
@@ -35,10 +39,9 @@
this.logTotals = logTotals;
this.logLikelihood = ll;
}
-
+
public double logProbWordGivenTopic(int word, int topic) {
double logProb = topicWordProbabilities.getQuick(topic, word);
- return logProb == Double.NEGATIVE_INFINITY ? -100.0
- : logProb - logTotals[topic];
+ return logProb == Double.NEGATIVE_INFINITY ? -100.0 : logProb - logTotals[topic];
}
}
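
logProbWordGivenTopic above normalizes an unnormalized log count by subtracting the topic's log total, i.e. it returns log(count(w,t) / total(t)); the -100.0 floor stands in for log 0 so downstream arithmetic stays finite. A worked instance with assumed counts:

public final class LogProbSketch {
  private LogProbSketch() { }

  public static void main(String[] args) {
    // Assume the topic has seen this word 3 times out of 12 total:
    double logProb = Math.log(3.0);   // unnormalized log p(w|t)
    double logTotal = Math.log(12.0); // log of the topic's total
    double normalized = logProb == Double.NEGATIVE_INFINITY ? -100.0 : logProb - logTotal;
    System.out.println(Math.exp(normalized)); // 0.25 = 3/12
  }
}
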
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java Sat Feb 13 21:07:53 2010
@@ -20,17 +20,14 @@
* Various utility methods for doing LDA inference.
*/
final class LDAUtil {
- private LDAUtil() {
- } // no creation
-
+ private LDAUtil() { } // no creation
+
/**
* @return log(exp(a) + exp(b))
*/
static double logSum(double a, double b) {
- return (a == Double.NEGATIVE_INFINITY) ? b
- : (b == Double.NEGATIVE_INFINITY) ? a
- : (a < b) ? b + Math.log(1 + Math.exp(a - b))
- : a + Math.log(1 + Math.exp(b - a));
+ return a == Double.NEGATIVE_INFINITY ? b : b == Double.NEGATIVE_INFINITY ? a
+ : a < b ? b + Math.log(1 + Math.exp(a - b)) : a + Math.log(1 + Math.exp(b - a));
}
-
+
}
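
The branching in logSum is what makes it numerically safe: the naive Math.log(Math.exp(a) + Math.exp(b)) overflows to Infinity once a or b passes roughly 709 (the limit of double exp), while factoring the larger argument out first does not. A quick comparison on assumed inputs:

public final class LogSumStabilitySketch {
  private LogSumStabilitySketch() { }

  public static void main(String[] args) {
    double a = 1000.0;
    double b = 1000.0;
    // Naive form: exp(1000) overflows a double, so this prints Infinity.
    System.out.println(Math.log(Math.exp(a) + Math.exp(b)));
    // Stable form (what logSum computes): 1000 + log(2), about 1000.6931.
    System.out.println(b + Math.log(1 + Math.exp(a - b)));
  }
}
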
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java Sat Feb 13 21:07:53 2010
@@ -17,68 +17,63 @@
package org.apache.mahout.clustering.meanshift;
-import com.google.gson.Gson;
-import com.google.gson.GsonBuilder;
-import com.google.gson.reflect.TypeToken;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.List;
+
import org.apache.mahout.clustering.ClusterBase;
import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.JsonVectorAdapter;
-import static org.apache.mahout.math.function.Functions.*;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
-import java.io.DataInput;
-import java.io.DataOutput;
-import java.io.IOException;
-import java.lang.reflect.Type;
-import java.util.ArrayList;
-import java.util.List;
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.reflect.TypeToken;
/**
- * This class models a canopy as a center point, the number of points that are contained within it according to the
- * application of some distance metric, and a point total which is the sum of all the points and is used to compute the
- * centroid when needed.
+ * This class models a canopy as a center point, the number of points that are contained within it according
+ * to the application of some distance metric, and a point total which is the sum of all the points and is
+ * used to compute the centroid when needed.
*/
public class MeanShiftCanopy extends ClusterBase {
-
- // TODO: this is problematic, but how else to encode membership?
+
+ // TODO: this is problematic, but how else to encode membership?
private List<Vector> boundPoints = new ArrayList<Vector>();
-
+
private boolean converged = false;
-
+
public MeanShiftCanopy() {
super();
}
-
+
/** Create a new Canopy with the given canopyId */
/*
- public MeanShiftCanopy(String id) {
- this.setId(Integer.parseInt(id.substring(1)));
- this.setCenter(null);
- this.setPointTotal(null);
- this.setNumPoints(0);
- }
- */
+ * public MeanShiftCanopy(String id) { this.setId(Integer.parseInt(id.substring(1))); this.setCenter(null);
+ * this.setPointTotal(null); this.setNumPoints(0); }
+ */
/**
* Create a new Canopy containing the given point
- *
- * @param point a Vector
+ *
+ * @param point
+ * a Vector
*/
/*
- public MeanShiftCanopy(Vector point) {
- this.setCenter(point);
- this.setPointTotal(point.clone());
- this.setNumPoints(1);
- this.boundPoints.add(point);
- }
- */
+ * public MeanShiftCanopy(Vector point) { this.setCenter(point); this.setPointTotal(point.clone());
+ * this.setNumPoints(1); this.boundPoints.add(point); }
+ */
/**
* Create a new Canopy containing the given point
- *
- * @param point a Vector
+ *
+ * @param point
+ * a Vector
*/
public MeanShiftCanopy(Vector point, int id) {
this.setId(id);
@@ -90,14 +85,17 @@
/**
* Create a new Canopy containing the given point, id and bound points
- *
- * @param point a Vector
- * @param id an int identifying the canopy local to this process only
- * @param boundPoints a List<Vector> containing points bound to the canopy
- * @param converged true if the canopy has converged
+ *
+ * @param point
+ * a Vector
+ * @param id
+ * an int identifying the canopy local to this process only
+ * @param boundPoints
+ * a List<Vector> containing points bound to the canopy
+ * @param converged
+ * true if the canopy has converged
*/
- MeanShiftCanopy(Vector point, int id, List<Vector> boundPoints,
- boolean converged) {
+ MeanShiftCanopy(Vector point, int id, List<Vector> boundPoints, boolean converged) {
this.setId(id);
this.setCenter(point);
this.setPointTotal(point.clone());
@@ -105,36 +103,39 @@
this.boundPoints = boundPoints;
this.converged = converged;
}
-
+
/**
* Add a point to the canopy some number of times
- *
- * @param point a Vector to add
- * @param nPoints the number of times to add the point
- * @throws CardinalityException if the cardinalities disagree
+ *
+ * @param point
+ * a Vector to add
+ * @param nPoints
+ * the number of times to add the point
+ * @throws CardinalityException
+ * if the cardinalities disagree
*/
void addPoints(Vector point, int nPoints) {
setNumPoints(getNumPoints() + nPoints);
- Vector subTotal = (nPoints == 1) ? point.clone() : point.times(nPoints);
- setPointTotal((getPointTotal() == null) ? subTotal : getPointTotal().plus(subTotal));
+ Vector subTotal = nPoints == 1 ? point.clone() : point.times(nPoints);
+ setPointTotal(getPointTotal() == null ? subTotal : getPointTotal().plus(subTotal));
}
-
+
/**
* Compute the bound centroid by averaging the bound points
- *
+ *
* @return a Vector which is the new bound centroid
*/
public Vector computeBoundCentroid() {
Vector result = new DenseVector(getCenter().size());
for (Vector v : boundPoints) {
- result.assign(v, plus);
+ result.assign(v, Functions.plus);
}
return result.divide(boundPoints.size());
}
-
+
/**
* Compute the centroid by normalizing the pointTotal
- *
+ *
* @return a Vector which is the new centroid
*/
@Override
@@ -145,55 +146,57 @@
return getPointTotal().divide(getNumPoints());
}
}
-
+
public List<Vector> getBoundPoints() {
return boundPoints;
}
-
+
public int getCanopyId() {
return getId();
}
-
+
@Override
public String getIdentifier() {
return (converged ? "V" : "C") + getId();
}
-
+
void init(MeanShiftCanopy canopy) {
setId(canopy.getId());
setCenter(canopy.getCenter());
addPoints(getCenter(), 1);
boundPoints.addAll(canopy.getBoundPoints());
}
-
+
public boolean isConverged() {
return converged;
}
-
+
/**
* The receiver overlaps the given canopy. Touch it and add my bound points to it.
- *
- * @param canopy an existing MeanShiftCanopy
+ *
+ * @param canopy
+ * an existing MeanShiftCanopy
*/
void merge(MeanShiftCanopy canopy) {
boundPoints.addAll(canopy.boundPoints);
}
-
+
@Override
public String toString() {
- return formatCanopy(this);
+ return MeanShiftCanopy.formatCanopy(this);
}
-
+
/**
* The receiver touches the given canopy. Add respective centers.
- *
- * @param canopy an existing MeanShiftCanopy
+ *
+ * @param canopy
+ * an existing MeanShiftCanopy
*/
void touch(MeanShiftCanopy canopy) {
canopy.addPoints(getCenter(), boundPoints.size());
addPoints(canopy.getCenter(), canopy.boundPoints.size());
}
-
+
@Override
public void readFields(DataInput in) throws IOException {
super.readFields(in);
@@ -207,7 +210,7 @@
this.boundPoints.add(temp.get());
}
}
-
+
@Override
public void write(DataOutput out) throws IOException {
super.write(out);
@@ -217,7 +220,7 @@
VectorWritable.writeVector(out, v);
}
}
-
+
public MeanShiftCanopy shallowCopy() {
MeanShiftCanopy result = new MeanShiftCanopy();
result.setId(this.getId());
@@ -227,43 +230,42 @@
result.boundPoints = this.boundPoints;
return result;
}
-
+
@Override
public String asFormatString() {
- return formatCanopy(this);
+ return MeanShiftCanopy.formatCanopy(this);
}
public void setBoundPoints(List<Vector> boundPoints) {
this.boundPoints = boundPoints;
}
-
+
public void setConverged(boolean converged) {
this.converged = converged;
}
-
+
/** Format the canopy for output */
public static String formatCanopy(MeanShiftCanopy canopy) {
- Type vectorType = new TypeToken<Vector>() {
- }.getType();
+ Type vectorType = new TypeToken<Vector>() { }.getType();
GsonBuilder gBuilder = new GsonBuilder();
gBuilder.registerTypeAdapter(vectorType, new JsonVectorAdapter());
Gson gson = gBuilder.create();
return gson.toJson(canopy, MeanShiftCanopy.class);
}
-
+
/**
* Decodes and returns a Canopy from the formattedString
- *
- * @param formattedString a String produced by formatCanopy
+ *
+ * @param formattedString
+ * a String produced by formatCanopy
* @return a new Canopy
*/
public static MeanShiftCanopy decodeCanopy(String formattedString) {
- Type vectorType = new TypeToken<Vector>() {
- }.getType();
+ Type vectorType = new TypeToken<Vector>() { }.getType();
GsonBuilder gBuilder = new GsonBuilder();
gBuilder.registerTypeAdapter(vectorType, new JsonVectorAdapter());
Gson gson = gBuilder.create();
return gson.fromJson(formattedString, MeanShiftCanopy.class);
}
-
+
}
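
computeBoundCentroid above is a plain coordinate-wise average of the bound points. The same computation on raw arrays, as a sketch (assumes all points share one cardinality and the list is non-empty; not the Mahout Vector API):

import java.util.Arrays;
import java.util.List;

public final class BoundCentroidSketch {
  private BoundCentroidSketch() { }

  static double[] boundCentroid(List<double[]> boundPoints) {
    double[] result = new double[boundPoints.get(0).length];
    for (double[] v : boundPoints) {
      for (int i = 0; i < result.length; i++) {
        result[i] += v[i]; // result := result + v, coordinate-wise
      }
    }
    for (int i = 0; i < result.length; i++) {
      result[i] /= boundPoints.size(); // average over the bound points
    }
    return result;
  }

  public static void main(String[] args) {
    List<double[]> pts = Arrays.asList(new double[] {0, 0}, new double[] {2, 4});
    System.out.println(Arrays.toString(boundCentroid(pts))); // [1.0, 2.0]
  }
}
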