You are viewing a plain text version of this content. The canonical version is available in the commits@mahout.apache.org mailing list archive.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/09/16 00:16:48 UTC

svn commit: r997525 - /mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java

Author: tdunning
Date: Wed Sep 15 22:16:47 2010
New Revision: 997525

URL: http://svn.apache.org/viewvc?rev=997525&view=rev
Log:
ModelDissector retains sign information for weights

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java?rev=997525&r1=997524&r2=997525&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java Wed Sep 15 22:16:47 2010
@@ -22,6 +22,8 @@ import com.google.common.collect.Maps;
 import com.google.common.collect.Ordering;
 import org.apache.mahout.classifier.AbstractVectorClassifier;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.BinaryFunction;
+import org.apache.mahout.math.function.Functions;
 
 import java.util.Collections;
 import java.util.List;
@@ -78,18 +80,29 @@ public class ModelDissector {
     private String feature;
     private double value;
     private int maxIndex;
-    private Vector weights;
 
     public Weight(String feature, Vector weights) {
-      this.weights = weights;
       this.feature = feature;
-      value = weights.norm(1);
+      // pick out the weight with the largest abs value, but don't forget the sign
+      value = weights.aggregate(new BinaryFunction() {
+        @Override
+        public double apply(double arg1, double arg2) {
+          int r = Double.compare(Math.abs(arg1), Math.abs(arg2));
+          if (r < 0) {
+            return arg2;
+          } else if (r > 0) {
+            return arg1;
+          } else {
+            return Math.max(arg1, arg2);
+          }
+        }
+      }, Functions.IDENTITY);
       maxIndex = weights.maxValueIndex();
     }
 
     @Override
     public int compareTo(Weight other) {
-      int r = Double.compare(this.value, other.value);
+      int r = Double.compare(Math.abs(this.value), Math.abs(other.value));
       if (r != 0) {
         return r;
       } else {