You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ro...@apache.org on 2010/02/14 21:23:43 UTC

svn commit: r910067 - /lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java

Author: robinanil
Date: Sun Feb 14 20:23:42 2010
New Revision: 910067

URL: http://svn.apache.org/viewvc?rev=910067&view=rev
Log:
Add a configuration parameter to the cluster dumper that sets the number of top words returned, and print the score (weight) of each top term alongside it.

Modified:
    lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java

Modified: lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java?rev=910067&r1=910066&r2=910067&view=diff
==============================================================================
--- lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java (original)
+++ lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java Sun Feb 14 20:23:42 2010
@@ -41,6 +41,7 @@
 import org.apache.commons.cli2.builder.DefaultOptionBuilder;
 import org.apache.commons.cli2.builder.GroupBuilder;
 import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.lang.StringUtils;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.SequenceFile;
@@ -51,30 +52,32 @@
 import org.apache.hadoop.mapred.jobcontrol.Job;
 import org.apache.mahout.clustering.ClusterBase;
 import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.Pair;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.utils.vectors.VectorHelper;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 public final class ClusterDumper {
-  
+
   private static final Logger log = LoggerFactory.getLogger(ClusterDumper.class);
-  
+
   private final String seqFileDir;
   private final String pointsDir;
   private String termDictionary;
   private String dictionaryFormat;
   private String outputFile;
   private int subString = Integer.MAX_VALUE;
+  private int numTopFeatures = 10;
   private Map<String,List<String>> clusterIdToPoints = null;
   private boolean useJSON = false;
-  
+
   public ClusterDumper(String seqFileDir, String pointsDir) throws IOException {
     this.seqFileDir = seqFileDir;
     this.pointsDir = pointsDir;
     init();
   }
-  
+
   private void init() throws IOException {
     if (this.pointsDir != null) {
       JobConf conf = new JobConf(Job.class);
@@ -84,12 +87,12 @@
       clusterIdToPoints = Collections.emptyMap();
     }
   }
-  
+
   public void printClusters() throws IOException, InstantiationException, IllegalAccessException {
     JobClient client = new JobClient();
     JobConf conf = new JobConf(Job.class);
     client.setConf(conf);
-    
+
     String[] dictionary = null;
     if (this.termDictionary != null) {
       if (dictionaryFormat.equals("text")) {
@@ -101,14 +104,14 @@
         throw new IllegalArgumentException("Invalid dictionary format");
       }
     }
-    
+
     Writer writer = null;
     if (this.outputFile != null) {
       writer = new FileWriter(this.outputFile);
     } else {
       writer = new OutputStreamWriter(System.out);
     }
-    
+
     File[] seqFileList = new File(this.seqFileDir).listFiles(new FilenameFilter() {
       @Override
       public boolean accept(File file, String name) {
@@ -134,14 +137,14 @@
           writer.append(":").append(fmtStr.substring(0, Math.min(subString, fmtStr.length())));
         }
         writer.append('\n');
-        
+
         if (dictionary != null) {
-          String topTerms = ClusterDumper.getTopFeatures(center, dictionary, 10);
+          String topTerms = ClusterDumper.getTopFeatures(center, dictionary, numTopFeatures);
           writer.write("\tTop Terms: ");
           writer.write(topTerms);
           writer.write('\n');
         }
-        
+
         List<String> points = clusterIdToPoints.get(String.valueOf(value.getId()));
         if (points != null) {
           writer.write("\tPoints: ");
@@ -163,41 +166,49 @@
       writer.close();
     }
   }
-  
+
   public String getOutputFile() {
     return outputFile;
   }
-  
+
   public void setOutputFile(String outputFile) {
     this.outputFile = outputFile;
   }
-  
+
   public int getSubString() {
     return subString;
   }
-  
+
   public void setSubString(int subString) {
     this.subString = subString;
   }
-  
+
   public Map<String,List<String>> getClusterIdToPoints() {
     return clusterIdToPoints;
   }
-  
+
   public String getTermDictionary() {
     return termDictionary;
   }
-  
+
   public void setTermDictionary(String termDictionary, String dictionaryType) {
     this.termDictionary = termDictionary;
     this.dictionaryFormat = dictionaryType;
   }
-  
+
+  public void setNumTopFeatures(int num) {
+    this.numTopFeatures = num;
+  }
+
+  public int getNumTopFeatures() {
+    return this.numTopFeatures;
+  }
+
   public static void main(String[] args) throws IOException, IllegalAccessException, InstantiationException {
     DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
     ArgumentBuilder abuilder = new ArgumentBuilder();
     GroupBuilder gbuilder = new GroupBuilder();
-    
+
     Option seqOpt = obuilder.withLongName("seqFileDir").withRequired(false).withArgument(
       abuilder.withName("seqFileDir").withMinimum(1).withMaximum(1).create()).withDescription(
       "The directory containing Sequence Files for the Clusters").withShortName("s").create();
@@ -207,6 +218,9 @@
     Option substringOpt = obuilder.withLongName("substring").withRequired(false).withArgument(
       abuilder.withName("substring").withMinimum(1).withMaximum(1).create()).withDescription(
       "The number of chars of the asFormatString() to print").withShortName("b").create();
+    Option numWordsOpt = obuilder.withLongName("numWords").withRequired(false).withArgument(
+      abuilder.withName("numWords").withMinimum(1).withMaximum(1).create()).withDescription(
+      "The number of top terms to print").withShortName("n").create();
     Option centroidJSonOpt = obuilder.withLongName("json").withRequired(false).withDescription(
       "Output the centroid as JSON.  Otherwise it substitues in the terms for vector cell entries")
         .withShortName("j").create();
@@ -223,11 +237,11 @@
       "The dictionary file type (text|sequencefile)").withShortName("dt").create();
     Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
         .create();
-    
+
     Group group = gbuilder.withName("Options").withOption(helpOpt).withOption(seqOpt).withOption(outputOpt)
         .withOption(substringOpt).withOption(pointsOpt).withOption(centroidJSonOpt).withOption(dictOpt)
-        .withOption(dictTypeOpt).create();
-    
+        .withOption(dictTypeOpt).withOption(numWordsOpt).create();
+
     try {
       Parser parser = new Parser();
       parser.setGroup(group);
@@ -244,7 +258,7 @@
       if (cmdLine.hasOption(dictOpt)) {
         termDictionary = cmdLine.getValue(dictOpt).toString();
       }
-      
+
       String pointsDir = null;
       if (cmdLine.hasOption(pointsOpt)) {
         pointsDir = cmdLine.getValue(pointsOpt).toString();
@@ -253,28 +267,35 @@
       if (cmdLine.hasOption(outputOpt)) {
         outputFile = cmdLine.getValue(outputOpt).toString();
       }
-      
+
       int sub = -1;
       if (cmdLine.hasOption(substringOpt)) {
         sub = Integer.parseInt(cmdLine.getValue(substringOpt).toString());
       }
+
       ClusterDumper clusterDumper = new ClusterDumper(seqFileDir, pointsDir);
       if (cmdLine.hasOption(centroidJSonOpt)) {
         clusterDumper.setUseJSON(true);
       }
-      
+
       if (outputFile != null) {
         clusterDumper.setOutputFile(outputFile);
       }
-      
+
       String dictionaryType = "text";
       if (cmdLine.hasOption(dictTypeOpt)) {
         dictionaryType = cmdLine.getValue(dictTypeOpt).toString();
       }
-      
+
       if (termDictionary != null) {
         clusterDumper.setTermDictionary(termDictionary, dictionaryType);
       }
+
+      if (cmdLine.hasOption(numWordsOpt)) {
+        int numWords = Integer.parseInt(cmdLine.getValue(numWordsOpt).toString());
+        clusterDumper.setNumTopFeatures(numWords);
+      }
+
       if (sub >= 0) {
         clusterDumper.setSubString(sub);
       }
@@ -284,21 +305,21 @@
       CommandLineUtil.printHelp(group);
     }
   }
-  
+
   private void setUseJSON(boolean json) {
     this.useJSON = json;
   }
-  
+
   private static Map<String,List<String>> readPoints(String pointsPathDir, JobConf conf) throws IOException {
     SortedMap<String,List<String>> result = new TreeMap<String,List<String>>();
-    
+
     File[] children = new File(pointsPathDir).listFiles(new FilenameFilter() {
       @Override
       public boolean accept(File file, String name) {
         return name.endsWith(".crc") == false;
       }
     });
-    
+
     for (File file : children) {
       if (!file.isFile()) {
         continue;
@@ -328,30 +349,30 @@
         ClusterDumper.log.error("Exception", e);
       }
     }
-    
+
     return result;
   }
-  
+
   static class TermIndexWeight {
     public int index = -1;
     public double weight = 0;
-    
+
     TermIndexWeight(int index, double weight) {
       this.index = index;
       this.weight = weight;
     }
   }
-  
+
   private static String getTopFeatures(Vector vector, String[] dictionary, int numTerms) {
-    
+
     List<TermIndexWeight> vectorTerms = new ArrayList<TermIndexWeight>();
-    
+
     Iterator<Vector.Element> iter = vector.iterateNonZero();
     while (iter.hasNext()) {
       Vector.Element elt = iter.next();
       vectorTerms.add(new TermIndexWeight(elt.index(), elt.get()));
     }
-    
+
     // Sort results in reverse order (ie weight in descending order)
     Collections.sort(vectorTerms, new Comparator<TermIndexWeight>() {
       @Override
@@ -359,28 +380,29 @@
         return Double.compare(two.weight, one.weight);
       }
     });
-    
-    List<String> topTerms = new LinkedList<String>();
-    
-    for (int i = 0; i < vectorTerms.size() && i < numTerms; i++) {
+
+    List<Pair<String,Double>> topTerms = new LinkedList<Pair<String,Double>>();
+
+    for (int i = 0; (i < vectorTerms.size()) && (i < numTerms); i++) {
       int index = vectorTerms.get(i).index;
       String dictTerm = dictionary[index];
       if (dictTerm == null) {
         ClusterDumper.log.error("Dictionary entry missing for {}", index);
         continue;
       }
-      topTerms.add(dictTerm);
+      topTerms.add(new Pair<String,Double>(dictTerm, vectorTerms.get(i).weight));
     }
-    
+
     StringBuilder sb = new StringBuilder();
-    for (Iterator<String> iterator = topTerms.iterator(); iterator.hasNext();) {
-      String term = iterator.next();
-      sb.append(term);
-      if (iterator.hasNext()) {
-        sb.append(", ");
-      }
+
+    for (Pair<String,Double> item : topTerms) {
+      String term = item.getFirst();
+      sb.append("\n\t\t");
+      sb.append(StringUtils.rightPad(term, 40));
+      sb.append("=>");
+      sb.append(StringUtils.leftPad(item.getSecond().toString(), 20));
     }
     return sb.toString();
   }
-  
+
 }