You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/09/16 00:16:48 UTC
svn commit: r997525 -
/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
Author: tdunning
Date: Wed Sep 15 22:16:47 2010
New Revision: 997525
URL: http://svn.apache.org/viewvc?rev=997525&view=rev
Log:
ModelDissector retains sign information for weights
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java?rev=997525&r1=997524&r2=997525&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java Wed Sep 15 22:16:47 2010
@@ -22,6 +22,8 @@ import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.BinaryFunction;
+import org.apache.mahout.math.function.Functions;
import java.util.Collections;
import java.util.List;
@@ -78,18 +80,29 @@ public class ModelDissector {
private String feature;
private double value;
private int maxIndex;
- private Vector weights;
public Weight(String feature, Vector weights) {
- this.weights = weights;
this.feature = feature;
- value = weights.norm(1);
+ // pick out the weight with the largest abs value, but don't forget the sign
+ value = weights.aggregate(new BinaryFunction() {
+ @Override
+ public double apply(double arg1, double arg2) {
+ int r = Double.compare(Math.abs(arg1), Math.abs(arg2));
+ if (r < 0) {
+ return arg2;
+ } else if (r > 0) {
+ return arg1;
+ } else {
+ return Math.max(arg1, arg2);
+ }
+ }
+ }, Functions.IDENTITY);
maxIndex = weights.maxValueIndex();
}
@Override
public int compareTo(Weight other) {
- int r = Double.compare(this.value, other.value);
+ int r = Double.compare(Math.abs(this.value), Math.abs(other.value));
if (r != 0) {
return r;
} else {