You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/09/17 20:41:22 UTC
svn commit: r998242 -
/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
Author: tdunning
Date: Fri Sep 17 18:41:21 2010
New Revision: 998242
URL: http://svn.apache.org/viewvc?rev=998242&view=rev
Log:
Separated classifier from link function.
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=998242&r1=998241&r2=998242&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java Fri Sep 17 18:41:21 2010
@@ -74,15 +74,38 @@ public abstract class AbstractOnlineLogi
return this;
}
- private Vector logisticLink(Vector v) {
+ /**
+ * Computes the inverse link function, by default the logistic link function.
+ *
+ * @param v The output of the linear combination in a GLM. Note that v is
+ * overwritten by this call (the function is applied to it via assign).
+ * @return A version of v with the link function applied.
+ */
+ public Vector link(Vector v) {
double max = v.maxValue();
- if (max < 40) {
+ if (max >= 40) {
+ // if max >= 40, we subtract the large offset first
+ // the size of the max means that 1+sum(exp(v)) = sum(exp(v)) to within round-off
+ v.assign(Functions.minus(max)).assign(Functions.EXP);
+ return v.divide(v.norm(1));
+ } else {
v.assign(Functions.EXP);
- double sum = 1 + v.norm(1);
- return v.divide(sum);
+ return v.divide(1 + v.norm(1));
+ }
+ }
+
+ /**
+ * Computes the binomial logistic inverse link function.
+ * @param r The value to transform.
+ * @return The logistic function of r, i.e. 1 / (1 + exp(-r)), which is the inverse of the logit.
+ */
+ public double link(double r){
+ if (r < 0) {
+ double s = Math.exp(r);
+ return s / (1 + s);
} else {
- v.assign(Functions.minus(max)).assign(Functions.EXP);
- return v;
+ double s = Math.exp(-r);
+ return 1 / (1 + s);
}
}
@@ -92,6 +115,10 @@ public abstract class AbstractOnlineLogi
return beta.times(instance);
}
+ public double classifyScalarNoLink(Vector instance) {
+ return beta.getRow(0).dot(instance);
+ }
+
/**
* Returns n-1 probabilities, one for each category but the 0-th. The probability of the 0-th
* category is 1 - sum(this result).
@@ -100,7 +127,7 @@ public abstract class AbstractOnlineLogi
* @return A vector of probabilities, one for each of the first n-1 categories.
*/
public Vector classify(Vector instance) {
- return logisticLink(classifyNoLink(instance));
+ return link(classifyNoLink(instance));
}
/**
@@ -121,8 +148,7 @@ public abstract class AbstractOnlineLogi
regularize(instance);
// result is a vector with one element so we can just use dot product
- double r = Math.exp(beta.getRow(0).dot(instance));
- return r / (1 + r);
+ return link(classifyScalarNoLink(instance));
}
@Override