You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/10/15 05:45:38 UTC

svn commit: r1022814 - in /mahout/trunk: core/src/main/java/org/apache/mahout/classifier/sgd/ examples/src/main/java/org/apache/mahout/classifier/sgd/

Author: tdunning
Date: Fri Oct 15 03:45:37 2010
New Revision: 1022814

URL: http://svn.apache.org/viewvc?rev=1022814&view=rev
Log:
Make model dissector return multiple coefficients
Make model serializer more polymorphic
Checkpoint model during TrainNewsGroups

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java?rev=1022814&r1=1022813&r2=1022814&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java Fri Oct 15 03:45:37 2010
@@ -25,6 +25,7 @@ import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.function.BinaryFunction;
 import org.apache.mahout.math.function.Functions;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -59,27 +60,51 @@ public class ModelDissector {
     weightMap = Maps.newHashMap();
   }
 
+  /**
+   * Probes a model to determine the effect of a particular variable.  This is done
+   * with the aid of a trace dictionary which has recorded the locations in the feature
+   * vector that are modified by various variable values.  We can set these locations to
+   * 1 and then look at the resulting score.  This tells us the weight the model places
+   * on that variable.
+   * @param features               A feature vector to use (destructively)
+   * @param traceDictionary        A trace dictionary containing variables and what locations in the feature vector are affected by them
+   * @param learner                The model that we are probing to find weights on features
+   */
+
   public void update(Vector features, Map<String, Set<Integer>> traceDictionary, AbstractVectorClassifier learner) {
+    // zero out feature vector
     features.assign(0);
     for (Map.Entry<String, Set<Integer>> entry : traceDictionary.entrySet()) {
+      // get a feature and locations where it is stored in the feature vector
       String key = entry.getKey();
       Set<Integer> value = entry.getValue();
+
+      // if we haven't looked at this feature yet
       if (!weightMap.containsKey(key)) {
+        // put probe values in the feature vector
         for (Integer where : value) {
           features.set(where, 1);
         }
 
+        // see what the model says
         Vector v = learner.classifyNoLink(features);
         weightMap.put(key, v);
 
+        // and zero out those locations again
         for (Integer where : value) {
           features.set(where, 0);
         }
       }
     }
-
   }
 
