You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this copy.
Posted to commits@mahout.apache.org by je...@apache.org on 2012/02/14 20:47:43 UTC

svn commit: r1244191 - in /mahout/trunk/core/src: main/java/org/apache/mahout/clustering/classify/ main/java/org/apache/mahout/clustering/topdown/postprocessor/ main/java/org/apache/mahout/common/commandline/ main/java/org/apache/mahout/common/iterator...

Author: jeastman
Date: Tue Feb 14 19:47:43 2012
New Revision: 1244191

URL: http://svn.apache.org/viewvc?rev=1244191&view=rev
Log:
MAHOUT-929: Committing patch Mahout-929. All tests run

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java   (with props)
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java   (with props)
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java   (with props)
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java   (with props)
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReader.java
    mahout/trunk/core/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java
    mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/sequencefile/PathFilters.java

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java?rev=1244191&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java Tue Feb 14 19:47:43 2012
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.classify;
+
+/**
+ * Configuration keys shared by {@link ClusterClassificationDriver} and
+ * {@link ClusterClassificationMapper}.
+ */
+public final class ClusterClassificationConfigKeys {
+
+  /** Key holding the location from which the mapper reads its cluster models. */
+  public static final String CLUSTERS_IN = "clusters_in";
+
+  /**
+   * Key holding the outlier-removal threshold read by the mapper: vectors whose largest pdf
+   * value is below it are not classified.
+   */
+  public static final String OUTLIER_REMOVAL_THRESHOLD = "pdf_threshold";
+
+  private ClusterClassificationConfigKeys() {
+    // holder of constants only; not meant to be instantiated
+  }
+
+}
+

