You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/09/17 20:41:22 UTC

svn commit: r998242 - /mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java

Author: tdunning
Date: Fri Sep 17 18:41:21 2010
New Revision: 998242

URL: http://svn.apache.org/viewvc?rev=998242&view=rev
Log:
Separated classifier from link function.

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=998242&r1=998241&r2=998242&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java Fri Sep 17 18:41:21 2010
@@ -74,15 +74,38 @@ public abstract class AbstractOnlineLogi
     return this;
   }
 
-  private Vector logisticLink(Vector v) {
+  /**
+   * Computes the inverse link function, by default the logistic link function.
+   *
+   * @param v  The output of the linear combination in a GLM.  Note that v is
+   * modified in place by this call.
+   * @return A version of v with the link function applied.
+   */
+  public Vector link(Vector v) {
     double max = v.maxValue();
-    if (max < 40) {
+    if (max >= 40) {
+      // if max >= 40, we subtract the large offset first
+      // the size of the max means that 1+sum(exp(v)) = sum(exp(v)) to within round-off
+      v.assign(Functions.minus(max)).assign(Functions.EXP);
+      return v.divide(v.norm(1));
+    } else {
       v.assign(Functions.EXP);
-      double sum = 1 + v.norm(1);
-      return v.divide(sum);
+      return v.divide(1 + v.norm(1));
+    }
+  }
+
+  /**
+   * Computes the binomial logistic inverse link function.
+   * @param r  The value to transform.
+   * @return   The logistic function (inverse logit) of r, i.e. 1 / (1 + exp(-r)).
+   */
+  public double link(double r){
+    if (r < 0) {
+      double s = Math.exp(r);
+      return s / (1 + s);
     } else {
-      v.assign(Functions.minus(max)).assign(Functions.EXP);
-      return v;
+      double s = Math.exp(-r);
+      return 1 / (1 + s);
     }
   }
 
@@ -92,6 +115,10 @@ public abstract class AbstractOnlineLogi
     return beta.times(instance);
   }
 
+  public double classifyScalarNoLink(Vector instance) {
+    return beta.getRow(0).dot(instance);
+  }
+
   /**
    * Returns n-1 probabilities, one for each category but the 0-th.  The probability of the 0-th
    * category is 1 - sum(this result).
@@ -100,7 +127,7 @@ public abstract class AbstractOnlineLogi
    * @return A vector of probabilities, one for each of the first n-1 categories.
    */
   public Vector classify(Vector instance) {
-    return logisticLink(classifyNoLink(instance));
+    return link(classifyNoLink(instance));
   }
 
   /**
@@ -121,8 +148,7 @@ public abstract class AbstractOnlineLogi
     regularize(instance);
 
     // result is a vector with one element so we can just use dot product
-    double r = Math.exp(beta.getRow(0).dot(instance));
-    return r / (1 + r);
+    return link(classifyScalarNoLink(instance));
   }
 
   @Override