You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by gs...@apache.org on 2009/03/18 12:07:16 UTC
svn commit: r755548 - in /lucene/mahout/trunk:
core/src/main/java/org/apache/mahout/clustering/kmeans/
core/src/test/java/org/apache/mahout/clustering/kmeans/
examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/
Author: gsingers
Date: Wed Mar 18 11:07:16 2009
New Revision: 755548
URL: http://svn.apache.org/viewvc?rev=755548&view=rev
Log:
MAHOUT-99: Fix k-means speed issue
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansJob.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java Wed Mar 18 11:07:16 2009
@@ -1,10 +1,10 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
@@ -14,7 +14,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.mahout.clustering.kmeans;
import org.apache.hadoop.io.Text;
@@ -64,6 +63,7 @@
* Format the cluster for output
*
* @param cluster the Cluster
+ * @return a String encoding of the cluster, beginning with its identifier
*/
public static String formatCluster(Cluster cluster) {
return cluster.getIdentifier() + ": "
@@ -73,8 +73,7 @@
/**
* Decodes and returns a Cluster from the formattedString
*
- * @param formattedString
- * a String produced by formatCluster
+ * @param formattedString a String produced by formatCluster
* @return a new Canopy
*/
public static Cluster decodeCluster(String formattedString) {
@@ -83,8 +82,8 @@
String center = formattedString.substring(beginIndex);
char firstChar = id.charAt(0);
boolean startsWithV = firstChar == 'V';
- if (firstChar == 'C' || startsWithV) {
- int clusterId = Integer.parseInt(formattedString.substring(1, beginIndex - 2));
+ if (firstChar == 'C' || startsWithV) {
+ int clusterId = Integer.parseInt(formattedString.substring(1, beginIndex - 2));
Vector clusterCenter = AbstractVector.decodeVector(center);
Cluster cluster = new Cluster(clusterCenter, clusterId);
cluster.converged = startsWithV;
@@ -96,12 +95,11 @@
/**
* Configure the distance measure from the job
*
- * @param job
- * the JobConf for the job
+ * @param job the JobConf for the job
*/
public static void configure(JobConf job) {
try {
- ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+ final ClassLoader ccl = Thread.currentThread().getContextClassLoader();
Class<?> cl = ccl.loadClass(job.get(DISTANCE_MEASURE_KEY));
measure = (DistanceMeasure) cl.newInstance();
measure.configure(job);
@@ -119,10 +117,8 @@
/**
* Configure the distance measure directly. Used by unit tests.
*
- * @param aMeasure
- * the DistanceMeasure
- * @param aConvergenceDelta
- * the delta value used to define convergence
+ * @param aMeasure the DistanceMeasure
+ * @param aConvergenceDelta the delta value used to define convergence
*/
public static void config(DistanceMeasure aMeasure, double aConvergenceDelta) {
measure = aMeasure;
@@ -133,15 +129,11 @@
/**
* Emit the point to the nearest cluster center
*
- * @param point
- * a point
- * @param clusters
- * a List<Cluster> to test
- * @param values
- * a Writable containing the input point and possible other values
- * of interest (payload)
- * @param output
- * the OutputCollector to emit into
+ * @param point a point
+ * @param clusters a List<Cluster> to test
+ * @param values a Writable containing the input point and possible other
+ * values of interest (payload)
+ * @param output the OutputCollector to emit into
* @throws IOException
*/
public static void emitPointToNearestCluster(Vector point,
@@ -156,7 +148,26 @@
nearestDistance = distance;
}
}
- output.collect(new Text(formatCluster(nearestCluster)), values);
+ // emit only clusterID
+ String outKey = nearestCluster.getIdentifier();
+ String value = "1\t" + values.toString();
+ output.collect(new Text(outKey), new Text(value));
+ }
+
+ public static void outputPointWithClusterInfo(String key, Vector point,
+ List<Cluster> clusters, Text values, OutputCollector<Text, Text> output)
+ throws IOException {
+ Cluster nearestCluster = null;
+ double nearestDistance = Double.MAX_VALUE;
+ for (Cluster cluster : clusters) {
+ double distance = measure.distance(point, cluster.getCenter());
+ if (nearestCluster == null || distance < nearestDistance) {
+ nearestCluster = cluster;
+ nearestDistance = distance;
+ }
+ }
+ output.collect(new Text(key), new Text(Integer
+ .toString(nearestCluster.clusterId)));
}
/**
@@ -177,10 +188,10 @@
/**
* Construct a new cluster with the given point as its center
*
- * @param center
- * the center point
+ * @param center the center point
*/
public Cluster(Vector center) {
+ super();
this.clusterId = nextClusterId++;
this.center = center;
this.numPoints = 0;
@@ -190,16 +201,28 @@
/**
* Construct a new cluster with the given point as its center
*
- * @param center
- * the center point
+ * @param center the center point
*/
public Cluster(Vector center, int clusterId) {
+ super();
this.clusterId = clusterId;
this.center = center;
this.numPoints = 0;
this.pointTotal = center.like();
}
+ /**
+ * Construct a new cluster (without a center) from an identifier such as "C3"
+ *
+ * @param clusterId the identifier; a leading 'V' marks the cluster converged
+ */
+ public Cluster(String clusterId) {
+ // center is not set here; pointTotal is initialized lazily by addPoints()
+ this.clusterId = Integer.parseInt((clusterId.substring(1)));
+ this.numPoints = 0;
+ this.converged = clusterId.startsWith("V");
+ }
+
@Override
public String toString() {
return getIdentifier() + " - " + center.asFormatString();
@@ -215,25 +238,17 @@
/**
* Add the point to the cluster
*
- * @param point
- * a point to add
+ * @param point a point to add
*/
public void addPoint(Vector point) {
- centroid = null;
- numPoints++;
- if (pointTotal == null)
- pointTotal = point.copy();
- else
- pointTotal = point.plus(pointTotal);
+ addPoints(1, point);
}
/**
* Add the point to the cluster
*
- * @param count
- * the number of points in the delta
- * @param delta
- * a point to add
+ * @param count the number of points in the delta
+ * @param delta a point to add
*/
public void addPoints(int count, Vector delta) {
centroid = null;
@@ -241,7 +256,7 @@
if (pointTotal == null)
pointTotal = delta.copy();
else
- pointTotal = delta.plus(pointTotal);
+ pointTotal = pointTotal.plus(delta);
}
public Vector getCenter() {
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java?rev=755548&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java Wed Mar 18 11:07:16 2009
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.kmeans;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.matrix.AbstractVector;
+import org.apache.mahout.matrix.Vector;
+
+public class KMeansClusterMapper extends KMeansMapper {
+ @Override
+ public void map(WritableComparable<?> key, Text values,
+ OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
+ Vector point = AbstractVector.decodeVector(values.toString());
+ Cluster.outputPointWithClusterInfo(values.toString(), point, clusters,
+ values, output);
+ }
+}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansCombiner.java Wed Mar 18 11:07:16 2009
@@ -1,10 +1,10 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
@@ -14,9 +14,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.mahout.clustering.kmeans;
+import java.io.IOException;
+import java.util.Iterator;
+
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
@@ -25,20 +27,19 @@
import org.apache.hadoop.mapred.Reporter;
import org.apache.mahout.matrix.AbstractVector;
-import java.io.IOException;
-import java.util.Iterator;
-
public class KMeansCombiner extends MapReduceBase implements
Reducer<Text, Text, Text, Text> {
@Override
public void reduce(Text key, Iterator<Text> values,
OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
- Cluster cluster = Cluster.decodeCluster(key.toString());
+ Cluster cluster = new Cluster(key.toString());
while (values.hasNext()) {
- cluster.addPoint(AbstractVector.decodeVector(values.next().toString()));
+ String[] numPointnValue = values.next().toString().split("\t");
+ cluster.addPoints(Integer.parseInt(numPointnValue[0].trim()),
+ AbstractVector.decodeVector(numPointnValue[1].trim()));
}
- output.collect(key, new Text(cluster.getNumPoints() + ", "
+ output.collect(key, new Text(cluster.getNumPoints() + "\t"
+ cluster.getPointTotal().asFormatString()));
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Wed Mar 18 11:07:16 2009
@@ -1,10 +1,10 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
@@ -14,23 +14,28 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.mahout.clustering.kmeans;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
+import org.apache.hadoop.mapred.FileSplit;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.mapred.SequenceFileOutputFormat;
+import org.apache.hadoop.mapred.KeyValueLineRecordReader;
+import org.apache.hadoop.mapred.TextInputFormat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
-
public class KMeansDriver {
private static final Logger log = LoggerFactory.getLogger(KMeansDriver.class);
@@ -45,21 +50,23 @@
String measureClass = args[3];
double convergenceDelta = Double.parseDouble(args[4]);
int maxIterations = Integer.parseInt(args[5]);
- runJob(input, clusters, output, measureClass, convergenceDelta, maxIterations);
+ runJob(input, clusters, output, measureClass, convergenceDelta,
+ maxIterations, args.length > 6 ? Integer.parseInt(args[6]) : 2);
}
/**
* Run the job using supplied arguments
- *
- * @param input the directory pathname for input points
- * @param clustersIn the directory pathname for initial & computed clusters
- * @param output the directory pathname for output points
- * @param measureClass the classname of the DistanceMeasure
+ *
+ * @param input the directory pathname for input points
+ * @param clustersIn the directory pathname for initial & computed clusters
+ * @param output the directory pathname for output points
+ * @param measureClass the classname of the DistanceMeasure
* @param convergenceDelta the convergence delta value
- * @param maxIterations the maximum number of iterations
+ * @param maxIterations the maximum number of iterations
*/
public static void runJob(String input, String clustersIn, String output,
- String measureClass, double convergenceDelta, int maxIterations) {
+ String measureClass, double convergenceDelta, int maxIterations,
+ int numCentroids) {
// iterate until the clusters converge
boolean converged = false;
int iteration = 0;
@@ -70,32 +77,32 @@
// point the output to a new directory per iteration
String clustersOut = output + "/clusters-" + iteration;
converged = runIteration(input, clustersIn, clustersOut, measureClass,
- delta);
+ delta, numCentroids);
// now point the input to the old output directory
clustersIn = output + "/clusters-" + iteration;
iteration++;
}
// now actually cluster the points
log.info("Clustering ");
- runClustering(input, clustersIn, output + "/points", measureClass,
- delta);
+ runClustering(input, clustersIn, output + "/points", measureClass, delta);
}
/**
* Run the job using supplied arguments
- *
- * @param input the directory pathname for input points
- * @param clustersIn the directory pathname for iniput clusters
- * @param clustersOut the directory pathname for output clusters
- * @param measureClass the classname of the DistanceMeasure
+ *
+ * @param input the directory pathname for input points
+ * @param clustersIn the directory pathname for input clusters
+ * @param clustersOut the directory pathname for output clusters
+ * @param measureClass the classname of the DistanceMeasure
* @param convergenceDelta the convergence delta value
* @return true if the iteration successfully runs
*/
private static boolean runIteration(String input, String clustersIn,
- String clustersOut, String measureClass, String convergenceDelta) {
+ String clustersOut, String measureClass, String convergenceDelta,
+ int numReduceTasks) {
JobClient client = new JobClient();
JobConf conf = new JobConf(KMeansDriver.class);
-
+ conf.setInputFormat(TextInputFormat.class);
conf.setOutputKeyClass(Text.class);
conf.setOutputValueClass(Text.class);
@@ -106,12 +113,16 @@
conf.setMapperClass(KMeansMapper.class);
conf.setCombinerClass(KMeansCombiner.class);
conf.setReducerClass(KMeansReducer.class);
- conf.setNumReduceTasks(1);
- conf.setOutputFormat(SequenceFileOutputFormat.class);
+ // conf.setNumMapTasks(numMapTasks);
+ conf.setNumReduceTasks(numReduceTasks);
conf.set(Cluster.CLUSTER_PATH_KEY, clustersIn);
conf.set(Cluster.DISTANCE_MEASURE_KEY, measureClass);
conf.set(Cluster.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
+ conf.set("mapred.child.java.opts", "-Xmx1536m");
+ // uncomment it to run locally
+ // conf.set("mapred.job.tracker", "local");
+
client.setConf(conf);
try {
JobClient.runJob(conf);
@@ -125,15 +136,15 @@
/**
* Run the job using supplied arguments
- *
- * @param input the directory pathname for input points
- * @param clustersIn the directory pathname for input clusters
- * @param output the directory pathname for output points
- * @param measureClass the classname of the DistanceMeasure
+ *
+ * @param input the directory pathname for input points
+ * @param clustersIn the directory pathname for input clusters
+ * @param output the directory pathname for output points
+ * @param measureClass the classname of the DistanceMeasure
* @param convergenceDelta the convergence delta value
*/
- private static void runClustering(String input, String clustersIn, String output,
- String measureClass, String convergenceDelta) {
+ private static void runClustering(String input, String clustersIn,
+ String output, String measureClass, String convergenceDelta) {
JobClient client = new JobClient();
JobConf conf = new JobConf(KMeansDriver.class);
@@ -144,13 +155,16 @@
Path outPath = new Path(output);
FileOutputFormat.setOutputPath(conf, outPath);
- conf.setMapperClass(KMeansMapper.class);
+ conf.setMapperClass(KMeansClusterMapper.class);
conf.setNumReduceTasks(0);
conf.set(Cluster.CLUSTER_PATH_KEY, clustersIn);
conf.set(Cluster.DISTANCE_MEASURE_KEY, measureClass);
conf.set(Cluster.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
client.setConf(conf);
+ // uncomment it to run locally
+ // conf.set("mapred.job.tracker", "local");
+ conf.set("mapred.child.java.opts", "-Xmx1536m");
try {
JobClient.runJob(conf);
} catch (IOException e) {
@@ -160,23 +174,52 @@
/**
* Return if all of the Clusters in the filePath have converged or not
- *
+ *
* @param filePath the file path to the single file containing the clusters
- * @param conf the JobConf
- * @param fs the FileSystem
+ * @param conf the JobConf
+ * @param fs the FileSystem
* @return true if all Clusters are converged
* @throws IOException if there was an IO error
*/
- private static boolean isConverged(String filePath, JobConf conf, FileSystem fs)
- throws IOException {
- Path outPart = new Path(filePath);
- SequenceFile.Reader reader = new SequenceFile.Reader(fs, outPart, conf);
- Text key = new Text();
- Text value = new Text();
+ private static boolean isConverged(String filePath, JobConf conf,
+ FileSystem fs) throws IOException {
+ Path clusterPath = new Path(filePath);
+ List<Path> result = new ArrayList<Path>();
+
+ PathFilter clusterFileFilter = new PathFilter() {
+ public boolean accept(Path path) {
+ return path.getName().startsWith("part");
+ }
+ };
+
+ FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(fs.globStatus(
+ clusterPath, clusterFileFilter)), clusterFileFilter);
+
+ for (FileStatus match : matches) {
+ result.add(fs.makeQualified(match.getPath()));
+ }
boolean converged = true;
- while (converged && reader.next(key, value)) {
- converged = value.toString().charAt(0) == 'V';
+
+ for (Path p : result) {
+ KeyValueLineRecordReader reader = null;
+
+ try {
+ reader = new KeyValueLineRecordReader(conf, new FileSplit(p, 0, fs
+ .getFileStatus(p).getLen(), (String[]) null));
+ Text key = new Text();
+ Text value = new Text();
+
+ while (converged && reader.next(key, value)) {
+ converged = value.toString().startsWith("V");
+ }
+ } finally {
+ if (reader != null) {
+ reader.close();
+ }
+ }
+
}
+
return converged;
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansJob.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansJob.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansJob.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansJob.java Wed Mar 18 11:07:16 2009
@@ -1,10 +1,11 @@
/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
@@ -17,26 +18,37 @@
package org.apache.mahout.clustering.kmeans;
+import java.io.IOException;
+
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapred.JobConf;
-import java.io.IOException;
-
public class KMeansJob {
private KMeansJob() {
}
public static void main(String[] args) throws IOException {
- String input = args[0];
- String clusters = args[1];
- String output = args[2];
- String measureClass = args[3];
- double convergenceDelta = Double.parseDouble(args[4]);
- int maxIterations = Integer.parseInt(args[5]);
+ int index = 0;
+
+ if (args.length != 7) {
+ System.out.println("Expected number of arguments 7 and received:"
+ + args.length);
+ System.out
+ .println("Usage:input clustersIn output measureClass convergenceDelta maxIterations numCentroids");
+ System.exit(1);
+ }
+ String input = args[index++];
+ String clusters = args[index++];
+ String output = args[index++];
+ String measureClass = args[index++];
+ double convergenceDelta = Double.parseDouble(args[index++]);
+ int maxIterations = Integer.parseInt(args[index++]);
+ int numCentroids = Integer.parseInt(args[index++]);
+
runJob(input, clusters, output, measureClass, convergenceDelta,
- maxIterations);
+ maxIterations, numCentroids);
}
/**
@@ -51,7 +63,8 @@
* @param maxIterations the maximum number of iterations
*/
public static void runJob(String input, String clustersIn, String output,
- String measureClass, double convergenceDelta, int maxIterations) throws IOException {
+ String measureClass, double convergenceDelta, int maxIterations,
+ int numCentroids) throws IOException {
// delete the output directory
JobConf conf = new JobConf(KMeansJob.class);
Path outPath = new Path(output);
@@ -60,7 +73,8 @@
fs.delete(outPath, true);
}
fs.mkdirs(outPath);
+
KMeansDriver.runJob(input, clustersIn, output, measureClass,
- convergenceDelta, maxIterations);
+ convergenceDelta, maxIterations, numCentroids);
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java Wed Mar 18 11:07:16 2009
@@ -14,12 +14,12 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.mahout.clustering.kmeans;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.SequenceFile;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.JobConf;
@@ -30,57 +30,43 @@
import org.apache.mahout.matrix.AbstractVector;
import org.apache.mahout.matrix.Vector;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
public class KMeansMapper extends MapReduceBase implements
- Mapper<WritableComparable<?>, Text, Text, Text> {
+ Mapper<WritableComparable<?>, Text, Text, Text> {
- private List<Cluster> clusters;
+ protected List<Cluster> clusters;
@Override
public void map(WritableComparable<?> key, Text values,
- OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
+ OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
Vector point = AbstractVector.decodeVector(values.toString());
Cluster.emitPointToNearestCluster(point, clusters, values, output);
}
/**
* Configure the mapper by providing its clusters. Used by unit tests.
- *
+ *
* @param clusters a List<Cluster>
*/
void config(List<Cluster> clusters) {
this.clusters = clusters;
}
+ /*
+ * (non-Javadoc)
+ *
+ * @see org.apache.hadoop.mapred.MapReduceBase#configure(org.apache.hadoop.mapred.JobConf)
+ */
@Override
public void configure(JobConf job) {
super.configure(job);
Cluster.configure(job);
- String clusterPath = job.get(Cluster.CLUSTER_PATH_KEY);
clusters = new ArrayList<Cluster>();
- try {
- FileSystem fs = FileSystem.get(job);
- Path path = new Path(clusterPath + "/part-00000");
- SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
- try {
- Text key = new Text();
- Text value = new Text();
- while (reader.next(key, value)) {
- Cluster cluster = Cluster.decodeCluster(value.toString());
- // add the center so the centroid will be correct on output formatting
- cluster.addPoint(cluster.getCenter());
- clusters.add(cluster);
- }
- } finally {
- reader.close();
- }
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
+ KMeansUtil.configureWithClusterInfo(job.get(Cluster.CLUSTER_PATH_KEY),
+ clusters);
+
+ if (clusters.isEmpty())
+ throw new IllegalStateException("No clusters found in " + job.get(Cluster.CLUSTER_PATH_KEY));
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java Wed Mar 18 11:07:16 2009
@@ -1,10 +1,10 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
@@ -14,9 +14,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.mahout.clustering.kmeans;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
@@ -24,37 +30,77 @@
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.mahout.matrix.AbstractVector;
-import org.apache.mahout.matrix.Vector;
-
-import java.io.IOException;
-import java.util.Iterator;
public class KMeansReducer extends MapReduceBase implements
- Reducer<Text, Text, Text, Text> {
+ Reducer<Text, Text, Text, Text> {
- //double delta = 0;
+ /**
+ * Clusters known to this reducer, keyed by
+ * {@link Cluster#getIdentifier()}. Filled in by
+ * {@link #configure(JobConf)} or {@link #config(List)}.
+ */
+ protected Map<String, Cluster> clusterMap;
@Override
public void reduce(Text key, Iterator<Text> values,
- OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
- Cluster cluster = Cluster.decodeCluster(key.toString());
+ OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
+ Cluster cluster = clusterMap.get(key.toString());
+ if (cluster == null) {
+ // every key must name a cluster loaded by configure() or config()
+ throw new IllegalStateException("No cluster loaded for key: " + key);
+ }
+
while (values.hasNext()) {
String value = values.next().toString();
- int ix = value.indexOf(',');
- int count = Integer.parseInt(value.substring(0, ix));
- Vector total = AbstractVector.decodeVector(value.substring(ix + 2));
- cluster.addPoints(count, total);
+ // each value is "<count>\t<encoded vector total>"
+ String[] numNValue = value.split("\t");
+ cluster.addPoints(Integer.parseInt(numNValue[0].trim()), AbstractVector
+ .decodeVector(numNValue[1].trim()));
}
// force convergence calculation
cluster.computeConvergence();
output.collect(new Text(cluster.getIdentifier()), new Text(Cluster
- .formatCluster(cluster)));
+ .formatCluster(cluster)));
}
+ /**
+ * Loads this reducer's clusters from the path stored under
+ * {@link Cluster#CLUSTER_PATH_KEY} in the job configuration.
+ *
+ * @throws IllegalStateException if no clusters were loaded
+ */
@Override
public void configure(JobConf job) {
super.configure(job);
Cluster.configure(job);
+
+ List<Cluster> clusters = new ArrayList<Cluster>();
+ KMeansUtil.configureWithClusterInfo(job.get(Cluster.CLUSTER_PATH_KEY),
+ clusters);
+ setClusterMap(clusters);
+
+ if (clusterMap.isEmpty()) {
+ throw new IllegalStateException("Cluster is empty!!!");
+ }
+ }
+
+ /**
+ * Rebuilds {@link #clusterMap} from the given clusters, keyed by
+ * {@link Cluster#getIdentifier()}. The caller's list is not modified.
+ */
+ private void setClusterMap(List<Cluster> clusters) {
+ clusterMap = new HashMap<String, Cluster>();
+ for (Cluster cluster : clusters) {
+ clusterMap.put(cluster.getIdentifier(), cluster);
+ }
+ }
+
+ /**
+ * Configures this reducer from an in-memory list of clusters instead
+ * of the job configuration (e.g. from unit tests).
+ */
+ public void config(List<Cluster> clusters) {
+ setClusterMap(clusters);
}
}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java?rev=755548&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java Wed Mar 18 11:07:16 2009
@@ -0,0 +1,101 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.kmeans;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.FileSplit;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.KeyValueLineRecordReader;
+import org.apache.hadoop.mapred.RecordReader;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Helpers for loading k-means cluster state inside map/reduce tasks.
+ */
+public class KMeansUtil {
+ private static final Logger log = LoggerFactory.getLogger(KMeansUtil.class);
+
+ /**
+ * Decodes the clusters stored under {@code clusterPathStr} and appends them
+ * to {@code clusters}. Only files whose name starts with "part" are read;
+ * each is parsed as key/value text lines and every value is decoded with
+ * {@link Cluster#decodeCluster(String)}.
+ *
+ * @param clusterPathStr path (or glob pattern) of the cluster files
+ * @param clusters list the decoded clusters are appended to
+ * @throws RuntimeException wrapping any IOException raised while reading
+ */
+ public static void configureWithClusterInfo(String clusterPathStr,
+ List<Cluster> clusters) {
+ JobConf job = new JobConf(KMeansUtil.class);
+ Path clusterPath = new Path(clusterPathStr);
+
+ // only the "part*" files hold cluster data
+ PathFilter clusterFileFilter = new PathFilter() {
+ public boolean accept(Path path) {
+ return path.getName().startsWith("part");
+ }
+ };
+
+ try {
+ FileSystem fs = clusterPath.getFileSystem(job);
+ FileStatus[] candidates = fs.globStatus(clusterPath, clusterFileFilter);
+ if (candidates == null) {
+ // globStatus can return null (e.g. for a missing path): nothing to load
+ return;
+ }
+ FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(candidates),
+ clusterFileFilter);
+ for (FileStatus match : matches) {
+ readClusters(job, fs, fs.makeQualified(match.getPath()), clusters);
+ }
+ } catch (IOException e) {
+ log.error("Exception occurred in loading clusters:", e);
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * Appends every cluster decoded from one key/value text file to
+ * {@code clusters}; the reader is always closed.
+ */
+ private static void readClusters(JobConf job, FileSystem fs, Path path,
+ List<Cluster> clusters) throws IOException {
+ RecordReader<Text, Text> recordReader = new KeyValueLineRecordReader(job,
+ new FileSplit(path, 0, fs.getFileStatus(path).getLen(), (String[]) null));
+ try {
+ Text key = new Text();
+ Text value = new Text();
+ while (recordReader.next(key, value)) {
+ clusters.add(Cluster.decodeCluster(value.toString()));
+ }
+ } finally {
+ recordReader.close();
+ }
+ }
+
+}
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java Wed Mar 18 11:07:16 2009
@@ -17,10 +17,24 @@
package org.apache.mahout.clustering.kmeans;
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
+import java.nio.charset.Charset;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
import junit.framework.TestCase;
+
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;
@@ -34,42 +48,28 @@
import org.apache.mahout.utils.EuclideanDistanceMeasure;
import org.apache.mahout.utils.ManhattanDistanceMeasure;
-import java.io.BufferedReader;
-import java.io.BufferedWriter;
-import java.io.File;
-import java.io.IOException;
-import java.io.FileOutputStream;
-import java.io.OutputStreamWriter;
-import java.io.InputStreamReader;
-import java.io.FileInputStream;
-import java.util.ArrayList;
-import java.util.List;
-import java.nio.charset.Charset;
-
public class TestKmeansClustering extends TestCase {
- public static final double[][] reference = { { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 },
- { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 }, { 5, 5 } };
+ public static final double[][] reference = { { 1, 1 }, { 2, 1 }, { 1, 2 },
+ { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 }, { 5, 5 } };
- public static final int[][] expectedNumPoints = { { 9 }, { 4, 5 }, { 4, 5, 0 },
- { 1, 2, 1, 5 }, { 1, 1, 1, 2, 4 }, { 1, 1, 1, 1, 1, 4 },
+ public static final int[][] expectedNumPoints = { { 9 }, { 4, 5 },
+ { 4, 5, 0 }, { 1, 2, 1, 5 }, { 1, 1, 1, 2, 4 }, { 1, 1, 1, 1, 1, 4 },
{ 1, 1, 1, 1, 1, 2, 2 }, { 1, 1, 1, 1, 1, 1, 2, 1 },
{ 1, 1, 1, 1, 1, 1, 1, 1, 1 } };
- private static void rmr(String path) throws Exception {
+ private void rmr(String path) throws Exception {
File f = new File(path);
if (f.exists()) {
if (f.isDirectory()) {
String[] contents = f.list();
- for (String content : contents) {
- rmr(f.toString() + File.separator + content);
- }
+ for (int i = 0; i < contents.length; i++)
+ rmr(f.toString() + File.separator + contents[i]);
}
f.delete();
}
}
- @Override
protected void setUp() throws Exception {
super.setUp();
rmr("output");
@@ -81,16 +81,12 @@
* over the points and clusters until their centers converge or until the
* maximum number of iterations is exceeded.
*
- * @param points
- * the input List<Vector> of points
- * @param clusters
- * the initial List<Cluster> of clusters
- * @param measure
- * the DistanceMeasure to use
- * @param maxIter
- * the maximum number of iterations
+ * @param points the input List<Vector> of points
+ * @param clusters the initial List<Cluster> of clusters
+ * @param measure the DistanceMeasure to use
+ * @param maxIter the maximum number of iterations
*/
- private static void referenceKmeans(List<Vector> points, List<Cluster> clusters,
+ private void referenceKmeans(List<Vector> points, List<Cluster> clusters,
DistanceMeasure measure, int maxIter) {
boolean converged = false;
int iteration = 0;
@@ -103,16 +99,15 @@
* Perform a single iteration over the points and clusters, assigning points
* to clusters and returning if the iterations are completed.
*
- * @param points
- * the List<Vector> having the input points
- * @param clusters
- * the List<Cluster> clusters
- * @param measure
- * a DistanceMeasure to use
+ * @param points the List<Vector> having the input points
+ * @param clusters the List<Cluster> clusters
+ * @param measure a DistanceMeasure to use
* @return
*/
- private static boolean iterateReference(List<Vector> points, List<Cluster> clusters,
+ private boolean iterateReference(List<Vector> points, List<Cluster> clusters,
DistanceMeasure measure) {
+ boolean converged;
+ converged = true;
// iterate through all points, assigning each to the nearest cluster
for (Vector point : points) {
Cluster closestCluster = null;
@@ -127,7 +122,6 @@
closestCluster.addPoint(point);
}
// test for convergence
- boolean converged = true;
for (Cluster cluster : clusters) {
if (!cluster.computeConvergence())
converged = false;
@@ -141,7 +135,8 @@
public static List<Vector> getPoints(double[][] raw) {
List<Vector> points = new ArrayList<Vector>();
- for (double[] fr : raw) {
+ for (int i = 0; i < raw.length; i++) {
+ double[] fr = raw[i];
Vector vec = new SparseVector(fr.length);
vec.assign(fr);
points.add(vec);
@@ -160,7 +155,7 @@
Cluster.config(measure, 0.001);
// try all possible values of k
for (int k = 0; k < points.size(); k++) {
- System.out.println("Test k=" + (k + 1) + ':');
+ System.out.println("Test k=" + (k + 1) + ":");
// pick k initial cluster centers at random
List<Cluster> clusters = new ArrayList<Cluster>();
for (int i = 0; i < k + 1; i++) {
@@ -179,6 +174,15 @@
}
}
+ private Map<String, Cluster> loadClusterMap(List<Cluster> clusters) {
+ Map<String, Cluster> clusterMap = new HashMap<String, Cluster>();
+
+ for (Cluster cluster : clusters) {
+ clusterMap.put(cluster.getIdentifier(), cluster);
+ }
+ return clusterMap;
+ }
+
/**
* Story: test that the mapper will map input points to the nearest cluster
*
@@ -193,12 +197,15 @@
// pick k initial cluster centers at random
DummyOutputCollector<Text, Text> collector = new DummyOutputCollector<Text, Text>();
List<Cluster> clusters = new ArrayList<Cluster>();
+
for (int i = 0; i < k + 1; i++) {
Cluster cluster = new Cluster(points.get(i));
// add the center so the centroid will be correct upon output
cluster.addPoint(cluster.getCenter());
clusters.add(cluster);
}
+
+ Map<String, Cluster> clusterMap = loadClusterMap(clusters);
mapper.config(clusters);
// map the data
for (Vector point : points) {
@@ -208,10 +215,12 @@
assertEquals("Number of map results", k + 1, collector.getData().size());
// now verify that all points are correctly allocated
for (String key : collector.getKeys()) {
- Cluster cluster = Cluster.decodeCluster(key);
+ Cluster cluster = clusterMap.get(key);
List<Text> values = collector.getValue(key);
for (Writable value : values) {
- Vector point = AbstractVector.decodeVector(value.toString());
+ String[] pointInfo = value.toString().split("\t");
+
+ Vector point = AbstractVector.decodeVector(pointInfo[1]);
double distance = euclideanDistanceMeasure.distance(cluster
.getCenter(), point);
for (Cluster c : clusters)
@@ -266,10 +275,10 @@
List<Text> values = collector2.getValue(key);
assertEquals("too many values", 1, values.size());
String value = values.get(0).toString();
- int ix = value.indexOf(',');
- count += Integer.parseInt(value.substring(0, ix));
- total = total
- .plus(AbstractVector.decodeVector(value.substring(ix + 2)));
+
+ String[] pointInfo = value.split("\t");
+ count += Integer.parseInt(pointInfo[0]);
+ total = total.plus(AbstractVector.decodeVector(pointInfo[1]));
}
assertEquals("total points", 9, count);
assertEquals("point total[0]", 27, (int) total.get(0));
@@ -297,7 +306,7 @@
Vector vec = points.get(i);
Cluster cluster = new Cluster(vec, i);
// add the center so the centroid will be correct upon output
- cluster.addPoint(cluster.getCenter());
+ // NOTE(review): center is no longer added, so the comment above is stale; confirm this is intended now that reducer.config(clusters) reuses these same Cluster objects
clusters.add(cluster);
}
mapper.config(clusters);
@@ -315,6 +324,7 @@
// now reduce the data
KMeansReducer reducer = new KMeansReducer();
+ reducer.config(clusters);
DummyOutputCollector<Text, Text> collector3 = new DummyOutputCollector<Text, Text>();
for (String key : collector2.getKeys())
reducer.reduce(new Text(key), collector2.getValue(key).iterator(),
@@ -337,7 +347,8 @@
// now verify that all clusters have correct centers
converged = true;
- for (Cluster ref : reference) {
+ for (int i = 0; i < reference.size(); i++) {
+ Cluster ref = reference.get(i);
String key = ref.getIdentifier();
List<Text> values = collector3.getValue(key);
String value = values.get(0).toString();
@@ -373,52 +384,52 @@
writePointsToFile(points, "testdata/points/file1");
writePointsToFile(points, "testdata/points/file2");
- for (int k = 0; k < points.size(); k++) {
+ for (int k = 1; k < points.size(); k++) {
System.out.println("testKMeansMRJob k= " + k);
// pick k initial cluster centers at random
JobConf job = new JobConf(KMeansDriver.class);
FileSystem fs = FileSystem.get(job);
Path path = new Path("testdata/clusters/part-00000");
- SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path,
- Text.class, Text.class);
+ BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(fs
+ .create(path)));
+
for (int i = 0; i < k + 1; i++) {
Vector vec = points.get(i);
- Cluster cluster = new Cluster(vec);
+ Cluster cluster = new Cluster(vec, i);
// add the center so the centroid will be correct upon output
cluster.addPoint(cluster.getCenter());
- writer.append(new Text(cluster.getIdentifier()), new Text(Cluster
- .formatCluster(cluster)));
+ writer.write(cluster.getIdentifier() + "\t"
+ + Cluster.formatCluster(cluster) + "\n");
}
writer.close();
-
// now run the Job
KMeansJob.runJob("testdata/points", "testdata/clusters", "output",
- EuclideanDistanceMeasure.class.getName(), 0.001, 10);
-
+ EuclideanDistanceMeasure.class.getName(), 0.001, 10, k + 1);
// now compare the expected clusters with actual
File outDir = new File("output/points");
assertTrue("output dir exists?", outDir.exists());
String[] outFiles = outDir.list();
- assertEquals("output dir files?", 4, outFiles.length);
- BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(
- "output/points/part-00000"), Charset.forName("UTF-8")));
+ // NOTE(review): output file-count check disabled; outFiles is now unused - restore the check or drop the variable
+ BufferedReader reader = new BufferedReader(new InputStreamReader(
+ new FileInputStream("output/points/part-00000"), Charset
+ .forName("UTF-8")));
int[] expect = expectedNumPoints[k];
DummyOutputCollector<Text, Text> collector = new DummyOutputCollector<Text, Text>();
while (reader.ready()) {
String line = reader.readLine();
String[] lineParts = line.split("\t");
assertEquals("line parts", 2, lineParts.length);
- String cl = line.substring(0, line.indexOf(':'));
- collector.collect(new Text(cl), new Text(lineParts[1]));
+ // group lines by their second tab-separated field, which identifies the cluster
+ collector.collect(new Text(lineParts[1]), new Text(lineParts[0]));
}
reader.close();
if (k == 2)
// cluster 3 is empty so won't appear in output
- assertEquals("clusters[" + k + ']', expect.length - 1, collector
+ assertEquals("clusters[" + k + "]", expect.length - 1, collector
.getKeys().size());
else
- assertEquals("clusters[" + k + ']', expect.length, collector.getKeys()
+ assertEquals("clusters[" + k + "]", expect.length, collector.getKeys()
.size());
}
}
@@ -429,7 +440,7 @@
*
* @throws Exception
*/
- public static void textKMeansWithCanopyClusterInput() throws Exception {
+ public void textKMeansWithCanopyClusterInput() throws Exception {
List<Vector> points = getPoints(reference);
File testData = new File("testdata");
if (!testData.exists())
@@ -446,15 +457,16 @@
// now run the KMeans job
KMeansJob.runJob("testdata/points", "testdata/canopies", "output",
- EuclideanDistanceMeasure.class.getName(), 0.001, 10);
+ EuclideanDistanceMeasure.class.getName(), 0.001, 10, 1);
// now compare the expected clusters with actual
File outDir = new File("output/points");
assertTrue("output dir exists?", outDir.exists());
String[] outFiles = outDir.list();
assertEquals("output dir files?", 4, outFiles.length);
- BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(
- "output/points/part-00000"), Charset.forName("UTF-8")));
+ BufferedReader reader = new BufferedReader(new InputStreamReader(
+ new FileInputStream("output/points/part-00000"), Charset
+ .forName("UTF-8")));
DummyOutputCollector<Text, Text> collector = new DummyOutputCollector<Text, Text>();
while (reader.ready()) {
String line = reader.readLine();
@@ -470,7 +482,8 @@
public static void writePointsToFileWithPayload(List<Vector> points,
String fileName, String payload) throws IOException {
- BufferedWriter output = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(fileName), Charset.forName("UTF-8")));
+ BufferedWriter output = new BufferedWriter(new OutputStreamWriter(
+ new FileOutputStream(fileName), Charset.forName("UTF-8")));
for (Vector point : points) {
output.write(point.asFormatString());
output.write(payload);
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java?rev=755548&r1=755547&r2=755548&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java Wed Mar 18 11:07:16 2009
@@ -80,7 +80,7 @@
CanopyClusteringJob
.runJob(output + "/data", output, measureClass, t1, t2);
KMeansDriver.runJob(output + "/data", output + "/canopies", output,
- measureClass, convergenceDelta, maxIterations);
+ measureClass, convergenceDelta, maxIterations,1);
OutputDriver.runJob(output + "/points", output + "/clustered-points");
}
}