Posted to commits@mahout.apache.org by pr...@apache.org on 2012/02/28 05:00:10 UTC

svn commit: r1294454 - in /mahout/trunk/core/src: main/java/org/apache/mahout/clustering/classify/ test/java/org/apache/mahout/clustering/classify/

Author: pranjan
Date: Tue Feb 28 04:00:09 2012
New Revision: 1294454

URL: http://svn.apache.org/viewvc?rev=1294454&view=rev
Log:
MAHOUT-929, MAHOUT-931. Implemented MapReduce version of ClusterClassificationDriver with outlier-removal capability.
Changed the output of the sequential version to WeightedVectorWritable. Fixed and added test cases.
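
For reference, the driver entry point exercised by the new MapReduce test takes the form below. This is a sketch based on the test code in this commit; the paths are illustrative, and the arguments are (input vectors, clustering output, classification output, outlier threshold, emitMostLikely, runSequential):

    Path input = new Path("points");
    Path clustersIn = new Path("output");               // canopy clustering output
    Path classified = new Path("classifiedClusters");
    // A threshold of 0.73 drops vectors whose best cluster pdf is below 0.73;
    // runSequential = false exercises the new MapReduce path.
    ClusterClassificationDriver.run(input, clustersIn, classified, 0.73,
        true /* emitMostLikely */, false /* runSequential */);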

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java?rev=1294454&r1=1294453&r2=1294454&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java Tue Feb 28 04:00:09 2012
@@ -18,14 +18,14 @@
 package org.apache.mahout.clustering.classify;
 
 /**
- * Constants used in Cluster Classification. 
+ * Constants used in Cluster Classification.
  */
 public class ClusterClassificationConfigKeys {
-
+  
   public static final String CLUSTERS_IN = "clusters_in";
   
   public static final String OUTLIER_REMOVAL_THRESHOLD = "pdf_threshold";
   
+  public static final String EMIT_MOST_LIKELY = "emit_most_likely";
+  
 }
-
-

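The new EMIT_MOST_LIKELY key joins CLUSTERS_IN and OUTLIER_REMOVAL_THRESHOLD as the knobs the driver passes to the mapper through the job Configuration. A minimal sketch of how they are set and read back (the path is illustrative; the defaults shown are the ones the mapper's setup() uses when a key is absent):

    Configuration conf = new Configuration();
    conf.set(ClusterClassificationConfigKeys.CLUSTERS_IN, "/path/to/clusters");
    conf.setFloat(ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD, 0.73f);
    conf.setBoolean(ClusterClassificationConfigKeys.EMIT_MOST_LIKELY, true);

    // On the mapper side (see setup() below):
    float threshold = conf.getFloat(
        ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD, 0.0f);  // default 0.0f
    boolean emitMostLikely = conf.getBoolean(
        ClusterClassificationConfigKeys.EMIT_MOST_LIKELY, false);          // default false
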
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java?rev=1294454&r1=1294453&r2=1294454&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java Tue Feb 28 04:00:09 2012
@@ -17,6 +17,8 @@
 
 package org.apache.mahout.clustering.classify;
 
+import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.CLUSTERS_IN;
+import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.EMIT_MOST_LIKELY;
 import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD;
 
 import java.io.IOException;
@@ -213,7 +215,7 @@ public class ClusterClassificationDriver
     Configuration conf = new Configuration();
     SequenceFile.Writer writer = new SequenceFile.Writer(
         input.getFileSystem(conf), conf, new Path(output, "part-m-" + 0),
-        IntWritable.class, VectorWritable.class);
+        IntWritable.class, WeightedVectorWritable.class);
     for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(
         input, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
       Vector pdfPerCluster = clusterClassifier.classify(vw.get());
@@ -231,7 +233,9 @@ public class ClusterClassificationDriver
       throws IOException {
     if (emitMostLikely) {
       int maxValueIndex = pdfPerCluster.maxValueIndex();
-      write(clusterModels, writer, vw, maxValueIndex);
+      WeightedVectorWritable wvw = new WeightedVectorWritable(
+          pdfPerCluster.maxValue(), vw.get());
+      write(clusterModels, writer, wvw, maxValueIndex);
     } else {
       writeAllAboveThreshold(clusterModels, clusterClassificationThreshold,
           writer, vw, pdfPerCluster);
@@ -245,17 +249,19 @@ public class ClusterClassificationDriver
     while (iterateNonZero.hasNext()) {
       Element pdf = iterateNonZero.next();
       if (pdf.get() >= clusterClassificationThreshold) {
+        WeightedVectorWritable wvw = new WeightedVectorWritable(pdf.get(),
+            vw.get());
         int clusterIndex = pdf.index();
-        write(clusterModels, writer, vw, clusterIndex);
+        write(clusterModels, writer, wvw, clusterIndex);
       }
     }
   }
   
   private static void write(List<Cluster> clusterModels,
-      SequenceFile.Writer writer, VectorWritable vw, int maxValueIndex)
+      SequenceFile.Writer writer, WeightedVectorWritable wvw, int maxValueIndex)
       throws IOException {
     Cluster cluster = clusterModels.get(maxValueIndex);
-    writer.append(new IntWritable(cluster.getId()), vw);
+    writer.append(new IntWritable(cluster.getId()), wvw);
   }
   
   /**
@@ -273,15 +279,16 @@ public class ClusterClassificationDriver
   private static void classifyClusterMR(Configuration conf, Path input,
       Path clustersIn, Path output, Double clusterClassificationThreshold,
       boolean emitMostLikely) throws IOException, InterruptedException,
-      ClassNotFoundException {
-    Job job = new Job(conf,
-        "Cluster Classification Driver running over input: " + input);
-    job.setJarByClass(ClusterClassificationDriver.class);
+      ClassNotFoundException {    
     
     conf.setFloat(OUTLIER_REMOVAL_THRESHOLD,
         clusterClassificationThreshold.floatValue());
+    conf.setBoolean(EMIT_MOST_LIKELY, emitMostLikely);
+    conf.set(CLUSTERS_IN, clustersIn.toUri().toString());
     
-    conf.set(ClusterClassificationConfigKeys.CLUSTERS_IN, input.toString());
+    Job job = new Job(conf,
+        "Cluster Classification Driver running over input: " + input);
+    job.setJarByClass(ClusterClassificationDriver.class);
     
     job.setInputFormatClass(SequenceFileInputFormat.class);
     job.setOutputFormatClass(SequenceFileOutputFormat.class);

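The reordering in the last hunk matters: Hadoop's Job constructor takes a snapshot of the Configuration it is given, so any conf.set* call made after new Job(conf, ...) is invisible to the tasks. The change also fixes CLUSTERS_IN, which was previously set to the input path rather than the clusters path. In outline:

    // Set everything the mapper needs first...
    conf.setFloat(OUTLIER_REMOVAL_THRESHOLD, clusterClassificationThreshold.floatValue());
    conf.setBoolean(EMIT_MOST_LIKELY, emitMostLikely);
    conf.set(CLUSTERS_IN, clustersIn.toUri().toString());
    // ...then construct the Job, which copies conf at this point.
    Job job = new Job(conf,
        "Cluster Classification Driver running over input: " + input);
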
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java?rev=1294454&r1=1294453&r2=1294454&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java Tue Feb 28 04:00:09 2012
@@ -17,6 +17,8 @@
 
 package org.apache.mahout.clustering.classify;
 
+import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.CLUSTERS_IN;
+import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.EMIT_MOST_LIKELY;
 import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD;
 
 import java.io.IOException;
@@ -29,6 +31,7 @@ import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapreduce.Mapper;
 import org.apache.mahout.clustering.Cluster;
@@ -37,70 +40,114 @@ import org.apache.mahout.common.iterator
 import org.apache.mahout.common.iterator.sequencefile.PathType;
 import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
 import org.apache.mahout.math.VectorWritable;
 
 /**
  * Mapper for classifying vectors into clusters.
  */
-public class ClusterClassificationMapper extends Mapper<IntWritable,VectorWritable,IntWritable,WeightedVectorWritable> {
+public class ClusterClassificationMapper extends
+    Mapper<LongWritable,VectorWritable,IntWritable,WeightedVectorWritable> {
   
   private static double threshold;
   private List<Cluster> clusterModels;
   private ClusterClassifier clusterClassifier;
   private IntWritable clusterId;
   private WeightedVectorWritable weightedVW;
+  private boolean emitMostLikely;
   
   @Override
-  protected void setup(Context context) throws IOException, InterruptedException {
+  protected void setup(Context context) throws IOException,
+      InterruptedException {
     super.setup(context);
     
     Configuration conf = context.getConfiguration();
-    String clustersIn = conf.get(ClusterClassificationConfigKeys.CLUSTERS_IN);
+    String clustersIn = conf.get(CLUSTERS_IN);
+    threshold = conf.getFloat(OUTLIER_REMOVAL_THRESHOLD, 0.0f);
+    emitMostLikely = conf.getBoolean(EMIT_MOST_LIKELY, false);
     
     clusterModels = new ArrayList<Cluster>();
     
     if (clustersIn != null && !clustersIn.isEmpty()) {
-      Path clustersInPath = new Path(clustersIn, "*");
-      populateClusterModels(clustersInPath);
-      ClusteringPolicy policy = ClusterClassifier.readPolicy(clustersInPath);
+      Path clustersInPath = new Path(clustersIn);
+      clusterModels = populateClusterModels(clustersInPath);
+      ClusteringPolicy policy = ClusterClassifier
+          .readPolicy(finalClustersPath(clustersInPath));
       clusterClassifier = new ClusterClassifier(clusterModels, policy);
     }
-    threshold = conf.getFloat(OUTLIER_REMOVAL_THRESHOLD, 0.0f);
     clusterId = new IntWritable();
     weightedVW = new WeightedVectorWritable(1, null);
   }
   
+  /**
+   * Classifies the input vector into its respective cluster(s).
+   */
   @Override
-  protected void map(IntWritable key, VectorWritable vw, Context context) throws IOException, InterruptedException {
+  protected void map(LongWritable key, VectorWritable vw, Context context)
+      throws IOException, InterruptedException {
     if (!clusterModels.isEmpty()) {
       Vector pdfPerCluster = clusterClassifier.classify(vw.get());
       if (shouldClassify(pdfPerCluster)) {
-        int maxValueIndex = pdfPerCluster.maxValueIndex();
-        Cluster cluster = clusterModels.get(maxValueIndex);
-        clusterId.set(cluster.getId());
-        weightedVW.setVector(vw.get());
-        context.write(clusterId, weightedVW);
+        if (emitMostLikely) {
+          int maxValueIndex = pdfPerCluster.maxValueIndex();
+          write(vw, context, maxValueIndex);
+        } else {
+          writeAllAboveThreshold(vw, context, pdfPerCluster);
+        }
       }
     }
   }
   
-  public static List<Cluster> populateClusterModels(Path clusterOutputPath) throws IOException {
-    List<Cluster> clusterModels = new ArrayList<Cluster>();
+  private void writeAllAboveThreshold(VectorWritable vw, Context context,
+      Vector pdfPerCluster) throws IOException, InterruptedException {
+    Iterator<Element> iterateNonZero = pdfPerCluster.iterateNonZero();
+    while (iterateNonZero.hasNext()) {
+      Element pdf = iterateNonZero.next();
+      if (pdf.get() >= threshold) {
+        int clusterIndex = pdf.index();
+        write(vw, context, clusterIndex);
+      }
+    }
+  }
+  
+  private void write(VectorWritable vw, Context context, int clusterIndex)
+      throws IOException, InterruptedException {
+    Cluster cluster = clusterModels.get(clusterIndex);
+    clusterId.set(cluster.getId());
+    weightedVW.setVector(vw.get());
+    context.write(clusterId, weightedVW);
+  }
+  
+  public static List<Cluster> populateClusterModels(Path clusterOutputPath)
+      throws IOException {
+    List<Cluster> clusters = new ArrayList<Cluster>();
     Configuration conf = new Configuration();
     Cluster cluster = null;
     FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
-    FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
-    Iterator<?> it = new SequenceFileDirValueIterator<Writable>(clusterFiles[0].getPath(), PathType.LIST,
-        PathFilters.partFilter(), null, false, conf);
+    FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath,
+        PathFilters.finalPartFilter());
+    Iterator<?> it = new SequenceFileDirValueIterator<Writable>(
+        clusterFiles[0].getPath(), PathType.LIST, PathFilters.partFilter(),
+        null, false, conf);
     while (it.hasNext()) {
       cluster = (Cluster) it.next();
-      clusterModels.add(cluster);
+      clusters.add(cluster);
     }
-    return clusterModels;
+    return clusters;
   }
   
   private static boolean shouldClassify(Vector pdfPerCluster) {
-    return pdfPerCluster.maxValue() >= threshold;
+    boolean isMaxPDFGreaterThanThreshold = pdfPerCluster.maxValue() >= threshold;
+    return isMaxPDFGreaterThanThreshold;
   }
   
+  private static Path finalClustersPath(Path clusterOutputPath)
+      throws IOException {
+    FileSystem fileSystem = clusterOutputPath
+        .getFileSystem(new Configuration());
+    FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath,
+        PathFilters.finalPartFilter());
+    Path finalClustersPath = clusterFiles[0].getPath();
+    return finalClustersPath;
+  }
 }

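With these changes both execution paths emit (IntWritable clusterId, WeightedVectorWritable vector) records. A minimal sketch of reading them back, modeled on the updated test below (the part-file path is illustrative; note that the mapper reuses a WeightedVectorWritable constructed with weight 1, while the sequential writer stores the cluster pdf as the weight):

    Configuration conf = new Configuration();
    FileSystem fs = FileSystem.get(conf);
    SequenceFile.Reader reader = new SequenceFile.Reader(fs,
        new Path("classifiedClusters/part-m-00000"), conf);
    IntWritable clusterId = new IntWritable();
    WeightedVectorWritable point = new WeightedVectorWritable();
    while (reader.next(clusterId, point)) {
      System.out.println(clusterId.get() + " -> " + point.getVector());
    }
    reader.close();
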
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java?rev=1294454&r1=1294453&r2=1294454&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java Tue Feb 28 04:00:09 2012
@@ -35,8 +35,10 @@ import org.apache.hadoop.io.Writable;
 import org.apache.mahout.clustering.ClusteringTestUtils;
 import org.apache.mahout.clustering.canopy.CanopyDriver;
 import org.apache.mahout.clustering.iterator.CanopyClusteringPolicy;
+import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.MahoutTestCase;
 import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
 import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
@@ -89,6 +91,25 @@ public class ClusterClassificationDriver
   }
   
   @Test
+  public void testVectorClassificationWithOutlierRemovalMR() throws Exception {
+    List<VectorWritable> points = getPointsWritable(REFERENCE);
+    
+    pointsPath = getTestTempDirPath("points");
+    clusteringOutputPath = getTestTempDirPath("output");
+    classifiedOutputPath = getTestTempDirPath("classifiedClusters");
+    HadoopUtil.delete(conf, classifiedOutputPath);
+    
+    conf = new Configuration();
+    
+    ClusteringTestUtils.writePointsToFile(points,
+        new Path(pointsPath, "file1"), fs, conf);
+    runClustering(pointsPath, conf, false);
+    runClassificationWithOutlierRemoval(conf, false);
+    collectVectorsForAssertion();
+    assertVectorsWithOutlierRemoval();
+  }
+  
+  @Test
   public void testVectorClassificationWithoutOutlierRemoval() throws Exception {
     List<VectorWritable> points = getPointsWritable(REFERENCE);
     
@@ -100,7 +121,7 @@ public class ClusterClassificationDriver
     
     ClusteringTestUtils.writePointsToFile(points,
         new Path(pointsPath, "file1"), fs, conf);
-    runClustering(pointsPath, conf);
+    runClustering(pointsPath, conf, true);
     runClassificationWithoutOutlierRemoval(conf);
     collectVectorsForAssertion();
     assertVectorsWithoutOutlierRemoval();
@@ -118,16 +139,17 @@ public class ClusterClassificationDriver
     
     ClusteringTestUtils.writePointsToFile(points,
         new Path(pointsPath, "file1"), fs, conf);
-    runClustering(pointsPath, conf);
-    runClassificationWithOutlierRemoval(conf);
+    runClustering(pointsPath, conf, true);
+    runClassificationWithOutlierRemoval(conf, true);
     collectVectorsForAssertion();
     assertVectorsWithOutlierRemoval();
   }
   
-  private void runClustering(Path pointsPath, Configuration conf)
-      throws IOException, InterruptedException, ClassNotFoundException {
+  private void runClustering(Path pointsPath, Configuration conf,
+      Boolean runSequential) throws IOException, InterruptedException,
+      ClassNotFoundException {
     CanopyDriver.run(conf, pointsPath, clusteringOutputPath,
-        new ManhattanDistanceMeasure(), 3.1, 2.1, false, true);
+        new ManhattanDistanceMeasure(), 3.1, 2.1, false, runSequential);
     Path finalClustersPath = new Path(clusteringOutputPath, "clusters-0-final");
     ClusterClassifier.writePolicy(new CanopyClusteringPolicy(),
         finalClustersPath);
@@ -139,23 +161,25 @@ public class ClusterClassificationDriver
         classifiedOutputPath, 0.0, true, true);
   }
   
-  private void runClassificationWithOutlierRemoval(Configuration conf2)
-      throws IOException, InterruptedException, ClassNotFoundException {
+  private void runClassificationWithOutlierRemoval(Configuration conf2,
+      boolean runSequential) throws IOException, InterruptedException,
+      ClassNotFoundException {
     ClusterClassificationDriver.run(pointsPath, clusteringOutputPath,
-        classifiedOutputPath, 0.73, true, true);
+        classifiedOutputPath, 0.73, true, runSequential);
   }
   
   private void collectVectorsForAssertion() throws IOException {
     Path[] partFilePaths = FileUtil.stat2Paths(fs
         .globStatus(classifiedOutputPath));
-    FileStatus[] listStatus = fs.listStatus(partFilePaths);
+    FileStatus[] listStatus = fs.listStatus(partFilePaths,
+        PathFilters.partFilter());
     for (FileStatus partFile : listStatus) {
       SequenceFile.Reader classifiedVectors = new SequenceFile.Reader(fs,
           partFile.getPath(), conf);
       Writable clusterIdAsKey = new IntWritable();
-      VectorWritable point = new VectorWritable();
+      WeightedVectorWritable point = new WeightedVectorWritable();
       while (classifiedVectors.next(clusterIdAsKey, point)) {
-        collectVector(clusterIdAsKey.toString(), point.get());
+        collectVector(clusterIdAsKey.toString(), point.getVector());
       }
     }
   }
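
The final hunk reflects the value-type change in the assertions as well: WeightedVectorWritable exposes its payload via getVector() (plus a weight via getWeight()), where VectorWritable used get(). A sketch of the type's surface as used in this commit:

    WeightedVectorWritable wvw = new WeightedVectorWritable(0.9, vector);
    double weight = wvw.getWeight();   // 0.9 here; the cluster pdf in the sequential path
    Vector v = wvw.getVector();        // the classified vector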