Propchange: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java
------------------------------------------------------------------------------
    svn:mime-type = text/plain

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java?rev=1244191&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java Tue Feb 14 19:47:43 2012
@@ -0,0 +1,267 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.classify;
+
+import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.ClusterClassifier;
+import org.apache.mahout.clustering.WeightedVectorWritable;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * Classifies the vectors into different clusters found by the clustering algorithm.
+ */
+public class ClusterClassificationDriver extends AbstractJob {
+
+  /**
+   * Constructor to be used by the ToolRunner.
+   */
+  private ClusterClassificationDriver() {}
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new Configuration(), new ClusterClassificationDriver(), args);
+  }
+
+  /**
+   * CLI to run the Cluster Classification Driver. Parses the input, output, method, clusters and
+   * (optional) outlier threshold options, then delegates to
+   * {@link #run(Configuration, Path, Path, Path, Double, boolean)} with this tool's configuration.
+   *
+   * @return 0 on success, -1 if the arguments could not be parsed
+   */
+  @Override
+  public int run(String[] args) throws Exception {
+    addInputOption();
+    addOutputOption();
+    addOption(DefaultOptionCreator.methodOption().create());
+    addOption(DefaultOptionCreator.clustersInOption()
+        .withDescription("The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.")
+        .create());
+    // The threshold option must be registered, otherwise hasOption(OUTLIER_THRESHOLD) is never true.
+    addOption(DefaultOptionCreator.classificationThresholdOption().create());
+
+    if (parseArguments(args) == null) {
+      return -1;
+    }
+
+    Path input = getInputPath();
+    Path output = getOutputPath();
+
+    if (getConf() == null) {
+      setConf(new Configuration());
+    }
+    Path clustersIn = new Path(getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION));
+    boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(
+        DefaultOptionCreator.SEQUENTIAL_METHOD);
+
+    double clusterClassificationThreshold = 0.0;
+    if (hasOption(DefaultOptionCreator.OUTLIER_THRESHOLD)) {
+      clusterClassificationThreshold = Double.parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD));
+    }
+
+    run(getConf(), input, clustersIn, output, clusterClassificationThreshold, runSequential);
+    return 0;
+  }
+
+  /**
+   * Classifies the input vectors using a new default {@link Configuration}; see
+   * {@link #run(Configuration, Path, Path, Path, Double, boolean)}.
+   */
+  public static void run(Path input, Path clusteringOutputPath, Path output, Double clusterClassificationThreshold,
+                         boolean runSequential) throws IOException, InterruptedException, ClassNotFoundException {
+    run(new Configuration(), input, clusteringOutputPath, output, clusterClassificationThreshold, runSequential);
+  }
+
+  /**
+   * Uses {@link ClusterClassifier} to classify input vectors into their respective clusters.
+   *
+   * @param conf
+   *         the configuration to use; it is read but not modified
+   * @param input
+   *         the input vectors
+   * @param clusteringOutputPath
+   *         the output path of clustering (the clusters-*-final directory is read from here)
+   * @param output
+   *         the location to store the classified vectors
+   * @param clusterClassificationThreshold
+   *         the threshold value of the probability distribution function, from 0.0 to 1.0.
+   *         Any vector whose largest pdf is less than this threshold is not classified.
+   * @param runSequential
+   *         run the process sequentially or as a MapReduce job
+   * @throws IOException if reading the clusters or the input, or writing the output, fails
+   * @throws InterruptedException if the MapReduce job fails or is interrupted
+   * @throws ClassNotFoundException propagated from MapReduce job submission
+   */
+  public static void run(Configuration conf, Path input, Path clusteringOutputPath, Path output,
+                         Double clusterClassificationThreshold, boolean runSequential)
+    throws IOException, InterruptedException, ClassNotFoundException {
+    if (runSequential) {
+      classifyClusterSeq(conf, input, clusteringOutputPath, output, clusterClassificationThreshold);
+    } else {
+      classifyClusterMR(conf, input, clusteringOutputPath, output, clusterClassificationThreshold);
+    }
+  }
+
+  private static void classifyClusterSeq(Configuration conf, Path input, Path clusters, Path output,
+                                         Double clusterClassificationThreshold) throws IOException {
+    List<Cluster> clusterModels = populateClusterModels(clusters, conf);
+    ClusterClassifier clusterClassifier = new ClusterClassifier(clusterModels);
+    selectCluster(conf, input, clusterModels, clusterClassifier, output, clusterClassificationThreshold);
+  }
+
+  /**
+   * Populates a list with the clusters present in the clusters-*-final directory.
+   *
+   * @param clusterOutputPath
+   *             the output path of the clustering
+   * @param conf
+   *             the configuration used to access the file system
+   * @return the list of clusters found by the clustering
+   * @throws IOException if no clusters-*-final directory exists or it cannot be read
+   */
+  private static List<Cluster> populateClusterModels(Path clusterOutputPath, Configuration conf) throws IOException {
+    List<Cluster> clusterModels = new ArrayList<Cluster>();
+    FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
+    FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
+    // Fail with a clear message instead of an ArrayIndexOutOfBoundsException on a null/empty listing.
+    if (clusterFiles == null || clusterFiles.length == 0) {
+      throw new IOException("No clusters-*-final directory found under " + clusterOutputPath);
+    }
+    Iterator<?> it = new SequenceFileDirValueIterator<Writable>(clusterFiles[0].getPath(),
+                                                                PathType.LIST,
+                                                                PathFilters.partFilter(),
+                                                                null,
+                                                                false,
+                                                                conf);
+    while (it.hasNext()) {
+      clusterModels.add((Cluster) it.next());
+    }
+    return clusterModels;
+  }
+
+  /**
+   * Classifies each input vector into its most likely cluster and writes it to {@code output}.
+   *
+   * @param conf
+   *            the configuration used to access the file systems
+   * @param input
+   *            the path containing the input vectors
+   * @param clusterModels
+   *            the clusters, indexed by position in the classifier's pdf vector
+   * @param clusterClassifier
+   *            used to classify the vectors into different clusters
+   * @param output
+   *            the path to store classified data
+   * @param clusterClassificationThreshold
+   *            vectors whose largest pdf is below this value are not written
+   * @throws IOException if reading the input or writing the output fails
+   */
+  private static void selectCluster(Configuration conf, Path input, List<Cluster> clusterModels,
+                                    ClusterClassifier clusterClassifier, Path output,
+                                    Double clusterClassificationThreshold) throws IOException {
+    // Write through the output path's file system; it need not be the same as the input's.
+    SequenceFile.Writer writer = new SequenceFile.Writer(output.getFileSystem(conf), conf,
+        new Path(output, "part-m-0"), IntWritable.class, VectorWritable.class);
+    try {
+      for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(
+          input, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
+        Vector pdfPerCluster = clusterClassifier.classify(vw.get());
+        if (shouldClassify(pdfPerCluster, clusterClassificationThreshold)) {
+          int maxValueIndex = pdfPerCluster.maxValueIndex();
+          Cluster cluster = clusterModels.get(maxValueIndex);
+          writer.append(new IntWritable(cluster.getId()), vw);
+        }
+      }
+    } finally {
+      writer.close();
+    }
+  }
+
+  /**
+   * Decides whether the vector should be classified, based on the max pdf value of the clusters
+   * and the threshold value.
+   *
+   * @param pdfPerCluster
+   *         pdf of vector belonging to different clusters.
+   * @param clusterClassificationThreshold
+   *         threshold below which the vectors won't be classified.
+   * @return whether the vector should be classified or not.
+   */
+  private static boolean shouldClassify(Vector pdfPerCluster, Double clusterClassificationThreshold) {
+    return pdfPerCluster.maxValue() >= clusterClassificationThreshold;
+  }
+
+  /**
+   * Runs the classification as a map-only job using {@link ClusterClassificationMapper}.
+   */
+  private static void classifyClusterMR(Configuration conf, Path input, Path clustersIn, Path output,
+                                        Double clusterClassificationThreshold)
+    throws IOException, InterruptedException, ClassNotFoundException {
+    // Job copies the Configuration it is constructed with, so everything the mapper reads must be
+    // set beforehand. Work on a copy so the caller's Configuration is left untouched.
+    Configuration jobConf = new Configuration(conf);
+    jobConf.setFloat(OUTLIER_REMOVAL_THRESHOLD, clusterClassificationThreshold.floatValue());
+    // The mapper loads its cluster models from here, so this must be the clustering output path.
+    jobConf.set(ClusterClassificationConfigKeys.CLUSTERS_IN, clustersIn.toString());
+
+    Job job = new Job(jobConf, "Cluster Classification Driver running over input: " + input);
+    job.setJarByClass(ClusterClassificationDriver.class);
+
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+
+    job.setMapperClass(ClusterClassificationMapper.class);
+    job.setNumReduceTasks(0);
+
+    job.setOutputKeyClass(IntWritable.class);
+    job.setOutputValueClass(WeightedVectorWritable.class);
+
+    FileInputFormat.addInputPath(job, input);
+    FileOutputFormat.setOutputPath(job, output);
+    if (!job.waitForCompletion(true)) {
+      throw new InterruptedException("Cluster Classification Driver Job failed processing " + input);
+    }
+  }
+
+}

