You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/10/15 05:45:38 UTC
svn commit: r1022814 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/classifier/sgd/
examples/src/main/java/org/apache/mahout/classifier/sgd/
Author: tdunning
Date: Fri Oct 15 03:45:37 2010
New Revision: 1022814
URL: http://svn.apache.org/viewvc?rev=1022814&view=rev
Log:
Make model dissector return multiple coefficients
Make model serializer more polymorphic
Checkpoint model during TrainNewsGroups
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java?rev=1022814&r1=1022813&r2=1022814&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java Fri Oct 15 03:45:37 2010
@@ -25,6 +25,7 @@ import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.BinaryFunction;
import org.apache.mahout.math.function.Functions;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -59,27 +60,51 @@ public class ModelDissector {
weightMap = Maps.newHashMap();
}
+ /**
+ * Probes a model to determine the effect of a particular variable. This is done
+ * with the aid of a trace dictionary which has recorded the locations in the feature
+ * vector that are modified by various variable values. We can set these locations to
+ * 1 and then look at the resulting score. This tells us the weight the model places
+ * on that variable.
+ * @param features A feature vector to use (destructively)
+ * @param traceDictionary A trace dictionary containing variables and what locations in the feature vector are affected by them
+ * @param learner The model that we are probing to find weights on features
+ */
+
public void update(Vector features, Map<String, Set<Integer>> traceDictionary, AbstractVectorClassifier learner) {
+ // zero out feature vector
features.assign(0);
for (Map.Entry<String, Set<Integer>> entry : traceDictionary.entrySet()) {
+ // get a feature and locations where it is stored in the feature vector
String key = entry.getKey();
Set<Integer> value = entry.getValue();
+
+ // if we haven't looked at this feature yet
if (!weightMap.containsKey(key)) {
+ // put probe values in the feature vector
for (Integer where : value) {
features.set(where, 1);
}
+ // see what the model says
Vector v = learner.classifyNoLink(features);
weightMap.put(key, v);
+ // and zero out those locations again
for (Integer where : value) {
features.set(where, 0);
}
}
}
-
}
+ /**
+ * Returns the n most important features with their
+ * weights, most important category and the top few
+ * categories that they affect.
+ * @param n How many results to return.
+ * @return A list of the top variables.
+ */
public List<Weight> summary(int n) {
Queue<Weight> pq = new PriorityQueue<Weight>();
for (Map.Entry<String, Vector> entry : weightMap.entrySet()) {
@@ -93,28 +118,50 @@ public class ModelDissector {
return r;
}
+ private static class Category implements Comparable<Category> {
+ int index;
+ double weight;
+
+ public Category(int index, double weight) {
+ this.index = index;
+ this.weight = weight;
+ }
+
+ @Override
+ public int compareTo(Category o) {
+ int r = Double.compare(Math.abs(weight), Math.abs(o.weight));
+ if (r != 0) {
+ return r;
+ } else {
+ return index - o.index;
+ }
+ }
+ }
+
public static class Weight implements Comparable<Weight> {
private final String feature;
private final double value;
private final int maxIndex;
+ private List<Category> categories;
public Weight(String feature, Vector weights) {
+ this(feature, weights, 3);
+ }
+
+ public Weight(String feature, Vector weights, int n) {
this.feature = feature;
// pick out the weight with the largest abs value, but don't forget the sign
- value = weights.aggregate(new BinaryFunction() {
- @Override
- public double apply(double arg1, double arg2) {
- int r = Double.compare(Math.abs(arg1), Math.abs(arg2));
- if (r < 0) {
- return arg2;
- } else if (r > 0) {
- return arg1;
- } else {
- return Math.max(arg1, arg2);
- }
+ PriorityQueue<Category> biggest = new PriorityQueue<Category>(n + 1, Ordering.natural().reverse());
+ for (Vector.Element element : weights) {
+ biggest.add(new Category(element.index(), element.get()));
+ while (biggest.size() > n) {
+ biggest.poll();
}
- }, Functions.IDENTITY);
- maxIndex = weights.maxValueIndex();
+ }
+ categories = Lists.newArrayList(biggest);
+ Collections.sort(categories, Ordering.natural().reverse());
+ value = categories.get(0).weight;
+ maxIndex = categories.get(0).index;
}
@Override
@@ -135,6 +182,14 @@ public class ModelDissector {
return value;
}
+ public double getWeight(int n) {
+ return categories.get(n).weight;
+ }
+
+ public double getCategory(int n) {
+ return categories.get(n).index;
+ }
+
public int getMaxImpact() {
return maxIndex;
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java?rev=1022814&r1=1022813&r2=1022814&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java Fri Oct 15 03:45:37 2010
@@ -29,6 +29,8 @@ import com.google.gson.JsonPrimitive;
import com.google.gson.JsonSerializationContext;
import com.google.gson.JsonSerializer;
import com.google.gson.reflect.TypeToken;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
import org.apache.mahout.ep.EvolutionaryProcess;
import org.apache.mahout.ep.Mapping;
import org.apache.mahout.ep.State;
@@ -80,7 +82,7 @@ public final class ModelSerializer {
return GSON.get();
}
- public static void writeJson(String path, AdaptiveLogisticRegression model) throws IOException {
+ public static void writeJson(String path, OnlineLearner model) throws IOException {
OutputStreamWriter out = new FileWriter(path);
try {
out.write(gson().toJson(model));
@@ -96,7 +98,7 @@ public final class ModelSerializer {
* @param clazz The class of the object we expect to read.
* @return The LogisticModelParameters object that we read.
*/
- public static AdaptiveLogisticRegression loadJsonFrom(Reader in, Class<AdaptiveLogisticRegression> clazz) {
+ public static AbstractVectorClassifier loadJsonFrom(Reader in, Class<? extends AbstractVectorClassifier> clazz) {
return gson().fromJson(in, clazz);
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java?rev=1022814&r1=1022813&r2=1022814&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java Fri Oct 15 03:45:37 2010
@@ -17,11 +17,13 @@
package org.apache.mahout.classifier.sgd;
+import com.google.common.collect.Collections2;
import com.google.common.collect.Lists;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
+import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
@@ -66,10 +68,13 @@ public class RankingGradient implements
}
public void addToHistory(int actual, Vector instance) {
+ while (history.size() <= actual) {
+ history.add(new ArrayDeque<Vector>(window));
+ }
// save this instance
Deque<Vector> ourSide = history.get(actual);
ourSide.add(instance);
- if (ourSide.size() >= window) {
+ while (ourSide.size() >= window) {
ourSide.pollFirst();
}
}
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=1022814&r1=1022813&r2=1022814&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java Fri Oct 15 03:45:37 2010
@@ -175,7 +175,7 @@ public final class TrainNewsGroups {
double lambda = 0;
double mu = 0;
-
+
if (best != null) {
CrossFoldLearner state = best.getPayload().getLearner();
averageCorrect = state.percentCorrect();
@@ -210,8 +210,12 @@ public final class TrainNewsGroups {
norm = 0;
}
if (k % (bump * scale) == 0) {
+ if (learningAlgorithm.getBest() != null) {
+ ModelSerializer.writeJson("/tmp/news-group-" + k + ".model", learningAlgorithm.getBest().getPayload().getLearner());
+ }
+
step += 0.25;
- System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8f\t%.8f\t", maxBeta, nonZeros, positive, norm, lambda, mu);
+ System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
System.out.printf("%d\t%.3f\t%.2f\t%s\n",
k, averageLL, averageCorrect * 100, LEAK_LABELS[leakType % 3]);
}
@@ -220,6 +224,8 @@ public final class TrainNewsGroups {
dissect(leakType, newsGroups, learningAlgorithm, files);
System.out.println("exiting main");
+ ModelSerializer.writeJson("/tmp/news-group.model", learningAlgorithm);
+
List<Integer> counts = Lists.newArrayList();
System.out.printf("Word counts\n");
for (String count : overallCounts.elementSet()) {
@@ -230,6 +236,9 @@ public final class TrainNewsGroups {
for (Integer count : counts) {
System.out.printf("%d\t%d\n", k, count);
k++;
+ if (k > 1000) {
+ break;
+ }
}
}
@@ -258,7 +267,8 @@ public final class TrainNewsGroups {
List<String> ngNames = Lists.newArrayList(newsGroups.values());
List<ModelDissector.Weight> weights = md.summary(100);
for (ModelDissector.Weight w : weights) {
- System.out.printf("%s\t%.1f\t%s\n", w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1));
+ System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n", w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1),
+ w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
}
}