You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/01/18 13:52:14 UTC
[ignite] branch master updated: IGNITE-8532: [ML] GA Grid:
Implement Roulette Wheel Selection
This is an automated email from the ASF dual-hosted git repository.
chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new 4cadb0d IGNITE-8532: [ML] GA Grid: Implement Roulette Wheel Selection
4cadb0d is described below
commit 4cadb0dd5920b8cf7b55c5e3f83c398008a8eddc
Author: Turik Campbell <ad...@techbysample.com>
AuthorDate: Fri Jan 18 16:51:24 2019 +0300
IGNITE-8532: [ML] GA Grid: Implement Roulette Wheel Selection
This closes #5842
---
.../ml/genetic/helloworld/HelloWorldGAExample.java | 23 +++-
.../java/org/apache/ignite/ml/genetic/GAGrid.java | 111 ++++++++-------
.../ml/genetic/RouletteWheelSelectionJob.java | 111 +++++++++++++++
.../ml/genetic/RouletteWheelSelectionTask.java | 153 +++++++++++++++++++++
.../ml/genetic/parameter/GAGridConstants.java | 7 +-
5 files changed, 355 insertions(+), 50 deletions(-)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/genetic/helloworld/HelloWorldGAExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/genetic/helloworld/HelloWorldGAExample.java
index 585cbb5..7e8bb8a 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/genetic/helloworld/HelloWorldGAExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/genetic/helloworld/HelloWorldGAExample.java
@@ -26,6 +26,7 @@ import org.apache.ignite.ml.genetic.Chromosome;
import org.apache.ignite.ml.genetic.GAGrid;
import org.apache.ignite.ml.genetic.Gene;
import org.apache.ignite.ml.genetic.parameter.GAConfiguration;
+import org.apache.ignite.ml.genetic.parameter.GAGridConstants;
/**
* This example demonstrates how to use the {@link GAGrid} framework. In this example, we want to evolve a string
@@ -37,6 +38,14 @@ import org.apache.ignite.ml.genetic.parameter.GAConfiguration;
* <p>
* You can change the test data and parameters of GA grid used in this example and re-run it to explore
* this functionality further.</p>
+ *
+ * For example, you may change some of the basic genetic parameters on the GAConfiguration object:
+ *
+ * Mutation Rate
+ * Crossover Rate
+ * Population Size
+ * Selection Method
+ *
* <p>
* How to run from command line:</p>
* <p>
@@ -72,7 +81,19 @@ public class HelloWorldGAExample {
// Initialize gene pool.
gaCfg.setGenePool(genes);
-
+
+ // Set CrossOver Rate.
+ gaCfg.setCrossOverRate(.05);
+
+ // Set Mutation Rate.
+ gaCfg.setMutationRate(.05);
+
+ // Set Selection Method.
+ gaCfg.setSelectionMtd(GAGridConstants.SELECTION_METHOD.SELECTION_METHOD_ROULETTE_WHEEL);
+
+ // Set Population Size.
+ gaCfg.setPopulationSize(2000);
+
// Create and set Fitness function.
HelloWorldFitnessFunction function = new HelloWorldFitnessFunction();
gaCfg.setFitnessFunction(function);
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/GAGrid.java b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/GAGrid.java
index 92eab5e..5531241 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/GAGrid.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/GAGrid.java
@@ -18,8 +18,10 @@
package org.apache.ignite.ml.genetic;
import java.util.ArrayList;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Random;
+import java.util.stream.Collectors;
import javax.cache.Cache.Entry;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
@@ -48,8 +50,7 @@ public class GAGrid {
private IgniteCache<Long, Chromosome> populationCache;
/** Gene cache */
private IgniteCache<Long, Gene> geneCache;
- /** population keys */
- private List<Long> populationKeys = new ArrayList<Long>();
+
/**
* @param cfg GAConfiguration
@@ -76,12 +77,11 @@ public class GAGrid {
* @return Average fitness score
*/
private Double calculateAverageFitness() {
-
double avgFitnessScore = 0;
IgniteCache<Long, Gene> cache = ignite.cache(GAGridConstants.POPULATION_CACHE);
- // Execute query to get names of all employees.
+ // Execute query to calculate the average fitness.
SqlFieldsQuery sql = new SqlFieldsQuery("select AVG(FITNESSSCORE) from Chromosome");
// Iterate over the result set.
@@ -110,7 +110,7 @@ public class GAGrid {
private Boolean copyFitterChromosomesToPopulation(List<Long> fittestKeys, List<Long> selectedKeys) {
double truncatePercentage = this.cfg.getTruncateRate();
- int totalSize = this.populationKeys.size();
+ int totalSize = this.cfg.getPopulationSize();
int truncateCnt = (int)(truncatePercentage * totalSize);
@@ -118,7 +118,6 @@ public class GAGrid {
return this.ignite.compute()
.execute(new TruncateSelectionTask(fittestKeys, numOfCopies), selectedKeys);
-
}
/**
@@ -137,7 +136,7 @@ public class GAGrid {
if (!(keys.contains(key))) {
genes[k] = key;
keys.add(key);
- k = k + 1;
+ k += 1;
}
}
return new Chromosome(genes);
@@ -165,27 +164,28 @@ public class GAGrid {
initializeGenePopulation();
- intializePopulation();
+ initializePopulation();
// Calculate Fitness
- calculateFitness(this.populationKeys);
+ calculateFitness(getPopulationKeys());
// Retrieve chromosomes in order by fitness value
- List<Long> keys = getChromosomesByFittest();
+ LinkedHashMap<Long, Double> map = getChromosomesByFittest();
// Calculate average fitness value of population
double averageFitnessScore = calculateAverageFitness();
- fittestChomosome = populationCache.get(keys.get(0));
+ Long key = map.keySet().iterator().next();
+
+ fittestChomosome = populationCache.get(key);
// while NOT terminateCondition met
while (!(cfg.getTerminateCriteria().isTerminationConditionMet(fittestChomosome, averageFitnessScore,
generationCnt))) {
- generationCnt = generationCnt + 1;
+ generationCnt += 1;
// We will crossover/mutate over chromosomes based on selection method
-
- List<Long> selectedKeysforCrossMutaton = selection(keys);
+ List<Long> selectedKeysforCrossMutaton = selection(map);
// Cross Over
crossover(selectedKeysforCrossMutaton);
@@ -197,10 +197,12 @@ public class GAGrid {
calculateFitness(selectedKeysforCrossMutaton);
// Retrieve chromosomes in order by fitness value
- keys = getChromosomesByFittest();
+ map = getChromosomesByFittest();
+ key = map.keySet().iterator().next();
+
// Retreive the first chromosome from the list
- fittestChomosome = populationCache.get(keys.get(0));
+ fittestChomosome = populationCache.get(key);
// Calculate average fitness value of population
averageFitnessScore = calculateAverageFitness();
@@ -214,27 +216,29 @@ public class GAGrid {
/**
* helper routine to retrieve Chromosome keys in order of fittest
*
- * @return List of primary keys for chromosomes.
+ * @return Map of primary key/fitness score pairs for chromosomes.
*/
- private List<Long> getChromosomesByFittest() {
- List<Long> orderChromKeysByFittest = new ArrayList<Long>();
+ private LinkedHashMap<Long,Double> getChromosomesByFittest() {
+ LinkedHashMap<Long, Double> orderChromKeysByFittest = new LinkedHashMap<>();
+
String orderDirection = "desc";
if (!cfg.isHigherFitnessValFitter())
orderDirection = "asc";
- String fittestSQL = "select _key from Chromosome order by fitnessScore " + orderDirection;
+ String fittestSQL = "select _key, fitnessScore from Chromosome order by fitnessScore " + orderDirection;
// Execute query to retrieve keys for ALL Chromosomes by fittnessScore
QueryCursor<List<?>> cursor = populationCache.query(new SqlFieldsQuery(fittestSQL));
-
+
List<List<?>> res = cursor.getAll();
-
+
for (List row : res) {
- Long key = (Long)row.get(0);
- orderChromKeysByFittest.add(key);
+ Long key = (Long)row.get(0);
+ Double fitnessScore= (Double)row.get(1);
+ orderChromKeysByFittest.put(key, fitnessScore);
}
-
+
return orderChromKeysByFittest;
}
@@ -272,25 +276,11 @@ public class GAGrid {
for (int j = 0; j < populationSize; j++) {
Chromosome chromosome = createChromosome(cfg.getChromosomeLen());
populationCache.put(chromosome.id(), chromosome);
- populationKeys.add(chromosome.id());
}
}
- /**
- * initialize the population of Chromosomes based on GAConfiguration
- */
- void intializePopulation() {
- int populationSize = cfg.getPopulationSize();
- populationCache.clear();
-
- for (int j = 0; j < populationSize; j++) {
- Chromosome chromosome = createChromosome(cfg.getChromosomeLen());
- populationCache.put(chromosome.id(), chromosome);
- populationKeys.add(chromosome.id());
- }
-
- }
+
/**
* Perform mutation
@@ -330,7 +320,7 @@ public class GAGrid {
* Truncation selection simply retains the fittest x% of the population. These fittest individuals are duplicated so
* that the population size is maintained.
*
- * @param keys
+ * @param keys Keys.
* @return List of keys
*/
private List<Long> selectByTruncation(List<Long> keys) {
@@ -340,6 +330,18 @@ public class GAGrid {
return keys.subList(truncateCnt, keys.size());
}
+
+ /**
+ * Roulette Wheel selection
+ *
+ * @param map Map of keys/fitness scores
+ * @return List of primary Keys for respective chromosomes that will breed
+ */
+ private List<Long> selectByRouletteWheel(LinkedHashMap map) {
+ List<Long> populationKeys = this.ignite.compute().execute(new RouletteWheelSelectionTask(this.cfg), map);
+
+ return populationKeys;
+ }
/**
* @param k Gene index in Chromosome.
@@ -359,7 +361,7 @@ public class GAGrid {
* @return Primary key of respective Gene
*/
private long selectGeneByChromsomeCriteria(int k) {
- List<Gene> genes = new ArrayList();
+ List<Gene> genes = new ArrayList<>();
StringBuffer sbSqlClause = new StringBuffer("_val like '");
sbSqlClause.append("%");
@@ -393,11 +395,14 @@ public class GAGrid {
/**
* Select chromosomes
*
- * @param chromosomeKeys List of population primary keys for respective Chromsomes
+ * @param map Map of keys/fitness scores for respective Chromosomes
* @return List of primary keys for respective Chromsomes
*/
- private List<Long> selection(List<Long> chromosomeKeys) {
- List<Long> selectedKeys = new ArrayList();
+ private List<Long> selection(LinkedHashMap map) {
+ List<Long> selectedKeys = new ArrayList<>();
+
+ // We will crossover/mutate over chromosomes based on selection method
+ List<Long> chromosomeKeys = new ArrayList<>(map.keySet());
GAGridConstants.SELECTION_METHOD selectionMtd = cfg.getSelectionMtd();
@@ -413,8 +418,10 @@ public class GAGrid {
copyFitterChromosomesToPopulation(fittestKeys, selectedKeys);
// copy more fit keys to rest of population
- break;
-
+ break;
+ case SELECTION_METHOD_ROULETTE_WHEEL:
+ selectedKeys = this.selectByRouletteWheel(map);
+
default:
break;
}
@@ -428,6 +435,14 @@ public class GAGrid {
* @return List of Chromosome primary keys
*/
List<Long> getPopulationKeys() {
- return populationKeys;
+ String fittestSQL = "select _key from Chromosome";
+
+ // Execute query to retrieve keys for ALL Chromosomes
+ QueryCursor<List<?>> cursor = populationCache.query(new SqlFieldsQuery(fittestSQL));
+
+ List<List<?>> res = cursor.getAll();
+
+ return (List<Long>) res.stream().map(x -> x.get(0)).collect(Collectors.toList());
}
+
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/RouletteWheelSelectionJob.java b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/RouletteWheelSelectionJob.java
new file mode 100644
index 0000000..5b288af
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/RouletteWheelSelectionJob.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.genetic;
+
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Random;
+import java.util.stream.Collectors;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.IgniteException;
+import org.apache.ignite.IgniteLogger;
+import org.apache.ignite.compute.ComputeJobAdapter;
+import org.apache.ignite.ml.genetic.parameter.GAGridConstants;
+import org.apache.ignite.resources.IgniteInstanceResource;
+import org.apache.ignite.resources.LoggerResource;
+
+/**
+ * Responsible for performing Roulette Wheel selection
+ */
+public class RouletteWheelSelectionJob extends ComputeJobAdapter {
+ /** Ignite instance */
+ @IgniteInstanceResource
+ private Ignite ignite = null;
+
+ /** Ignite logger */
+ @LoggerResource
+ private IgniteLogger log = null;
+
+ /** Total Fitness score */
+ Double totalFitnessScore = null;
+
+ /** Chromosome key/fitness score pair */
+ LinkedHashMap<Long, Double> map = null;
+
+ /**
+ * @param totalFitnessScore Total fitness score
+ * @param map Chromosome key / fitness score map
+ */
+ public RouletteWheelSelectionJob(Double totalFitnessScore, LinkedHashMap<Long, Double> map) {
+ this.totalFitnessScore = totalFitnessScore;
+ this.map = map;
+ }
+
+ /**
+ * Perform Roulette Wheel selection
+ *
+ * @return Chromosome parent chosen after 'spinning' the wheel.
+ */
+ @Override public Chromosome execute() throws IgniteException {
+
+ IgniteCache<Long, Chromosome> populationCache = ignite.cache(GAGridConstants.POPULATION_CACHE);
+
+ int value = spintheWheel(this.totalFitnessScore);
+
+ double partialSum = 0;
+ boolean notFound = true;
+
+ //sort map in ascending order by fitness score
+ Map<Long, Double> sortedAscendingMap = map.entrySet().stream()
+ .sorted(Map.Entry.comparingByValue())
+ .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new));
+
+ Iterator<Entry<Long, Double>> entries = sortedAscendingMap.entrySet().iterator();
+
+ Long chromosomeKey = (long)-1;
+
+ while (entries.hasNext() && notFound) {
+ Entry<Long, Double> entry = entries.next();
+ Long key = entry.getKey();
+ Double fitnessScore = entry.getValue();
+ partialSum = partialSum + fitnessScore;
+
+ if (partialSum >= value) {
+ notFound = false;
+ chromosomeKey = key;
+ }
+ }
+
+ return populationCache.get(chromosomeKey);
+ }
+
+ /**
+ * Spin the wheel.
+ *
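+ * @param fitnessScore Total fitness score; its integer part is the exclusive upper bound of the spin
+ * @return Random integer between 0 (inclusive) and the integer part of fitnessScore (exclusive)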
+ */
+ private int spintheWheel(Double fitnessScore) {
+ Random randomGenerator = new Random();
+ return randomGenerator.nextInt(fitnessScore.intValue());
+ }
+
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/RouletteWheelSelectionTask.java b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/RouletteWheelSelectionTask.java
new file mode 100644
index 0000000..9d81471
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/RouletteWheelSelectionTask.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.genetic;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.IgniteException;
+import org.apache.ignite.cache.affinity.Affinity;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.SqlFieldsQuery;
+import org.apache.ignite.cluster.ClusterNode;
+import org.apache.ignite.compute.ComputeJob;
+import org.apache.ignite.compute.ComputeJobResult;
+import org.apache.ignite.compute.ComputeJobResultPolicy;
+import org.apache.ignite.compute.ComputeLoadBalancer;
+import org.apache.ignite.compute.ComputeTaskAdapter;
+import org.apache.ignite.ml.genetic.parameter.GAConfiguration;
+import org.apache.ignite.ml.genetic.parameter.GAGridConstants;
+import org.apache.ignite.resources.IgniteInstanceResource;
+import org.apache.ignite.resources.LoadBalancerResource;
+
+/**
+ * Responsible for performing Roulette Wheel selection.
+ */
+public class RouletteWheelSelectionTask extends ComputeTaskAdapter<LinkedHashMap<Long, Double>, List<Long>> {
+ /** Ignite resource. */
+ @IgniteInstanceResource
+ private Ignite ignite = null;
+
+ // Injected load balancer; used in map() to pick the node for each selection job.
+ @LoadBalancerResource
+ ComputeLoadBalancer balancer;
+
+ /** GAConfiguration */
+ private GAConfiguration cfg = null;
+
+ /**
+ * @param cfg GAConfiguration
+ */
+ public RouletteWheelSelectionTask(GAConfiguration cfg) {
+ this.cfg = cfg;
+ }
+
+ /**
+ * Calculate total fitness of population
+ *
+ * @return Double value representing total fitness score of population
+ */
+ private Double calculateTotalFitness() {
+ double totalFitnessScore = 0;
+
+ IgniteCache<Long, Chromosome> cache = ignite.cache(GAGridConstants.POPULATION_CACHE);
+
+ SqlFieldsQuery sql = new SqlFieldsQuery("select SUM(FITNESSSCORE) from Chromosome");
+
+ // Iterate over the result set.
+ try (QueryCursor<List<?>> cursor = cache.query(sql)) {
+ for (List<?> row : cursor)
+ totalFitnessScore = (Double)row.get(0);
+ }
+
+ return totalFitnessScore;
+ }
+
+ /**
+ * @param nodes List of ClusterNode.
+ * @param chromosomeKeyFitness Map of key/fitness score pairs.
+ * @return Map of nodes to jobs.
+ */
+ @Override public Map<ComputeJob, ClusterNode> map(List<ClusterNode> nodes,
+ LinkedHashMap<Long, Double> chromosomeKeyFitness) throws IgniteException {
+ Map<ComputeJob, ClusterNode> map = new HashMap<>();
+
+ Affinity affinity = ignite.affinity(GAGridConstants.POPULATION_CACHE);
+ Double totalFitness = this.calculateTotalFitness();
+
+ int populationSize = this.cfg.getPopulationSize();
+
+ for (int i = 0; i < populationSize; i++) {
+ // Pick the next best balanced node for the job.
+ RouletteWheelSelectionJob job = new RouletteWheelSelectionJob(totalFitness, chromosomeKeyFitness);
+ map.put(job, balancer.getBalancedNode(job, null));
+ }
+
+ return map;
+ }
+
+ /**
+ * Collect the parent Chromosomes chosen by the jobs, re-create them in the cache and return their keys.
+ *
+ * @param list List of ComputeJobResult.
+ * @return List of Chromosome keys.
+ */
+ @Override public List<Long> reduce(List<ComputeJobResult> list) throws IgniteException {
+ List<Chromosome> parents = list.stream().map((x) -> (Chromosome)x.getData()).collect(Collectors.toList());
+
+ return createParents(parents);
+ }
+
+ /**
+ * Clear the population cache, then create new parents and add them to it.
+ *
+ * @param parents Chromosomes chosen to breed
+ * @return List of Chromosome keys.
+ */
+ private List<Long> createParents(List<Chromosome> parents) {
+ IgniteCache<Long, Chromosome> cache = ignite.cache(GAGridConstants.POPULATION_CACHE);
+ cache.clear();
+
+ List<Long> keys = new ArrayList();
+
+ parents.stream().forEach((x) -> {
+ long[] genes = x.getGenes();
+ Chromosome newparent = new Chromosome(genes);
+ cache.put(newparent.id(), newparent);
+ keys.add(newparent.id());
+ });
+
+ return keys;
+ }
+
+ /** {@inheritDoc} */
+ @Override public ComputeJobResultPolicy result(ComputeJobResult res, List<ComputeJobResult> rcvd) {
+ IgniteException err = res.getException();
+
+ if (err != null)
+ return ComputeJobResultPolicy.FAILOVER;
+
+ // If there is no exception, wait for all job results.
+ return ComputeJobResultPolicy.WAIT;
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/GAGridConstants.java b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/GAGridConstants.java
index 6d1645f..a44a802 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/GAGridConstants.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/parameter/GAGridConstants.java
@@ -29,6 +29,11 @@ public interface GAGridConstants {
/** Selection Method type **/
public enum SELECTION_METHOD {
- SELECTON_METHOD_ELETISM, SELECTION_METHOD_TRUNCATION
+ /** Selection method elitism. */
+ SELECTON_METHOD_ELETISM,
+ /** Selection method truncation. */
+ SELECTION_METHOD_TRUNCATION,
+ /** Selection method roulette wheel. */
+ SELECTION_METHOD_ROULETTE_WHEEL
}
}