You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/08/19 09:26:28 UTC
svn commit: r987049 - in /mahout/trunk/core/src:
main/java/org/apache/mahout/ep/EvolutionaryProcess.java
test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java
Author: tdunning
Date: Thu Aug 19 07:26:28 2010
New Revision: 987049
URL: http://svn.apache.org/viewvc?rev=987049&view=rev
Log:
Exposed evolutionary algorithm suitable for training
classifiers
Added:
mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java?rev=987049&r1=987048&r2=987049&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java Thu Aug 19 07:26:28 2010
@@ -1,8 +1,74 @@
package org.apache.mahout.ep;
-/**
- * Created by IntelliJ IDEA. User: tdunning Date: Aug 17, 2010 Time: 12:04:41 PM To change this
- * template use File | Settings | File Templates.
- */
-public class EvolutionaryProcess {
+import com.google.common.collect.Lists;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
+public class EvolutionaryProcess<T extends Copyable<T>> {
+ private ExecutorService pool;
+ private List<State<T>> population;
+ private int populationSize;
+
+ public EvolutionaryProcess(int threadCount, int populationSize, State<T> seed) {
+ this.populationSize = populationSize;
+ pool = Executors.newFixedThreadPool(threadCount);
+ population = Lists.newArrayList();
+ for (int i = 0; i < populationSize; i++) {
+ population.add(seed.mutate());
+ }
+ }
+
+ public void mutatePopulation(int survivors) {
+ Collections.sort(population);
+ List<State<T>> parents = Lists.newArrayList(population.subList(0, survivors));
+ population.subList(survivors, population.size()).clear();
+
+ int i = 0;
+ while (population.size() < populationSize) {
+ population.add(parents.get(i % survivors).mutate());
+ i++;
+ }
+ }
+
+ public State<T> parallelDo(final Function<T> fn) throws InterruptedException, ExecutionException {
+ Collection<Callable<State<T>>> tasks = Lists.newArrayList();
+ for (final State<T> state : population) {
+ tasks.add(new Callable<State<T>>() {
+ @Override
+ public State<T> call() throws Exception {
+ double v = fn.apply(state.getPayload(), state.getMappedParams());
+ state.setValue(v);
+ return state;
+ }
+ });
+ }
+ List<Future<State<T>>> r = pool.invokeAll(tasks);
+
+ double max = Double.NEGATIVE_INFINITY;
+ State<T> best = null;
+ for (Future<State<T>> future : r) {
+ State<T> s = future.get();
+ double value = s.getValue();
+ if (!Double.isNaN(value) && value >= max) {
+ max = value;
+ best = s;
+ }
+ }
+ if (best == null) {
+ best = r.get(0).get();
+ }
+
+ return best;
+ }
+
+ public abstract static class Function<U> {
+ abstract double apply(U payload, double[] params);
+ }
}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java?rev=987049&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java Thu Aug 19 07:26:28 2010
@@ -0,0 +1,46 @@
+package org.apache.mahout.ep;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Random;
+import java.util.concurrent.ExecutionException;
+
+public class EvolutionaryProcessTest {
+ @Test
+ public void converges() throws ExecutionException, InterruptedException {
+ State<Foo> s0 = new State<Foo>(new double[5], 1);
+ s0.setPayload(new Foo());
+ s0.setRand(new Random(1));
+ EvolutionaryProcess<Foo> ep = new EvolutionaryProcess<Foo>(10, 100, s0);
+
+ State<Foo> best = null;
+ for (int i = 0; i < 10; i++) {
+ best = ep.parallelDo(new EvolutionaryProcess.Function<Foo>() {
+ @Override
+ double apply(Foo payload, double[] params) {
+ int i = 1;
+ double sum = 0;
+ for (double x : params) {
+ sum += i * (x - i) * (x - i);
+ }
+ return -sum;
+ }
+ });
+
+ ep.mutatePopulation(3);
+
+ System.out.printf("%.3f\n", best.getValue());
+ }
+
+ Assert.assertNotNull(best);
+ Assert.assertEquals(0, best.getValue(), 0.02);
+ }
+
+ private static class Foo implements Copyable<Foo> {
+ @Override
+ public Foo copy() {
+ return this;
+ }
+ }
+}