Propchange: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
------------------------------------------------------------------------------
    svn:mime-type = text/plain

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java?rev=1244191&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java Tue Feb 14 19:47:43 2012
@@ -0,0 +1,142 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.classify;
+
+import static org.apache.mahout.clustering.classify.ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.ClusterClassifier;
+import org.apache.mahout.clustering.WeightedVectorWritable;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * Mapper for classifying vectors into clusters.
+ * <p>
+ * The cluster models are loaded in {@link #setup} from the location stored under
+ * {@link ClusterClassificationConfigKeys#CLUSTERS_IN}. Each input vector is emitted with the id of
+ * its most likely cluster, unless its largest pdf is below the threshold stored under
+ * {@link ClusterClassificationConfigKeys#OUTLIER_REMOVAL_THRESHOLD}. Input keys are ignored, so
+ * any key type may be used.
+ */
+public class ClusterClassificationMapper extends
+    Mapper<WritableComparable<?>,VectorWritable,IntWritable,WeightedVectorWritable> {
+
+  private double threshold;
+  private List<Cluster> clusterModels;
+  private ClusterClassifier clusterClassifier;
+  private IntWritable clusterId;
+  private WeightedVectorWritable weightedVW;
+
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+
+    Configuration conf = context.getConfiguration();
+    String clustersIn = conf.get(ClusterClassificationConfigKeys.CLUSTERS_IN);
+
+    clusterModels = new ArrayList<Cluster>();
+
+    if (clustersIn != null && !clustersIn.isEmpty()) {
+      // listStatus() does not expand glob patterns, so pass the clustering output directory itself;
+      // populateClusterModels() looks for the clusters-*-final directory beneath it. The result
+      // must be assigned, otherwise clusterModels stays empty and map() never emits anything.
+      clusterModels = populateClusterModels(new Path(clustersIn), conf);
+      clusterClassifier = new ClusterClassifier(clusterModels);
+    }
+    threshold = conf.getFloat(OUTLIER_REMOVAL_THRESHOLD, 0.0f);
+    clusterId = new IntWritable();
+    weightedVW = new WeightedVectorWritable(1, null);
+  }
+
+  @Override
+  protected void map(WritableComparable<?> key, VectorWritable vw, Context context)
+    throws IOException, InterruptedException {
+    if (!clusterModels.isEmpty()) {
+      Vector pdfPerCluster = clusterClassifier.classify(vw.get());
+      if (shouldClassify(pdfPerCluster)) {
+        int maxValueIndex = pdfPerCluster.maxValueIndex();
+        Cluster cluster = clusterModels.get(maxValueIndex);
+        clusterId.set(cluster.getId());
+        weightedVW.setVector(vw.get());
+        context.write(clusterId, weightedVW);
+      }
+    }
+  }
+
+  /**
+   * Reads the clusters below {@code clusterOutputPath} using a new default {@link Configuration};
+   * see {@link #populateClusterModels(Path, Configuration)}.
+   */
+  public static List<Cluster> populateClusterModels(Path clusterOutputPath) throws IOException {
+    return populateClusterModels(clusterOutputPath, new Configuration());
+  }
+
+  /**
+   * Reads the clusters of the clusters-*-final directory found below {@code clusterOutputPath}.
+   *
+   * @param clusterOutputPath the output path of the clustering
+   * @param conf the configuration used to access the file system
+   * @return the clusters read from the first matching clusters-*-final directory listed
+   * @throws IOException if no clusters-*-final directory exists or it cannot be read
+   */
+  public static List<Cluster> populateClusterModels(Path clusterOutputPath, Configuration conf)
+    throws IOException {
+    List<Cluster> clusterModels = new ArrayList<Cluster>();
+    FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
+    FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
+    // Fail with a clear message instead of an ArrayIndexOutOfBoundsException on a null/empty listing.
+    if (clusterFiles == null || clusterFiles.length == 0) {
+      throw new IOException("No clusters-*-final directory found under " + clusterOutputPath);
+    }
+    Iterator<?> it = new SequenceFileDirValueIterator<Writable>(clusterFiles[0].getPath(),
+                                                                PathType.LIST,
+                                                                PathFilters.partFilter(),
+                                                                null,
+                                                                false,
+                                                                conf);
+    while (it.hasNext()) {
+      clusterModels.add((Cluster) it.next());
+    }
+    return clusterModels;
+  }
+
+  /**
+   * @return whether the largest pdf value reaches the configured outlier-removal threshold
+   */
+  private boolean shouldClassify(Vector pdfPerCluster) {
+    return pdfPerCluster.maxValue() >= threshold;
+  }
+
+}

Propchange: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
------------------------------------------------------------------------------
    svn:mime-type = text/plain

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReader.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReader.java?rev=1244191&r1=1244190&r2=1244191&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReader.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReader.java Tue Feb 14 19:47:43 2012
@@ -24,7 +24,6 @@ import org.apache.hadoop.conf.Configurat
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.fs.PathFilter;
 import org.apache.hadoop.io.Writable;
 import org.apache.mahout.common.iterator.sequencefile.PathFilters;
 import org.apache.mahout.common.iterator.sequencefile.PathType;
@@ -49,7 +48,7 @@ public final class ClusterCountReader {
    */
   public static int getNumberOfClusters(Path clusterOutputPath, Configuration conf) throws IOException {
     FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
-    FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, CLUSTER_FINAL);
+    FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
     int numberOfClusters = 0;
     Iterator<?> it = new SequenceFileDirValueIterator<Writable>(clusterFiles[0].getPath(),
                                                                 PathType.LIST,
@@ -64,14 +63,4 @@ public final class ClusterCountReader {
     return numberOfClusters;
   }
 
-  /**
-   * Pathfilter to read the final clustering file.
-   */
-  private static final PathFilter CLUSTER_FINAL = new PathFilter() {
-    @Override
-    public boolean accept(Path path) {
-      String name = path.getName();
-      return name.startsWith("clusters-") && name.endsWith("-final");
-    }
-  };
 }
\ No newline at end of file

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java?rev=1244191&r1=1244190&r2=1244191&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java Tue Feb 14 19:47:43 2012
@@ -59,6 +59,8 @@ public final class DefaultOptionCreator 
   
   public static final String T4_OPTION = "t4";
   
+  public static final String OUTLIER_THRESHOLD = "outlierThreshold";
+  
   public static final String CLUSTER_FILTER_OPTION = "clusterFilter";
   
   public static final String THRESHOLD_OPTION = "threshold";
@@ -403,4 +405,18 @@ public static DefaultOptionBuilder clust
             "If present, the input directory already contains MeanShiftCanopies");
   }
   
+  /**
+   * Returns a default command line option for the outlier threshold used by Cluster Classification:
+   * an optional option named {@code outlierThreshold} (used as both its long and short name) that
+   * takes exactly one argument.
+   */
+  public static DefaultOptionBuilder classificationThresholdOption() {
+    ArgumentBuilder argument = new ArgumentBuilder().withName(OUTLIER_THRESHOLD);
+    argument = argument.withMinimum(1).withMaximum(1);
+    DefaultOptionBuilder builder = new DefaultOptionBuilder().withLongName(OUTLIER_THRESHOLD);
+    builder = builder.withRequired(false).withArgument(argument.create());
+    builder = builder.withDescription("Outlier threshold value");
+    return builder.withShortName(OUTLIER_THRESHOLD);
+  }
+  
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/sequencefile/PathFilters.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/sequencefile/PathFilters.java?rev=1244191&r1=1244190&r2=1244191&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/sequencefile/PathFilters.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/common/iterator/sequencefile/PathFilters.java Tue Feb 14 19:47:43 2012
@@ -32,6 +32,17 @@ public final class PathFilters {
       return name.startsWith("part-") && !name.endsWith(".crc");
     }
   };
