You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/04 14:29:25 UTC
[23/53] [abbrv] [partial] mahout git commit: end of day 6-2-2018
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/driver/MahoutDriver.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/driver/MahoutDriver.java b/community/mahout-mr/src/main/java/org/apache/mahout/driver/MahoutDriver.java
new file mode 100644
index 0000000..5c5b8a4
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/driver/MahoutDriver.java
@@ -0,0 +1,244 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.driver;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.util.ProgramDriver;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * General-purpose driver class for Mahout programs. Utilizes org.apache.hadoop.util.ProgramDriver to run
+ * main methods of other classes, but first loads up default properties from a properties file.
+ * <p/>
+ * To run locally:
+ *
+ * <pre>$MAHOUT_HOME/bin/mahout run shortJobName [over-ride ops]</pre>
+ * <p/>
+ * Works like this: by default, the file "driver.classes.props" is loaded from the classpath, which
+ * defines a mapping between short names like "vectordump" and fully qualified class names.
+ * The format of driver.classes.props is like so:
+ * <p/>
+ *
+ * <pre>fully.qualified.class.name = shortJobName : descriptive string</pre>
+ * <p/>
+ * The default properties to be applied to the program run is pulled out of, by default, "<shortJobName>.props"
+ * (also off of the classpath).
+ * <p/>
+ * The format of the default properties files is as follows:
+ * <pre>
+ i|input = /path/to/my/input
+ o|output = /path/to/my/output
+ m|jarFile = /path/to/jarFile
+ # etc - each line is shortArg|longArg = value
+ </pre>
+ *
+ * The next argument to the Driver is supposed to be the short name of the class to be run (as defined in the
+ * driver.classes.props file).
+ * <p/>
+ * Then the class which will be run will have it's main called with
+ *
+ * <pre>main(new String[] { "--input", "/path/to/my/input", "--output", "/path/to/my/output" });</pre>
+ *
+ * After all the "default" properties are loaded from the file, any further command-line arguments are taken in,
+ * and over-ride the defaults.
+ * <p/>
+ * So if your driver.classes.props looks like so:
+ *
+ * <pre>org.apache.mahout.utils.vectors.VectorDumper = vecDump : dump vectors from a sequence file</pre>
+ *
+ * and you have a file core/src/main/resources/vecDump.props which looks like
+ * <pre>
+ o|output = /tmp/vectorOut
+ s|seqFile = /my/vector/sequenceFile
+ </pre>
+ *
+ * And you execute the command-line:
+ *
+ * <pre>$MAHOUT_HOME/bin/mahout run vecDump -s /my/otherVector/sequenceFile</pre>
+ *
+ * Then org.apache.mahout.utils.vectors.VectorDumper.main() will be called with arguments:
+ * <pre>{"--output", "/tmp/vectorOut", "-s", "/my/otherVector/sequenceFile"}</pre>
+ */
+public final class MahoutDriver {
+
+ private static final Logger log = LoggerFactory.getLogger(MahoutDriver.class);
+
+ private MahoutDriver() {
+ }
+
+ public static void main(String[] args) throws Throwable {
+
+ Properties mainClasses = loadProperties("driver.classes.props");
+ if (mainClasses == null) {
+ mainClasses = loadProperties("driver.classes.default.props");
+ }
+ if (mainClasses == null) {
+ throw new IOException("Can't load any properties file?");
+ }
+
+ boolean foundShortName = false;
+ ProgramDriver programDriver = new ProgramDriver();
+ for (Object key : mainClasses.keySet()) {
+ String keyString = (String) key;
+ if (args.length > 0 && shortName(mainClasses.getProperty(keyString)).equals(args[0])) {
+ foundShortName = true;
+ }
+ if (args.length > 0 && keyString.equalsIgnoreCase(args[0]) && isDeprecated(mainClasses, keyString)) {
+ log.error(desc(mainClasses.getProperty(keyString)));
+ return;
+ }
+ if (isDeprecated(mainClasses, keyString)) {
+ continue;
+ }
+ addClass(programDriver, keyString, mainClasses.getProperty(keyString));
+ }
+
+ if (args.length < 1 || args[0] == null || "-h".equals(args[0]) || "--help".equals(args[0])) {
+ programDriver.driver(args);
+ return;
+ }
+
+ String progName = args[0];
+ if (!foundShortName) {
+ addClass(programDriver, progName, progName);
+ }
+ shift(args);
+
+ Properties mainProps = loadProperties(progName + ".props");
+ if (mainProps == null) {
+ log.warn("No {}.props found on classpath, will use command-line arguments only", progName);
+ mainProps = new Properties();
+ }
+
+ Map<String,String[]> argMap = new HashMap<>();
+ int i = 0;
+ while (i < args.length && args[i] != null) {
+ List<String> argValues = new ArrayList<>();
+ String arg = args[i];
+ i++;
+ if (arg.startsWith("-D")) { // '-Dkey=value' or '-Dkey=value1,value2,etc' case
+ String[] argSplit = arg.split("=");
+ arg = argSplit[0];
+ if (argSplit.length == 2) {
+ argValues.add(argSplit[1]);
+ }
+ } else { // '-key [values]' or '--key [values]' case.
+ while (i < args.length && args[i] != null) {
+ if (args[i].startsWith("-")) {
+ break;
+ }
+ argValues.add(args[i]);
+ i++;
+ }
+ }
+ argMap.put(arg, argValues.toArray(new String[argValues.size()]));
+ }
+
+ // Add properties from the .props file that are not overridden on the command line
+ for (String key : mainProps.stringPropertyNames()) {
+ String[] argNamePair = key.split("\\|");
+ String shortArg = '-' + argNamePair[0].trim();
+ String longArg = argNamePair.length < 2 ? null : "--" + argNamePair[1].trim();
+ if (!argMap.containsKey(shortArg) && (longArg == null || !argMap.containsKey(longArg))) {
+ argMap.put(longArg, new String[] {mainProps.getProperty(key)});
+ }
+ }
+
+ // Now add command-line args
+ List<String> argsList = new ArrayList<>();
+ argsList.add(progName);
+ for (Map.Entry<String,String[]> entry : argMap.entrySet()) {
+ String arg = entry.getKey();
+ if (arg.startsWith("-D")) { // arg is -Dkey - if value for this !isEmpty(), then arg -> -Dkey + "=" + value
+ String[] argValues = entry.getValue();
+ if (argValues.length > 0 && !argValues[0].trim().isEmpty()) {
+ arg += '=' + argValues[0].trim();
+ }
+ argsList.add(1, arg);
+ } else {
+ argsList.add(arg);
+ for (String argValue : Arrays.asList(argMap.get(arg))) {
+ if (!argValue.isEmpty()) {
+ argsList.add(argValue);
+ }
+ }
+ }
+ }
+
+ long start = System.currentTimeMillis();
+
+ programDriver.driver(argsList.toArray(new String[argsList.size()]));
+
+ if (log.isInfoEnabled()) {
+ log.info("Program took {} ms (Minutes: {})", System.currentTimeMillis() - start,
+ (System.currentTimeMillis() - start) / 60000.0);
+ }
+ }
+
+ private static boolean isDeprecated(Properties mainClasses, String keyString) {
+ return "deprecated".equalsIgnoreCase(shortName(mainClasses.getProperty(keyString)));
+ }
+
+ private static Properties loadProperties(String resource) throws IOException {
+ InputStream propsStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(resource);
+ if (propsStream != null) {
+ try {
+ Properties properties = new Properties();
+ properties.load(propsStream);
+ return properties;
+ } finally {
+ Closeables.close(propsStream, true);
+ }
+ }
+ return null;
+ }
+
+ private static String[] shift(String[] args) {
+ System.arraycopy(args, 1, args, 0, args.length - 1);
+ args[args.length - 1] = null;
+ return args;
+ }
+
+ private static String shortName(String valueString) {
+ return valueString.contains(":") ? valueString.substring(0, valueString.indexOf(':')).trim() : valueString;
+ }
+
+ private static String desc(String valueString) {
+ return valueString.contains(":") ? valueString.substring(valueString.indexOf(':')).trim() : valueString;
+ }
+
+ private static void addClass(ProgramDriver driver, String classString, String descString) {
+ try {
+ Class<?> clazz = Class.forName(classString);
+ driver.addClass(shortName(descString), clazz, desc(descString));
+ } catch (Throwable t) {
+ log.warn("Unable to add class: {}", classString, t);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java b/community/mahout-mr/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
new file mode 100644
index 0000000..4b2eea1
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
@@ -0,0 +1,229 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.ep;
+
+import java.io.Closeable;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.sgd.PolymorphicWritable;
+
+/**
+ * Allows evolutionary optimization where the state function can't be easily
+ * packaged for the optimizer to execute. A good example of this is with
+ * on-line learning where optimizing the learning parameters is desirable.
+ * We would like to pass training examples to the learning algorithms, but
+ * we definitely want to do the training in multiple threads and then after
+ * several training steps, we want to do a selection and mutation step.
+ *
+ * In such a case, it is highly desirable to leave most of the control flow
+ * in the hands of our caller. As such, this class provides three functions,
+ * <ul>
+ * <li> Storage of the evolutionary state. The state variables have payloads
+ * which can be anything that implements Payload.
+ * <li> Threaded execution of a single operation on each of the members of the
+ * population being evolved. In the on-line learning example, this is used for
+ * training all of the classifiers in the population.
+ * <li> Propagating mutations of the most successful members of the population.
+ * This propagation involves copying the state and the payload and then updating
+ * the payload after mutation of the evolutionary state.
+ * </ul>
+ *
+ * The State class that we use for storing the state of each member of the
+ * population also provides parameter mapping. Check out Mapping and State
+ * for more info.
+ *
+ * @see Mapping
+ * @see Payload
+ * @see State
+ *
+ * @param <T> The payload class.
+ */
+public class EvolutionaryProcess<T extends Payload<U>, U> implements Writable, Closeable {
+ // used to execute operations on the population in thread parallel.
+ private ExecutorService pool;
+
+ // threadCount is serialized so that we can reconstruct the thread pool
+ private int threadCount;
+
+ // list of members of the population
+ private List<State<T, U>> population;
+
+ // how big should the population be. If this is changed, it will take effect
+ // the next time the population is mutated.
+
+ private int populationSize;
+
+ public EvolutionaryProcess() {
+ population = new ArrayList<>();
+ }
+
+ /**
+ * Creates an evolutionary optimization framework with specified threadiness,
+ * population size and initial state.
+ * @param threadCount How many threads to use in parallelDo
+ * @param populationSize How large a population to use
+ * @param seed An initial population member
+ */
+ public EvolutionaryProcess(int threadCount, int populationSize, State<T, U> seed) {
+ this.populationSize = populationSize;
+ setThreadCount(threadCount);
+ initializePopulation(populationSize, seed);
+ }
+
+ private void initializePopulation(int populationSize, State<T, U> seed) {
+ population = Lists.newArrayList(seed);
+ for (int i = 0; i < populationSize; i++) {
+ population.add(seed.mutate());
+ }
+ }
+
+ public void add(State<T, U> value) {
+ population.add(value);
+ }
+
+ /**
+ * Nuke all but a few of the current population and then repopulate with
+ * variants of the survivors.
+ * @param survivors How many survivors we want to keep.
+ */
+ public void mutatePopulation(int survivors) {
+ // largest value first, oldest first in case of ties
+ Collections.sort(population);
+
+ // we copy here to avoid concurrent modification
+ List<State<T, U>> parents = new ArrayList<>(population.subList(0, survivors));
+ population.subList(survivors, population.size()).clear();
+
+ // fill out the population with offspring from the survivors
+ int i = 0;
+ while (population.size() < populationSize) {
+ population.add(parents.get(i % survivors).mutate());
+ i++;
+ }
+ }
+
+ /**
+ * Execute an operation on all of the members of the population with many threads. The
+ * return value is taken as the current fitness of the corresponding member.
+ * @param fn What to do on each member. Gets payload and the mapped parameters as args.
+ * @return The member of the population with the best fitness.
+ * @throws InterruptedException Shouldn't happen.
+ * @throws ExecutionException If fn throws an exception, that exception will be collected
+ * and rethrown nested in an ExecutionException.
+ */
+ public State<T, U> parallelDo(final Function<Payload<U>> fn) throws InterruptedException, ExecutionException {
+ Collection<Callable<State<T, U>>> tasks = new ArrayList<>();
+ for (final State<T, U> state : population) {
+ tasks.add(new Callable<State<T, U>>() {
+ @Override
+ public State<T, U> call() {
+ double v = fn.apply(state.getPayload(), state.getMappedParams());
+ state.setValue(v);
+ return state;
+ }
+ });
+ }
+
+ List<Future<State<T, U>>> r = pool.invokeAll(tasks);
+
+ // zip through the results and find the best one
+ double max = Double.NEGATIVE_INFINITY;
+ State<T, U> best = null;
+ for (Future<State<T, U>> future : r) {
+ State<T, U> s = future.get();
+ double value = s.getValue();
+ if (!Double.isNaN(value) && value >= max) {
+ max = value;
+ best = s;
+ }
+ }
+ if (best == null) {
+ best = r.get(0).get();
+ }
+
+ return best;
+ }
+
+ public void setThreadCount(int threadCount) {
+ this.threadCount = threadCount;
+ pool = Executors.newFixedThreadPool(threadCount);
+ }
+
+ public int getThreadCount() {
+ return threadCount;
+ }
+
+ public int getPopulationSize() {
+ return populationSize;
+ }
+
+ public List<State<T, U>> getPopulation() {
+ return population;
+ }
+
+ @Override
+ public void close() {
+ List<Runnable> remainingTasks = pool.shutdownNow();
+ try {
+ pool.awaitTermination(10, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ throw new IllegalStateException("Had to forcefully shut down " + remainingTasks.size() + " tasks");
+ }
+ if (!remainingTasks.isEmpty()) {
+ throw new IllegalStateException("Had to forcefully shut down " + remainingTasks.size() + " tasks");
+ }
+ }
+
+ public interface Function<T> {
+ double apply(T payload, double[] params);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(threadCount);
+ out.writeInt(population.size());
+ for (State<T, U> state : population) {
+ PolymorphicWritable.write(out, state);
+ }
+ }
+
+ @Override
+ public void readFields(DataInput input) throws IOException {
+ setThreadCount(input.readInt());
+ int n = input.readInt();
+ population = new ArrayList<>();
+ for (int i = 0; i < n; i++) {
+ State<T, U> state = (State<T, U>) PolymorphicWritable.read(input, State.class);
+ population.add(state);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/ep/Mapping.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/ep/Mapping.java b/community/mahout-mr/src/main/java/org/apache/mahout/ep/Mapping.java
new file mode 100644
index 0000000..41a8942
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/ep/Mapping.java
@@ -0,0 +1,206 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.ep;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.sgd.PolymorphicWritable;
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * Provides coordinate transformations so that evolution can proceed on the entire space of
+ * reals but have the output limited and squished in convenient (and safe) ways.
+ * <p/>
+ * Instances are created through the static factory methods. Every concrete mapping is
+ * {@link Writable} so that it can be serialized along with the {@link State} that uses it.
+ */
+public abstract class Mapping extends DoubleFunction implements Writable {
+
+  private Mapping() {
+  }
+
+  /**
+   * Logistic squashing of the reals into (min, max):
+   * {@code min + (max - min) / (1 + exp(-v * scale))}.
+   */
+  public static final class SoftLimit extends Mapping {
+    private double min;
+    private double max;
+    private double scale;
+
+    public SoftLimit() {
+    }
+
+    private SoftLimit(double min, double max, double scale) {
+      this.min = min;
+      this.max = max;
+      this.scale = scale;
+    }
+
+    @Override
+    public double apply(double v) {
+      return min + (max - min) / (1 + Math.exp(-v * scale));
+    }
+
+    @Override
+    public void write(DataOutput out) throws IOException {
+      out.writeDouble(min);
+      out.writeDouble(max);
+      out.writeDouble(scale);
+    }
+
+    @Override
+    public void readFields(DataInput in) throws IOException {
+      min = in.readDouble();
+      max = in.readDouble();
+      scale = in.readDouble();
+    }
+  }
+
+  /** {@code exp} of a soft limit between {@code log(low)} and {@code log(high)}. */
+  public static final class LogLimit extends Mapping {
+    private Mapping wrapped;
+
+    public LogLimit() {
+    }
+
+    private LogLimit(double low, double high) {
+      wrapped = softLimit(Math.log(low), Math.log(high));
+    }
+
+    @Override
+    public double apply(double v) {
+      return Math.exp(wrapped.apply(v));
+    }
+
+    @Override
+    public void write(DataOutput dataOutput) throws IOException {
+      PolymorphicWritable.write(dataOutput, wrapped);
+    }
+
+    @Override
+    public void readFields(DataInput in) throws IOException {
+      wrapped = PolymorphicWritable.read(in, Mapping.class);
+    }
+  }
+
+  /** Unbounded positive mapping {@code exp(v * scale)}. */
+  public static final class Exponential extends Mapping {
+    private double scale;
+
+    public Exponential() {
+    }
+
+    private Exponential(double scale) {
+      this.scale = scale;
+    }
+
+    @Override
+    public double apply(double v) {
+      return Math.exp(v * scale);
+    }
+
+    @Override
+    public void write(DataOutput out) throws IOException {
+      out.writeDouble(scale);
+    }
+
+    @Override
+    public void readFields(DataInput in) throws IOException {
+      scale = in.readDouble();
+    }
+  }
+
+  /** Returns its input unchanged. */
+  public static final class Identity extends Mapping {
+    @Override
+    public double apply(double v) {
+      return v;
+    }
+
+    @Override
+    public void write(DataOutput dataOutput) {
+      // stateless
+    }
+
+    @Override
+    public void readFields(DataInput dataInput) {
+      // stateless
+    }
+  }
+
+  /**
+   * Maps input to the open interval (min, max) with 0 going to the mean of min and
+   * max. When scale is large, a larger proportion of values are mapped to points
+   * near the boundaries. When scale is small, a larger proportion of values are mapped to
+   * points well within the boundaries.
+   * @param min The largest lower bound on values to be returned.
+   * @param max The least upper bound on values to be returned.
+   * @param scale Defines how sharp the boundaries are.
+   * @return A mapping that satisfies the desired constraint.
+   */
+  public static Mapping softLimit(double min, double max, double scale) {
+    return new SoftLimit(min, max, scale);
+  }
+
+  /**
+   * Maps input to the open interval (min, max) with 0 going to the mean of min and
+   * max, using a scale of 1.
+   * @see #softLimit(double, double, double)
+   * @param min The largest lower bound on values to be returned.
+   * @param max The least upper bound on values to be returned.
+   * @return A mapping that satisfies the desired constraint.
+   */
+  public static Mapping softLimit(double min, double max) {
+    return softLimit(min, max, 1);
+  }
+
+  /**
+   * Maps input to positive values in the open interval (low, high) with
+   * 0 going to the geometric mean {@code sqrt(low * high)}. Near the geometric mean,
+   * values are distributed roughly geometrically.
+   * @param low The largest lower bound for output results. Must be >0.
+   * @param high The least upper bound for output results. Must be >0.
+   * @return A mapping that satisfies the desired constraint.
+   * @throws IllegalArgumentException if either bound is not strictly positive.
+   */
+  public static Mapping logLimit(double low, double high) {
+    // Guava's Preconditions only substitutes %s placeholders
+    Preconditions.checkArgument(low > 0, "Lower bound for log limit must be > 0 but was %s", low);
+    Preconditions.checkArgument(high > 0, "Upper bound for log limit must be > 0 but was %s", high);
+    return new LogLimit(low, high);
+  }
+
+  /**
+   * Maps results to positive values via {@code exp(v)}.
+   * @return A mapping whose outputs are always positive.
+   */
+  public static Mapping exponential() {
+    return exponential(1);
+  }
+
+  /**
+   * Maps results to positive values via {@code exp(v * scale)}.
+   * @param scale If large, then large values are more likely.
+   * @return A mapping whose outputs are always positive.
+   */
+  public static Mapping exponential(double scale) {
+    return new Exponential(scale);
+  }
+
+  /**
+   * Maps results to themselves.
+   * @return A mapping that returns its input unchanged.
+   */
+  public static Mapping identity() {
+    return new Identity();
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/ep/Payload.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/ep/Payload.java b/community/mahout-mr/src/main/java/org/apache/mahout/ep/Payload.java
new file mode 100644
index 0000000..920237d
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/ep/Payload.java
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.ep;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Payloads for evolutionary state must be copyable and updatable. The copy should be a deep copy
+ * unless some aspect of the state is sharable or immutable.
+ * <p/>
+ * During mutation, a copy is first made and then after the parameters in the State structure are
+ * suitably modified, update is called with the mapped versions of the parameters.
+ *
+ * @param <T> type parameter carried through to {@link State} and {@link EvolutionaryProcess};
+ *            this interface itself only uses it in the return type of {@link #copy()}
+ * @see State
+ */
+public interface Payload<T> extends Writable {
+  /**
+   * Creates an independent copy of this payload; called when the owning {@link State} is copied.
+   *
+   * @return the copy (deep, unless parts of the state are safely shareable)
+   */
+  Payload<T> copy();
+
+  /**
+   * Informs the payload of new parameter values after its owning state has been mutated.
+   *
+   * @param params the state's parameters, each passed through its {@link Mapping}
+   */
+  void update(double[] params);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/ep/State.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/ep/State.java b/community/mahout-mr/src/main/java/org/apache/mahout/ep/State.java
new file mode 100644
index 0000000..7a0fb5e
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/ep/State.java
@@ -0,0 +1,302 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.ep;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.sgd.PolymorphicWritable;
+import org.apache.mahout.common.RandomUtils;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Locale;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Records evolutionary state and provides a mutation operation for recorded-step meta-mutation.
+ *
+ * You provide the payload, this class provides the mutation operations. During mutation,
+ * the payload is copied and after the state variables are changed, they are passed to the
+ * payload.
+ *
+ * Parameters are internally mutated in a state space that spans all of R^n, but parameters
+ * passed to the payload are transformed as specified by a call to setMap(). The default
+ * mapping is the identity map, but uniform-ish or exponential-ish coverage of a range are
+ * also supported.
+ *
+ * More information on the underlying algorithm can be found in the following paper
+ *
+ * http://arxiv.org/abs/0803.3838
+ *
+ * @see Mapping
+ */
+public class State<T extends Payload<U>, U> implements Comparable<State<T, U>>, Writable {
+
+  // object count is kept to break ties in comparison.
+  private static final AtomicInteger OBJECT_COUNT = new AtomicInteger();
+
+  private int id = OBJECT_COUNT.getAndIncrement();
+  private Random gen = RandomUtils.getRandom();
+  // current state
+  private double[] params;
+  // mappers to transform state; a null entry behaves like the identity mapping
+  private Mapping[] maps;
+  // omni-directional mutation
+  private double omni;
+  // directional mutation
+  private double[] step;
+  // current fitness value
+  private double value;
+  private T payload;
+
+  /** Creates an empty state whose fields are filled in by {@link #copy()} or {@link #readFields}. */
+  public State() {
+  }
+
+  /**
+   * Invent a new state with no momentum (yet).
+   * @param x0 initial parameters (copied)
+   * @param omni initial magnitude of the omni-directional mutation
+   */
+  public State(double[] x0, double omni) {
+    params = Arrays.copyOf(x0, x0.length);
+    this.omni = omni;
+    step = new double[params.length];
+    maps = new Mapping[params.length];
+  }
+
+  /**
+   * Deep copies a state, useful in mutation. The copy gets a fresh id, shares the random
+   * generator and mapping instances, and does not carry over the fitness value.
+   */
+  @SuppressWarnings("unchecked")
+  public State<T, U> copy() {
+    State<T, U> r = new State<>();
+    r.params = Arrays.copyOf(this.params, this.params.length);
+    r.omni = this.omni;
+    r.step = Arrays.copyOf(this.step, this.step.length);
+    r.maps = Arrays.copyOf(this.maps, this.maps.length);
+    if (this.payload != null) {
+      r.payload = (T) this.payload.copy();
+    }
+    r.gen = this.gen;
+    return r;
+  }
+
+  /**
+   * Clones this state with a random change in position. Copies the payload and
+   * lets it know about the change.
+   *
+   * @return A new state.
+   */
+  public State<T, U> mutate() {
+    double sum = 0;
+    for (double v : step) {
+      sum += v * v;
+    }
+    sum = Math.sqrt(sum);
+    // momentum factor applied to the previous step
+    double lambda = 1 + gen.nextGaussian();
+
+    State<T, U> r = this.copy();
+    double magnitude = 0.9 * omni + sum / 10;
+    // exponentially distributed omni-directional magnitude with mean `magnitude`
+    r.omni = magnitude * -Math.log1p(-gen.nextDouble());
+    for (int i = 0; i < step.length; i++) {
+      r.step[i] = lambda * step[i] + r.omni * gen.nextGaussian();
+      r.params[i] += r.step[i];
+    }
+    if (this.payload != null) {
+      r.payload.update(r.getMappedParams());
+    }
+    return r;
+  }
+
+  /**
+   * Defines the transformation for a parameter.
+   * @param i Which parameter's mapping to define.
+   * @param m The mapping to use.
+   * @see org.apache.mahout.ep.Mapping
+   */
+  public void setMap(int i, Mapping m) {
+    maps[i] = m;
+  }
+
+  /**
+   * Returns a transformed parameter.
+   * @param i The parameter to return.
+   * @return The value of the parameter, passed through its mapping if one is set.
+   */
+  public double get(int i) {
+    Mapping m = maps[i];
+    return m == null ? params[i] : m.apply(params[i]);
+  }
+
+  public int getId() {
+    return id;
+  }
+
+  public double[] getParams() {
+    return params;
+  }
+
+  public Mapping[] getMaps() {
+    return maps;
+  }
+
+  /**
+   * Returns all the parameters in mapped form.
+   * @return A new array of mapped parameters.
+   */
+  public double[] getMappedParams() {
+    double[] r = new double[params.length];
+    for (int i = 0; i < params.length; i++) {
+      r[i] = get(i);
+    }
+    return r;
+  }
+
+  public double getOmni() {
+    return omni;
+  }
+
+  public double[] getStep() {
+    return step;
+  }
+
+  public T getPayload() {
+    return payload;
+  }
+
+  public double getValue() {
+    return value;
+  }
+
+  public void setOmni(double omni) {
+    this.omni = omni;
+  }
+
+  public void setId(int id) {
+    this.id = id;
+  }
+
+  public void setStep(double[] step) {
+    this.step = step;
+  }
+
+  public void setMaps(Mapping[] maps) {
+    this.maps = maps;
+  }
+
+  public void setMaps(Iterable<Mapping> maps) {
+    Collection<Mapping> list = Lists.newArrayList(maps);
+    this.maps = list.toArray(new Mapping[list.size()]);
+  }
+
+  public void setValue(double v) {
+    value = v;
+  }
+
+  public void setPayload(T payload) {
+    this.payload = payload;
+  }
+
+  /**
+   * States are equal when they have the same id and fitness value. Values are compared with
+   * {@link Double#compare}, so NaN equals itself (reflexivity) and 0.0 differs from -0.0.
+   * This matches {@link #compareTo}.
+   */
+  @Override
+  public boolean equals(Object o) {
+    if (!(o instanceof State)) {
+      return false;
+    }
+    State<?, ?> other = (State<?, ?>) o;
+    return id == other.id && Double.compare(value, other.value) == 0;
+  }
+
+  @Override
+  public int hashCode() {
+    return RandomUtils.hashDouble(value) ^ id;
+  }
+
+  /**
+   * Natural order is to sort in descending order of score. Creation order is used as a
+   * tie-breaker.
+   *
+   * @param other The state to compare with.
+   * @return negative, zero or positive if this state is better than, identical to or worse than the other.
+   */
+  @Override
+  public int compareTo(State<T, U> other) {
+    int byValue = Double.compare(other.value, this.value);
+    return byValue != 0 ? byValue : Integer.compare(this.id, other.id);
+  }
+
+  @Override
+  public String toString() {
+    double sum = 0;
+    for (double v : step) {
+      sum += v * v;
+    }
+    return String.format(Locale.ENGLISH, "<S/%s %.3f %.3f>", payload, omni + Math.sqrt(sum), value);
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(id);
+    out.writeInt(params.length);
+    for (double v : params) {
+      out.writeDouble(v);
+    }
+    for (Mapping map : maps) {
+      // an unset mapping behaves like the identity (see get(int)); write it explicitly because
+      // PolymorphicWritable presumably cannot serialize null and readFields expects one per parameter
+      PolymorphicWritable.write(out, map == null ? Mapping.identity() : map);
+    }
+
+    out.writeDouble(omni);
+    for (double v : step) {
+      out.writeDouble(v);
+    }
+
+    out.writeDouble(value);
+    // NOTE(review): payload is written unconditionally; a state without a payload presumably
+    // cannot be serialized — confirm
+    PolymorphicWritable.write(out, payload);
+  }
+
+  @Override
+  @SuppressWarnings("unchecked")
+  public void readFields(DataInput input) throws IOException {
+    id = input.readInt();
+    int n = input.readInt();
+    params = new double[n];
+    for (int i = 0; i < n; i++) {
+      params[i] = input.readDouble();
+    }
+
+    maps = new Mapping[n];
+    for (int i = 0; i < n; i++) {
+      maps[i] = PolymorphicWritable.read(input, Mapping.class);
+    }
+    omni = input.readDouble();
+    step = new double[n];
+    for (int i = 0; i < n; i++) {
+      step[i] = input.readDouble();
+    }
+    value = input.readDouble();
+    payload = (T) PolymorphicWritable.read(input, Payload.class);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/ep/package-info.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/ep/package-info.java b/community/mahout-mr/src/main/java/org/apache/mahout/ep/package-info.java
new file mode 100644
index 0000000..4afe677
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/ep/package-info.java
@@ -0,0 +1,26 @@
+/**
+ * <p>Provides basic evolutionary optimization using <a href="http://arxiv.org/abs/0803.3838">recorded-step</a>
+ * mutation.</p>
+ *
+ * <p>With this style of optimization, we can optimize a function {@code f: R^n -> R} by stochastic
+ * hill-climbing with some of the benefits of conjugate gradient style history encoded in the mutation function.
+ * This mutation function will adapt to allow weakly directed search rather than using the somewhat more
+ * conventional symmetric Gaussian.</p>
+ *
+ * <p>With recorded-step mutation, the meta-mutation parameters are all auto-encoded in the current state of each point.
+ * This avoids the classic problem of having more mutation rate parameters than are in the original state and then
+ * requiring even more parameters to describe the meta-mutation rate. Instead, we store the previous point and one
+ * omni-directional mutation component. Mutation is performed by first mutating along the line formed by the previous
+ * and current points and then adding a scaled symmetric Gaussian. The magnitude of the omni-directional mutation is
+ * then mutated using itself as a scale.</p>
+ *
+ * <p>Because it is convenient not to restrict the parameter space, this package also provides parameter
+ * mapping methods. These mapping methods map the set of reals to a finite open interval (a,b) in such a way that
+ * {@code lim_{x->-\inf} f(x) = a} and {@code lim_{x->\inf} f(x) = b}. The linear mapping is defined so that
+ * {@code f(0) = (a+b)/2} and the exponential mapping requires that a and b are both positive and has
+ * {@code f(0) = sqrt(ab)}. The linear mapping is useful for values that must stay roughly within a range but
+ * which are roughly uniform within the center of that range. The exponential
+ * mapping is useful for values that must stay within a range but whose distribution is roughly exponential near
+ * the geometric mean of the end-points. An identity mapping is also supplied.</p>
+ */
+package org.apache.mahout.ep;
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java
new file mode 100644
index 0000000..6618a1a
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+
+import java.io.IOException;
+
+public final class DistributedRowMatrixWriter {
+
+ private DistributedRowMatrixWriter() {
+ }
+
+ public static void write(Path outputDir, Configuration conf, Iterable<MatrixSlice> matrix) throws IOException {
+ FileSystem fs = outputDir.getFileSystem(conf);
+ SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, outputDir,
+ IntWritable.class, VectorWritable.class);
+ IntWritable topic = new IntWritable();
+ VectorWritable vector = new VectorWritable();
+ for (MatrixSlice slice : matrix) {
+ topic.set(slice.index());
+ vector.set(slice.vector());
+ writer.append(topic, vector);
+ }
+ writer.close();
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/MatrixUtils.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/MatrixUtils.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/MatrixUtils.java
new file mode 100644
index 0000000..f9ca52e
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/MatrixUtils.java
@@ -0,0 +1,114 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+import java.io.IOException;
+import java.util.List;
+
+public final class MatrixUtils {
+
+ private MatrixUtils() {
+ }
+
+ public static void write(Path outputDir, Configuration conf, VectorIterable matrix)
+ throws IOException {
+ FileSystem fs = outputDir.getFileSystem(conf);
+ fs.delete(outputDir, true);
+ SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, outputDir,
+ IntWritable.class, VectorWritable.class);
+ IntWritable topic = new IntWritable();
+ VectorWritable vector = new VectorWritable();
+ for (MatrixSlice slice : matrix) {
+ topic.set(slice.index());
+ vector.set(slice.vector());
+ writer.append(topic, vector);
+ }
+ writer.close();
+ }
+
+ public static Matrix read(Configuration conf, Path... modelPaths) throws IOException {
+ int numRows = -1;
+ int numCols = -1;
+ boolean sparse = false;
+ List<Pair<Integer, Vector>> rows = Lists.newArrayList();
+ for (Path modelPath : modelPaths) {
+ for (Pair<IntWritable, VectorWritable> row
+ : new SequenceFileIterable<IntWritable, VectorWritable>(modelPath, true, conf)) {
+ rows.add(Pair.of(row.getFirst().get(), row.getSecond().get()));
+ numRows = Math.max(numRows, row.getFirst().get());
+ sparse = !row.getSecond().get().isDense();
+ if (numCols < 0) {
+ numCols = row.getSecond().get().size();
+ }
+ }
+ }
+ if (rows.isEmpty()) {
+ throw new IOException(Arrays.toString(modelPaths) + " have no vectors in it");
+ }
+ numRows++;
+ Vector[] arrayOfRows = new Vector[numRows];
+ for (Pair<Integer, Vector> pair : rows) {
+ arrayOfRows[pair.getFirst()] = pair.getSecond();
+ }
+ Matrix matrix;
+ if (sparse) {
+ matrix = new SparseRowMatrix(numRows, numCols, arrayOfRows);
+ } else {
+ matrix = new DenseMatrix(numRows, numCols);
+ for (int i = 0; i < numRows; i++) {
+ matrix.assignRow(i, arrayOfRows[i]);
+ }
+ }
+ return matrix;
+ }
+
+ public static OpenObjectIntHashMap<String> readDictionary(Configuration conf, Path... dictPath) {
+ OpenObjectIntHashMap<String> dictionary = new OpenObjectIntHashMap<>();
+ for (Path dictionaryFile : dictPath) {
+ for (Pair<Writable, IntWritable> record
+ : new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) {
+ dictionary.put(record.getFirst().toString(), record.getSecond().get());
+ }
+ }
+ return dictionary;
+ }
+
+ public static String[] invertDictionary(OpenObjectIntHashMap<String> termIdMap) {
+ int maxTermId = -1;
+ for (String term : termIdMap.keys()) {
+ maxTermId = Math.max(maxTermId, termIdMap.get(term));
+ }
+ maxTermId++;
+ String[] dictionary = new String[maxTermId];
+ for (String term : termIdMap.keys()) {
+ dictionary[termIdMap.get(term)] = term;
+ }
+ return dictionary;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java
new file mode 100644
index 0000000..0c45c9a
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.hadoop.io.Writable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Writable to handle serialization of a vector and a variable list of
+ * associated label indexes.
+ */
+public final class MultiLabelVectorWritable implements Writable {
+
+ private final VectorWritable vectorWritable = new VectorWritable();
+ private int[] labels;
+
+ public MultiLabelVectorWritable() {
+ }
+
+ public MultiLabelVectorWritable(Vector vector, int[] labels) {
+ this.vectorWritable.set(vector);
+ this.labels = labels;
+ }
+
+ public Vector getVector() {
+ return vectorWritable.get();
+ }
+
+ public void setVector(Vector vector) {
+ vectorWritable.set(vector);
+ }
+
+ public void setLabels(int[] labels) {
+ this.labels = labels;
+ }
+
+ public int[] getLabels() {
+ return labels;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ vectorWritable.readFields(in);
+ int labelSize = in.readInt();
+ labels = new int[labelSize];
+ for (int i = 0; i < labelSize; i++) {
+ labels[i] = in.readInt();
+ }
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ vectorWritable.write(out);
+ out.writeInt(labels.length);
+ for (int label : labels) {
+ out.writeInt(label);
+ }
+ }
+
+ public static MultiLabelVectorWritable read(DataInput in) throws IOException {
+ MultiLabelVectorWritable writable = new MultiLabelVectorWritable();
+ writable.readFields(in);
+ return writable;
+ }
+
+ public static void write(DataOutput out, SequentialAccessSparseVector ssv, int[] labels) throws IOException {
+ new MultiLabelVectorWritable(ssv, labels).write(out);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java
new file mode 100644
index 0000000..dbe1f8b
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/als/AlternatingLeastSquaresSolver.java
@@ -0,0 +1,116 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.als;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.QRDecomposition;
+import org.apache.mahout.math.Vector;
+
+/**
+ * See
+ * <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
+ * this paper.</a>
+ */
+public final class AlternatingLeastSquaresSolver {
+
+ private AlternatingLeastSquaresSolver() {}
+
+ //TODO make feature vectors a simple array
+ public static Vector solve(Iterable<Vector> featureVectors, Vector ratingVector, double lambda, int numFeatures) {
+
+ Preconditions.checkNotNull(featureVectors, "Feature Vectors cannot be null");
+ Preconditions.checkArgument(!Iterables.isEmpty(featureVectors));
+ Preconditions.checkNotNull(ratingVector, "Rating Vector cannot be null");
+ Preconditions.checkArgument(ratingVector.getNumNondefaultElements() > 0, "Rating Vector cannot be empty");
+ Preconditions.checkArgument(Iterables.size(featureVectors) == ratingVector.getNumNondefaultElements());
+
+ int nui = ratingVector.getNumNondefaultElements();
+
+ Matrix MiIi = createMiIi(featureVectors, numFeatures);
+ Matrix RiIiMaybeTransposed = createRiIiMaybeTransposed(ratingVector);
+
+ /* compute Ai = MiIi * t(MiIi) + lambda * nui * E */
+ Matrix Ai = miTimesMiTransposePlusLambdaTimesNuiTimesE(MiIi, lambda, nui);
+ /* compute Vi = MiIi * t(R(i,Ii)) */
+ Matrix Vi = MiIi.times(RiIiMaybeTransposed);
+ /* compute Ai * ui = Vi */
+ return solve(Ai, Vi);
+ }
+
+ private static Vector solve(Matrix Ai, Matrix Vi) {
+ return new QRDecomposition(Ai).solve(Vi).viewColumn(0);
+ }
+
+ static Matrix addLambdaTimesNuiTimesE(Matrix matrix, double lambda, int nui) {
+ Preconditions.checkArgument(matrix.numCols() == matrix.numRows(), "Must be a Square Matrix");
+ double lambdaTimesNui = lambda * nui;
+ int numCols = matrix.numCols();
+ for (int n = 0; n < numCols; n++) {
+ matrix.setQuick(n, n, matrix.getQuick(n, n) + lambdaTimesNui);
+ }
+ return matrix;
+ }
+
+ private static Matrix miTimesMiTransposePlusLambdaTimesNuiTimesE(Matrix MiIi, double lambda, int nui) {
+
+ double lambdaTimesNui = lambda * nui;
+ int rows = MiIi.numRows();
+
+ double[][] result = new double[rows][rows];
+
+ for (int i = 0; i < rows; i++) {
+ for (int j = i; j < rows; j++) {
+ double dot = MiIi.viewRow(i).dot(MiIi.viewRow(j));
+ if (i != j) {
+ result[i][j] = dot;
+ result[j][i] = dot;
+ } else {
+ result[i][i] = dot + lambdaTimesNui;
+ }
+ }
+ }
+ return new DenseMatrix(result, true);
+ }
+
+
+ static Matrix createMiIi(Iterable<Vector> featureVectors, int numFeatures) {
+ double[][] MiIi = new double[numFeatures][Iterables.size(featureVectors)];
+ int n = 0;
+ for (Vector featureVector : featureVectors) {
+ for (int m = 0; m < numFeatures; m++) {
+ MiIi[m][n] = featureVector.getQuick(m);
+ }
+ n++;
+ }
+ return new DenseMatrix(MiIi, true);
+ }
+
+ static Matrix createRiIiMaybeTransposed(Vector ratingVector) {
+ Preconditions.checkArgument(ratingVector.isSequentialAccess(), "Ratings should be iterable in Index or Sequential Order");
+
+ double[][] RiIiMaybeTransposed = new double[ratingVector.getNumNondefaultElements()][1];
+ int index = 0;
+ for (Vector.Element elem : ratingVector.nonZeroes()) {
+ RiIiMaybeTransposed[index++][0] = elem.get();
+ }
+ return new DenseMatrix(RiIiMaybeTransposed, true);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java
new file mode 100644
index 0000000..5d77898
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java
@@ -0,0 +1,171 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.als;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.QRDecomposition;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.list.IntArrayList;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/** see <a href="http://research.yahoo.com/pub/2433">Collaborative Filtering for Implicit Feedback Datasets</a> */
+public class ImplicitFeedbackAlternatingLeastSquaresSolver {
+
+ private final int numFeatures;
+ private final double alpha;
+ private final double lambda;
+ private final int numTrainingThreads;
+
+ private final OpenIntObjectHashMap<Vector> Y;
+ private final Matrix YtransposeY;
+
+ private static final Logger log = LoggerFactory.getLogger(ImplicitFeedbackAlternatingLeastSquaresSolver.class);
+
+ public ImplicitFeedbackAlternatingLeastSquaresSolver(int numFeatures, double lambda, double alpha,
+ OpenIntObjectHashMap<Vector> Y, int numTrainingThreads) {
+ this.numFeatures = numFeatures;
+ this.lambda = lambda;
+ this.alpha = alpha;
+ this.Y = Y;
+ this.numTrainingThreads = numTrainingThreads;
+ YtransposeY = getYtransposeY(Y);
+ }
+
+ public Vector solve(Vector ratings) {
+ return solve(YtransposeY.plus(getYtransponseCuMinusIYPlusLambdaI(ratings)), getYtransponseCuPu(ratings));
+ }
+
+ private static Vector solve(Matrix A, Matrix y) {
+ return new QRDecomposition(A).solve(y).viewColumn(0);
+ }
+
+ double confidence(double rating) {
+ return 1 + alpha * rating;
+ }
+
+ /* Y' Y */
+ public Matrix getYtransposeY(final OpenIntObjectHashMap<Vector> Y) {
+
+ ExecutorService queue = Executors.newFixedThreadPool(numTrainingThreads);
+ if (log.isInfoEnabled()) {
+ log.info("Starting the computation of Y'Y");
+ }
+ long startTime = System.nanoTime();
+ final IntArrayList indexes = Y.keys();
+ final int numIndexes = indexes.size();
+
+ final double[][] YtY = new double[numFeatures][numFeatures];
+
+ // Compute Y'Y by dot products between the 'columns' of Y
+ for (int i = 0; i < numFeatures; i++) {
+ for (int j = i; j < numFeatures; j++) {
+
+ final int ii = i;
+ final int jj = j;
+ queue.execute(new Runnable() {
+ @Override
+ public void run() {
+ double dot = 0;
+ for (int k = 0; k < numIndexes; k++) {
+ Vector row = Y.get(indexes.getQuick(k));
+ dot += row.getQuick(ii) * row.getQuick(jj);
+ }
+ YtY[ii][jj] = dot;
+ if (ii != jj) {
+ YtY[jj][ii] = dot;
+ }
+ }
+ });
+
+ }
+ }
+ queue.shutdown();
+ try {
+ queue.awaitTermination(1, TimeUnit.DAYS);
+ } catch (InterruptedException e) {
+ log.error("Error during Y'Y queue shutdown", e);
+ throw new RuntimeException("Error during Y'Y queue shutdown");
+ }
+ if (log.isInfoEnabled()) {
+ log.info("Computed Y'Y in " + (System.nanoTime() - startTime) / 1000000.0 + " ms" );
+ }
+ return new DenseMatrix(YtY, true);
+ }
+
+ /** Y' (Cu - I) Y + λ I */
+ private Matrix getYtransponseCuMinusIYPlusLambdaI(Vector userRatings) {
+ Preconditions.checkArgument(userRatings.isSequentialAccess(), "need sequential access to ratings!");
+
+ /* (Cu -I) Y */
+ OpenIntObjectHashMap<Vector> CuMinusIY = new OpenIntObjectHashMap<>(userRatings.getNumNondefaultElements());
+ for (Element e : userRatings.nonZeroes()) {
+ CuMinusIY.put(e.index(), Y.get(e.index()).times(confidence(e.get()) - 1));
+ }
+
+ Matrix YtransponseCuMinusIY = new DenseMatrix(numFeatures, numFeatures);
+
+ /* Y' (Cu -I) Y by outer products */
+ for (Element e : userRatings.nonZeroes()) {
+ for (Element feature : Y.get(e.index()).all()) {
+ Vector partial = CuMinusIY.get(e.index()).times(feature.get());
+ YtransponseCuMinusIY.viewRow(feature.index()).assign(partial, Functions.PLUS);
+ }
+ }
+
+ /* Y' (Cu - I) Y + λ I add lambda on the diagonal */
+ for (int feature = 0; feature < numFeatures; feature++) {
+ YtransponseCuMinusIY.setQuick(feature, feature, YtransponseCuMinusIY.getQuick(feature, feature) + lambda);
+ }
+
+ return YtransponseCuMinusIY;
+ }
+
+ /** Y' Cu p(u) */
+ private Matrix getYtransponseCuPu(Vector userRatings) {
+ Preconditions.checkArgument(userRatings.isSequentialAccess(), "need sequential access to ratings!");
+
+ Vector YtransponseCuPu = new DenseVector(numFeatures);
+
+ for (Element e : userRatings.nonZeroes()) {
+ YtransponseCuPu.assign(Y.get(e.index()).times(confidence(e.get())), Functions.PLUS);
+ }
+
+ return columnVectorAsMatrix(YtransponseCuPu);
+ }
+
+ private Matrix columnVectorAsMatrix(Vector v) {
+ double[][] matrix = new double[numFeatures][1];
+ for (Element e : v.all()) {
+ matrix[e.index()][0] = e.get();
+ }
+ return new DenseMatrix(matrix, true);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java
new file mode 100644
index 0000000..0233848
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer;
+
+import java.io.Closeable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+
+public class AsyncEigenVerifier extends SimpleEigenVerifier implements Closeable {
+
+ private final ExecutorService threadPool;
+ private EigenStatus status;
+ private boolean finished;
+ private boolean started;
+
+ public AsyncEigenVerifier() {
+ threadPool = Executors.newFixedThreadPool(1);
+ status = new EigenStatus(-1, 0);
+ }
+
+ @Override
+ public synchronized EigenStatus verify(VectorIterable corpus, Vector vector) {
+ if (!finished && !started) { // not yet started or finished, so start!
+ status = new EigenStatus(-1, 0);
+ Vector vectorCopy = vector.clone();
+ threadPool.execute(new VerifierRunnable(corpus, vectorCopy));
+ started = true;
+ }
+ if (finished) {
+ finished = false;
+ }
+ return status;
+ }
+
+ @Override
+ public void close() {
+ this.threadPool.shutdownNow();
+ }
+ protected EigenStatus innerVerify(VectorIterable corpus, Vector vector) {
+ return super.verify(corpus, vector);
+ }
+
+ private class VerifierRunnable implements Runnable {
+ private final VectorIterable corpus;
+ private final Vector vector;
+
+ protected VerifierRunnable(VectorIterable corpus, Vector vector) {
+ this.corpus = corpus;
+ this.vector = vector;
+ }
+
+ @Override
+ public void run() {
+ EigenStatus status = innerVerify(corpus, vector);
+ synchronized (AsyncEigenVerifier.this) {
+ AsyncEigenVerifier.this.status = status;
+ finished = true;
+ started = false;
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java
new file mode 100644
index 0000000..a284f50
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer;
+
/**
 * Result of an eigenvector check: an eigenvalue estimate, a cosine-of-angle measure, and whether the check
 * is still in progress.
 */
public class EigenStatus {
  private final double eigenValue;
  private final double cosAngle;
  // Primitive boolean: the value is never null, so a boxed Boolean only added autoboxing on every access.
  private volatile boolean inProgress;

  /** Creates a status that is marked as still in progress. */
  public EigenStatus(double eigenValue, double cosAngle) {
    this(eigenValue, cosAngle, true);
  }

  public EigenStatus(double eigenValue, double cosAngle, boolean inProgress) {
    this.eigenValue = eigenValue;
    this.cosAngle = cosAngle;
    this.inProgress = inProgress;
  }

  public double getCosAngle() {
    return cosAngle;
  }

  public double getEigenValue() {
    return eigenValue;
  }

  public boolean inProgress() {
    return inProgress;
  }

  void setInProgress(boolean status) {
    inProgress = status;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/SimpleEigenVerifier.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/SimpleEigenVerifier.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/SimpleEigenVerifier.java
new file mode 100644
index 0000000..71aaa30
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/SimpleEigenVerifier.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.decomposer;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+
+public class SimpleEigenVerifier implements SingularVectorVerifier {
+
+ @Override
+ public EigenStatus verify(VectorIterable corpus, Vector vector) {
+ Vector resultantVector = corpus.timesSquared(vector);
+ double newNorm = resultantVector.norm(2);
+ double oldNorm = vector.norm(2);
+ double eigenValue;
+ double cosAngle;
+ if (newNorm > 0 && oldNorm > 0) {
+ eigenValue = newNorm / oldNorm;
+ cosAngle = resultantVector.dot(vector) / newNorm * oldNorm;
+ } else {
+ eigenValue = 1.0;
+ cosAngle = 0.0;
+ }
+ return new EigenStatus(eigenValue, cosAngle, false);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java
new file mode 100644
index 0000000..a9a7af8
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java
@@ -0,0 +1,25 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+
+public interface SingularVectorVerifier {
+ EigenStatus verify(VectorIterable eigenMatrix, Vector vector);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java
new file mode 100644
index 0000000..ac9cc41
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java
@@ -0,0 +1,25 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer.hebbian;
+
+import org.apache.mahout.math.Vector;
+
+
+/**
+ * Strategy for refining a candidate singular vector one training vector at a time.
+ *
+ * <p>{@code HebbianSolver} calls {@link #update} once for every corpus row on each pass, always with the same
+ * {@code pseudoEigen} instance, which it never replaces. Any refinement therefore has to be applied to
+ * {@code pseudoEigen} in place.
+ */
+public interface EigenUpdater {
+ /**
+ * Incorporates one training vector into the current singular-vector estimate.
+ *
+ * @param pseudoEigen the current "best guess" singular vector, to be updated in place
+ * @param trainingVector one row of the input corpus
+ * @param currentState shared solver state; HebbianSolver sets its training index to the row being presented
+ * before each call
+ */
+ void update(Vector pseudoEigen, Vector trainingVector, TrainingState currentState);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java
new file mode 100644
index 0000000..5b5cc9b
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java
@@ -0,0 +1,342 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.decomposer.hebbian;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Properties;
+import java.util.Random;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.decomposer.AsyncEigenVerifier;
+import org.apache.mahout.math.decomposer.EigenStatus;
+import org.apache.mahout.math.decomposer.SingularVectorVerifier;
+import org.apache.mahout.math.function.PlusMult;
+import org.apache.mahout.math.function.TimesFunction;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The Hebbian solver is an iterative, sparse, singular value decomposition solver, based on the paper
+ * <a href="http://www.dcs.shef.ac.uk/~genevieve/gorrell_webb.pdf">Generalized Hebbian Algorithm for
+ * Latent Semantic Analysis</a> (2005) by Genevieve Gorrell and Brandyn Webb (a.k.a. Simon Funk).
+ *
+ * <p>Singular vectors are found one at a time:
+ * <ul>
+ *   <li>For each vector, the solver repeatedly sweeps over every row of the corpus. It hands each row to an
+ *   {@link EigenUpdater}, which refines the current "best guess" vector in place.</li>
+ *   <li>Between sweeps, a {@link SingularVectorVerifier} is consulted to decide whether the guess has
+ *   converged or the per-vector pass budget is exhausted.</li>
+ *   <li>Once either happens, the vector is normalized and stored, and the {@link TrainingState} is reset for
+ *   the next vector.</li>
+ * </ul>
+ *
+ * <p>A solver instance keeps a mutable pass counter, so one instance should not run several solves
+ * concurrently.
+ */
+public class HebbianSolver {
+
+  private static final Logger log = LoggerFactory.getLogger(HebbianSolver.class);
+
+  /** Compile-time switch for the (clone-per-pass) similarity diagnostics in {@link #solve}. */
+  private static final boolean DEBUG = false;
+
+  /** A corpus row needs at least this many non-default entries to be preferred as the seed of a new vector. */
+  private static final int MIN_SEED_NON_DEFAULT_ELEMENTS = 5;
+
+  /** Random seed draws allowed per corpus row before falling back to a deterministic scan. */
+  private static final int RANDOM_SEED_ATTEMPTS_PER_ROW = 10;
+
+  private final EigenUpdater updater;
+  private final SingularVectorVerifier verifier;
+  private final double convergenceTarget;
+  private final int maxPassesPerEigen;
+  private final Random rng = RandomUtils.getRandom();
+
+  /** Number of convergence checks (passes) made for the singular vector currently being solved. */
+  private int numPasses = 0;
+
+  /**
+   * Creates a new HebbianSolver.
+   *
+   * @param updater {@link EigenUpdater} that does the actual work of iteratively updating the current
+   *     "best guess" singular vector, one data-point presentation at a time.
+   * @param verifier {@link SingularVectorVerifier} that checks how close to convergence the current singular
+   *     vector is. This is typically an {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier}, which
+   *     does the check in a background thread while the main thread keeps converging.
+   * @param convergenceTarget a small "epsilon". A vector counts as converged once {@code 1 - cos(angle)} is at
+   *     most this value, where the angle is between the proposed eigenvector and that same vector after being
+   *     multiplied by the (square of the) input corpus.
+   * @param maxPassesPerEigen cutoff on the number of completed verifier checks per singular vector. The solver
+   *     keeps iterating while at most this many checks have completed, so up to {@code maxPassesPerEigen + 1}
+   *     checks are made before it moves on, even if {@code convergenceTarget} was not reached.
+   */
+  public HebbianSolver(EigenUpdater updater,
+                       SingularVectorVerifier verifier,
+                       double convergenceTarget,
+                       int maxPassesPerEigen) {
+    this.updater = updater;
+    this.verifier = verifier;
+    this.convergenceTarget = convergenceTarget;
+    this.maxPassesPerEigen = maxPassesPerEigen;
+  }
+
+  /**
+   * Creates a new HebbianSolver with maxPassesPerEigen = Integer.MAX_VALUE, i.e. it keeps iterating until
+   * convergenceTarget is reached. <b>Not recommended</b> unless you only need the first few (5, maybe 10?)
+   * singular vectors: small errors that compound early on quickly put a minimum error on subsequent vectors.
+   *
+   * @param updater {@link EigenUpdater} that does the actual work of iteratively updating the current
+   *     "best guess" singular vector, one data-point presentation at a time.
+   * @param verifier {@link SingularVectorVerifier} that checks how close to convergence the current singular
+   *     vector is. This is typically an {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier}, which
+   *     does the check in a background thread.
+   * @param convergenceTarget a small "epsilon". A vector counts as converged once {@code 1 - cos(angle)} is at
+   *     most this value, where the angle is between the proposed eigenvector and that same vector after being
+   *     multiplied by the (square of the) input corpus.
+   */
+  public HebbianSolver(EigenUpdater updater,
+                       SingularVectorVerifier verifier,
+                       double convergenceTarget) {
+    this(updater,
+        verifier,
+        convergenceTarget,
+        Integer.MAX_VALUE);
+  }
+
+  /**
+   * <b>This is the recommended constructor to use if you're not sure.</b> It creates a new HebbianSolver with
+   * the default {@link HebbianUpdater} to do the updating work. It also uses the default
+   * {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier} to check for convergence in a (single)
+   * background thread.
+   *
+   * @param convergenceTarget a small "epsilon". A vector counts as converged once {@code 1 - cos(angle)} is at
+   *     most this value, where the angle is between the proposed eigenvector and that same vector after being
+   *     multiplied by the (square of the) input corpus.
+   * @param maxPassesPerEigen cutoff on the number of completed verifier checks per singular vector. Up to
+   *     {@code maxPassesPerEigen + 1} checks are made before the solver moves on, even without convergence.
+   */
+  public HebbianSolver(double convergenceTarget, int maxPassesPerEigen) {
+    this(new HebbianUpdater(),
+        new AsyncEigenVerifier(),
+        convergenceTarget,
+        maxPassesPerEigen);
+  }
+
+  /**
+   * Creates a new HebbianSolver with the default {@link HebbianUpdater} to do the updating work. It uses the
+   * default {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier} to check for convergence in a
+   * (single) background thread, and sets maxPassesPerEigen to Integer.MAX_VALUE. <b>Not recommended</b>
+   * unless you only need the first few (5, maybe 10?) singular vectors: small errors that compound early on
+   * quickly put a minimum error on subsequent vectors.
+   *
+   * @param convergenceTarget a small "epsilon". A vector counts as converged once {@code 1 - cos(angle)} is at
+   *     most this value.
+   */
+  public HebbianSolver(double convergenceTarget) {
+    this(convergenceTarget, Integer.MAX_VALUE);
+  }
+
+  /**
+   * Creates a new HebbianSolver with the default {@link HebbianUpdater} and the default
+   * {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier}, with convergenceTarget set to 0. With a
+   * target of 0, convergence effectively never ends the loop: it is only used if {@code 1 - cos(angle)}
+   * reaches exactly 0. Convergence is still checked, so it is logged and singular values are recorded.
+   *
+   * @param numPassesPerEigen limit on completed verifier checks per eigenvector. The solver moves on to the
+   *     next eigenvector after {@code numPassesPerEigen + 1} completed checks.
+   */
+  public HebbianSolver(int numPassesPerEigen) {
+    this(0.0, numPassesPerEigen);
+  }
+
+  /**
+   * Primary singular vector solving method.
+   *
+   * @param corpus input matrix to find singular vectors of. It need not be symmetric and should probably be
+   *     sparse. The input vectors are not mutated, and are accessed only via dot products and sums, so they
+   *     should be {@link org.apache.mahout.math.SequentialAccessSparseVector}s.
+   * @param desiredRank the number of singular vectors to find (in roughly decreasing order by singular value)
+   * @return the final {@link TrainingState} of the solver, after {@code desiredRank} singular vectors (and
+   *     approximate singular values) have been found.
+   * @throws IllegalArgumentException if the corpus has no rows, or only zero rows, so no seed vector exists
+   * @throws IllegalStateException if iteration for a vector ends before the verifier completed any check,
+   *     leaving no eigenvalue to report
+   */
+  public TrainingState solve(Matrix corpus,
+                             int desiredRank) {
+    int cols = corpus.numCols();
+    Matrix eigens = new DenseMatrix(desiredRank, cols);
+    List<Double> eigenValues = new ArrayList<>();
+    log.info("Finding {} singular vectors of matrix with {} rows, via Hebbian", desiredRank, corpus.numRows());
+    /*
+     * The corpusProjections matrix is a running cache of the residual projection of each corpus vector against
+     * all of the previously found singular vectors. Without it, when multiple passes are made over the data
+     * (per singular vector), recalculating these projections eventually dominates the computational complexity
+     * of the solver.
+     */
+    Matrix corpusProjections = new DenseMatrix(corpus.numRows(), desiredRank);
+    TrainingState state = new TrainingState(eigens, corpusProjections);
+    for (int i = 0; i < desiredRank; i++) {
+      Vector currentEigen = new DenseVector(cols);
+      Vector previousEigen = null;
+      while (hasNotConverged(currentEigen, corpus, state)) {
+        trainOnePass(currentEigen, corpus, eigens, state);
+        state.setFirstPass(false);
+        if (DEBUG) {
+          previousEigen = logPassSimilarity(currentEigen, previousEigen);
+        }
+      }
+      // converged (or out of passes)!
+      double eigenValue = lastEigenValue(state, i);
+      normalizeInPlace(currentEigen, i);
+      eigens.assignRow(i, currentEigen);
+      eigenValues.add(eigenValue);
+      state.setCurrentEigenValues(eigenValues);
+      log.info("Found eigenvector {}, eigenvalue: {}", i, eigenValue);
+
+      // TODO: Persist intermediate output!
+      resetForNextEigen(state);
+    }
+    return state;
+  }
+
+  /**
+   * Presents every corpus row to the updater once. A random seed row goes first, then the remaining rows in
+   * order.
+   */
+  private void trainOnePass(Vector currentEigen, Matrix corpus, Matrix eigens, TrainingState state) {
+    int randomStartingIndex = getRandomStartingIndex(corpus, eigens);
+    Vector initialTrainingVector = corpus.viewRow(randomStartingIndex);
+    state.setTrainingIndex(randomStartingIndex);
+    updater.update(currentEigen, initialTrainingVector, state);
+    for (int corpusRow = 0; corpusRow < corpus.numRows(); corpusRow++) {
+      state.setTrainingIndex(corpusRow);
+      if (corpusRow != randomStartingIndex) {
+        updater.update(currentEigen, corpus.viewRow(corpusRow), state);
+      }
+    }
+  }
+
+  /**
+   * Debug-only diagnostic that logs the cosine similarity between this pass's vector and the previous pass's.
+   *
+   * @return a snapshot of {@code currentEigen}, to compare against after the next pass
+   */
+  private static Vector logPassSimilarity(Vector currentEigen, Vector previousEigen) {
+    if (previousEigen != null) {
+      double norms = currentEigen.norm(2) * previousEigen.norm(2);
+      double cosine = norms > 0.0 ? currentEigen.dot(previousEigen) / norms : 0.0;
+      log.debug("Current pass * previous pass = {}", cosine);
+    }
+    return currentEigen.clone();
+  }
+
+  /**
+   * Returns the eigenvalue reported by the most recent completed verification of the current vector.
+   *
+   * @throws IllegalStateException if no verification completed, which happens when the loop stopped while the
+   *     verifier was still in progress
+   */
+  private static double lastEigenValue(TrainingState state, int eigenIndex) {
+    int completedChecks = state.getStatusProgress().size();
+    if (completedChecks == 0) {
+      throw new IllegalStateException("Verifier never completed a check for eigenvector " + eigenIndex
+          + "; no eigenvalue is available");
+    }
+    return state.getStatusProgress().get(completedChecks - 1).getEigenValue();
+  }
+
+  /**
+   * Scales {@code vector} to unit length in place. Scaling in place is cheaper than
+   * {@code vector.normalize()}, which clones. A zero vector is left unchanged with a warning rather than being
+   * turned into NaN/Infinity.
+   */
+  private static void normalizeInPlace(Vector vector, int eigenIndex) {
+    double norm = vector.norm(2);
+    if (norm > 0.0) {
+      vector.assign(new TimesFunction(), 1 / norm);
+    } else {
+      log.warn("Eigenvector {} has zero (or undefined) norm and cannot be normalized", eigenIndex);
+    }
+  }
+
+  /** Resets the per-vector training state (and pass counter) so the next singular vector starts fresh. */
+  private void resetForNextEigen(TrainingState state) {
+    state.setFirstPass(true);
+    state.setNumEigensProcessed(state.getNumEigensProcessed() + 1);
+    state.setActivationDenominatorSquared(0);
+    state.setActivationNumerator(0);
+    state.getStatusProgress().clear();
+    numPasses = 0;
+  }
+
+  /**
+   * Picks the corpus row used to seed a pass. You have to start somewhere...
+   *
+   * <p>Rows are first drawn at random until one is non-zero and has at least
+   * {@link #MIN_SEED_NON_DEFAULT_ELEMENTS} non-default entries. If that keeps failing (for example, the
+   * corpus has fewer columns than that), the rows are scanned in order instead. In that case the first
+   * preferred row wins, or failing that the first non-zero row. The original unbounded retry loop never
+   * terminated for such corpora.
+   *
+   * <p>TODO: start instead wherever you find a vector with maximum residual length after subtracting off the
+   * projection onto all previous eigenvectors.
+   *
+   * @param corpus the corpus matrix
+   * @param eigens not currently used, but should be (see above TODO)
+   * @return the index into the corpus where the "starting seed" input vector lies.
+   * @throws IllegalArgumentException if the corpus has no rows or only zero rows
+   */
+  private int getRandomStartingIndex(Matrix corpus, Matrix eigens) {
+    int numRows = corpus.numRows();
+    if (numRows <= 0) {
+      throw new IllegalArgumentException("Cannot choose a starting vector from a corpus with no rows");
+    }
+    long maxRandomAttempts = (long) RANDOM_SEED_ATTEMPTS_PER_ROW * numRows;
+    for (long attempt = 0; attempt < maxRandomAttempts; attempt++) {
+      int index = (int) (rng.nextDouble() * numRows);
+      if (isPreferredSeed(corpus.viewRow(index))) {
+        return index;
+      }
+    }
+    // Random sampling keeps missing: scan deterministically, settling for any non-zero row if necessary.
+    int firstNonZeroRow = -1;
+    for (int row = 0; row < numRows; row++) {
+      Vector candidate = corpus.viewRow(row);
+      if (isPreferredSeed(candidate)) {
+        return row;
+      }
+      if (firstNonZeroRow < 0 && isNonZero(candidate)) {
+        firstNonZeroRow = row;
+      }
+    }
+    if (firstNonZeroRow >= 0) {
+      return firstNonZeroRow;
+    }
+    throw new IllegalArgumentException("Corpus contains only zero rows; cannot seed the Hebbian solver");
+  }
+
+  /** True if {@code v} exists and is not the zero vector. */
+  private static boolean isNonZero(Vector v) {
+    return v != null && v.norm(2) != 0;
+  }
+
+  /** True if {@code v} is non-zero and dense enough to be a preferred seed row. */
+  private static boolean isPreferredSeed(Vector v) {
+    return isNonZero(v) && v.getNumNondefaultElements() >= MIN_SEED_NON_DEFAULT_ELEMENTS;
+  }
+
+  /**
+   * Uses the {@link SingularVectorVerifier} to check for convergence.
+   *
+   * <p>Before verifying, {@code currentPseudoEigen} is orthogonalized in place against every previously found
+   * eigenvector. This uses the coefficients accumulated in the state's helper vector, which are then zeroed.
+   * Every call counts as one pass, and the first pass for each vector is never verified.
+   *
+   * @param currentPseudoEigen the purported singular vector whose convergence is being checked (modified in
+   *     place)
+   * @param corpus the corpus to check against
+   * @param state contains the previous eigens and various other solving state ({@link TrainingState}).
+   *     Completed verifier statuses are appended to its status progress.
+   * @return {@code true} if the solver should make another pass, i.e. the vector has <em>not</em> converged
+   *     <em>and</em> {@code maxPassesPerEigen} has not been exceeded. Returns {@code false} once either
+   *     condition fails.
+   */
+  protected boolean hasNotConverged(Vector currentPseudoEigen,
+                                    Matrix corpus,
+                                    TrainingState state) {
+    numPasses++;
+    if (state.isFirstPass()) {
+      log.info("First pass through the corpus, no need to check convergence...");
+      return true;
+    }
+    Matrix previousEigens = state.getCurrentEigens();
+    log.info("Have made {} passes through the corpus, checking convergence...", numPasses);
+    /*
+     * Step 1: orthogonalize currentPseudoEigen by subtracting off eigen(i) * helper.get(i)
+     * Step 2: zero-out the helper vector because it has already helped.
+     */
+    for (int i = 0; i < state.getNumEigensProcessed(); i++) {
+      Vector previousEigen = previousEigens.viewRow(i);
+      currentPseudoEigen.assign(previousEigen, new PlusMult(-state.getHelperVector().get(i)));
+      state.getHelperVector().set(i, 0);
+    }
+    double pseudoEigenNorm = currentPseudoEigen.norm(2);
+    if (pseudoEigenNorm > 0) {
+      for (int i = 0; i < state.getNumEigensProcessed(); i++) {
+        Vector previousEigen = previousEigens.viewRow(i);
+        log.info("dot with previous: {}", previousEigen.dot(currentPseudoEigen) / pseudoEigenNorm);
+      }
+    }
+    /*
+     * Step 3: verify how eigen-like the prospective eigen is. This is potentially asynchronous.
+     */
+    EigenStatus status = verify(corpus, currentPseudoEigen);
+    if (status.inProgress()) {
+      log.info("Verifier not finished, making another pass...");
+    } else {
+      log.info("Has 1 - cosAngle: {}, convergence target is: {}", 1.0 - status.getCosAngle(), convergenceTarget);
+      state.getStatusProgress().add(status);
+    }
+    // NOTE(review): an in-progress status's cosAngle is also compared here. This relies on the verifier
+    // reporting a value that keeps the loop going until a check completes; confirm against implementations.
+    return state.getStatusProgress().size() <= maxPassesPerEigen
+        && 1.0 - status.getCosAngle() > convergenceTarget;
+  }
+
+  /**
+   * Delegates to the configured {@link SingularVectorVerifier}. This is protected so subclasses can
+   * customize verification.
+   */
+  protected EigenStatus verify(Matrix corpus, Vector currentPseudoEigen) {
+    return verifier.verify(corpus, currentPseudoEigen);
+  }
+
+  /**
+   * Command-line entry point; currently an incomplete stub.
+   *
+   * <p>It is meant to read solver settings from a properties file ({@code args[0]}, default
+   * {@code config/solver.properties}). However, loading that file and building the corpus are not
+   * implemented, so the method logs an error and returns without solving.
+   */
+  public static void main(String[] args) {
+    Properties props = new Properties();
+    String propertiesFile = args.length > 0 ? args[0] : "config/solver.properties";
+    // TODO: load the settings, e.g. props.load(new FileInputStream(propertiesFile)); until then props is empty.
+
+    String corpusDir = props.getProperty("solver.input.dir");
+    String outputDir = props.getProperty("solver.output.dir");
+    if (corpusDir == null || corpusDir.isEmpty() || outputDir == null || outputDir.isEmpty()) {
+      log.error("{} must contain values for solver.input.dir and solver.output.dir", propertiesFile);
+      return;
+    }
+    int rank = Integer.parseInt(props.getProperty("solver.output.desiredRank"));
+    double convergence = Double.parseDouble(props.getProperty("solver.convergence"));
+    int maxPasses = Integer.parseInt(props.getProperty("solver.maxPasses"));
+
+    // TODO: build the corpus from corpusDir (previously sketched with DiskBufferedDoubleMatrix and, for
+    // multiple threads, ParallelMultiplyingDiskBufferedDoubleMatrix).
+    Matrix corpus = null;
+    if (corpus == null) {
+      log.error("Loading a corpus from {} is not implemented; nothing to solve", corpusDir);
+      return;
+    }
+
+    HebbianUpdater updater = new HebbianUpdater();
+    SingularVectorVerifier verifier = new AsyncEigenVerifier();
+    HebbianSolver solver = new HebbianSolver(updater, verifier, convergence, maxPasses);
+    long now = System.currentTimeMillis();
+    TrainingState finalState = solver.solve(corpus, rank);
+    long time = (System.currentTimeMillis() - now) / 1000;
+    log.info("Solved {} eigenVectors in {} seconds. Persisted to {}",
+        finalState.getCurrentEigens().rowSize(), time, outputDir);
+  }
+}