You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/09/28 02:11:51 UTC
svn commit: r1001974 - in /mahout/trunk/core/src:
main/java/org/apache/mahout/classifier/sgd/
test/java/org/apache/mahout/classifier/sgd/
Author: tdunning
Date: Tue Sep 28 00:11:51 2010
New Revision: 1001974
URL: http://svn.apache.org/viewvc?rev=1001974&view=rev
Log:
Added variable evolutionary step size to improve annealing evaluation in each step.
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java?rev=1001974&r1=1001973&r2=1001974&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java Tue Sep 28 00:11:51 2010
@@ -63,7 +63,11 @@ public class AdaptiveLogisticRegression
private static final int SURVIVORS = 2;
private int record;
- private int evaluationInterval = 1000;
+ private int cutoff = 1000;
+ private int minInterval = 1000;
+ private int maxInterval = 1000;
+ private int currentStep = 1000;
+ private int bufferSize = 1000;
// transient here is a signal to GSON not to serialize pending records
private transient List<TrainingExample> buffer = Lists.newArrayList();
@@ -105,7 +109,7 @@ public class AdaptiveLogisticRegression
record++;
buffer.add(new TrainingExample(trackingKey, groupKey, actual, instance));
- if (buffer.size() > evaluationInterval) {
+ if (buffer.size() > bufferSize) {
trainWithBufferedExamples();
}
}
@@ -134,19 +138,52 @@ public class AdaptiveLogisticRegression
} catch (ExecutionException e) {
throw new IllegalStateException(e);
}
+ buffer.clear();
+
+ if (record > cutoff) {
+ cutoff = nextStep(record);
- // evolve based on new fitness
- ep.mutatePopulation(SURVIVORS);
+ // evolve based on new fitness
+ ep.mutatePopulation(SURVIVORS);
- if (freezeSurvivors) {
- // now grossly hack the top survivors so they stick around. Set their
- // mutation rates small and also hack their learning rate to be small
- // as well.
- for (State<Wrapper> state : ep.getPopulation().subList(0, SURVIVORS)) {
- state.getPayload().freeze(state);
+ if (freezeSurvivors) {
+ // now grossly hack the top survivors so they stick around. Set their
+ // mutation rates small and also hack their learning rate to be small
+ // as well.
+ for (State<Wrapper> state : ep.getPopulation().subList(0, SURVIVORS)) {
+ state.getPayload().freeze(state);
+ }
}
}
- buffer.clear();
+
+ }
+
+ public int nextStep(int recordNumber) {
+ int stepSize = stepSize(recordNumber, 2.6);
+ if (stepSize < minInterval) {
+ stepSize = minInterval;
+ }
+
+ if (stepSize > maxInterval) {
+ stepSize = maxInterval;
+ }
+
+ int newCutoff = stepSize * (recordNumber / stepSize + 1);
+ if (newCutoff < cutoff + currentStep) {
+ newCutoff = cutoff + currentStep;
+ } else {
+ this.currentStep = stepSize;
+ }
+ return newCutoff;
+ }
+
+ public static int stepSize(int recordNumber, double multiplier) {
+ final int[] bumps = new int[]{1, 2, 5};
+ double log = Math.floor(multiplier * Math.log10(recordNumber));
+ int bump = bumps[(int) log % bumps.length];
+ int scale = (int) Math.pow(10, Math.floor(log / bumps.length));
+
+ return bump * scale;
}
@Override
@@ -173,7 +210,23 @@ public class AdaptiveLogisticRegression
* @param interval Number of training examples to use in each epoch of optimization.
*/
public void setInterval(int interval) {
- this.evaluationInterval = interval;
+ this.minInterval = interval;
+ this.maxInterval = interval;
+ this.cutoff = interval * (record / interval + 1);
+ }
+
+ /**
+ * Starts optimization using the shorter interval and progressively grows the epoch length
+ * toward the longer interval as more data is seen. Note that values < 200 are not accepted.
+ * Values even that small are unlikely to be useful.
+ *
+ * @param minInterval The minimum epoch length for the evolutionary optimization
+ * @param maxInterval The maximum epoch length
+ */
+ public void setInterval(int minInterval, int maxInterval) {
+ this.minInterval = Math.max(200, minInterval);
+ this.maxInterval = Math.max(200, maxInterval);
+ this.cutoff = minInterval * (record / minInterval + 1);
}
public void setPoolSize(int poolSize) {
@@ -234,8 +287,12 @@ public class AdaptiveLogisticRegression
this.record = record;
}
- public int getEvaluationInterval() {
- return evaluationInterval;
+ public int getMinInterval() {
+ return minInterval;
+ }
+
+ public int getMaxInterval() {
+ return maxInterval;
}
public int getNumCategories() {
@@ -246,10 +303,6 @@ public class AdaptiveLogisticRegression
return seed.getPayload().getLearner().getPrior();
}
- public void setEvaluationInterval(int evaluationInterval) {
- this.evaluationInterval = evaluationInterval;
- }
-
public void setBuffer(List<TrainingExample> buffer) {
this.buffer = buffer;
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java?rev=1001974&r1=1001973&r2=1001974&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java Tue Sep 28 00:11:51 2010
@@ -340,7 +340,11 @@ public final class ModelSerializer {
x.get("numFeatures").getAsInt(),
jdc.<PriorFunction>deserialize(x.get("prior"), PriorFunction.class));
Type stateType = new TypeToken<State<AdaptiveLogisticRegression.Wrapper>>() {}.getType();
- r.setEvaluationInterval(x.get("evaluationInterval").getAsInt());
+ if (x.get("evaluationInterval")!=null) {
+ r.setInterval(x.get("evaluationInterval").getAsInt());
+ } else {
+ r.setInterval(x.get("minInterval").getAsInt(), x.get("maxInterval").getAsInt());
+ }
r.setRecord(x.get("record").getAsInt());
Type epType = new TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>() {}.getType();
@@ -360,7 +364,8 @@ public final class ModelSerializer {
new TypeToken<EvolutionaryProcess<AdaptiveLogisticRegression.Wrapper>>() {}.getType()));
r.add("buffer", jsc.serialize(x.getBuffer(),
new TypeToken<List<AdaptiveLogisticRegression.TrainingExample>>() {}.getType()));
- r.add("evaluationInterval", jsc.serialize(x.getEvaluationInterval()));
+ r.add("minInterval", jsc.serialize(x.getMinInterval()));
+ r.add("maxInterval", jsc.serialize(x.getMaxInterval()));
Type stateType = new TypeToken<State<AdaptiveLogisticRegression.Wrapper>>() {}.getType();
r.add("best", jsc.serialize(x.getBest(), stateType));
r.add("numFeatures", jsc.serialize(x.getNumFeatures()));
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java?rev=1001974&r1=1001973&r2=1001974&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java Tue Sep 28 00:11:51 2010
@@ -138,4 +138,44 @@ public final class AdaptiveLogisticRegre
// make sure that the copy didn't lose anything
assertEquals(auc1, w.getLearner().auc(), 0);
}
+
+ @Test
+ public void stepSize() {
+ assertEquals(500, AdaptiveLogisticRegression.stepSize(15000, 2));
+ assertEquals(2000, AdaptiveLogisticRegression.stepSize(15000, 2.6));
+ assertEquals(5000, AdaptiveLogisticRegression.stepSize(24000, 2.6));
+ assertEquals(10000, AdaptiveLogisticRegression.stepSize(15000, 3));
+ }
+
+ @Test
+ public void constantStep() {
+ AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
+ lr.setInterval(5000);
+ assertEquals(20000, lr.nextStep(15000));
+ assertEquals(20000, lr.nextStep(15001));
+ assertEquals(20000, lr.nextStep(16500));
+ assertEquals(20000, lr.nextStep(19999));
+ }
+
+
+ @Test
+ public void growingStep() {
+ AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
+ lr.setInterval(2000, 10000);
+
+ // start with minimum step size
+ for (int i = 2000; i < 20000;i+=2000) {
+ assertEquals(i + 2000, lr.nextStep(i));
+ }
+
+ // then level up a bit
+ for (int i = 20000; i < 50000; i += 5000) {
+ assertEquals(i + 5000, lr.nextStep(i));
+ }
+
+ // and more, but we top out with this step size
+ for (int i = 50000; i < 500000; i += 10000) {
+ assertEquals(i + 10000, lr.nextStep(i));
+ }
+ }
}