You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself was not preserved in this plain-text copy).
Posted to commits@mahout.apache.org by je...@apache.org on 2010/04/21 22:35:23 UTC
svn commit: r936489 [1/2] - in /lucene/mahout/trunk:
core/src/main/java/org/apache/mahout/clustering/
core/src/main/java/org/apache/mahout/clustering/canopy/
core/src/main/java/org/apache/mahout/clustering/dirichlet/
core/src/main/java/org/apache/mahou...
Author: jeastman
Date: Wed Apr 21 20:35:22 2010
New Revision: 936489
URL: http://svn.apache.org/viewvc?rev=936489&view=rev
Log:
MAHOUT-236: Fixed a couple of compile errors caused by the DummyCollector changes; otherwise identical to the JIRA patch. All tests pass. Committing before Sean makes another patch :)
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterMapper.java
lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/
lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwDistantPointWritable.java
lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwDriver.java
lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java
lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwReducer.java
lucene/mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/cdbw/
lucene/mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalDistribution.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModelDistribution.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFeatureMapperTest.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/DummyOutputCollector.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/ga/watchmaker/EvalMapperTest.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java
lucene/mahout/trunk/examples/src/test/java/org/apache/mahout/ga/watchmaker/cd/hadoop/CDMapperTest.java
lucene/mahout/trunk/examples/src/test/java/org/apache/mahout/ga/watchmaker/cd/hadoop/CDReducerTest.java
lucene/mahout/trunk/examples/src/test/java/org/apache/mahout/ga/watchmaker/cd/tool/ToolMapperTest.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterBase.java Wed Apr 21 20:35:22 2010
@@ -43,78 +43,79 @@ public abstract class ClusterBase implem
// this cluster's clusterId
private int id;
-
+
// the current cluster center
private Vector center = new RandomAccessSparseVector(0);
-
+
// the number of points in the cluster
private int numPoints;
-
+
// the Vector total of all points added to the cluster
private Vector pointTotal;
-
+
@Override
public int getId() {
return id;
}
-
+
public void setId(int id) {
this.id = id;
}
-
+
@Override
public Vector getCenter() {
return center;
}
-
+
public void setCenter(Vector center) {
this.center = center;
}
-
+
@Override
public int getNumPoints() {
return numPoints;
}
-
+
public void setNumPoints(int numPoints) {
this.numPoints = numPoints;
}
-
+
public Vector getPointTotal() {
return pointTotal;
}
-
+
public void setPointTotal(Vector pointTotal) {
this.pointTotal = pointTotal;
}
-
+
/**
* @deprecated
* @return
*/
@Deprecated
public abstract String asFormatString();
-
+
@Override
public String asFormatString(String[] bindings) {
StringBuilder buf = new StringBuilder();
buf.append(getIdentifier()).append(": ").append(formatVector(computeCentroid(), bindings));
return buf.toString();
}
-
+
public abstract Vector computeCentroid();
-
+
public abstract Object getIdentifier();
-
+
@Override
public String asJsonString() {
- Type vectorType = new TypeToken<Vector>() { }.getType();
+ Type vectorType = new TypeToken<Vector>() {
+ }.getType();
GsonBuilder gBuilder = new GsonBuilder();
gBuilder.registerTypeAdapter(vectorType, new JsonVectorAdapter());
Gson gson = gBuilder.create();
return gson.toJson(this, this.getClass());
}
-
+
/**
* Simply writes out the id, and that's it!
*
@@ -125,13 +126,13 @@ public abstract class ClusterBase implem
public void write(DataOutput out) throws IOException {
out.writeInt(id);
}
-
+
/** Reads in the id, nothing else */
@Override
public void readFields(DataInput in) throws IOException {
id = in.readInt();
}
-
+
/**
* Return a human-readable formatted string representation of the vector, not intended to be complete nor
* usable as an input/output representation such as Json
@@ -171,7 +172,9 @@ public abstract class ClusterBase implem
buf.append(String.format(Locale.ENGLISH, "%.3f", elem)).append(", ");
}
}
- buf.setLength(buf.length() - 2);
+ if (buf.length() > 1) {
+ buf.setLength(buf.length() - 2);
+ }
buf.append(']');
return buf.toString();
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java Wed Apr 21 20:35:22 2010
@@ -22,6 +22,7 @@ import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
@@ -164,7 +165,7 @@ public class CanopyClusterer {
*/
public void emitPointToExistingCanopies(Vector point,
List<Canopy> canopies,
- OutputCollector<Text,VectorWritable> collector,
+ OutputCollector<IntWritable,VectorWritable> collector,
Reporter reporter) throws IOException {
double minDist = Double.MAX_VALUE;
Canopy closest = null;
@@ -174,7 +175,7 @@ public class CanopyClusterer {
if (dist < t1) {
isCovered = true;
VectorWritable vw = new VectorWritable(point);
- collector.collect(new Text(canopy.getIdentifier()), vw);
+ collector.collect(new IntWritable(canopy.getId()), vw);
reporter.setStatus("Emit Canopy ID:" + canopy.getIdentifier());
} else if (dist < minDist) {
minDist = dist;
@@ -184,8 +185,7 @@ public class CanopyClusterer {
// if the point is not contained in any canopies (due to canopy centroid
// clustering), emit the point to the closest covering canopy.
if (!isCovered) {
- VectorWritable vw = new VectorWritable(point);
- collector.collect(new Text(closest.getIdentifier()), vw);
+ collector.collect(new IntWritable(closest.getId()), new VectorWritable(point));
reporter.setStatus("Emit Closest Canopy ID:" + closest.getIdentifier());
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java Wed Apr 21 20:35:22 2010
@@ -30,7 +30,7 @@ import org.apache.commons.cli2.commandli
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
@@ -154,11 +154,7 @@ public final class ClusterDriver {
conf.set(CanopyConfigKeys.CANOPY_PATH_KEY, canopies);
conf.setInputFormat(SequenceFileInputFormat.class);
-
- /*
- * conf.setMapOutputKeyClass(Text.class); conf.setMapOutputValueClass(RandomAccessSparseVector.class);
- */
- conf.setOutputKeyClass(Text.class);
+ conf.setOutputKeyClass(IntWritable.class);
conf.setOutputValueClass(VectorWritable.class);
conf.setOutputFormat(SequenceFileOutputFormat.class);
@@ -168,6 +164,7 @@ public final class ClusterDriver {
conf.setMapperClass(ClusterMapper.class);
conf.setReducerClass(IdentityReducer.class);
+ conf.setNumReduceTasks(0);
client.setConf(conf);
FileSystem dfs = FileSystem.get(outPath.toUri(), conf);
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java Wed Apr 21 20:35:22 2010
@@ -23,6 +23,7 @@ import java.util.List;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
@@ -35,7 +36,7 @@ import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
public class ClusterMapper extends MapReduceBase implements
- Mapper<WritableComparable<?>,VectorWritable,Text,VectorWritable> {
+ Mapper<WritableComparable<?>,VectorWritable,IntWritable,VectorWritable> {
private CanopyClusterer canopyClusterer;
private final List<Canopy> canopies = new ArrayList<Canopy>();
@@ -43,7 +44,7 @@ public class ClusterMapper extends MapRe
@Override
public void map(WritableComparable<?> key,
VectorWritable point,
- OutputCollector<Text,VectorWritable> output,
+ OutputCollector<IntWritable,VectorWritable> output,
Reporter reporter) throws IOException {
canopyClusterer.emitPointToExistingCanopies(point.get(), canopies, output, reporter);
}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterMapper.java?rev=936489&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterMapper.java Wed Apr 21 20:35:22 2010
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.dirichlet;
+
+import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.Mapper;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.OutputLogFilter;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.math.VectorWritable;
+
+public class DirichletClusterMapper extends MapReduceBase implements
+ Mapper<WritableComparable<?>, VectorWritable, IntWritable, VectorWritable> {
+
+ private OutputCollector<IntWritable, VectorWritable> output;
+
+ private List<DirichletCluster> clusters;
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public void map(WritableComparable<?> key, VectorWritable vector, OutputCollector<IntWritable, VectorWritable> output,
+ Reporter reporter) throws IOException {
+ int clusterId = -1;
+ double clusterPdf = 0;
+ for (int i = 0; i < clusters.size(); i++) {
+ double pdf = clusters.get(i).getModel().pdf(vector);
+ if (pdf > clusterPdf) {
+ clusterId = i;
+ clusterPdf = pdf;
+ }
+ }
+ System.out.println(clusterId + ": " + ClusterBase.formatVector(vector.get(), null));
+ output.collect(new IntWritable(clusterId), vector);
+ }
+
+ @Override
+ public void configure(JobConf job) {
+ super.configure(job);
+ try {
+ clusters = getClusters(job);
+ } catch (SecurityException e) {
+ throw new IllegalStateException(e);
+ } catch (IllegalArgumentException e) {
+ throw new IllegalStateException(e);
+ } catch (NoSuchMethodException e) {
+ throw new IllegalStateException(e);
+ } catch (InvocationTargetException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ public static List<DirichletCluster> getClusters(JobConf job) throws SecurityException, IllegalArgumentException,
+ NoSuchMethodException, InvocationTargetException {
+ String statePath = job.get(DirichletDriver.STATE_IN_KEY);
+ List<DirichletCluster> clusters = new ArrayList<DirichletCluster>();
+ try {
+ Path path = new Path(statePath);
+ FileSystem fs = FileSystem.get(path.toUri(), job);
+ FileStatus[] status = fs.listStatus(path, new OutputLogFilter());
+ for (FileStatus s : status) {
+ SequenceFile.Reader reader = new SequenceFile.Reader(fs, s.getPath(), job);
+ try {
+ Text key = new Text();
+ DirichletCluster cluster = new DirichletCluster();
+ while (reader.next(key, cluster)) {
+ clusters.add(cluster);
+ cluster = new DirichletCluster();
+ }
+ } finally {
+ reader.close();
+ }
+ }
+ return clusters;
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java Wed Apr 21 20:35:22 2010
@@ -32,6 +32,7 @@ import org.apache.commons.cli2.commandli
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
@@ -41,7 +42,9 @@ import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.mahout.clustering.dirichlet.models.VectorModelDistribution;
+import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.clustering.meanshift.MeanShiftCanopyClusterMapper;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.Vector;
@@ -183,6 +186,10 @@ public class DirichletDriver {
* the directory pathname for output points
* @param modelFactory
* the String ModelDistribution class name to use
+ * @param modelPrototype
+ * the String class name of the model prototype
+ * @param prototypeSize
+ * the int size of the prototype to use
* @param numClusters
* the number of models
* @param maxIterations
@@ -220,6 +227,8 @@ public class DirichletDriver {
// now point the input to the old output directory
stateIn = stateOut;
}
+ // now cluster the most likely points
+ runClustering(input, stateIn, output + "/clusters");
}
private static void writeInitialState(String output,
@@ -363,20 +372,25 @@ public class DirichletDriver {
* the directory pathname for output points
*/
public static void runClustering(String input, String stateIn, String output) {
- Configurable client = new JobClient();
JobConf conf = new JobConf(DirichletDriver.class);
+ conf.setJobName("Dirichlet Clustering");
- conf.setOutputKeyClass(Text.class);
- conf.setOutputValueClass(Text.class);
+ conf.setOutputKeyClass(IntWritable.class);
+ conf.setOutputValueClass(VectorWritable.class);
FileInputFormat.setInputPaths(conf, new Path(input));
Path outPath = new Path(output);
FileOutputFormat.setOutputPath(conf, outPath);
- conf.setMapperClass(DirichletMapper.class);
- conf.setNumReduceTasks(0);
+ conf.setMapperClass(DirichletClusterMapper.class);
- client.setConf(conf);
+ conf.setInputFormat(SequenceFileInputFormat.class);
+ conf.setOutputFormat(SequenceFileOutputFormat.class);
+
+ // uncomment it to run locally
+ // conf.set("mapred.job.tracker", "local");
+ conf.setNumReduceTasks(0);
+ conf.set(STATE_IN_KEY, stateIn);
try {
JobClient.runJob(conf);
} catch (IOException e) {
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java Wed Apr 21 20:35:22 2010
@@ -31,21 +31,22 @@ import org.apache.mahout.clustering.diri
import org.apache.mahout.math.VectorWritable;
public class DirichletReducer extends MapReduceBase implements
- Reducer<Text,VectorWritable,Text,DirichletCluster<VectorWritable>> {
-
+ Reducer<Text, VectorWritable, Text, DirichletCluster<VectorWritable>> {
+
private DirichletState<VectorWritable> state;
-
+
private Model<VectorWritable>[] newModels;
-
+
+ private OutputCollector<Text, DirichletCluster<VectorWritable>> output;
+
public Model<VectorWritable>[] getNewModels() {
return newModels;
}
-
+
@Override
- public void reduce(Text key,
- Iterator<VectorWritable> values,
- OutputCollector<Text,DirichletCluster<VectorWritable>> output,
- Reporter reporter) throws IOException {
+ public void reduce(Text key, Iterator<VectorWritable> values, OutputCollector<Text, DirichletCluster<VectorWritable>> output,
+ Reporter reporter) throws IOException {
+ this.output = output;
int k = Integer.parseInt(key.toString());
Model<VectorWritable> model = newModels[k];
while (values.hasNext()) {
@@ -55,14 +56,25 @@ public class DirichletReducer extends Ma
model.computeParameters();
DirichletCluster<VectorWritable> cluster = state.getClusters().get(k);
cluster.setModel(model);
- output.collect(key, cluster);
}
-
+
+ /* (non-Javadoc)
+ * @see org.apache.hadoop.mapred.MapReduceBase#close()
+ */
+ @Override
+ public void close() throws IOException {
+ for (int i = 0; i < state.getNumClusters(); i++) {
+ DirichletCluster cluster = state.getClusters().get(i);
+ output.collect(new Text(String.valueOf(i)), cluster);
+ }
+ super.close();
+ }
+
public void configure(DirichletState<VectorWritable> state) {
this.state = state;
this.newModels = state.getModelFactory().sampleFromPosterior(state.getModels());
}
-
+
@Override
public void configure(JobConf job) {
super.configure(job);
@@ -81,5 +93,5 @@ public class DirichletReducer extends Ma
}
this.newModels = state.getModelFactory().sampleFromPosterior(state.getModels());
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalDistribution.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalDistribution.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalDistribution.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalDistribution.java Wed Apr 21 20:35:22 2010
@@ -49,7 +49,7 @@ public class AsymmetricSampledNormalDist
for (int j = 0; j < prototype.size(); j++) {
sd.set(j, UncommonDistributions.rNorm(1, 1));
}
- result[i] = new AsymmetricSampledNormalModel(mean, sd);
+ result[i] = new AsymmetricSampledNormalModel(i, mean, sd);
}
return result;
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java Wed Apr 21 20:35:22 2010
@@ -56,8 +56,9 @@ public class AsymmetricSampledNormalMode
super();
}
- public AsymmetricSampledNormalModel(Vector mean, Vector stdDev) {
+ public AsymmetricSampledNormalModel(int id, Vector mean, Vector stdDev) {
super();
+ this.id = id;
this.mean = mean;
this.stdDev = stdDev;
this.s0 = 0;
@@ -79,7 +80,7 @@ public class AsymmetricSampledNormalMode
* @return an AsymmetricSampledNormalModel
*/
AsymmetricSampledNormalModel sample() {
- return new AsymmetricSampledNormalModel(mean, stdDev);
+ return new AsymmetricSampledNormalModel(id, mean, stdDev);
}
@Override
@@ -166,6 +167,7 @@ public class AsymmetricSampledNormalMode
@Override
public void readFields(DataInput in) throws IOException {
+ this.id = in.readInt();
VectorWritable temp = new VectorWritable();
temp.readFields(in);
this.mean = temp.get();
@@ -180,6 +182,7 @@ public class AsymmetricSampledNormalMode
@Override
public void write(DataOutput out) throws IOException {
+ out.writeInt(id);
VectorWritable.writeVector(out, mean);
VectorWritable.writeVector(out, stdDev);
out.writeInt(s0);
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java Wed Apr 21 20:35:22 2010
@@ -42,7 +42,8 @@ public class L1Model implements Model<Ve
super();
}
- public L1Model(Vector v) {
+ public L1Model(int id, Vector v) {
+ this.id = id;
observed = v.like();
coefficients = v;
}
@@ -79,20 +80,23 @@ public class L1Model implements Model<Ve
@Override
public void readFields(DataInput in) throws IOException {
- count = in.readInt();
+ this.id = in.readInt();
+ this.count = in.readInt();
VectorWritable temp = new VectorWritable();
temp.readFields(in);
- coefficients = temp.get();
+ this.coefficients = temp.get();
+ this.observed = coefficients.like();
}
@Override
public void write(DataOutput out) throws IOException {
+ out.writeInt(id);
out.writeInt(count);
VectorWritable.writeVector(out, coefficients);
}
public L1Model sample() {
- return new L1Model(coefficients.clone());
+ return new L1Model(id, coefficients.clone());
}
@Override
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java Wed Apr 21 20:35:22 2010
@@ -39,7 +39,7 @@ public class L1ModelDistribution extends
Model<VectorWritable>[] result = new L1Model[howMany];
for (int i = 0; i < howMany; i++) {
Vector prototype = getModelPrototype().get();
- result[i] = new L1Model(prototype.like());
+ result[i] = new L1Model(i, prototype.like());
}
return result;
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Wed Apr 21 20:35:22 2010
@@ -55,7 +55,8 @@ public class NormalModel implements Mode
public NormalModel() { }
- public NormalModel(Vector mean, double stdDev) {
+ public NormalModel(int id, Vector mean, double stdDev) {
+ this.id = id;
this.mean = mean;
this.stdDev = stdDev;
this.s0 = 0;
@@ -81,7 +82,7 @@ public class NormalModel implements Mode
* @return an NormalModel
*/
public NormalModel sample() {
- return new NormalModel(mean, stdDev);
+ return new NormalModel(id, mean, stdDev);
}
@Override
@@ -147,6 +148,7 @@ public class NormalModel implements Mode
@Override
public void readFields(DataInput in) throws IOException {
+ this.id = in.readInt();
VectorWritable temp = new VectorWritable();
temp.readFields(in);
this.mean = temp.get();
@@ -160,6 +162,7 @@ public class NormalModel implements Mode
@Override
public void write(DataOutput out) throws IOException {
+ out.writeInt(id);
VectorWritable.writeVector(out, mean);
out.writeDouble(stdDev);
out.writeInt(s0);
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModelDistribution.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModelDistribution.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModelDistribution.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModelDistribution.java Wed Apr 21 20:35:22 2010
@@ -39,7 +39,7 @@ public class NormalModelDistribution ext
Model<VectorWritable>[] result = new NormalModel[howMany];
for (int i = 0; i < howMany; i++) {
Vector prototype = getModelPrototype().get();
- result[i] = new NormalModel(prototype.like(), 1);
+ result[i] = new NormalModel(i, prototype.like(), 1);
}
return result;
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java Wed Apr 21 20:35:22 2010
@@ -47,7 +47,7 @@ public class SampledNormalDistribution e
}
Vector mean = prototype.like();
mean.assign(m);
- result[i] = new SampledNormalModel(mean, 1);
+ result[i] = new SampledNormalModel(i, mean, 1);
}
return result;
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java Wed Apr 21 20:35:22 2010
@@ -28,8 +28,8 @@ public class SampledNormalModel extends
super();
}
- public SampledNormalModel(Vector mean, double sd) {
- super(mean, sd);
+ public SampledNormalModel(int id, Vector mean, double sd) {
+ super(id, mean, sd);
}
@Override
@@ -44,7 +44,7 @@ public class SampledNormalModel extends
*/
@Override
public NormalModel sample() {
- return new SampledNormalModel(getMean(), getStdDev());
+ return new SampledNormalModel(getId(), getMean(), getStdDev());
}
@Override
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterMapper.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterMapper.java Wed Apr 21 20:35:22 2010
@@ -21,6 +21,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.JobConf;
@@ -32,19 +33,18 @@ import org.apache.mahout.math.NamedVecto
import org.apache.mahout.math.VectorWritable;
public class FuzzyKMeansClusterMapper extends MapReduceBase implements
- Mapper<WritableComparable<?>,VectorWritable,Text,FuzzyKMeansOutput> {
-
+ Mapper<WritableComparable<?>, VectorWritable, IntWritable, VectorWritable> {
+
private final List<SoftCluster> clusters = new ArrayList<SoftCluster>();
+
private FuzzyKMeansClusterer clusterer;
-
+
@Override
- public void map(WritableComparable<?> key,
- VectorWritable point,
- OutputCollector<Text,FuzzyKMeansOutput> output,
- Reporter reporter) throws IOException {
+ public void map(WritableComparable<?> key, VectorWritable point, OutputCollector<IntWritable, VectorWritable> output,
+ Reporter reporter) throws IOException {
clusterer.outputPointWithClusterProbabilities(key.toString(), (NamedVector) point.get(), clusters, output);
}
-
+
/**
* Configure the mapper by providing its clusters. Used by unit tests.
*
@@ -55,21 +55,21 @@ public class FuzzyKMeansClusterMapper ex
this.clusters.clear();
this.clusters.addAll(clusters);
}
-
+
@Override
public void configure(JobConf job) {
-
+
super.configure(job);
clusterer = new FuzzyKMeansClusterer(job);
-
+
String clusterPath = job.get(FuzzyKMeansConfigKeys.CLUSTER_PATH_KEY);
if ((clusterPath != null) && (clusterPath.length() > 0)) {
FuzzyKMeansUtil.configureWithClusterInfo(clusterPath, clusters);
}
-
+
if (clusters.isEmpty()) {
throw new IllegalStateException("Cluster is empty!!!");
}
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java Wed Apr 21 20:35:22 2010
@@ -21,23 +21,26 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.mahout.clustering.ClusterBase;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
public class FuzzyKMeansClusterer {
-
+
private static final double MINIMAL_VALUE = 0.0000000001;
-
+
private DistanceMeasure measure;
-
+
private double convergenceDelta;
-
+
private double m = 2.0; // default value
-
+
/**
* Init the fuzzy k-means clusterer with the distance measure to use for comparison.
*
@@ -52,11 +55,11 @@ public class FuzzyKMeansClusterer {
this.convergenceDelta = convergenceDelta;
this.m = m;
}
-
+
public FuzzyKMeansClusterer(JobConf job) {
this.configure(job);
}
-
+
/**
* Configure the distance measure from the job
*
@@ -80,7 +83,7 @@ public class FuzzyKMeansClusterer {
throw new IllegalStateException(e);
}
}
-
+
/**
* Emit the point and its probability of belongingness to each cluster
*
@@ -91,15 +94,14 @@ public class FuzzyKMeansClusterer {
* @param output
* the OutputCollector to emit into
*/
- public void emitPointProbToCluster(Vector point,
- List<SoftCluster> clusters,
- OutputCollector<Text,FuzzyKMeansInfo> output) throws IOException {
-
+ public void emitPointProbToCluster(Vector point, List<SoftCluster> clusters, OutputCollector<Text, FuzzyKMeansInfo> output)
+ throws IOException {
+
List<Double> clusterDistanceList = new ArrayList<Double>();
for (SoftCluster cluster : clusters) {
clusterDistanceList.add(measure.distance(cluster.getCenter(), point));
}
-
+
for (int i = 0; i < clusters.size(); i++) {
double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
Text key = new Text(clusters.get(i).getIdentifier());
@@ -112,7 +114,7 @@ public class FuzzyKMeansClusterer {
output.collect(key, value);
}
}
-
+
/**
* Output point with cluster info (Cluster and probability)
*
@@ -123,24 +125,38 @@ public class FuzzyKMeansClusterer {
* @param output
* the OutputCollector to emit into
*/
- public void outputPointWithClusterProbabilities(String key,
- NamedVector point,
- List<SoftCluster> clusters,
- OutputCollector<Text,FuzzyKMeansOutput> output) throws IOException {
-
- List<Double> clusterDistanceList = new ArrayList<Double>();
-
+ public void outputPointWithClusterProbabilities(String key, Vector point, List<SoftCluster> clusters,
+ OutputCollector<IntWritable, VectorWritable> output) throws IOException {
+
+ // TODO: remove this later
+ // List<Double> clusterDistanceList = new ArrayList<Double>();
+ //
+ // for (SoftCluster cluster : clusters) {
+ // clusterDistanceList.add(measure.distance(cluster.getCenter(), point));
+ // }
+ // FuzzyKMeansOutput fOutput = new FuzzyKMeansOutput(clusters.size());
+ // for (int i = 0; i < clusters.size(); i++) {
+ // double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
+ // fOutput.add(i, clusters.get(i), probWeight);
+ // }
+ // String name = point.getName();
+
+ // for now just emit the closest cluster
+ int clusterId = -1;
+ double distance = Double.MAX_VALUE;
for (SoftCluster cluster : clusters) {
- clusterDistanceList.add(measure.distance(cluster.getCenter(), point));
- }
- FuzzyKMeansOutput fOutput = new FuzzyKMeansOutput(clusters.size());
- for (int i = 0; i < clusters.size(); i++) {
- double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
- fOutput.add(i, clusters.get(i), probWeight);
+ Vector center = cluster.getCenter();
+ // System.out.println("cluster-" + cluster.getId() + "@ " + ClusterBase.formatVector(center, null));
+ double d = measure.distance(center, point);
+ if (d < distance) {
+ clusterId = cluster.getId();
+ distance = d;
+ }
}
- output.collect(new Text(point.getName()), fOutput);
+ // System.out.println("cluster-" + clusterId + ": " + ClusterBase.formatVector(point, null));
+ output.collect(new IntWritable(clusterId), new VectorWritable(point));
}
-
+
/** Computes the probability of a point belonging to a cluster */
public double computeProbWeight(double clusterDistance, List<Double> clusterDistanceList) {
if (clusterDistance == 0) {
@@ -155,7 +171,7 @@ public class FuzzyKMeansClusterer {
}
return 1.0 / denom;
}
-
+
/**
* Return if the cluster is converged by comparing its center and centroid.
*
@@ -166,15 +182,15 @@ public class FuzzyKMeansClusterer {
cluster.setConverged(measure.distance(cluster.getCenter(), centroid) <= convergenceDelta);
return cluster.isConverged();
}
-
+
public double getM() {
return m;
}
-
+
public DistanceMeasure getMeasure() {
return this.measure;
}
-
+
/**
* This is the reference k-means implementation. Given its inputs it iterates over the points and clusters
* until their centers converge or until the maximum number of iterations is exceeded.
@@ -210,7 +226,7 @@ public class FuzzyKMeansClusterer {
}
return clustersList;
}
-
+
/**
* Perform a single iteration over the points and clusters, assigning points to clusters and returning if
* the iterations are completed.
@@ -224,13 +240,12 @@ public class FuzzyKMeansClusterer {
public static boolean runFuzzyKMeansIteration(List<NamedVector> points,
List<SoftCluster> clusterList,
FuzzyKMeansClusterer clusterer) {
- // for each
for (Vector point : points) {
List<Double> clusterDistanceList = new ArrayList<Double>();
for (SoftCluster cluster : clusterList) {
clusterDistanceList.add(clusterer.getMeasure().distance(point, cluster.getCenter()));
}
-
+
for (int i = 0; i < clusterList.size(); i++) {
double probWeight = clusterer.computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
clusterList.get(i).addPoint(point, Math.pow(probWeight, clusterer.getM()));
@@ -249,6 +264,6 @@ public class FuzzyKMeansClusterer {
}
}
return converged;
-
+
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java Wed Apr 21 20:35:22 2010
@@ -36,6 +36,7 @@ import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
@@ -345,10 +346,8 @@ public final class FuzzyKMeansDriver {
JobConf conf = new JobConf(FuzzyKMeansDriver.class);
conf.setJobName("Fuzzy K Means Clustering");
- conf.setMapOutputKeyClass(Text.class);
- conf.setMapOutputValueClass(VectorWritable.class);
- conf.setOutputKeyClass(Text.class);
- conf.setOutputValueClass(FuzzyKMeansOutput.class);
+ conf.setOutputKeyClass(IntWritable.class);
+ conf.setOutputValueClass(VectorWritable.class);
FileInputFormat.setInputPaths(conf, new Path(input));
Path outPath = new Path(output);
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java Wed Apr 21 20:35:22 2010
@@ -29,15 +29,17 @@ import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.canopy.Canopy;
import org.apache.mahout.clustering.kmeans.Cluster;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class FuzzyKMeansUtil {
private static final Logger log = LoggerFactory.getLogger(FuzzyKMeansUtil.class);
-
- private FuzzyKMeansUtil() { }
-
+
+ private FuzzyKMeansUtil() {
+ }
+
/** Configure the mapper with the cluster info */
public static void configureWithClusterInfo(String clusterPathStr, List<SoftCluster> clusters) {
// Get the path location where the cluster Info is stored
@@ -52,17 +54,16 @@ class FuzzyKMeansUtil {
return path.getName().startsWith("part");
}
};
-
+
try {
// get all filtered file names in result list
FileSystem fs = clusterPath.getFileSystem(job);
- FileStatus[] matches = fs.listStatus(
- FileUtil.stat2Paths(fs.globStatus(clusterPath, clusterFileFilter)), clusterFileFilter);
-
+ FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(fs.globStatus(clusterPath, clusterFileFilter)), clusterFileFilter);
+
for (FileStatus match : matches) {
result.add(fs.makeQualified(match.getPath()));
}
-
+
// iterate thru the result path list
for (Path path : result) {
// RecordReader<Text, Text> recordReader = null;
@@ -96,16 +97,24 @@ class FuzzyKMeansUtil {
clusters.add(value);
value = new SoftCluster();
}
+ } else if (valueClass.equals(Canopy.class)) {
+ Canopy value = new Canopy();
+ while (reader.next(key, value)) {
+ // get the cluster info
+ SoftCluster theCluster = new SoftCluster(value.getCenter(), value.getId());
+ clusters.add(theCluster);
+ value = new Canopy();
+ }
}
} finally {
reader.close();
}
}
-
+
} catch (IOException e) {
log.info("Exception occurred in loading clusters:", e);
throw new IllegalStateException(e);
}
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java Wed Apr 21 20:35:22 2010
@@ -21,7 +21,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
-import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
@@ -33,7 +33,7 @@ import org.apache.mahout.math.NamedVecto
import org.apache.mahout.math.VectorWritable;
public class KMeansClusterMapper extends MapReduceBase implements
- Mapper<WritableComparable<?>,VectorWritable,Text,Text> {
+ Mapper<WritableComparable<?>,VectorWritable,IntWritable,VectorWritable> {
private final List<Cluster> clusters = new ArrayList<Cluster>();
private KMeansClusterer clusterer;
@@ -41,7 +41,7 @@ public class KMeansClusterMapper extends
@Override
public void map(WritableComparable<?> key,
VectorWritable point,
- OutputCollector<Text,Text> output,
+ OutputCollector<IntWritable,VectorWritable> output,
Reporter reporter) throws IOException {
clusterer.outputPointWithClusterInfo((NamedVector) point.get(), clusters, output);
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java Wed Apr 21 20:35:22 2010
@@ -20,11 +20,13 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -81,7 +83,7 @@ public class KMeansClusterer {
public void outputPointWithClusterInfo(NamedVector point,
List<Cluster> clusters,
- OutputCollector<Text,Text> output) throws IOException {
+ OutputCollector<IntWritable,VectorWritable> output) throws IOException {
Cluster nearestCluster = null;
double nearestDistance = Double.MAX_VALUE;
for (Cluster cluster : clusters) {
@@ -93,7 +95,7 @@ public class KMeansClusterer {
}
}
- output.collect(new Text(point.getName()), new Text(String.valueOf(nearestCluster.getId())));
+ output.collect(new IntWritable(nearestCluster.getId()), new VectorWritable(point));
}
/**
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Wed Apr 21 20:35:22 2010
@@ -29,6 +29,7 @@ import org.apache.commons.cli2.commandli
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
@@ -305,11 +306,11 @@ public final class KMeansDriver {
conf.setInputFormat(SequenceFileInputFormat.class);
conf.setOutputFormat(SequenceFileOutputFormat.class);
- conf.setMapOutputKeyClass(Text.class);
- conf.setMapOutputValueClass(Text.class);
- conf.setOutputKeyClass(Text.class);
+ conf.setMapOutputKeyClass(IntWritable.class);
+ conf.setMapOutputValueClass(VectorWritable.class);
+ conf.setOutputKeyClass(IntWritable.class);
// the output is the cluster id
- conf.setOutputValueClass(Text.class);
+ conf.setOutputValueClass(VectorWritable.class);
FileInputFormat.setInputPaths(conf, new Path(input));
Path outPath = new Path(output);
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterMapper.java?rev=936489&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterMapper.java Wed Apr 21 20:35:22 2010
@@ -0,0 +1,107 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.meanshift;
+
+import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.Mapper;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.OutputLogFilter;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.math.VectorWritable;
+/** Mapper that assigns each input point to the loaded mean-shift canopies whose bound-point lists contain the point's id. */
+public class MeanShiftCanopyClusterMapper extends MapReduceBase implements
+ Mapper<WritableComparable<?>, MeanShiftCanopy, IntWritable, VectorWritable> {
+ // NOTE(review): clusterer is never assigned or read in this class; candidate for removal.
+ private MeanShiftCanopyClusterer clusterer;
+ // NOTE(review): never assigned; map() collects into its own 'output' parameter, which shadows this field.
+ private OutputCollector<IntWritable, VectorWritable> output;
+ // Canopies read from the STATE_IN_KEY path in configure(); consulted by every map() call.
+ private List<MeanShiftCanopy> canopies;
+ /** Emits (canopy id, point center) once for each loaded canopy whose bound points include this point's id. */
+ @Override
+ public void map(WritableComparable<?> key, MeanShiftCanopy vector, OutputCollector<IntWritable, VectorWritable> output,
+ Reporter reporter) throws IOException {
+ int vectorId = vector.getId();
+ for (MeanShiftCanopy msc : canopies) {
+ for (int containedId : msc.getBoundPoints().toList()) {
+ if (vectorId == containedId) {
+ // Full scan with no early exit: a point listed in several canopies is emitted once per matching canopy.
+ output.collect(new IntWritable(msc.getId()), new VectorWritable(vector.getCenter()));
+ }
+ }
+ }
+ }
+ /** Loads the canopies via getCanopies(job), rethrowing each declared failure type as IllegalStateException. */
+ @Override
+ public void configure(JobConf job) {
+ super.configure(job);
+ try {
+ canopies = getCanopies(job);
+ } catch (SecurityException e) {
+ throw new IllegalStateException(e);
+ } catch (IllegalArgumentException e) {
+ throw new IllegalStateException(e);
+ } catch (NoSuchMethodException e) {
+ throw new IllegalStateException(e);
+ } catch (InvocationTargetException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+ /** Reads every file under the STATE_IN_KEY path accepted by OutputLogFilter as a SequenceFile of (Text, MeanShiftCanopy); IOException is rethrown as IllegalStateException. NOTE(review): nothing visible here throws the declared NoSuchMethodException/InvocationTargetException. */
+ public static List<MeanShiftCanopy> getCanopies(JobConf job) throws SecurityException, IllegalArgumentException,
+ NoSuchMethodException, InvocationTargetException {
+ String statePath = job.get(MeanShiftCanopyDriver.STATE_IN_KEY);
+ List<MeanShiftCanopy> canopies = new ArrayList<MeanShiftCanopy>();
+ try {
+ Path path = new Path(statePath);
+ FileSystem fs = FileSystem.get(path.toUri(), job);
+ FileStatus[] status = fs.listStatus(path, new OutputLogFilter());
+ for (FileStatus s : status) {
+ SequenceFile.Reader reader = new SequenceFile.Reader(fs, s.getPath(), job);
+ try {
+ Text key = new Text();
+ MeanShiftCanopy canopy = new MeanShiftCanopy();
+ while (reader.next(key, canopy)) {
+ canopies.add(canopy);
+ canopy = new MeanShiftCanopy();
+ }
+ } finally {
+ reader.close();
+ }
+ }
+ return canopies;
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java Wed Apr 21 20:35:22 2010
@@ -21,6 +21,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.JobConf;
@@ -28,35 +29,39 @@ import org.apache.hadoop.mapred.OutputCo
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
public class MeanShiftCanopyClusterer {
-
+
private double convergenceDelta = 0;
+
// the next canopyId to be allocated
// private int nextCanopyId = 0;
// the T1 distance threshold
private double t1;
+
// the T2 distance threshold
private double t2;
+
// the distance measure
private DistanceMeasure measure;
-
+
public MeanShiftCanopyClusterer(JobConf job) {
configure(job);
}
-
+
public MeanShiftCanopyClusterer(DistanceMeasure aMeasure, double aT1, double aT2, double aDelta) {
config(aMeasure, aT1, aT2, aDelta);
}
-
+
public double getT1() {
return t1;
}
-
+
public double getT2() {
return t2;
}
-
+
/**
* Configure the Canopy and its distance measure
*
@@ -65,8 +70,8 @@ public class MeanShiftCanopyClusterer {
*/
public void configure(JobConf job) {
try {
- measure = Class.forName(job.get(MeanShiftCanopyConfigKeys.DISTANCE_MEASURE_KEY)).asSubclass(
- DistanceMeasure.class).newInstance();
+ measure = Class.forName(job.get(MeanShiftCanopyConfigKeys.DISTANCE_MEASURE_KEY)).asSubclass(DistanceMeasure.class)
+ .newInstance();
measure.configure(job);
} catch (ClassNotFoundException e) {
throw new IllegalStateException(e);
@@ -80,7 +85,7 @@ public class MeanShiftCanopyClusterer {
t2 = Double.parseDouble(job.get(MeanShiftCanopyConfigKeys.T2_KEY));
convergenceDelta = Double.parseDouble(job.get(MeanShiftCanopyConfigKeys.CLUSTER_CONVERGENCE_KEY));
}
-
+
/**
* Configure the Canopy for unit tests
*
@@ -94,7 +99,7 @@ public class MeanShiftCanopyClusterer {
t2 = aT2;
convergenceDelta = aDelta;
}
-
+
/**
* Merge the given canopy into the canopies list. If it touches any existing canopy (norm<T1) then add the
* center of each to the other. If it covers any other canopies (norm<T2), then merge the given canopy with
@@ -127,13 +132,13 @@ public class MeanShiftCanopyClusterer {
closestCoveringCanopy.merge(aCanopy);
}
}
-
+
/** Emit the new canopy to the collector, keyed by the canopy's Id */
- static void emitCanopy(MeanShiftCanopy canopy, OutputCollector<Text,WritableComparable<?>> collector) throws IOException {
+ static void emitCanopy(MeanShiftCanopy canopy, OutputCollector<Text, WritableComparable<?>> collector) throws IOException {
String identifier = canopy.getIdentifier();
collector.collect(new Text(identifier), new Text("new " + canopy.toString()));
}
-
+
/**
* Shift the center to the new centroid of the cluster
*
@@ -149,7 +154,7 @@ public class MeanShiftCanopyClusterer {
canopy.setPointTotal(centroid.clone());
return canopy.isConverged();
}
-
+
/**
* Return if the point is covered by this canopy
*
@@ -162,7 +167,7 @@ public class MeanShiftCanopyClusterer {
boolean covers(MeanShiftCanopy canopy, Vector point) {
return measure.distance(canopy.getCenter(), point) < t1;
}
-
+
/**
* Return if the point is closely covered by the canopy
*
@@ -175,17 +180,16 @@ public class MeanShiftCanopyClusterer {
public boolean closelyBound(MeanShiftCanopy canopy, Vector point) {
return measure.distance(canopy.getCenter(), point) < t2;
}
-
+
/**
* Story: User can exercise the reference implementation to verify that the test datapoints are clustered in
* a reasonable manner.
*/
public void testReferenceImplementation() {
- MeanShiftCanopyClusterer clusterer = new MeanShiftCanopyClusterer(new EuclideanDistanceMeasure(), 4.0,
- 1.0, 0.5);
+ MeanShiftCanopyClusterer clusterer = new MeanShiftCanopyClusterer(new EuclideanDistanceMeasure(), 4.0, 1.0, 0.5);
List<MeanShiftCanopy> canopies = new ArrayList<MeanShiftCanopy>();
// add all points to the canopies
-
+
boolean done = false;
int iter = 1;
while (!done) {// shift canopies to their centroids
@@ -199,7 +203,7 @@ public class MeanShiftCanopyClusterer {
System.out.println(iter++);
}
}
-
+
/**
* This is the reference mean-shift implementation. Given its inputs it iterates over the points and
* clusters until their centers converge or until the maximum number of iterations is exceeded.
@@ -211,27 +215,23 @@ public class MeanShiftCanopyClusterer {
* @param numIter
* the maximum number of iterations
*/
- public static List<MeanShiftCanopy> clusterPoints(List<Vector> points,
- DistanceMeasure measure,
- double convergenceThreshold,
- double t1,
- double t2,
- int numIter) {
+ public static List<MeanShiftCanopy> clusterPoints(List<Vector> points, DistanceMeasure measure, double convergenceThreshold,
+ double t1, double t2, int numIter) {
MeanShiftCanopyClusterer clusterer = new MeanShiftCanopyClusterer(measure, t1, t2, convergenceThreshold);
-
+
List<MeanShiftCanopy> canopies = new ArrayList<MeanShiftCanopy>();
int nextCanopyId = 0;
for (Vector point : points) {
clusterer.mergeCanopy(new MeanShiftCanopy(point, nextCanopyId++), canopies);
}
-
+
boolean converged = false;
for (int iter = 0; !converged && iter < numIter; iter++) {
converged = runMeanShiftCanopyIteration(canopies, clusterer);
}
return canopies;
}
-
+
/**
* Perform a single iteration over the points and clusters, assigning points to clusters and returning if
* the iterations are completed.
@@ -239,8 +239,7 @@ public class MeanShiftCanopyClusterer {
* @param canopies
* the List<MeanShiftCanopy> clusters
*/
- public static boolean runMeanShiftCanopyIteration(List<MeanShiftCanopy> canopies,
- MeanShiftCanopyClusterer clusterer) {
+ public static boolean runMeanShiftCanopyIteration(List<MeanShiftCanopy> canopies, MeanShiftCanopyClusterer clusterer) {
boolean converged = true;
List<MeanShiftCanopy> migratedCanopies = new ArrayList<MeanShiftCanopy>();
for (MeanShiftCanopy canopy : canopies) {
@@ -249,7 +248,7 @@ public class MeanShiftCanopyClusterer {
}
canopies = migratedCanopies;
return converged;
-
+
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java Wed Apr 21 20:35:22 2010
@@ -29,6 +29,7 @@ import org.apache.commons.cli2.builder.G
import org.apache.commons.cli2.commandline.Parser;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
@@ -36,43 +37,49 @@ import org.apache.hadoop.mapred.JobClien
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
+import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansClusterMapper;
+import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansConfigKeys;
+import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public final class MeanShiftCanopyDriver {
-
+
private static final Logger log = LoggerFactory.getLogger(MeanShiftCanopyDriver.class);
-
- private MeanShiftCanopyDriver() {}
-
+
+ public static final String STATE_IN_KEY = "org.apache.mahout.clustering.meanshift.stateInKey";
+
+ private MeanShiftCanopyDriver() {
+ }
+
public static void main(String[] args) {
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
-
+
Option inputOpt = DefaultOptionCreator.inputOption().create();
Option outputOpt = DefaultOptionCreator.outputOption().create();
Option convergenceDeltaOpt = DefaultOptionCreator.convergenceOption().create();
Option helpOpt = DefaultOptionCreator.helpOption();
-
- Option modelOpt = obuilder.withLongName("distanceClass").withRequired(true).withShortName("d")
- .withArgument(abuilder.withName("distanceClass").withMinimum(1).withMaximum(1).create())
- .withDescription("The distance measure class name.").create();
-
- Option threshold1Opt = obuilder.withLongName("threshold_1").withRequired(true).withShortName("t1")
- .withArgument(abuilder.withName("threshold_1").withMinimum(1).withMaximum(1).create())
- .withDescription("The T1 distance threshold.").create();
-
- Option threshold2Opt = obuilder.withLongName("threshold_2").withRequired(true).withShortName("t2")
- .withArgument(abuilder.withName("threshold_2").withMinimum(1).withMaximum(1).create())
- .withDescription("The T1 distance threshold.").create();
-
- Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt)
- .withOption(modelOpt).withOption(helpOpt).withOption(convergenceDeltaOpt).withOption(threshold1Opt)
- .withOption(threshold2Opt).create();
-
+
+ Option modelOpt = obuilder.withLongName("distanceClass").withRequired(true).withShortName("d").withArgument(
+ abuilder.withName("distanceClass").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The distance measure class name.").create();
+
+ Option threshold1Opt = obuilder.withLongName("threshold_1").withRequired(true).withShortName("t1").withArgument(
+ abuilder.withName("threshold_1").withMinimum(1).withMaximum(1).create()).withDescription("The T1 distance threshold.")
+ .create();
+
+ Option threshold2Opt = obuilder.withLongName("threshold_2").withRequired(true).withShortName("t2").withArgument(
+ abuilder.withName("threshold_2").withMinimum(1).withMaximum(1).create()).withDescription("The T1 distance threshold.")
+ .create();
+
+ Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(modelOpt).withOption(helpOpt)
+ .withOption(convergenceDeltaOpt).withOption(threshold1Opt).withOption(threshold2Opt).create();
+
try {
Parser parser = new Parser();
parser.setGroup(group);
@@ -81,7 +88,7 @@ public final class MeanShiftCanopyDriver
CommandLineUtil.printHelp(group);
return;
}
-
+
String input = cmdLine.getValue(inputOpt).toString();
String output = cmdLine.getValue(outputOpt).toString();
String measureClassName = cmdLine.getValue(modelOpt).toString();
@@ -89,14 +96,14 @@ public final class MeanShiftCanopyDriver
double t2 = Double.parseDouble(cmdLine.getValue(threshold2Opt).toString());
double convergenceDelta = Double.parseDouble(cmdLine.getValue(convergenceDeltaOpt).toString());
createCanopyFromVectors(input, output + "/intial-canopies");
- runJob(output + "/intial-canopies", output, output + MeanShiftCanopyConfigKeys.CONTROL_PATH_KEY,
- measureClassName, t1, t2, convergenceDelta);
+ runJob(output + "/intial-canopies", output, output + MeanShiftCanopyConfigKeys.CONTROL_PATH_KEY, measureClassName, t1, t2,
+ convergenceDelta);
} catch (OptionException e) {
log.error("Exception parsing command line: ", e);
CommandLineUtil.printHelp(group);
}
}
-
+
/**
* Run the job
*
@@ -115,24 +122,19 @@ public final class MeanShiftCanopyDriver
* @param convergenceDelta
* the double convergence criteria
*/
- public static void runJob(String input,
- String output,
- String control,
- String measureClassName,
- double t1,
- double t2,
- double convergenceDelta) {
-
+ public static void runJob(String input, String output, String control, String measureClassName, double t1, double t2,
+ double convergenceDelta) {
+
Configurable client = new JobClient();
JobConf conf = new JobConf(MeanShiftCanopyDriver.class);
-
+
conf.setOutputKeyClass(Text.class);
conf.setOutputValueClass(MeanShiftCanopy.class);
-
+
FileInputFormat.setInputPaths(conf, new Path(input));
Path outPath = new Path(output);
FileOutputFormat.setOutputPath(conf, outPath);
-
+
conf.setMapperClass(MeanShiftCanopyMapper.class);
conf.setReducerClass(MeanShiftCanopyReducer.class);
conf.setNumReduceTasks(1);
@@ -143,7 +145,7 @@ public final class MeanShiftCanopyDriver
conf.set(MeanShiftCanopyConfigKeys.T1_KEY, String.valueOf(t1));
conf.set(MeanShiftCanopyConfigKeys.T2_KEY, String.valueOf(t2));
conf.set(MeanShiftCanopyConfigKeys.CONTROL_PATH_KEY, control);
-
+
client.setConf(conf);
try {
JobClient.runJob(conf);
@@ -151,7 +153,7 @@ public final class MeanShiftCanopyDriver
log.warn(e.toString(), e);
}
}
-
+
/**
* Run the job
*
@@ -161,22 +163,22 @@ public final class MeanShiftCanopyDriver
* the output pathname String
*/
public static void createCanopyFromVectors(String input, String output) {
-
+
Configurable client = new JobClient();
JobConf conf = new JobConf(MeanShiftCanopyDriver.class);
-
+
conf.setOutputKeyClass(Text.class);
conf.setOutputValueClass(MeanShiftCanopy.class);
-
+
FileInputFormat.setInputPaths(conf, new Path(input));
Path outPath = new Path(output);
FileOutputFormat.setOutputPath(conf, outPath);
-
+
conf.setMapperClass(MeanShiftCanopyCreatorMapper.class);
conf.setNumReduceTasks(0);
conf.setInputFormat(SequenceFileInputFormat.class);
conf.setOutputFormat(SequenceFileOutputFormat.class);
-
+
client.setConf(conf);
try {
JobClient.runJob(conf);
@@ -184,4 +186,44 @@ public final class MeanShiftCanopyDriver
log.warn(e.toString(), e);
}
}
+
+ /**
+ * Run the job using supplied arguments
+ *
+ * @param input
+ * the directory pathname for input points
+ * @param clustersIn
+ * the directory pathname for input clusters
+ * @param output
+ * the directory pathname for output clustered points
+ */
+ public static void runClustering(String input,
+ String clustersIn,
+ String output) {
+
+ JobConf conf = new JobConf(FuzzyKMeansDriver.class);
+ conf.setJobName("Mean Shift Clustering");
+
+ conf.setOutputKeyClass(IntWritable.class);
+ conf.setOutputValueClass(VectorWritable.class);
+
+ FileInputFormat.setInputPaths(conf, new Path(input));
+ Path outPath = new Path(output);
+ FileOutputFormat.setOutputPath(conf, outPath);
+
+ conf.setMapperClass(MeanShiftCanopyClusterMapper.class);
+
+ conf.setInputFormat(SequenceFileInputFormat.class);
+ conf.setOutputFormat(SequenceFileOutputFormat.class);
+
+ // uncomment it to run locally
+ // conf.set("mapred.job.tracker", "local");
+ conf.setNumReduceTasks(0);
+ conf.set(STATE_IN_KEY, clustersIn);
+ try {
+ JobClient.runJob(conf);
+ } catch (IOException e) {
+ log.warn(e.toString(), e);
+ }
+ }
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java Wed Apr 21 20:35:22 2010
@@ -114,7 +114,7 @@ public class MeanShiftCanopyJob {
*/
public static void runJob(String input, String output, String measureClassName, double t1, double t2, double convergenceDelta,
int maxIterations) throws IOException {
- runJob(input, output, measureClassName, t1,t2,convergenceDelta, maxIterations, false);
+ runJob(input, output, measureClassName, t1, t2, convergenceDelta, maxIterations, false);
}
/**
@@ -147,14 +147,14 @@ public class MeanShiftCanopyJob {
fs.delete(outPath, true);
}
fs.mkdirs(outPath);
-
+
String clustersIn = output + "/initial-canopies";
if (inputIsCanopies) {
clustersIn = input;
} else {
MeanShiftCanopyDriver.createCanopyFromVectors(input, clustersIn);
}
-
+
// iterate until the clusters converge
boolean converged = false;
int iteration = 0;
@@ -169,6 +169,9 @@ public class MeanShiftCanopyJob {
clustersIn = output + "/canopies-" + iteration;
iteration++;
}
+
+ // now cluster the points
+ MeanShiftCanopyDriver.runClustering((inputIsCanopies ? input : output + "/initial-canopies"), clustersIn, output + "/clusters");
}
}
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFeatureMapperTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFeatureMapperTest.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFeatureMapperTest.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFeatureMapperTest.java Wed Apr 21 20:35:22 2010
@@ -19,6 +19,7 @@ package org.apache.mahout.classifier.bay
import java.util.List;
import java.util.Map;
+import java.util.Map.Entry;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.Text;
@@ -44,11 +45,11 @@ public class BayesFeatureMapperTest exte
DummyOutputCollector<StringTuple,DoubleWritable> output = new DummyOutputCollector<StringTuple,DoubleWritable>();
mapper.map(new Text("foo"), new Text("big brown shoe"), output,
Reporter.NULL);
- Map<String,List<DoubleWritable>> outMap = output.getData();
+ Map<StringTuple, List<DoubleWritable>> outMap = output.getData();
System.out.println("Map: " + outMap);
assertNotNull("outMap is null and it shouldn't be", outMap);
// TODO: How about not such a lame test here?
- for (Map.Entry<String,List<DoubleWritable>> entry : outMap.entrySet()) {
+ for (Entry<StringTuple, List<DoubleWritable>> entry : outMap.entrySet()) {
assertTrue("entry.getKey() Size: " + entry.getKey().length()
+ " is not greater than: 0", entry.getKey().length() > 0);
assertEquals("entry.getValue() Size: " + entry.getValue().size()
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java?rev=936489&r1=936488&r2=936489&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java Wed Apr 21 20:35:22 2010
@@ -49,7 +49,7 @@ public class TestClusterInterface extend
public void testDirichletNormalModel() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Cluster model = new NormalModel(m, 0.75);
+ Cluster model = new NormalModel(5, m, 0.75);
String format = model.asFormatString(null);
assertEquals("format", "nm{n=0 m=[1.100, 2.200, 3.300] sd=0.75}", format);
String json = model.asJsonString();
@@ -63,7 +63,7 @@ public class TestClusterInterface extend
public void testDirichletSampledNormalModel() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Cluster model = new SampledNormalModel(m, 0.75);
+ Cluster model = new SampledNormalModel(5, m, 0.75);
String format = model.asFormatString(null);
assertEquals("format", "snm{n=0 m=[1.100, 2.200, 3.300] sd=0.75}", format);
String json = model.asJsonString();
@@ -77,7 +77,7 @@ public class TestClusterInterface extend
public void testDirichletASNormalModel() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Cluster model = new AsymmetricSampledNormalModel(m, m);
+ Cluster model = new AsymmetricSampledNormalModel(5, m, m);
String format = model.asFormatString(null);
assertEquals("format", "asnm{n=0 m=[1.100, 2.200, 3.300] sd=[1.100, 2.200, 3.300]}", format);
String json = model.asJsonString();
@@ -91,7 +91,7 @@ public class TestClusterInterface extend
public void testDirichletL1Model() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- Cluster model = new L1Model(m);
+ Cluster model = new L1Model(5, m);
String format = model.asFormatString(null);
assertEquals("format", "l1m{n=0 c=[1.100, 2.200, 3.300]}", format);
String json = model.asJsonString();
@@ -105,7 +105,7 @@ public class TestClusterInterface extend
public void testDirichletNormalModelClusterAsFormatString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- NormalModel model = new NormalModel(m, 0.75);
+ NormalModel model = new NormalModel(5, m, 0.75);
Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0);
String format = cluster.asFormatString(null);
assertEquals("format", "nm{n=0 m=[1.100, 2.200, 3.300] sd=0.75}", format);
@@ -114,7 +114,7 @@ public class TestClusterInterface extend
public void testDirichletNormalModelClusterAsJsonString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- NormalModel model = new NormalModel(m, 0.75);
+ NormalModel model = new NormalModel(5, m, 0.75);
Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0);
String json = cluster.asJsonString();
GsonBuilder builder = new GsonBuilder();
@@ -128,7 +128,7 @@ public class TestClusterInterface extend
public void testDirichletAsymmetricSampledNormalModelClusterAsFormatString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- AsymmetricSampledNormalModel model = new AsymmetricSampledNormalModel(m, m);
+ AsymmetricSampledNormalModel model = new AsymmetricSampledNormalModel(5, m, m);
Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0);
String format = cluster.asFormatString(null);
assertEquals("format", "asnm{n=0 m=[1.100, 2.200, 3.300] sd=[1.100, 2.200, 3.300]}", format);
@@ -137,7 +137,7 @@ public class TestClusterInterface extend
public void testDirichletAsymmetricSampledNormalModelClusterAsJsonString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- AsymmetricSampledNormalModel model = new AsymmetricSampledNormalModel(m, m);
+ AsymmetricSampledNormalModel model = new AsymmetricSampledNormalModel(5, m, m);
Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0);
String json = cluster.asJsonString();
@@ -152,7 +152,7 @@ public class TestClusterInterface extend
public void testDirichletL1ModelClusterAsFormatString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- L1Model model = new L1Model(m);
+ L1Model model = new L1Model(5, m);
Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0);
String format = cluster.asFormatString(null);
assertEquals("format", "l1m{n=0 c=[1.100, 2.200, 3.300]}", format);
@@ -161,7 +161,7 @@ public class TestClusterInterface extend
public void testDirichletL1ModelClusterAsJsonString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
- L1Model model = new L1Model(m);
+ L1Model model = new L1Model(5, m);
Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0);
String json = cluster.asJsonString();