You are viewing a plain-text version of this content. The link to the canonical version was a hyperlink that did not survive the plain-text conversion.
Posted to commits@mahout.apache.org by pr...@apache.org on 2012/02/28 05:00:10 UTC
svn commit: r1294454 - in /mahout/trunk/core/src:
main/java/org/apache/mahout/clustering/classify/
test/java/org/apache/mahout/clustering/classify/
Author: pranjan
Date: Tue Feb 28 04:00:09 2012
New Revision: 1294454
URL: http://svn.apache.org/viewvc?rev=1294454&view=rev
Log:
MAHOUT-929, MAHOUT-931. Implemented a MapReduce version of ClusterClassificationDriver with outlier-removal capability.
Changed the output of the sequential version to WeightedVectorWritable. Fixed existing test cases and added new ones.
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java?rev=1294454&r1=1294453&r2=1294454&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java Tue Feb 28 04:00:09 2012
@@ -18,14 +18,14 @@
package org.apache.mahout.clustering.classify;
/**
- * Constants used in Cluster Classification.
+ * Constants used in Cluster Classification.
*/
public class ClusterClassificationConfigKeys {
-
+
public static final String CLUSTERS_IN = "clusters_in";
public static final String OUTLIER_REMOVAL_THRESHOLD = "pdf_threshold";
+ public static final String EMIT_MOST_LIKELY = "emit_most_likely";
+
}
-
-
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java?rev=1294454&r1=1294453&r2=1294454&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java Tue Feb 28 04:00:09 2012
@@ -17,6 +17,8 @@
package org.apache.mahout.clustering.classify;
+import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.CLUSTERS_IN;
+import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.EMIT_MOST_LIKELY;
import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD;
import java.io.IOException;
@@ -213,7 +215,7 @@ public class ClusterClassificationDriver
Configuration conf = new Configuration();
SequenceFile.Writer writer = new SequenceFile.Writer(
input.getFileSystem(conf), conf, new Path(output, "part-m-" + 0),
- IntWritable.class, VectorWritable.class);
+ IntWritable.class, WeightedVectorWritable.class);
for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(
input, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
Vector pdfPerCluster = clusterClassifier.classify(vw.get());
@@ -231,7 +233,9 @@ public class ClusterClassificationDriver
throws IOException {
if (emitMostLikely) {
int maxValueIndex = pdfPerCluster.maxValueIndex();
- write(clusterModels, writer, vw, maxValueIndex);
+ WeightedVectorWritable wvw = new WeightedVectorWritable(
+ pdfPerCluster.maxValue(), vw.get());
+ write(clusterModels, writer, wvw, maxValueIndex);
} else {
writeAllAboveThreshold(clusterModels, clusterClassificationThreshold,
writer, vw, pdfPerCluster);
@@ -245,17 +249,19 @@ public class ClusterClassificationDriver
while (iterateNonZero.hasNext()) {
Element pdf = iterateNonZero.next();
if (pdf.get() >= clusterClassificationThreshold) {
+ WeightedVectorWritable wvw = new WeightedVectorWritable(pdf.get(),
+ vw.get());
int clusterIndex = pdf.index();
- write(clusterModels, writer, vw, clusterIndex);
+ write(clusterModels, writer, wvw, clusterIndex);
}
}
}
private static void write(List<Cluster> clusterModels,
- SequenceFile.Writer writer, VectorWritable vw, int maxValueIndex)
+ SequenceFile.Writer writer, WeightedVectorWritable wvw, int maxValueIndex)
throws IOException {
Cluster cluster = clusterModels.get(maxValueIndex);
- writer.append(new IntWritable(cluster.getId()), vw);
+ writer.append(new IntWritable(cluster.getId()), wvw);
}
/**
@@ -273,15 +279,16 @@ public class ClusterClassificationDriver
private static void classifyClusterMR(Configuration conf, Path input,
Path clustersIn, Path output, Double clusterClassificationThreshold,
boolean emitMostLikely) throws IOException, InterruptedException,
- ClassNotFoundException {
- Job job = new Job(conf,
- "Cluster Classification Driver running over input: " + input);
- job.setJarByClass(ClusterClassificationDriver.class);
+ ClassNotFoundException {
conf.setFloat(OUTLIER_REMOVAL_THRESHOLD,
clusterClassificationThreshold.floatValue());
+ conf.setBoolean(EMIT_MOST_LIKELY, emitMostLikely);
+ conf.set(CLUSTERS_IN, clustersIn.toUri().toString());
- conf.set(ClusterClassificationConfigKeys.CLUSTERS_IN, input.toString());
+ Job job = new Job(conf,
+ "Cluster Classification Driver running over input: " + input);
+ job.setJarByClass(ClusterClassificationDriver.class);
job.setInputFormatClass(SequenceFileInputFormat.class);
job.setOutputFormatClass(SequenceFileOutputFormat.class);
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java?rev=1294454&r1=1294453&r2=1294454&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java Tue Feb 28 04:00:09 2012
@@ -17,6 +17,8 @@
package org.apache.mahout.clustering.classify;
+import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.CLUSTERS_IN;
+import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.EMIT_MOST_LIKELY;
import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD;
import java.io.IOException;
@@ -29,6 +31,7 @@ import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.clustering.Cluster;
@@ -37,70 +40,114 @@ import org.apache.mahout.common.iterator
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.VectorWritable;
/**
* Mapper for classifying vectors into clusters.
*/
-public class ClusterClassificationMapper extends Mapper<IntWritable,VectorWritable,IntWritable,WeightedVectorWritable> {
+public class ClusterClassificationMapper extends
+ Mapper<LongWritable,VectorWritable,IntWritable,WeightedVectorWritable> {
private static double threshold;
private List<Cluster> clusterModels;
private ClusterClassifier clusterClassifier;
private IntWritable clusterId;
private WeightedVectorWritable weightedVW;
+ private boolean emitMostLikely;
@Override
- protected void setup(Context context) throws IOException, InterruptedException {
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
super.setup(context);
Configuration conf = context.getConfiguration();
- String clustersIn = conf.get(ClusterClassificationConfigKeys.CLUSTERS_IN);
+ String clustersIn = conf.get(CLUSTERS_IN);
+ threshold = conf.getFloat(OUTLIER_REMOVAL_THRESHOLD, 0.0f);
+ emitMostLikely = conf.getBoolean(EMIT_MOST_LIKELY, false);
clusterModels = new ArrayList<Cluster>();
if (clustersIn != null && !clustersIn.isEmpty()) {
- Path clustersInPath = new Path(clustersIn, "*");
- populateClusterModels(clustersInPath);
- ClusteringPolicy policy = ClusterClassifier.readPolicy(clustersInPath);
+ Path clustersInPath = new Path(clustersIn);
+ clusterModels = populateClusterModels(clustersInPath);
+ ClusteringPolicy policy = ClusterClassifier
+ .readPolicy(finalClustersPath(clustersInPath));
clusterClassifier = new ClusterClassifier(clusterModels, policy);
}
- threshold = conf.getFloat(OUTLIER_REMOVAL_THRESHOLD, 0.0f);
clusterId = new IntWritable();
weightedVW = new WeightedVectorWritable(1, null);
}
+ /**
+ * Mapper which classifies the vectors to respective clusters.
+ */
@Override
- protected void map(IntWritable key, VectorWritable vw, Context context) throws IOException, InterruptedException {
+ protected void map(LongWritable key, VectorWritable vw, Context context)
+ throws IOException, InterruptedException {
if (!clusterModels.isEmpty()) {
Vector pdfPerCluster = clusterClassifier.classify(vw.get());
if (shouldClassify(pdfPerCluster)) {
- int maxValueIndex = pdfPerCluster.maxValueIndex();
- Cluster cluster = clusterModels.get(maxValueIndex);
- clusterId.set(cluster.getId());
- weightedVW.setVector(vw.get());
- context.write(clusterId, weightedVW);
+ if (emitMostLikely) {
+ int maxValueIndex = pdfPerCluster.maxValueIndex();
+ write(vw, context, maxValueIndex);
+ } else {
+ writeAllAboveThreshold(vw, context, pdfPerCluster);
+ }
}
}
}
- public static List<Cluster> populateClusterModels(Path clusterOutputPath) throws IOException {
- List<Cluster> clusterModels = new ArrayList<Cluster>();
+ private void writeAllAboveThreshold(VectorWritable vw, Context context,
+ Vector pdfPerCluster) throws IOException, InterruptedException {
+ Iterator<Element> iterateNonZero = pdfPerCluster.iterateNonZero();
+ while (iterateNonZero.hasNext()) {
+ Element pdf = iterateNonZero.next();
+ if (pdf.get() >= threshold) {
+ int clusterIndex = pdf.index();
+ write(vw, context, clusterIndex);
+ }
+ }
+ }
+
+ private void write(VectorWritable vw, Context context, int clusterIndex)
+ throws IOException, InterruptedException {
+ Cluster cluster = clusterModels.get(clusterIndex);
+ clusterId.set(cluster.getId());
+ weightedVW.setVector(vw.get());
+ context.write(clusterId, weightedVW);
+ }
+
+ public static List<Cluster> populateClusterModels(Path clusterOutputPath)
+ throws IOException {
+ List<Cluster> clusters = new ArrayList<Cluster>();
Configuration conf = new Configuration();
Cluster cluster = null;
FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
- FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
- Iterator<?> it = new SequenceFileDirValueIterator<Writable>(clusterFiles[0].getPath(), PathType.LIST,
- PathFilters.partFilter(), null, false, conf);
+ FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath,
+ PathFilters.finalPartFilter());
+ Iterator<?> it = new SequenceFileDirValueIterator<Writable>(
+ clusterFiles[0].getPath(), PathType.LIST, PathFilters.partFilter(),
+ null, false, conf);
while (it.hasNext()) {
cluster = (Cluster) it.next();
- clusterModels.add(cluster);
+ clusters.add(cluster);
}
- return clusterModels;
+ return clusters;
}
private static boolean shouldClassify(Vector pdfPerCluster) {
- return pdfPerCluster.maxValue() >= threshold;
+ boolean isMaxPDFGreatherThanThreshold = pdfPerCluster.maxValue() >= threshold;
+ return isMaxPDFGreatherThanThreshold;
}
+ private static Path finalClustersPath(Path clusterOutputPath)
+ throws IOException {
+ FileSystem fileSystem = clusterOutputPath
+ .getFileSystem(new Configuration());
+ FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath,
+ PathFilters.finalPartFilter());
+ Path finalClustersPath = clusterFiles[0].getPath();
+ return finalClustersPath;
+ }
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java?rev=1294454&r1=1294453&r2=1294454&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java Tue Feb 28 04:00:09 2012
@@ -35,8 +35,10 @@ import org.apache.hadoop.io.Writable;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.canopy.CanopyDriver;
import org.apache.mahout.clustering.iterator.CanopyClusteringPolicy;
+import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
@@ -89,6 +91,25 @@ public class ClusterClassificationDriver
}
@Test
+ public void testVectorClassificationWithOutlierRemovalMR() throws Exception {
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+
+ pointsPath = getTestTempDirPath("points");
+ clusteringOutputPath = getTestTempDirPath("output");
+ classifiedOutputPath = getTestTempDirPath("classifiedClusters");
+ HadoopUtil.delete(conf, classifiedOutputPath);
+
+ conf = new Configuration();
+
+ ClusteringTestUtils.writePointsToFile(points,
+ new Path(pointsPath, "file1"), fs, conf);
+ runClustering(pointsPath, conf, false);
+ runClassificationWithOutlierRemoval(conf, false);
+ collectVectorsForAssertion();
+ assertVectorsWithOutlierRemoval();
+ }
+
+ @Test
public void testVectorClassificationWithoutOutlierRemoval() throws Exception {
List<VectorWritable> points = getPointsWritable(REFERENCE);
@@ -100,7 +121,7 @@ public class ClusterClassificationDriver
ClusteringTestUtils.writePointsToFile(points,
new Path(pointsPath, "file1"), fs, conf);
- runClustering(pointsPath, conf);
+ runClustering(pointsPath, conf, true);
runClassificationWithoutOutlierRemoval(conf);
collectVectorsForAssertion();
assertVectorsWithoutOutlierRemoval();
@@ -118,16 +139,17 @@ public class ClusterClassificationDriver
ClusteringTestUtils.writePointsToFile(points,
new Path(pointsPath, "file1"), fs, conf);
- runClustering(pointsPath, conf);
- runClassificationWithOutlierRemoval(conf);
+ runClustering(pointsPath, conf, true);
+ runClassificationWithOutlierRemoval(conf, true);
collectVectorsForAssertion();
assertVectorsWithOutlierRemoval();
}
- private void runClustering(Path pointsPath, Configuration conf)
- throws IOException, InterruptedException, ClassNotFoundException {
+ private void runClustering(Path pointsPath, Configuration conf,
+ Boolean runSequential) throws IOException, InterruptedException,
+ ClassNotFoundException {
CanopyDriver.run(conf, pointsPath, clusteringOutputPath,
- new ManhattanDistanceMeasure(), 3.1, 2.1, false, true);
+ new ManhattanDistanceMeasure(), 3.1, 2.1, false, runSequential);
Path finalClustersPath = new Path(clusteringOutputPath, "clusters-0-final");
ClusterClassifier.writePolicy(new CanopyClusteringPolicy(),
finalClustersPath);
@@ -139,23 +161,25 @@ public class ClusterClassificationDriver
classifiedOutputPath, 0.0, true, true);
}
- private void runClassificationWithOutlierRemoval(Configuration conf2)
- throws IOException, InterruptedException, ClassNotFoundException {
+ private void runClassificationWithOutlierRemoval(Configuration conf2,
+ boolean runSequential) throws IOException, InterruptedException,
+ ClassNotFoundException {
ClusterClassificationDriver.run(pointsPath, clusteringOutputPath,
- classifiedOutputPath, 0.73, true, true);
+ classifiedOutputPath, 0.73, true, runSequential);
}
private void collectVectorsForAssertion() throws IOException {
Path[] partFilePaths = FileUtil.stat2Paths(fs
.globStatus(classifiedOutputPath));
- FileStatus[] listStatus = fs.listStatus(partFilePaths);
+ FileStatus[] listStatus = fs.listStatus(partFilePaths,
+ PathFilters.partFilter());
for (FileStatus partFile : listStatus) {
SequenceFile.Reader classifiedVectors = new SequenceFile.Reader(fs,
partFile.getPath(), conf);
Writable clusterIdAsKey = new IntWritable();
- VectorWritable point = new VectorWritable();
+ WeightedVectorWritable point = new WeightedVectorWritable();
while (classifiedVectors.next(clusterIdAsKey, point)) {
- collectVector(clusterIdAsKey.toString(), point.get());
+ collectVector(clusterIdAsKey.toString(), point.getVector());
}
}
}