+  
+  /**
+   * Accepts only paths whose name starts with "clusters-" and ends with "-final".
+   */
+  private static final PathFilter CLUSTER_FINAL = new PathFilter() {
+    @Override
+    public boolean accept(Path candidate) {
+      String fileName = candidate.getName();
+      return fileName.endsWith("-final") && fileName.startsWith("clusters-");
+    }
+  };
 
   private static final PathFilter LOGS_CRC_INSTANCE = new PathFilter() {
     @Override
@@ -51,6 +62,13 @@ public final class PathFilters {
   public static PathFilter partFilter() {
     return PART_FILE_INSTANCE;
   }
+  
+  /**
+   * @return {@link PathFilter} that accepts paths whose file name starts with "clusters-" and ends with "-final".
+   */
+  public static PathFilter finalPartFilter() {
+    return CLUSTER_FINAL;
+  }
 
   /**
    * @return {@link PathFilter} that rejects paths whose file name starts with "_" (e.g. Cloudera

Added: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java?rev=1244191&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java Tue Feb 14 19:47:43 2012
@@ -0,0 +1,199 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.classify;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import junit.framework.Assert;
+
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.canopy.CanopyDriver;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Tests {@link ClusterClassificationDriver} in sequential mode on points clustered with canopy,
+ * with and without outlier removal.
+ */
+public class ClusterClassificationDriverTest extends MahoutTestCase {
+
+  private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 4}, {5, 4}, {4, 5}, {5, 5}, {9, 9}, {8, 8}};
+
+  private FileSystem fs;
+
+  private Path clusteringOutputPath;
+
+  private Configuration conf;
+
+  private Path pointsPath;
+
+  private Path classifiedOutputPath;
+
+  private List<Vector> firstCluster;
+
+  private List<Vector> secondCluster;
+
+  private List<Vector> thirdCluster;
+
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    // Assign the field rather than shadowing it with a local, so helpers never see a null conf.
+    conf = new Configuration();
+    fs = FileSystem.get(conf);
+    firstCluster = new ArrayList<Vector>();
+    secondCluster = new ArrayList<Vector>();
+    thirdCluster = new ArrayList<Vector>();
+  }
+
+  private static List<VectorWritable> getPointsWritable(double[][] raw) {
+    List<VectorWritable> points = Lists.newArrayList();
+    for (double[] fr : raw) {
+      Vector vec = new RandomAccessSparseVector(fr.length);
+      vec.assign(fr);
+      points.add(new VectorWritable(vec));
+    }
+    return points;
+  }
+
+  @Test
+  public void testVectorClassificationWithoutOutlierRemoval() throws Exception {
+    prepareClusteredPoints();
+    ClusterClassificationDriver.run(pointsPath, clusteringOutputPath, classifiedOutputPath, 0.0, true);
+    collectVectorsForAssertion();
+    assertVectorsWithoutOutlierRemoval();
+  }
+
+  @Test
+  public void testVectorClassificationWithOutlierRemoval() throws Exception {
+    prepareClusteredPoints();
+    ClusterClassificationDriver.run(pointsPath, clusteringOutputPath, classifiedOutputPath, 0.73, true);
+    collectVectorsForAssertion();
+    assertVectorsWithOutlierRemoval();
+  }
+
+  /**
+   * Writes the reference points and clusters them with canopy into {@code clusteringOutputPath}.
+   */
+  private void prepareClusteredPoints() throws Exception {
+    pointsPath = getTestTempDirPath("points");
+    clusteringOutputPath = getTestTempDirPath("output");
+    classifiedOutputPath = getTestTempDirPath("classify");
+    ClusteringTestUtils.writePointsToFile(getPointsWritable(REFERENCE), new Path(pointsPath, "file1"), fs, conf);
+    CanopyDriver.run(conf, pointsPath, clusteringOutputPath, new ManhattanDistanceMeasure(), 3.1, 2.1, false, true);
+  }
+
+  private void collectVectorsForAssertion() throws IOException {
+    Path[] partFilePaths = FileUtil.stat2Paths(fs.globStatus(classifiedOutputPath));
+    FileStatus[] listStatus = fs.listStatus(partFilePaths);
+    for (FileStatus partFile : listStatus) {
+      SequenceFile.Reader classifiedVectors = new SequenceFile.Reader(fs, partFile.getPath(), conf);
+      try {
+        Writable clusterIdAsKey = new IntWritable();
+        VectorWritable point = new VectorWritable();
+        while (classifiedVectors.next(clusterIdAsKey, point)) {
+          collectVector(clusterIdAsKey.toString(), point.get());
+        }
+      } finally {
+        classifiedVectors.close();
+      }
+    }
+  }
+
+  private void collectVector(String clusterId, Vector vector) {
+    if ("0".equals(clusterId)) {
+      firstCluster.add(vector);
+    } else if ("1".equals(clusterId)) {
+      secondCluster.add(vector);
+    } else if ("2".equals(clusterId)) {
+      thirdCluster.add(vector);
+    }
+  }
+
+  private void assertVectorsWithOutlierRemoval() {
+    assertFirstClusterWithOutlierRemoval();
+    assertSecondClusterWithOutlierRemoval();
+    assertThirdClusterWithOutlierRemoval();
+  }
+
+  private void assertVectorsWithoutOutlierRemoval() {
+    assertFirstClusterWithoutOutlierRemoval();
+    assertSecondClusterWithoutOutlierRemoval();
+    assertThirdClusterWithoutOutlierRemoval();
+  }
+
+  private void assertThirdClusterWithoutOutlierRemoval() {
+    Assert.assertEquals(2, thirdCluster.size());
+    for (Vector vector : thirdCluster) {
+      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}", "{1:8.0,0:8.0}"}, vector.asFormatString()));
+    }
+  }
+
+  private void assertSecondClusterWithoutOutlierRemoval() {
+    Assert.assertEquals(4, secondCluster.size());
+    for (Vector vector : secondCluster) {
+      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:4.0,0:4.0}", "{1:4.0,0:5.0}", "{1:5.0,0:4.0}",
+          "{1:5.0,0:5.0}"}, vector.asFormatString()));
+    }
+  }
+
+  private void assertFirstClusterWithoutOutlierRemoval() {
+    Assert.assertEquals(3, firstCluster.size());
+    for (Vector vector : firstCluster) {
+      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}", "{1:1.0,0:2.0}", "{1:2.0,0:1.0}"}, vector.asFormatString()));
+    }
+  }
+
+  private void assertThirdClusterWithOutlierRemoval() {
+    Assert.assertEquals(1, thirdCluster.size());
+    for (Vector vector : thirdCluster) {
+      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}"}, vector.asFormatString()));
+    }
+  }
+
+  private void assertSecondClusterWithOutlierRemoval() {
+    Assert.assertEquals(0, secondCluster.size());
+  }
+
+  private void assertFirstClusterWithOutlierRemoval() {
+    Assert.assertEquals(1, firstCluster.size());
+    for (Vector vector : firstCluster) {
+      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}"}, vector.asFormatString()));
+    }
+  }
+
+}

Propchange: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
------------------------------------------------------------------------------
    svn:mime-type = text/plain