You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/08/19 09:26:28 UTC

svn commit: r987049 - in /mahout/trunk/core/src: main/java/org/apache/mahout/ep/EvolutionaryProcess.java test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java

Author: tdunning
Date: Thu Aug 19 07:26:28 2010
New Revision: 987049

URL: http://svn.apache.org/viewvc?rev=987049&view=rev
Log:
Exposed evolutionary algorithm suitable for training
classifiers

Added:
    mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java?rev=987049&r1=987048&r2=987049&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java Thu Aug 19 07:26:28 2010
@@ -1,8 +1,74 @@
 package org.apache.mahout.ep;
 
-/**
- * Created by IntelliJ IDEA. User: tdunning Date: Aug 17, 2010 Time: 12:04:41 PM To change this
- * template use File | Settings | File Templates.
- */
-public class EvolutionaryProcess {
+import com.google.common.collect.Lists;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
+public class EvolutionaryProcess<T extends Copyable<T>> {
+  private ExecutorService pool;
+  private List<State<T>> population;
+  private int populationSize;
+
+  public EvolutionaryProcess(int threadCount, int populationSize, State<T> seed) {
+    this.populationSize = populationSize;
+    pool = Executors.newFixedThreadPool(threadCount);
+    population = Lists.newArrayList();
+    for (int i = 0; i < populationSize; i++) {
+      population.add(seed.mutate());
+    }
+  }
+
+  public void mutatePopulation(int survivors) {
+    Collections.sort(population);
+    List<State<T>> parents = Lists.newArrayList(population.subList(0, survivors));
+    population.subList(survivors, population.size()).clear();
+
+    int i = 0;
+    while (population.size() < populationSize) {
+      population.add(parents.get(i % survivors).mutate());
+      i++;
+    }
+  }
+
+  public State<T> parallelDo(final Function<T> fn) throws InterruptedException, ExecutionException {
+    Collection<Callable<State<T>>> tasks = Lists.newArrayList();
+    for (final State<T> state : population) {
+      tasks.add(new Callable<State<T>>() {
+        @Override
+        public State<T> call() throws Exception {
+          double v = fn.apply(state.getPayload(), state.getMappedParams());
+          state.setValue(v);
+          return state;
+        }
+      });
+    }
+    List<Future<State<T>>> r = pool.invokeAll(tasks);
+
+    double max = Double.NEGATIVE_INFINITY;
+    State<T> best = null;
+    for (Future<State<T>> future : r) {
+      State<T> s = future.get();
+      double value = s.getValue();
+      if (!Double.isNaN(value) && value >= max) {
+        max = value;
+        best = s;
+      }
+    }
+    if (best == null) {
+      best = r.get(0).get();
+    }
+
+    return best;
+  }
+
+  public abstract static class Function<U> {
+    abstract double apply(U payload, double[] params);
+  }
 }

Added: mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java?rev=987049&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java Thu Aug 19 07:26:28 2010
@@ -0,0 +1,55 @@
+package org.apache.mahout.ep;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Random;
+import java.util.concurrent.ExecutionException;
+
+/**
+ * Checks that {@link EvolutionaryProcess} drives a simple quadratic fitness function to (near) its
+ * maximum of zero.
+ */
+public class EvolutionaryProcessTest {
+  @Test
+  public void converges() throws ExecutionException, InterruptedException {
+    State<Foo> s0 = new State<Foo>(new double[5], 1);
+    s0.setPayload(new Foo());
+    // presumably seeds the mutation RNG so that the run is reproducible
+    s0.setRand(new Random(1));
+    EvolutionaryProcess<Foo> ep = new EvolutionaryProcess<Foo>(10, 100, s0);
+
+    // Fitness is -sum((x - 1)^2): at most 0, reached when every mapped parameter equals 1.
+    // NOTE(review): every parameter has weight 1 and target 1. If per-parameter weights/targets
+    // 1..n were intended, re-validate the generation count and tolerance below before changing
+    // the objective.
+    EvolutionaryProcess.Function<Foo> fitness = new EvolutionaryProcess.Function<Foo>() {
+      @Override
+      double apply(Foo payload, double[] params) {
+        double sum = 0;
+        for (double x : params) {
+          sum += (x - 1) * (x - 1);
+        }
+        return -sum;
+      }
+    };
+
+    State<Foo> best = null;
+    for (int generation = 0; generation < 10; generation++) {
+      best = ep.parallelDo(fitness);
+      ep.mutatePopulation(3);
+      System.out.printf("%.3f%n", best.getValue());
+    }
+
+    Assert.assertNotNull(best);
+    Assert.assertEquals(0, best.getValue(), 0.02);
+  }
+
+  /** Stateless payload, so handing out the same instance from copy() is safe. */
+  private static class Foo implements Copyable<Foo> {
+    @Override
+    public Foo copy() {
+      return this;
+    }
+  }
+}