You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/10/02 10:18:36 UTC
svn commit: r1003753 - in
/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd:
AbstractOnlineLogisticRegression.java DefaultGradient.java Gradient.java
MixedGradient.java RankingGradient.java
Author: tdunning
Date: Sat Oct 2 08:18:35 2010
New Revision: 1003753
URL: http://svn.apache.org/viewvc?rev=1003753&view=rev
Log:
Added ranking gradient implementation
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java?rev=1003753&r1=1003752&r2=1003753&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java Sat Oct 2 08:18:35 2010
@@ -163,11 +163,8 @@ public abstract class AbstractOnlineLogi
// push coefficients back to zero based on the prior
regularize(instance);
- // what does the current model say?
- Vector v = classify(instance);
-
// update each row of coefficients according to result
- Vector gradient = this.gradient.apply(groupKey, actual, v);
+ Vector gradient = this.gradient.apply(groupKey, actual, instance, this);
for (int i = 0; i < numCategories - 1; i++) {
double gradientBase = gradient.get(i);
@@ -177,7 +174,7 @@ public abstract class AbstractOnlineLogi
Vector.Element updateLocation = nonZeros.next();
int j = updateLocation.index();
- double newValue = beta.getQuick(i, j) + learningRate * gradientBase * instance.get(j) * perTermLearningRate(j);
+ double newValue = beta.getQuick(i, j) + gradientBase * learningRate * perTermLearningRate(j) * instance.get(j);
beta.setQuick(i, j, newValue);
}
}
@@ -324,24 +321,4 @@ public abstract class AbstractOnlineLogi
return k < 1;
}
- public static class DefaultGradient implements Gradient {
- /**
- * Provides a default gradient computation useful for logistic regression. This
- * can be over-ridden to incorporate AUC driven learning.
- * <p>
- * See www.eecs.tufts.edu/~dsculley/papers/combined-ranking-and-regression.pdf
- * @param groupKey A grouping key to allow per-something AUC loss to be used for training.
- *@param actual The target variable value.
- * @param v The current score vector. @return
- */
- @Override
- public final Vector apply(String groupKey, int actual, Vector v) {
- Vector r = v.like();
- if (actual != 0) {
- r.setQuick(actual - 1, 1);
- }
- r.assign(v, Functions.MINUS);
- return r;
- }
- }
}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java?rev=1003753&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java Sat Oct 2 08:18:35 2010
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Implements the basic logistic training law: the update for each score is
+ * the encoded target minus the model's current score.
+ */
+public class DefaultGradient implements Gradient {
+  /**
+   * Computes the log-loss gradient for a single training example.
+   * <p/>
+   * The result has one entry per score produced by the classifier. A non-zero
+   * target {@code actual} is encoded as a 1 at position {@code actual - 1}; a
+   * target of 0 sets no position (it starts from {@code scores.like()}, which is
+   * presumably all zeros — verify for the Vector types in use).
+   *
+   * @param groupKey A grouping key intended for per-group AUC loss; not used by this implementation.
+   * @param actual The target variable value.
+   * @param instance The current feature vector to use for gradient computation.
+   * @param classifier The classifier that computes scores for {@code instance}.
+   * @return The gradient to be applied to beta.
+   */
+  @Override
+  public final Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+    Vector scores = classifier.classify(instance);
+
+    Vector target = scores.like();
+    if (actual != 0) {
+      target.setQuick(actual - 1, 1);
+    }
+
+    // target - scores, computed in place on the target vector
+    target.assign(scores, Functions.MINUS);
+    return target;
+  }
+}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java?rev=1003753&r1=1003752&r2=1003753&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java Sat Oct 2 08:18:35 2010
@@ -17,6 +17,7 @@
package org.apache.mahout.classifier.sgd;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.math.Vector;
/**
@@ -25,5 +26,5 @@ import org.apache.mahout.math.Vector;
* normal loss function.
*/
public interface Gradient {
- Vector apply(String groupKey, int actual, Vector v);
+  /**
+   * Computes the update signal for a single training example.
+   * <p/>
+   * AbstractOnlineLogisticRegression reads element {@code i} of the result as the
+   * base update for coefficient row {@code i}.
+   *
+   * @param groupKey a grouping key intended to allow per-group (e.g. AUC) losses
+   * @param actual the target value of the example
+   * @param instance the example's feature vector
+   * @param classifier the model being trained; implementations may use it to score vectors
+   * @return the per-category gradient for this example
+   */
+ Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier);
}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java?rev=1003753&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java Sat Oct 2 08:18:35 2010
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+
+import java.util.Random;
+
+/**
+ * Provides a stochastic mixture of ranking updates and normal logistic updates. This uses a
+ * combination of AUC driven learning to improve ranking performance and traditional log-loss driven
+ * learning to improve log-likelihood.
+ * <p/>
+ * See www.eecs.tufts.edu/~dsculley/papers/combined-ranking-and-regression.pdf
+ */
+public class MixedGradient implements Gradient {
+  /** Probability that a call performs a ranking update instead of a log-loss update. */
+  private final double alpha;
+
+  private final RankingGradient rank;
+  private final Gradient basic;
+
+  private final Random random = RandomUtils.getRandom();
+
+  /**
+   * @param alpha probability of choosing a ranking update on each call. With alpha <= 0 no
+   *        ranking update is ever chosen; with alpha >= 1 one is always chosen.
+   * @param window number of recent examples the underlying RankingGradient keeps per class
+   */
+  public MixedGradient(double alpha, int window) {
+    this.alpha = alpha;
+    this.rank = new RankingGradient(window);
+    this.basic = rank.getBaseGradient();
+  }
+
+  /**
+   * Randomly picks either a ranking update (with probability {@code alpha}) or a plain
+   * log-loss update. Both paths record {@code instance} in the ranking history.
+   */
+  @Override
+  public Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+    if (random.nextDouble() < alpha) {
+      // ranking update relative to recent history (this also records the instance)
+      return rank.apply(groupKey, actual, instance, classifier);
+    }
+    // normal update, but keep the ranking history current so later ranking updates see it
+    rank.addToHistory(actual, instance);
+    return basic.apply(groupKey, actual, instance, classifier);
+  }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java?rev=1003753&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java Sat Oct 2 08:18:35 2010
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.List;
+
+/**
+ * Uses the difference between this instance and recent history to get a
+ * gradient that optimizes ranking performance. Essentially this is the
+ * same as directly optimizing AUC. It isn't expected that this would
+ * be used alone, but rather that a MixedGradient would use it and a
+ * DefaultGradient together to combine both ranking and log-likelihood
+ * goals.
+ * <p/>
+ * Only binary targets (0 and 1) are supported. Each example is compared
+ * against the most recent {@code window} examples of the opposite class.
+ */
+public class RankingGradient implements Gradient {
+  private static final Gradient BASIC = new DefaultGradient();
+
+  /** Maximum number of recent examples retained for each target class. */
+  private final int window;
+
+  /** history.get(k) holds the most recent examples whose target was k (k = 0 or 1). */
+  private final List<Deque<Vector>> history = Lists.newArrayList();
+
+  /**
+   * @param window number of recent examples of each class to rank against; must be positive
+   * @throws IllegalArgumentException if {@code window < 1}
+   */
+  public RankingGradient(int window) {
+    if (window < 1) {
+      throw new IllegalArgumentException("Window must be positive: " + window);
+    }
+    this.window = window;
+    // one bounded history queue per binary target value
+    history.add(new ArrayDeque<Vector>(window + 1));
+    history.add(new ArrayDeque<Vector>(window + 1));
+  }
+
+  /**
+   * Records {@code instance} in the history for its class. Returns the average of the
+   * base (log-loss) gradients computed on {@code instance - other}, taken over each
+   * retained example {@code other} of the opposite class.
+   *
+   * @param groupKey passed through to the base gradient
+   * @param actual the target value; must be 0 or 1
+   * @param instance the current feature vector
+   * @param classifier the classifier used to score the difference vectors
+   * @return the averaged gradient, or an empty (zero) vector from {@code like()} when no
+   *         example of the opposite class has been seen yet
+   * @throws IllegalArgumentException if {@code actual} is not 0 or 1
+   */
+  @Override
+  public final Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+    addToHistory(actual, instance);
+
+    Deque<Vector> otherSide = history.get(1 - actual);
+    if (otherSide.isEmpty()) {
+      // nothing to rank against yet; return an empty update instead of null
+      return classifier.classify(instance).like();
+    }
+
+    // NOTE(review): the caller (AbstractOnlineLogisticRegression) scales this result by the
+    // features of `instance`, not by `instance - other`; confirm that approximation is intended.
+    double weight = 1.0 / otherSide.size();
+    Vector r = null;
+    for (Vector other : otherSide) {
+      Vector g = BASIC.apply(groupKey, actual, instance.minus(other), classifier);
+      if (r == null) {
+        r = g.like();
+      }
+      // every pairwise gradient gets the same 1/n weight, so r is a true mean
+      r.assign(g, Functions.plusMult(weight));
+    }
+    return r;
+  }
+
+  /**
+   * Appends {@code instance} to the history for target {@code actual}. Drops the oldest
+   * entries so that at most {@code window} examples are kept for that target.
+   *
+   * @param actual the target value; must be 0 or 1
+   * @param instance the example to remember (stored by reference)
+   * @throws IllegalArgumentException if {@code actual} is not 0 or 1
+   */
+  public void addToHistory(int actual, Vector instance) {
+    if (actual != 0 && actual != 1) {
+      throw new IllegalArgumentException("Only binary targets (0 or 1) are supported, got " + actual);
+    }
+    Deque<Vector> ourSide = history.get(actual);
+    ourSide.add(instance);
+    while (ourSide.size() > window) {
+      ourSide.pollFirst();
+    }
+  }
+
+  /** Returns the log-loss gradient (a DefaultGradient) used to score pairwise differences. */
+  public Gradient getBaseGradient() {
+    return BASIC;
+  }
+}