+  /**
+   * Returns the n most important features with their
+   * weights, most important category and the top few
+   * categories that they affect.
+   * @param n      How many results to return.
+   * @return       A list of the top variables.
+   */
   public List<Weight> summary(int n) {
     Queue<Weight> pq = new PriorityQueue<Weight>();
     for (Map.Entry<String, Vector> entry : weightMap.entrySet()) {
@@ -93,28 +118,50 @@ public class ModelDissector {
     return r;
   }
 
+  private static class Category implements Comparable<Category> {
+    int index;
+    double weight;
+
+    public Category(int index, double weight) {
+      this.index = index;
+      this.weight = weight;
+    }
+
+    @Override
+    public int compareTo(Category o) {
+      int r = Double.compare(Math.abs(weight), Math.abs(o.weight));
+      if (r != 0) {
+        return r;
+      } else {
+        return index - o.index;
+      }
+    }
+  }
+
   public static class Weight implements Comparable<Weight> {
     private final String feature;
     private final double value;
     private final int maxIndex;
+    private List<Category> categories;
 
     public Weight(String feature, Vector weights) {
+      this(feature, weights, 3);
+    }
+
+    public Weight(String feature, Vector weights, int n) {
       this.feature = feature;
       // pick out the weight with the largest abs value, but don't forget the sign
-      value = weights.aggregate(new BinaryFunction() {
-        @Override
-        public double apply(double arg1, double arg2) {
-          int r = Double.compare(Math.abs(arg1), Math.abs(arg2));
-          if (r < 0) {
-            return arg2;
-          } else if (r > 0) {
-            return arg1;
-          } else {
-            return Math.max(arg1, arg2);
-          }
+      PriorityQueue<Category> biggest = new PriorityQueue<Category>(n + 1, Ordering.natural().reverse());
+      for (Vector.Element element : weights) {
+        biggest.add(new Category(element.index(), element.get()));
+        while (biggest.size() > n) {
+          biggest.poll();
         }
-      }, Functions.IDENTITY);
-      maxIndex = weights.maxValueIndex();
+      }
+      categories = Lists.newArrayList(biggest);
+      Collections.sort(categories, Ordering.natural().reverse());
+      value = categories.get(0).weight;
+      maxIndex = categories.get(0).index;
     }
 
     @Override
@@ -135,6 +182,14 @@ public class ModelDissector {
       return value;
     }
 
+    public double getWeight(int n) {
+      return categories.get(n).weight;
+    }
+
+    public double getCategory(int n) {
+      return categories.get(n).index;
+    }
+
     public int getMaxImpact() {
       return maxIndex;
     }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java?rev=1022814&r1=1022813&r2=1022814&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java Fri Oct 15 03:45:37 2010
@@ -29,6 +29,8 @@ import com.google.gson.JsonPrimitive;
 import com.google.gson.JsonSerializationContext;
 import com.google.gson.JsonSerializer;
 import com.google.gson.reflect.TypeToken;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
 import org.apache.mahout.ep.EvolutionaryProcess;
 import org.apache.mahout.ep.Mapping;
 import org.apache.mahout.ep.State;
@@ -80,7 +82,7 @@ public final class ModelSerializer {
     return GSON.get();
   }
 
-  public static void writeJson(String path, AdaptiveLogisticRegression model) throws IOException {
+  public static void writeJson(String path, OnlineLearner model) throws IOException {
     OutputStreamWriter out = new FileWriter(path);
     try {
       out.write(gson().toJson(model));
@@ -96,7 +98,7 @@ public final class ModelSerializer {
    * @param clazz The class of the object we expect to read.
    * @return The LogisticModelParameters object that we read.
    */
-  public static AdaptiveLogisticRegression loadJsonFrom(Reader in, Class<AdaptiveLogisticRegression> clazz) {
+  public static AbstractVectorClassifier loadJsonFrom(Reader in, Class<? extends AbstractVectorClassifier> clazz) {
     return gson().fromJson(in, clazz);
   }
 

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java?rev=1022814&r1=1022813&r2=1022814&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java Fri Oct 15 03:45:37 2010
@@ -17,11 +17,13 @@
 
 package org.apache.mahout.classifier.sgd;
 
+import com.google.common.collect.Collections2;
 import com.google.common.collect.Lists;
 import org.apache.mahout.classifier.AbstractVectorClassifier;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.function.Functions;
 
+import java.util.ArrayDeque;
 import java.util.Deque;
 import java.util.List;
 
@@ -66,10 +68,13 @@ public class RankingGradient implements 
   }
 
   public void addToHistory(int actual, Vector instance) {
+    while (history.size() <= actual) {
+      history.add(new ArrayDeque<Vector>(window));
+    }
     // save this instance
     Deque<Vector> ourSide = history.get(actual);
     ourSide.add(instance);
-    if (ourSide.size() >= window) {
+    while (ourSide.size() >= window) {
       ourSide.pollFirst();
     }
   }

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=1022814&r1=1022813&r2=1022814&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java Fri Oct 15 03:45:37 2010
@@ -175,7 +175,7 @@ public final class TrainNewsGroups {
 
       double lambda = 0;
       double mu = 0;
-      
+
       if (best != null) {
         CrossFoldLearner state = best.getPayload().getLearner();
         averageCorrect = state.percentCorrect();
@@ -210,8 +210,12 @@ public final class TrainNewsGroups {
         norm = 0;
       }
       if (k % (bump * scale) == 0) {
+        if (learningAlgorithm.getBest() != null) {
+          ModelSerializer.writeJson("/tmp/news-group-" + k + ".model", learningAlgorithm.getBest().getPayload().getLearner());
+        }
+
         step += 0.25;
-        System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8f\t%.8f\t", maxBeta, nonZeros, positive, norm, lambda, mu);
+        System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
         System.out.printf("%d\t%.3f\t%.2f\t%s\n",
           k, averageLL, averageCorrect * 100, LEAK_LABELS[leakType % 3]);
       }
@@ -220,6 +224,8 @@ public final class TrainNewsGroups {
     dissect(leakType, newsGroups, learningAlgorithm, files);
     System.out.println("exiting main");
 
+    ModelSerializer.writeJson("/tmp/news-group.model", learningAlgorithm);
+
     List<Integer> counts = Lists.newArrayList();
     System.out.printf("Word counts\n");
     for (String count : overallCounts.elementSet()) {
@@ -230,6 +236,9 @@ public final class TrainNewsGroups {
     for (Integer count : counts) {
       System.out.printf("%d\t%d\n", k, count);
       k++;
+      if (k > 1000) {
+        break;
+      }
     }
   }
 
@@ -258,7 +267,8 @@ public final class TrainNewsGroups {
     List<String> ngNames = Lists.newArrayList(newsGroups.values());
     List<ModelDissector.Weight> weights = md.summary(100);
     for (ModelDissector.Weight w : weights) {
-      System.out.printf("%s\t%.1f\t%s\n", w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1));
+      System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n", w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1),
+        w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
     }
   }