Posted to dev@horn.apache.org by ed...@apache.org on 2016/04/26 05:46:24 UTC
[3/4] incubator-horn git commit: Code refactoring
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
new file mode 100644
index 0000000..f87e771
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/AbstractLayeredNeuralNetwork.java
@@ -0,0 +1,221 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.hama.HamaConfiguration;
+import org.apache.hama.commons.math.DoubleDoubleFunction;
+import org.apache.hama.commons.math.DoubleFunction;
+import org.apache.hama.commons.math.DoubleMatrix;
+import org.apache.hama.commons.math.DoubleVector;
+import org.apache.horn.funcs.FunctionFactory;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+
+/**
+ * AbstractLayeredNeuralNetwork defines the general operations for derived
+ * layered models, including Linear Regression, Logistic Regression, Multilayer
+ * Perceptron, Autoencoder, and Restricted Boltzmann Machine.
+ *
+ * In general, these models consist of neurons which are aligned in layers.
+ * Between layers, for any two adjacent layers, the neurons are connected to
+ * form a bipartite weighted graph.
+ *
+ */
+abstract class AbstractLayeredNeuralNetwork extends AbstractNeuralNetwork {
+
+ private static final double DEFAULT_MOMENTUM_WEIGHT = 0.1;
+
+ double trainingError;
+
+ /* The momentumWeight */
+ protected double momentumWeight;
+
+ /* The cost function of the model */
+ protected DoubleDoubleFunction costFunction;
+
+ /* Record the size of each layer */
+ protected List<Integer> layerSizeList;
+
+ protected TrainingMethod trainingMethod;
+
+ protected LearningStyle learningStyle;
+
+ public static enum TrainingMethod {
+ GRADIENT_DESCENT
+ }
+
+ public static enum LearningStyle {
+ UNSUPERVISED,
+ SUPERVISED
+ }
+
+ public AbstractLayeredNeuralNetwork() {
+ this.momentumWeight = DEFAULT_MOMENTUM_WEIGHT;
+ this.trainingMethod = TrainingMethod.GRADIENT_DESCENT;
+ this.learningStyle = LearningStyle.SUPERVISED;
+ }
+
+ public AbstractLayeredNeuralNetwork(HamaConfiguration conf, String modelPath) {
+ super(conf, modelPath);
+ }
+
+ public void setTrainingMethod(TrainingMethod method) {
+ this.trainingMethod = method;
+ }
+
+ public TrainingMethod getTrainingMethod() {
+ return this.trainingMethod;
+ }
+
+ public void setLearningStyle(LearningStyle style) {
+ this.learningStyle = style;
+ }
+
+ public LearningStyle getLearningStyle() {
+ return this.learningStyle;
+ }
+
+ /**
+ * Set the cost function for the model.
+ *
+ * @param costFunction
+ */
+ public void setCostFunction(DoubleDoubleFunction costFunction) {
+ this.costFunction = costFunction;
+ }
+
+ /**
+ * Add a layer of neurons with the specified size. If the added layer is not
+ * the first layer, its neurons will automatically be connected to those of
+ * the previous layer.
+ *
+ * @param size
+ * @param isFinalLayer If false, add a bias neuron.
+ * @param squashingFunction The squashing function for this layer, input layer
+ * is f(x) = x by default.
+ * @return The layer index, starting from 0.
+ */
+ public abstract int addLayer(int size, boolean isFinalLayer,
+ DoubleFunction squashingFunction);
+
+ /**
+ * Get the size of a particular layer.
+ *
+ * @param layer
+ * @return The layer size.
+ */
+ public int getLayerSize(int layer) {
+ Preconditions.checkArgument(
+ layer >= 0 && layer < this.layerSizeList.size(),
+ String.format("Input must be in range [0, %d]\n",
+ this.layerSizeList.size() - 1));
+ return this.layerSizeList.get(layer);
+ }
+
+ /**
+ * Get the layer size list.
+ *
+ * @return The layer size list.
+ */
+ protected List<Integer> getLayerSizeList() {
+ return this.layerSizeList;
+ }
+
+ /**
+ * Get the weights between layer layerIdx and layerIdx + 1
+ *
+ * @param layerIdx The index of the layer
+ * @return The weights in form of {@link DoubleMatrix}
+ */
+ public abstract DoubleMatrix getWeightsByLayer(int layerIdx);
+
+ /**
+ * Get the updated weights using one training instance.
+ *
+ * @param trainingInstance The trainingInstance is the concatenation of
+ * feature vector and class label vector.
+ * @return The update of each weight, in the form of a matrix list.
+ * @throws Exception
+ */
+ public abstract DoubleMatrix[] trainByInstance(DoubleVector trainingInstance);
+
+ /**
+ * Get the output calculated by the model.
+ *
+ * @param instance The feature instance.
+ * @return a new vector with the result of the operation.
+ */
+ public abstract DoubleVector getOutput(DoubleVector instance);
+
+ /**
+ * Calculate the training error based on the labels and outputs.
+ *
+ * @param labels
+ * @param output
+ */
+ protected abstract void calculateTrainingError(DoubleVector labels,
+ DoubleVector output);
+
+ @Override
+ public void readFields(DataInput input) throws IOException {
+ super.readFields(input);
+ // read momentum weight
+ this.momentumWeight = input.readDouble();
+
+ // read cost function
+ this.costFunction = FunctionFactory
+ .createDoubleDoubleFunction(WritableUtils.readString(input));
+
+ // read layer size list
+ int numLayers = input.readInt();
+ this.layerSizeList = Lists.newArrayList();
+ for (int i = 0; i < numLayers; ++i) {
+ this.layerSizeList.add(input.readInt());
+ }
+
+ this.trainingMethod = WritableUtils.readEnum(input, TrainingMethod.class);
+ this.learningStyle = WritableUtils.readEnum(input, LearningStyle.class);
+ }
+
+ @Override
+ public void write(DataOutput output) throws IOException {
+ super.write(output);
+ // write momentum weight
+ output.writeDouble(this.momentumWeight);
+
+ // write cost function
+ WritableUtils.writeString(output, costFunction.getFunctionName());
+
+ // write layer size list
+ output.writeInt(this.layerSizeList.size());
+ for (Integer aLayerSizeList : this.layerSizeList) {
+ output.writeInt(aLayerSizeList);
+ }
+
+ WritableUtils.writeEnum(output, this.trainingMethod);
+ WritableUtils.writeEnum(output, this.learningStyle);
+ }
+
+}
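
A usage sketch of the configuration API above, assuming the concrete LayeredNeuralNetwork subclass introduced later in this commit; the layer sizes are illustrative, and the "Sigmoid"/"SquaredError" factory names match those used elsewhere in this patch:

    // configure a small three-layer supervised network
    LayeredNeuralNetwork ann = new LayeredNeuralNetwork();
    ann.addLayer(4, false, FunctionFactory.createDoubleFunction("Sigmoid")); // input layer, bias added internally
    ann.addLayer(3, false, FunctionFactory.createDoubleFunction("Sigmoid")); // hidden layer
    ann.addLayer(1, true, FunctionFactory.createDoubleFunction("Sigmoid"));  // output layer
    ann.setCostFunction(FunctionFactory.createDoubleDoubleFunction("SquaredError"));
    ann.setLearningStyle(AbstractLayeredNeuralNetwork.LearningStyle.SUPERVISED);
    ann.setTrainingMethod(AbstractLayeredNeuralNetwork.TrainingMethod.GRADIENT_DESCENT);
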
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java b/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
new file mode 100644
index 0000000..45f56a3
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/AbstractNeuralNetwork.java
@@ -0,0 +1,237 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+
+import org.apache.commons.lang.SerializationUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.hama.HamaConfiguration;
+import org.apache.hama.bsp.BSPJob;
+import org.apache.hama.ml.util.DefaultFeatureTransformer;
+import org.apache.hama.ml.util.FeatureTransformer;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+
+/**
+ * AbstractNeuralNetwork defines the general operations for all the derived
+ * models. Typically, derived models such as Linear Regression, Logistic
+ * Regression, and Multilayer Perceptron consist of neurons and the weights
+ * between neurons.
+ *
+ */
+public abstract class AbstractNeuralNetwork implements Writable {
+ protected HamaConfiguration conf;
+ protected FileSystem fs;
+
+ private static final double DEFAULT_LEARNING_RATE = 0.5;
+
+ protected double learningRate;
+ protected boolean learningRateDecay = false;
+
+ // the name of the model
+ protected String modelType;
+ // the path to store the model
+ protected String modelPath;
+
+ protected FeatureTransformer featureTransformer;
+
+ public AbstractNeuralNetwork() {
+ this.learningRate = DEFAULT_LEARNING_RATE;
+ this.modelType = this.getClass().getSimpleName();
+ this.featureTransformer = new DefaultFeatureTransformer();
+ }
+
+ public AbstractNeuralNetwork(String modelPath) {
+ this.modelPath = modelPath;
+ }
+
+ public AbstractNeuralNetwork(HamaConfiguration conf, String modelPath) {
+ try {
+ this.conf = conf;
+ this.fs = FileSystem.get(conf);
+ this.modelPath = modelPath;
+
+ this.readFromModel();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+
+ }
+
+ public void isLearningRateDecay(boolean decay) {
+ this.learningRateDecay = decay;
+ }
+
+ public String getModelType() {
+ return this.modelType;
+ }
+
+ /**
+ * Train the model with the given configuration, which carries the training
+ * data path and the training parameters.
+ *
+ * @param conf The configuration for the training job.
+ * @throws InterruptedException
+ * @throws ClassNotFoundException
+ * @throws IOException
+ */
+ public BSPJob train(Configuration conf) throws ClassNotFoundException, IOException, InterruptedException {
+ Preconditions.checkArgument(this.modelPath != null,
+ "Please set the model path before training.");
+
+ // train with BSP job
+ return trainInternal((HamaConfiguration) conf);
+ }
+
+ /**
+ * Train the model with the given training data and parameters.
+ */
+ protected abstract BSPJob trainInternal(HamaConfiguration hamaConf)
+ throws IOException, InterruptedException, ClassNotFoundException;
+
+ /**
+ * Read the model meta-data from the specified location.
+ *
+ * @throws IOException
+ */
+ protected void readFromModel() throws IOException {
+ Preconditions.checkArgument(this.modelPath != null,
+ "Model path has not been set.");
+ FSDataInputStream is = fs.open(new Path(modelPath));
+ this.readFields(is);
+ Closeables.close(is, false);
+ }
+
+ /**
+ * Write the model data to specified location.
+ *
+ * @throws IOException
+ */
+ public void writeModelToFile() throws IOException {
+ Preconditions.checkArgument(this.modelPath != null,
+ "Model path has not been set.");
+
+ FSDataOutputStream os = fs.create(new Path(this.modelPath), true);
+ this.write(os);
+
+ Closeables.close(os, false);
+ }
+
+ /**
+ * Set the model path.
+ *
+ * @param modelPath
+ */
+ public void setModelPath(String modelPath) {
+ this.modelPath = modelPath;
+ }
+
+ /**
+ * Get the model path.
+ *
+ * @return the path to store the model.
+ */
+ public String getModelPath() {
+ return this.modelPath;
+ }
+
+ @SuppressWarnings({ "rawtypes", "unchecked" })
+ @Override
+ public void readFields(DataInput input) throws IOException {
+ // read model type
+ this.modelType = WritableUtils.readString(input);
+ // read learning rate
+ this.learningRate = input.readDouble();
+ // read model path
+ this.modelPath = WritableUtils.readString(input);
+
+ if (this.modelPath.equals("null")) {
+ this.modelPath = null;
+ }
+
+ // read feature transformer
+ int bytesLen = input.readInt();
+ byte[] featureTransformerBytes = new byte[bytesLen];
+ input.readFully(featureTransformerBytes);
+
+ Class<? extends FeatureTransformer> featureTransformerCls = (Class<? extends FeatureTransformer>) SerializationUtils
+ .deserialize(featureTransformerBytes);
+
+ Constructor[] constructors = featureTransformerCls
+ .getDeclaredConstructors();
+ Constructor constructor = constructors[0];
+
+ try {
+ this.featureTransformer = (FeatureTransformer) constructor
+ .newInstance(new Object[] {});
+ } catch (InstantiationException e) {
+ e.printStackTrace();
+ } catch (IllegalAccessException e) {
+ e.printStackTrace();
+ } catch (IllegalArgumentException e) {
+ e.printStackTrace();
+ } catch (InvocationTargetException e) {
+ e.printStackTrace();
+ }
+ }
+
+ @Override
+ public void write(DataOutput output) throws IOException {
+ // write model type
+ WritableUtils.writeString(output, modelType);
+ // write learning rate
+ output.writeDouble(learningRate);
+ // write model path
+ if (this.modelPath != null) {
+ WritableUtils.writeString(output, modelPath);
+ } else {
+ WritableUtils.writeString(output, "null");
+ }
+
+ // serialize the class
+ Class<? extends FeatureTransformer> featureTransformerCls = this.featureTransformer
+ .getClass();
+ byte[] featureTransformerBytes = SerializationUtils
+ .serialize(featureTransformerCls);
+ output.writeInt(featureTransformerBytes.length);
+ output.write(featureTransformerBytes);
+ }
+
+ public void setFeatureTransformer(FeatureTransformer featureTransformer) {
+ this.featureTransformer = featureTransformer;
+ }
+
+ public FeatureTransformer getFeatureTransformer() {
+ return this.featureTransformer;
+ }
+
+}
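
A hedged driver sketch for the train() entry point above; the paths are placeholders, and the layer configuration is assumed to have been done as in the earlier sketch:

    HamaConfiguration conf = new HamaConfiguration();
    conf.set("training.input.path", "/tmp/train.seq"); // consumed by trainInternal()
    LayeredNeuralNetwork ann = new LayeredNeuralNetwork();
    // ... addLayer()/setCostFunction() calls as sketched above ...
    ann.setModelPath("/tmp/model"); // must be set before train()
    BSPJob job = ann.train(conf);   // persists the model, then builds the BSP job
    job.waitForCompletion(true);
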
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/core/AbstractNeuralNetworkTrainer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/AbstractNeuralNetworkTrainer.java b/src/main/java/org/apache/horn/core/AbstractNeuralNetworkTrainer.java
new file mode 100644
index 0000000..3547a1a
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/AbstractNeuralNetworkTrainer.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import java.io.IOException;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hama.bsp.BSP;
+import org.apache.hama.bsp.BSPPeer;
+import org.apache.hama.bsp.sync.SyncException;
+import org.apache.hama.commons.io.VectorWritable;
+import org.apache.hama.ml.util.DefaultFeatureTransformer;
+import org.apache.hama.ml.util.FeatureTransformer;
+
+/**
+ * The trainer used to train the {@link LayeredNeuralNetwork} with BSP. The
+ * trainer reads the training data and obtains the trained parameters of the
+ * model.
+ *
+ */
+public abstract class AbstractNeuralNetworkTrainer
+ extends
+ BSP<LongWritable, VectorWritable, NullWritable, NullWritable, Synapse<DoubleWritable, DoubleWritable>> {
+
+ protected static final Log LOG = LogFactory
+ .getLog(AbstractNeuralNetworkTrainer.class);
+
+ protected Configuration conf;
+ protected int maxIteration;
+ protected int batchSize;
+ protected String trainingMode;
+
+ protected FeatureTransformer featureTransformer;
+
+ @Override
+ final public void setup(
+ BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, Synapse<DoubleWritable, DoubleWritable>> peer)
+ throws IOException, SyncException, InterruptedException {
+ conf = peer.getConfiguration();
+ featureTransformer = new DefaultFeatureTransformer();
+ this.extraSetup(peer);
+ }
+
+ /**
+ * Handle extra setup for sub-classes.
+ *
+ * @param peer
+ * @throws IOException
+ * @throws SyncException
+ * @throws InterruptedException
+ */
+ protected void extraSetup(
+ BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, Synapse<DoubleWritable, DoubleWritable>> peer)
+ throws IOException, SyncException, InterruptedException {
+
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public abstract void bsp(
+ BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, Synapse<DoubleWritable, DoubleWritable>> peer)
+ throws IOException, SyncException, InterruptedException;
+
+ @Override
+ public void cleanup(
+ BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, Synapse<DoubleWritable, DoubleWritable>> peer)
+ throws IOException {
+ this.extraCleanup(peer);
+ // write model to modelPath
+ }
+
+ /**
+ * Handle cleanup for sub-classes. Write the trained model back.
+ *
+ * @param peer
+ * @throws IOException
+ * @throws SyncException
+ * @throws InterruptedException
+ */
+ protected void extraCleanup(
+ BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, Synapse<DoubleWritable, DoubleWritable>> peer)
+ throws IOException {
+
+ }
+
+}
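
To illustrate the template methods above, a sketch of a concrete trainer subclass; the class name is an assumption for illustration, and the configuration keys match those set by HornJob later in this commit:

    public class SimpleTrainer extends AbstractNeuralNetworkTrainer {
      @Override
      protected void extraSetup(
          BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, Synapse<DoubleWritable, DoubleWritable>> peer) {
        // fill the protected fields declared by the base class
        this.maxIteration = conf.getInt("training.max.iterations", 1000);
        this.batchSize = conf.getInt("training.batch.size", 100);
      }

      @Override
      public void bsp(
          BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, Synapse<DoubleWritable, DoubleWritable>> peer)
          throws IOException, SyncException, InterruptedException {
        // read (LongWritable, VectorWritable) records, train on each batch,
        // and peer.sync() once per superstep
      }
    }
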
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/core/AutoEncoder.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/AutoEncoder.java b/src/main/java/org/apache/horn/core/AutoEncoder.java
new file mode 100644
index 0000000..f638245
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/AutoEncoder.java
@@ -0,0 +1,197 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import java.io.IOException;
+import java.util.Map;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hama.HamaConfiguration;
+import org.apache.hama.bsp.BSPJob;
+import org.apache.hama.commons.math.DenseDoubleVector;
+import org.apache.hama.commons.math.DoubleFunction;
+import org.apache.hama.commons.math.DoubleMatrix;
+import org.apache.hama.commons.math.DoubleVector;
+import org.apache.hama.ml.util.FeatureTransformer;
+import org.apache.horn.funcs.FunctionFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * AutoEncoder is a model used for dimensionality reduction and feature
+ * learning. It is a special kind of {@link AbstractNeuralNetwork} that
+ * consists of three layers of neurons, where the first and third layers
+ * contain the same number of neurons.
+ *
+ */
+public class AutoEncoder {
+
+ private final LayeredNeuralNetwork model;
+
+ /**
+ * Initialize the autoencoder.
+ *
+ * @param inputDimensions The number of dimensions for the input feature.
+ * @param compressedDimensions The number of dimensions for the compressed
+ * information.
+ */
+ public AutoEncoder(int inputDimensions, int compressedDimensions) {
+ model = new LayeredNeuralNetwork();
+ model.addLayer(inputDimensions, false,
+ FunctionFactory.createDoubleFunction("Sigmoid"));
+ model.addLayer(compressedDimensions, false,
+ FunctionFactory.createDoubleFunction("Sigmoid"));
+ model.addLayer(inputDimensions, true,
+ FunctionFactory.createDoubleFunction("Sigmoid"));
+ model
+ .setLearningStyle(AbstractLayeredNeuralNetwork.LearningStyle.UNSUPERVISED);
+ model.setCostFunction(FunctionFactory
+ .createDoubleDoubleFunction("SquaredError"));
+ }
+
+ public AutoEncoder(HamaConfiguration conf, String modelPath) {
+ model = new LayeredNeuralNetwork(conf, modelPath);
+ }
+
+ public AutoEncoder setModelPath(String modelPath) {
+ model.setModelPath(modelPath);
+ return this;
+ }
+
+ /**
+ * Train the autoencoder with the given data. Note that the features of the
+ * training data are pre-processed by the configured {@link FeatureTransformer}.
+ *
+ * @param dataInputPath
+ * @param trainingParams
+ * @throws InterruptedException
+ * @throws IOException
+ * @throws ClassNotFoundException
+ */
+ public BSPJob train(HamaConfiguration conf, Path dataInputPath,
+ Map<String, String> trainingParams) throws ClassNotFoundException, IOException, InterruptedException {
+ return model.train(conf);
+ }
+
+ /**
+ * Train the model with one instance.
+ *
+ * @param trainingInstance
+ */
+ public void trainOnline(DoubleVector trainingInstance) {
+ model.trainOnline(trainingInstance);
+ }
+
+ /**
+ * Get the matrix M used to encode the input features.
+ *
+ * @return The matrix that encodes the input features.
+ */
+ public DoubleMatrix getEncodeWeightMatrix() {
+ return model.getWeightsByLayer(0);
+ }
+
+ /**
+ * Get the matrix M used to decode the compressed information.
+ *
+ * @return The matrix that decodes the compressed information.
+ */
+ public DoubleMatrix getDecodeWeightMatrix() {
+ return model.getWeightsByLayer(1);
+ }
+
+ /**
+ * Transform the input features.
+ *
+ * @param inputInstance
+ * @return The compressed information.
+ */
+ private DoubleVector transform(DoubleVector inputInstance, int inputLayer) {
+ DoubleVector internalInstance = new DenseDoubleVector(
+ inputInstance.getDimension() + 1);
+ internalInstance.set(0, 1);
+ for (int i = 0; i < inputInstance.getDimension(); ++i) {
+ internalInstance.set(i + 1, inputInstance.get(i));
+ }
+ DoubleFunction squashingFunction = model.getSquashingFunction(inputLayer);
+ DoubleMatrix weightMatrix = null;
+ if (inputLayer == 0) {
+ weightMatrix = this.getEncodeWeightMatrix();
+ } else {
+ weightMatrix = this.getDecodeWeightMatrix();
+ }
+ DoubleVector vec = weightMatrix.multiplyVectorUnsafe(internalInstance);
+ vec = vec.applyToElements(squashingFunction);
+ return vec;
+ }
+
+ /**
+ * Encode the input instance.
+ *
+ * @param inputInstance
+ * @return A new vector containing the encoded input instance.
+ */
+ public DoubleVector encode(DoubleVector inputInstance) {
+ Preconditions.checkArgument(
+ inputInstance.getDimension() == model.getLayerSize(0) - 1,
+ String.format(
+ "The dimension of input instance is %d, but the model requires dimension %d.",
+ inputInstance.getDimension(), model.getLayerSize(0) - 1));
+ return this.transform(inputInstance, 0);
+ }
+
+ /**
+ * Decode the input instance.
+ *
+ * @param inputInstance
+ * @return A new vector containing the decoded input instance.
+ */
+ public DoubleVector decode(DoubleVector inputInstance) {
+ Preconditions.checkArgument(
+ inputInstance.getDimension() == model.getLayerSize(1) - 1,
+ String.format(
+ "The dimension of input instance is %d, but the model requires dimension %d.",
+ inputInstance.getDimension(), model.getLayerSize(1) - 1));
+ return this.transform(inputInstance, 1);
+ }
+
+ /**
+ * Get the label(s) according to the given features.
+ *
+ * @param inputInstance
+ * @return A new vector with the output of the model for the given feature
+ * instance.
+ */
+ public DoubleVector getOutput(DoubleVector inputInstance) {
+ return model.getOutput(inputInstance);
+ }
+
+ /**
+ * Set the feature transformer.
+ *
+ * @param featureTransformer
+ */
+ public void setFeatureTransformer(FeatureTransformer featureTransformer) {
+ this.model.setFeatureTransformer(featureTransformer);
+ }
+
+}
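
A usage sketch of the encode/decode pair above; the ten- and five-dimensional sizes and the feature values are arbitrary examples:

    // compress 10-dimensional features into a 5-dimensional code
    AutoEncoder encoder = new AutoEncoder(10, 5);
    DoubleVector feature = new DenseDoubleVector(
        new double[] { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0 });
    DoubleVector code = encoder.encode(feature);       // dimension 5
    DoubleVector reconstructed = encoder.decode(code); // dimension 10
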
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/core/HornJob.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/HornJob.java b/src/main/java/org/apache/horn/core/HornJob.java
new file mode 100644
index 0000000..82dcad8
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/HornJob.java
@@ -0,0 +1,109 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import java.io.IOException;
+
+import org.apache.hama.HamaConfiguration;
+import org.apache.hama.bsp.BSPJob;
+import org.apache.hama.commons.math.Function;
+import org.apache.horn.funcs.FunctionFactory;
+
+public class HornJob extends BSPJob {
+
+ LayeredNeuralNetwork neuralNetwork;
+
+ public HornJob(HamaConfiguration conf, Class<?> exampleClass)
+ throws IOException {
+ super(conf);
+ this.setJarByClass(exampleClass);
+
+ neuralNetwork = new LayeredNeuralNetwork();
+ }
+
+ public void inputLayer(int featureDimension, Class<? extends Function> func) {
+ addLayer(featureDimension, func);
+ }
+
+ public void addLayer(int featureDimension, Class<? extends Function> func) {
+ neuralNetwork.addLayer(featureDimension, false,
+ FunctionFactory.createDoubleFunction(func.getSimpleName()));
+ }
+
+ public void outputLayer(int labels, Class<? extends Function> func) {
+ neuralNetwork.addLayer(labels, true,
+ FunctionFactory.createDoubleFunction(func.getSimpleName()));
+ }
+
+ public void setCostFunction(Class<? extends Function> func) {
+ neuralNetwork.setCostFunction(FunctionFactory
+ .createDoubleDoubleFunction(func.getSimpleName()));
+ }
+
+ public void setDouble(String name, double value) {
+ conf.setDouble(name, value);
+ }
+
+ public void setMaxIteration(int maxIteration) {
+ this.conf.setInt("training.max.iterations", maxIteration);
+ }
+
+ public void setBatchSize(int batchSize) {
+ this.conf.setInt("training.batch.size", batchSize);
+ }
+
+ public void setLearningRate(double learningRate) {
+ this.conf.setDouble("mlp.learning.rate", learningRate);
+ }
+
+ public void setConvergenceCheckInterval(int n) {
+ this.conf.setInt("convergence.check.interval", n);
+ }
+
+ public void setMomentumWeight(double momentumWeight) {
+ this.conf.setDouble("mlp.momentum.weight", momentumWeight);
+ }
+
+ public LayeredNeuralNetwork getNeuralNetwork() {
+ return neuralNetwork;
+ }
+
+ public boolean waitForCompletion(boolean verbose) throws IOException,
+ InterruptedException, ClassNotFoundException {
+ BSPJob job = neuralNetwork.train(this.conf);
+ return job.waitForCompletion(verbose);
+ }
+
+ public void setRegularizationWeight(double regularizationWeight) {
+ this.conf.setDouble("regularization.weight", regularizationWeight);
+ }
+
+ public void setModelPath(String modelPath) {
+ this.conf.set("model.path", modelPath);
+ neuralNetwork.setModelPath(modelPath);
+ }
+
+ public void setTrainingSetPath(String inputPath) {
+ this.conf.set("training.input.path", inputPath);
+ }
+
+}
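
A driver sketch using the builder-style methods above. MyDriver is a hypothetical example class, the paths and layer sizes are placeholders, and the Sigmoid/CrossEntropy function classes are assumed to live in org.apache.horn.funcs (FunctionFactory resolves them by simple name):

    HamaConfiguration conf = new HamaConfiguration();
    HornJob job = new HornJob(conf, MyDriver.class);
    job.setTrainingSetPath("/tmp/train.seq");
    job.setModelPath("/tmp/model");
    job.setMaxIteration(1000);
    job.setBatchSize(100);
    job.setLearningRate(0.4);
    job.setMomentumWeight(0.2);
    job.setRegularizationWeight(0.001);
    job.inputLayer(784, Sigmoid.class);  // feature dimension
    job.addLayer(100, Sigmoid.class);    // hidden layer
    job.outputLayer(10, Sigmoid.class);  // number of labels
    job.setCostFunction(CrossEntropy.class);
    boolean success = job.waitForCompletion(true);
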
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
new file mode 100644
index 0000000..afccbff
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/LayeredNeuralNetwork.java
@@ -0,0 +1,621 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.commons.lang.math.RandomUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.hama.Constants;
+import org.apache.hama.HamaConfiguration;
+import org.apache.hama.bsp.BSPJob;
+import org.apache.hama.commons.io.MatrixWritable;
+import org.apache.hama.commons.io.VectorWritable;
+import org.apache.hama.commons.math.DenseDoubleMatrix;
+import org.apache.hama.commons.math.DenseDoubleVector;
+import org.apache.hama.commons.math.DoubleFunction;
+import org.apache.hama.commons.math.DoubleMatrix;
+import org.apache.hama.commons.math.DoubleVector;
+import org.apache.hama.util.ReflectionUtils;
+import org.apache.horn.examples.MultiLayerPerceptron.StandardNeuron;
+import org.apache.horn.funcs.FunctionFactory;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+
+/**
+ * LayeredNeuralNetwork defines the general operations for derived layered
+ * models, including Linear Regression, Logistic Regression, Multilayer
+ * Perceptron, Autoencoder, and Restricted Boltzmann Machine. For
+ * LayeredNeuralNetwork, the training can be conducted in parallel, but the
+ * parameters of the model are assumed to be stored on a single machine.
+ *
+ * In general, these models consist of neurons which are aligned in layers.
+ * Between layers, for any two adjacent layers, the neurons are connected to
+ * form a bipartite weighted graph.
+ *
+ */
+public class LayeredNeuralNetwork extends AbstractLayeredNeuralNetwork {
+
+ private static final Log LOG = LogFactory
+ .getLog(LayeredNeuralNetwork.class);
+
+ public static Class<Neuron<Synapse<DoubleWritable, DoubleWritable>>> neuronClass;
+
+ /* Weights between neurons at adjacent layers */
+ protected List<DoubleMatrix> weightMatrixList;
+
+ /* Previous weight updates between neurons at adjacent layers */
+ protected List<DoubleMatrix> prevWeightUpdatesList;
+
+ /* Different layers can have different squashing function */
+ protected List<DoubleFunction> squashingFunctionList;
+
+ protected int finalLayerIdx;
+
+ protected double regularizationWeight;
+
+ public LayeredNeuralNetwork() {
+ this.layerSizeList = Lists.newArrayList();
+ this.weightMatrixList = Lists.newArrayList();
+ this.prevWeightUpdatesList = Lists.newArrayList();
+ this.squashingFunctionList = Lists.newArrayList();
+ }
+
+ public LayeredNeuralNetwork(HamaConfiguration conf, String modelPath) {
+ super(conf, modelPath);
+ this.regularizationWeight = conf.getDouble("regularization.weight", 0);
+ }
+
+ @Override
+ /**
+ * {@inheritDoc}
+ */
+ public int addLayer(int size, boolean isFinalLayer,
+ DoubleFunction squashingFunction) {
+ Preconditions.checkArgument(size > 0,
+ "Size of layer must be larger than 0.");
+ if (!isFinalLayer) {
+ size += 1;
+ }
+
+ LOG.info("Add Layer: " + size);
+ this.layerSizeList.add(size);
+ int layerIdx = this.layerSizeList.size() - 1;
+ if (isFinalLayer) {
+ this.finalLayerIdx = layerIdx;
+ }
+
+ // add weights between current layer and previous layer, and input layer has
+ // no squashing function
+ if (layerIdx > 0) {
+ int sizePrevLayer = this.layerSizeList.get(layerIdx - 1);
+ // the row count equals the size of the current layer, and the column
+ // count equals the size of the previous layer
+ int row = isFinalLayer ? size : size - 1;
+ int col = sizePrevLayer;
+ DoubleMatrix weightMatrix = new DenseDoubleMatrix(row, col);
+ // initialize weights
+ weightMatrix.applyToElements(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return RandomUtils.nextDouble() - 0.5;
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ throw new UnsupportedOperationException("");
+ }
+ });
+ this.weightMatrixList.add(weightMatrix);
+ this.prevWeightUpdatesList.add(new DenseDoubleMatrix(row, col));
+ this.squashingFunctionList.add(squashingFunction);
+ }
+ return layerIdx;
+ }
+
+ /**
+ * Update the weight matrices with given matrices.
+ *
+ * @param matrices
+ */
+ public void updateWeightMatrices(DoubleMatrix[] matrices) {
+ for (int i = 0; i < matrices.length; ++i) {
+ DoubleMatrix matrix = this.weightMatrixList.get(i);
+ this.weightMatrixList.set(i, matrix.add(matrices[i]));
+ }
+ }
+
+ /**
+ * Set the previous weight update matrices.
+ *
+ * @param prevUpdates
+ */
+ void setPrevWeightMatrices(DoubleMatrix[] prevUpdates) {
+ this.prevWeightUpdatesList.clear();
+ Collections.addAll(this.prevWeightUpdatesList, prevUpdates);
+ }
+
+ /**
+ * Add a batch of matrices onto the given destination matrices.
+ *
+ * @param destMatrices
+ * @param sourceMatrices
+ */
+ static void matricesAdd(DoubleMatrix[] destMatrices,
+ DoubleMatrix[] sourceMatrices) {
+ for (int i = 0; i < destMatrices.length; ++i) {
+ destMatrices[i] = destMatrices[i].add(sourceMatrices[i]);
+ }
+ }
+
+ /**
+ * Get all the weight matrices.
+ *
+ * @return The matrices in the form of a matrix array.
+ */
+ DoubleMatrix[] getWeightMatrices() {
+ DoubleMatrix[] matrices = new DoubleMatrix[this.weightMatrixList.size()];
+ this.weightMatrixList.toArray(matrices);
+ return matrices;
+ }
+
+ /**
+ * Set the weight matrices.
+ *
+ * @param matrices
+ */
+ public void setWeightMatrices(DoubleMatrix[] matrices) {
+ this.weightMatrixList = new ArrayList<DoubleMatrix>();
+ Collections.addAll(this.weightMatrixList, matrices);
+ }
+
+ /**
+ * Get the previous weight update matrices in the form of an array.
+ *
+ * @return The matrices in the form of a matrix array.
+ */
+ public DoubleMatrix[] getPrevMatricesUpdates() {
+ DoubleMatrix[] prevMatricesUpdates = new DoubleMatrix[this.prevWeightUpdatesList
+ .size()];
+ for (int i = 0; i < this.prevWeightUpdatesList.size(); ++i) {
+ prevMatricesUpdates[i] = this.prevWeightUpdatesList.get(i);
+ }
+ return prevMatricesUpdates;
+ }
+
+ public void setWeightMatrix(int index, DoubleMatrix matrix) {
+ Preconditions.checkArgument(
+ 0 <= index && index < this.weightMatrixList.size(), String.format(
+ "index [%d] should be in range [%d, %d].", index, 0,
+ this.weightMatrixList.size() - 1));
+ this.weightMatrixList.set(index, matrix);
+ }
+
+ @Override
+ public void readFields(DataInput input) throws IOException {
+ super.readFields(input);
+
+ // read squash functions
+ int squashingFunctionSize = input.readInt();
+ this.squashingFunctionList = Lists.newArrayList();
+ for (int i = 0; i < squashingFunctionSize; ++i) {
+ this.squashingFunctionList.add(FunctionFactory
+ .createDoubleFunction(WritableUtils.readString(input)));
+ }
+
+ // read weights and construct matrices of previous updates
+ int numOfMatrices = input.readInt();
+ this.weightMatrixList = Lists.newArrayList();
+ this.prevWeightUpdatesList = Lists.newArrayList();
+ for (int i = 0; i < numOfMatrices; ++i) {
+ DoubleMatrix matrix = MatrixWritable.read(input);
+ this.weightMatrixList.add(matrix);
+ this.prevWeightUpdatesList.add(new DenseDoubleMatrix(
+ matrix.getRowCount(), matrix.getColumnCount()));
+ }
+
+ }
+
+ @Override
+ public void write(DataOutput output) throws IOException {
+ super.write(output);
+
+ // write squashing functions
+ output.writeInt(this.squashingFunctionList.size());
+ for (DoubleFunction aSquashingFunctionList : this.squashingFunctionList) {
+ WritableUtils.writeString(output,
+ aSquashingFunctionList.getFunctionName());
+ }
+
+ // write weight matrices
+ output.writeInt(this.weightMatrixList.size());
+ for (DoubleMatrix aWeightMatrixList : this.weightMatrixList) {
+ MatrixWritable.write(aWeightMatrixList, output);
+ }
+
+ // DO NOT WRITE WEIGHT UPDATE
+ }
+
+ @Override
+ public DoubleMatrix getWeightsByLayer(int layerIdx) {
+ return this.weightMatrixList.get(layerIdx);
+ }
+
+ /**
+ * Get the output of the model according to given feature instance.
+ */
+ @Override
+ public DoubleVector getOutput(DoubleVector instance) {
+ Preconditions.checkArgument(this.layerSizeList.get(0) - 1 == instance
+ .getDimension(), String.format(
+ "The dimension of input instance should be %d.",
+ this.layerSizeList.get(0) - 1));
+ // transform the features to another space
+ DoubleVector transformedInstance = this.featureTransformer
+ .transform(instance);
+ // add bias feature
+ DoubleVector instanceWithBias = new DenseDoubleVector(
+ transformedInstance.getDimension() + 1);
+ instanceWithBias.set(0, 0.99999); // set bias to be a little bit less than
+ // 1.0
+ for (int i = 1; i < instanceWithBias.getDimension(); ++i) {
+ instanceWithBias.set(i, transformedInstance.get(i - 1));
+ }
+
+ List<DoubleVector> outputCache = getOutputInternal(instanceWithBias);
+ // return the output of the last layer
+ DoubleVector result = outputCache.get(outputCache.size() - 1);
+ // remove bias
+ return result.sliceUnsafe(1, result.getDimension() - 1);
+ }
+
+ /**
+ * Calculate the output internally; the intermediate output of each layer
+ * will be stored.
+ *
+ * @param instanceWithBias The instance contains the features.
+ * @return Cached output of each layer.
+ */
+ public List<DoubleVector> getOutputInternal(DoubleVector instanceWithBias) {
+ List<DoubleVector> outputCache = new ArrayList<DoubleVector>();
+ // fill with instance
+ DoubleVector intermediateOutput = instanceWithBias;
+ outputCache.add(intermediateOutput);
+
+ for (int i = 0; i < this.layerSizeList.size() - 1; ++i) {
+ intermediateOutput = forward(i, intermediateOutput);
+ outputCache.add(intermediateOutput);
+ }
+
+ return outputCache;
+ }
+
+ /**
+ * @return a new neuron instance
+ */
+ public static Neuron<Synapse<DoubleWritable, DoubleWritable>> newNeuronInstance() {
+ return (Neuron<Synapse<DoubleWritable, DoubleWritable>>) ReflectionUtils
+ .newInstance(neuronClass);
+ }
+
+ /**
+ * Forward the calculation for one layer.
+ *
+ * @param fromLayer The index of the previous layer.
+ * @param intermediateOutput The intermediateOutput of previous layer.
+ * @return a new vector with the result of the operation.
+ */
+ @SuppressWarnings("unchecked")
+ protected DoubleVector forward(int fromLayer, DoubleVector intermediateOutput) {
+ DoubleMatrix weightMatrix = this.weightMatrixList.get(fromLayer);
+
+ neuronClass = (Class<Neuron<Synapse<DoubleWritable, DoubleWritable>>>) conf
+ .getClass("neuron.class", Neuron.class);
+
+ // TODO use the multithread processing
+ DoubleVector vec = new DenseDoubleVector(weightMatrix.getRowCount());
+ for (int row = 0; row < weightMatrix.getRowCount(); row++) {
+ List<Synapse<DoubleWritable, DoubleWritable>> msgs = new ArrayList<Synapse<DoubleWritable, DoubleWritable>>();
+ for (int col = 0; col < weightMatrix.getColumnCount(); col++) {
+ msgs.add(new Synapse<DoubleWritable, DoubleWritable>(
+ new DoubleWritable(intermediateOutput.get(col)),
+ new DoubleWritable(weightMatrix.get(row, col))));
+ }
+ Iterable<Synapse<DoubleWritable, DoubleWritable>> iterable = msgs;
+ Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance();
+ n.setup(conf);
+ n.setSquashingFunction(this.squashingFunctionList.get(fromLayer));
+ try {
+ n.forward(iterable);
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+ vec.set(row, n.getOutput());
+ }
+
+ // add bias
+ DoubleVector vecWithBias = new DenseDoubleVector(vec.getDimension() + 1);
+ vecWithBias.set(0, 1);
+ for (int i = 0; i < vec.getDimension(); ++i) {
+ vecWithBias.set(i + 1, vec.get(i));
+ }
+
+ return vecWithBias;
+ }
+
+ /**
+ * Train the model online.
+ *
+ * @param trainingInstance
+ */
+ public void trainOnline(DoubleVector trainingInstance) {
+ DoubleMatrix[] updateMatrices = this.trainByInstance(trainingInstance);
+ this.updateWeightMatrices(updateMatrices);
+ }
+
+ @Override
+ public DoubleMatrix[] trainByInstance(DoubleVector trainingInstance) {
+ DoubleVector transformedVector = this.featureTransformer
+ .transform(trainingInstance.sliceUnsafe(this.layerSizeList.get(0) - 1));
+
+ int inputDimension = this.layerSizeList.get(0) - 1;
+ int outputDimension;
+ DoubleVector inputInstance = null;
+ DoubleVector labels = null;
+ if (this.learningStyle == LearningStyle.SUPERVISED) {
+ outputDimension = this.layerSizeList.get(this.layerSizeList.size() - 1);
+ // validate training instance
+ Preconditions.checkArgument(
+ inputDimension + outputDimension == trainingInstance.getDimension(),
+ String
+ .format(
+ "The dimension of training instance is %d, but requires %d.",
+ trainingInstance.getDimension(), inputDimension
+ + outputDimension));
+
+ inputInstance = new DenseDoubleVector(this.layerSizeList.get(0));
+ inputInstance.set(0, 1); // add bias
+ // get the features from the transformed vector
+ for (int i = 0; i < inputDimension; ++i) {
+ inputInstance.set(i + 1, transformedVector.get(i));
+ }
+ // get the labels from the original training instance
+ labels = trainingInstance.sliceUnsafe(inputInstance.getDimension() - 1,
+ trainingInstance.getDimension() - 1);
+ } else if (this.learningStyle == LearningStyle.UNSUPERVISED) {
+ // labels are identical to input features
+ outputDimension = inputDimension;
+ // validate training instance
+ Preconditions.checkArgument(inputDimension == trainingInstance
+ .getDimension(), String.format(
+ "The dimension of training instance is %d, but requires %d.",
+ trainingInstance.getDimension(), inputDimension));
+
+ inputInstance = new DenseDoubleVector(this.layerSizeList.get(0));
+ inputInstance.set(0, 1); // add bias
+ // get the features from the transformed vector
+ for (int i = 0; i < inputDimension; ++i) {
+ inputInstance.set(i + 1, transformedVector.get(i));
+ }
+ // get the labels by copying the transformed vector
+ labels = transformedVector.deepCopy();
+ }
+
+ List<DoubleVector> internalResults = this.getOutputInternal(inputInstance);
+ DoubleVector output = internalResults.get(internalResults.size() - 1);
+
+ // get the training error
+ calculateTrainingError(labels,
+ output.deepCopy().sliceUnsafe(1, output.getDimension() - 1));
+
+ if (this.trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
+ return this.trainByInstanceGradientDescent(labels, internalResults);
+ } else {
+ throw new IllegalArgumentException(
+ String.format("Training method is not supported."));
+ }
+ }
+
+ /**
+ * Train by gradient descent. Get the updated weights using one training
+ * instance.
+ *
+ * @param labels The label part of the training instance.
+ * @param internalResults The cached output of each layer.
+ * @return The weight update matrices.
+ */
+ private DoubleMatrix[] trainByInstanceGradientDescent(DoubleVector labels,
+ List<DoubleVector> internalResults) {
+
+ DoubleVector output = internalResults.get(internalResults.size() - 1);
+ // initialize weight update matrices
+ DenseDoubleMatrix[] weightUpdateMatrices = new DenseDoubleMatrix[this.weightMatrixList
+ .size()];
+ for (int m = 0; m < weightUpdateMatrices.length; ++m) {
+ weightUpdateMatrices[m] = new DenseDoubleMatrix(this.weightMatrixList
+ .get(m).getRowCount(), this.weightMatrixList.get(m).getColumnCount());
+ }
+ DoubleVector deltaVec = new DenseDoubleVector(
+ this.layerSizeList.get(this.layerSizeList.size() - 1));
+
+ DoubleFunction squashingFunction = this.squashingFunctionList
+ .get(this.squashingFunctionList.size() - 1);
+
+ DoubleMatrix lastWeightMatrix = this.weightMatrixList
+ .get(this.weightMatrixList.size() - 1);
+ for (int i = 0; i < deltaVec.getDimension(); ++i) {
+ double costFuncDerivative = this.costFunction.applyDerivative(
+ labels.get(i), output.get(i + 1));
+ // add regularization
+ costFuncDerivative += this.regularizationWeight
+ * lastWeightMatrix.getRowVector(i).sum();
+ deltaVec.set(
+ i,
+ costFuncDerivative
+ * squashingFunction.applyDerivative(output.get(i + 1)));
+ }
+
+ // start from previous layer of output layer
+ for (int layer = this.layerSizeList.size() - 2; layer >= 0; --layer) {
+ output = internalResults.get(layer);
+ deltaVec = backpropagate(layer, deltaVec, internalResults,
+ weightUpdateMatrices[layer]);
+ }
+
+ this.setPrevWeightMatrices(weightUpdateMatrices);
+
+ return weightUpdateMatrices;
+ }
+
+ /**
+ * Back-propagate the errors from the next layer to the current layer. The
+ * weight update information will be stored in the weightUpdateMatrix, and
+ * the delta of the current layer will be returned.
+ *
+ * @param curLayerIdx Index of the current layer.
+ * @param nextLayerDelta Delta of the next layer.
+ * @param outputCache Cached output of each layer.
+ * @param weightUpdateMatrix The matrix in which to store the weight updates.
+ * @return The delta vector of the current layer.
+ */
+ private DoubleVector backpropagate(int curLayerIdx,
+ DoubleVector nextLayerDelta, List<DoubleVector> outputCache,
+ DenseDoubleMatrix weightUpdateMatrix) {
+
+ // get layer related information
+ DoubleVector curLayerOutput = outputCache.get(curLayerIdx);
+ DoubleMatrix weightMatrix = this.weightMatrixList.get(curLayerIdx);
+ DoubleMatrix prevWeightMatrix = this.prevWeightUpdatesList.get(curLayerIdx);
+
+ // next layer is not output layer, remove the delta of bias neuron
+ if (curLayerIdx != this.layerSizeList.size() - 2) {
+ nextLayerDelta = nextLayerDelta.slice(1,
+ nextLayerDelta.getDimension() - 1);
+ }
+
+ // DoubleMatrix transposed = weightMatrix.transpose();
+ DoubleVector deltaVector = new DenseDoubleVector(
+ weightMatrix.getColumnCount());
+ for (int row = 0; row < weightMatrix.getColumnCount(); ++row) {
+ Neuron<Synapse<DoubleWritable, DoubleWritable>> n = newNeuronInstance();
+ // calls setup method
+ n.setup(conf);
+ n.setSquashingFunction(this.squashingFunctionList.get(curLayerIdx));
+ n.setOutput(curLayerOutput.get(row));
+
+ List<Synapse<DoubleWritable, DoubleWritable>> msgs = new ArrayList<Synapse<DoubleWritable, DoubleWritable>>();
+
+ n.setWeightVector(weightMatrix.getRowCount());
+
+ for (int col = 0; col < weightMatrix.getRowCount(); ++col) {
+ // sum += (transposed.get(row, col) * nextLayerDelta.get(col));
+ msgs.add(new Synapse<DoubleWritable, DoubleWritable>(
+ new DoubleWritable(nextLayerDelta.get(col)), new DoubleWritable(
+ weightMatrix.get(col, row)), new DoubleWritable(
+ prevWeightMatrix.get(col, row))));
+ }
+
+ Iterable<Synapse<DoubleWritable, DoubleWritable>> iterable = msgs;
+ try {
+ n.backward(iterable);
+ } catch (IOException e) {
+ // TODO Auto-generated catch block
+ e.printStackTrace();
+ }
+
+ // update weights
+ weightUpdateMatrix.setColumn(row, n.getWeights());
+ deltaVector.set(row, n.getDelta());
+ }
+
+ return deltaVector;
+ }
+
+ @Override
+ protected BSPJob trainInternal(HamaConfiguration hamaConf)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ this.conf = hamaConf;
+ this.fs = FileSystem.get(conf);
+
+ String modelPath = conf.get("model.path");
+ if (modelPath != null) {
+ this.modelPath = modelPath;
+ }
+ // modelPath must be set before training
+ if (this.modelPath == null) {
+ throw new IllegalArgumentException(
+ "Please specify the modelPath for model, "
+ + "either through setModelPath() or add 'modelPath' to the training parameters.");
+ }
+ this.writeModelToFile();
+
+ // create job
+ BSPJob job = new BSPJob(conf, LayeredNeuralNetworkTrainer.class);
+ job.setJobName("Small scale Neural Network training");
+ job.setJarByClass(LayeredNeuralNetworkTrainer.class);
+ job.setBspClass(LayeredNeuralNetworkTrainer.class);
+
+ job.getConfiguration().setClass("neuron.class", StandardNeuron.class,
+ Neuron.class);
+
+ // additional for parameter server
+ // TODO at this moment, we use 1 task as a parameter server
+ // In the future, the number of parameter server should be configurable
+ job.getConfiguration().setInt(Constants.ADDITIONAL_BSP_TASKS, 1);
+
+ job.setInputPath(new Path(conf.get("training.input.path")));
+ job.setInputFormat(org.apache.hama.bsp.SequenceFileInputFormat.class);
+ job.setInputKeyClass(LongWritable.class);
+ job.setInputValueClass(VectorWritable.class);
+ job.setOutputKeyClass(NullWritable.class);
+ job.setOutputValueClass(NullWritable.class);
+ job.setOutputFormat(org.apache.hama.bsp.NullOutputFormat.class);
+
+ return job;
+ }
+
+ @Override
+ protected void calculateTrainingError(DoubleVector labels, DoubleVector output) {
+ DoubleVector errors = labels.deepCopy().applyToElements(output,
+ this.costFunction);
+ this.trainingError = errors.sum();
+ }
+
+ /**
+ * Get the squashing function of a specified layer.
+ *
+ * @param idx
+ * @return The squashing function of the specified layer.
+ */
+ public DoubleFunction getSquashingFunction(int idx) {
+ return this.squashingFunctionList.get(idx);
+ }
+
+}
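
An inference sketch for the class above: loading a previously trained model and querying it. The path and feature values are placeholders; note that forward() resolves "neuron.class" from the configuration, so a concrete neuron class must be registered, mirroring what trainInternal() does:

    HamaConfiguration conf = new HamaConfiguration();
    conf.setClass("neuron.class", StandardNeuron.class, Neuron.class);
    LayeredNeuralNetwork trained = new LayeredNeuralNetwork(conf, "/tmp/model");
    // feature dimension must equal getLayerSize(0) - 1
    DoubleVector features = new DenseDoubleVector(new double[] { 0.2, 0.7 });
    DoubleVector prediction = trained.getOutput(features);
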
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java b/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
new file mode 100644
index 0000000..effd5b0
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/LayeredNeuralNetworkTrainer.java
@@ -0,0 +1,216 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hama.HamaConfiguration;
+import org.apache.hama.bsp.BSP;
+import org.apache.hama.bsp.BSPPeer;
+import org.apache.hama.bsp.sync.SyncException;
+import org.apache.hama.commons.io.VectorWritable;
+import org.apache.hama.commons.math.DenseDoubleMatrix;
+import org.apache.hama.commons.math.DoubleMatrix;
+import org.apache.hama.commons.math.DoubleVector;
+import org.apache.hama.ipc.RPC;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * The trainer that trains the {@link LayeredNeuralNetwork} based on the BSP
+ * framework.
+ *
+ */
+public final class LayeredNeuralNetworkTrainer
+ extends
+ BSP<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> {
+
+ private static final Log LOG = LogFactory
+ .getLog(LayeredNeuralNetworkTrainer.class);
+
+ /* When given peer is master worker: base of parameter merge */
+ /* When given peer is slave worker: neural network for training */
+ private LayeredNeuralNetwork inMemoryModel;
+
+ /* Job configuration */
+ private HamaConfiguration conf;
+
+ /* Default batch size */
+ private int batchSize;
+
+ /* whether it is converging or not */
+ private AtomicBoolean isConverge;
+
+ /* When given peer is master worker: Asynchronous parameter merger */
+ /* When given peer is slave worker: null */
+ private RPC.Server merger;
+
+ /* When given peer is master worker: null */
+ /* When given peer is slave worker: proxy to Asynchronous parameter merger */
+ private ParameterMerger proxy;
+
+ /**
+ * Returns true if this worker is the master worker.
+ *
+ * @param peer
+ */
+ private boolean isMaster(
+ BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> peer) {
+ return peer.getPeerIndex() == peer.getNumPeers() - 1;
+ }
+
+ @Override
+ /**
+ * If the model path is specified, load the existing model from the storage location.
+ */
+ public void setup(
+ BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> peer) {
+ // At least one master & slave worker exist.
+ Preconditions.checkArgument(peer.getNumPeers() >= 2);
+ this.conf = peer.getConfiguration();
+
+ String modelPath = conf.get("model.path");
+ this.inMemoryModel = new LayeredNeuralNetwork(conf, modelPath);
+
+ this.batchSize = conf.getInt("training.batch.size", 50);
+ this.isConverge = new AtomicBoolean(false);
+
+ int slaveCount = peer.getNumPeers() - 1;
+ int mergeLimit = conf.getInt("training.max.iterations", 100000);
+ int convergenceCheckInterval = peer.getNumPeers()
+ * conf.getInt("convergence.check.interval", 2000);
+ String master = peer.getPeerName();
+ String masterAddr = master.substring(0, master.indexOf(':'));
+ int port = conf.getInt("sync.server.port", 40052);
+
+ if (isMaster(peer)) {
+ try {
+ this.merger = RPC.getServer(new ParameterMergerServer(inMemoryModel,
+ isConverge, slaveCount, mergeLimit, convergenceCheckInterval),
+ masterAddr, port, conf);
+ merger.start();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ LOG.info("Begin to train");
+ } else {
+ InetSocketAddress addr = new InetSocketAddress(masterAddr, port);
+ try {
+ this.proxy = (ParameterMerger) RPC.getProxy(ParameterMerger.class,
+ ParameterMerger.versionID, addr, conf);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ }
+
+ @Override
+ /**
+ * Write the trained model back to the storage location.
+ */
+ public void cleanup(
+ BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> peer) {
+ // write model to modelPath
+ if (isMaster(peer)) {
+ try {
+ LOG.info("Write model back to " + inMemoryModel.getModelPath());
+ this.inMemoryModel.writeModelToFile();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ }
+
+ @Override
+ public void bsp(
+ BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> peer)
+ throws IOException, SyncException, InterruptedException {
+ while (!this.isConverge.get()) {
+ // Each slave worker computes matrix updates from its local partition of
+ // the data and merges them with the master; the master simply spins
+ // until the merger server flags convergence.
+ if (!isMaster(peer)) {
+ calculateUpdates(peer);
+ }
+ }
+
+ if (isMaster(peer)) {
+ merger.stop();
+ }
+ peer.sync(); // finalize the bsp program.
+ }
+
+ /**
+ * Calculate the matrix updates from the local partition of the data and
+ * exchange them with the master.
+ *
+ * @param peer
+ * @throws IOException
+ */
+ private void calculateUpdates(
+ BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, ParameterMessage> peer)
+ throws IOException {
+
+ DoubleMatrix[] weightUpdates = new DoubleMatrix[this.inMemoryModel.weightMatrixList
+ .size()];
+ for (int i = 0; i < weightUpdates.length; ++i) {
+ int row = this.inMemoryModel.weightMatrixList.get(i).getRowCount();
+ int col = this.inMemoryModel.weightMatrixList.get(i).getColumnCount();
+ weightUpdates[i] = new DenseDoubleMatrix(row, col);
+ }
+
+ // continue to train
+ double avgTrainingError = 0.0;
+ LongWritable key = new LongWritable();
+ VectorWritable value = new VectorWritable();
+ for (int recordsRead = 0; recordsRead < batchSize; ++recordsRead) {
+ if (!peer.readNext(key, value)) {
+ peer.reopenInput(); // exhausted the local split: rewind and keep reading
+ peer.readNext(key, value);
+ }
+ DoubleVector trainingInstance = value.getVector();
+ LayeredNeuralNetwork.matricesAdd(weightUpdates,
+ this.inMemoryModel.trainByInstance(trainingInstance));
+ avgTrainingError += this.inMemoryModel.trainingError;
+ }
+ avgTrainingError /= batchSize;
+
+ // calculate the average of updates
+ for (int i = 0; i < weightUpdates.length; ++i) {
+ weightUpdates[i] = weightUpdates[i].divide(batchSize);
+ }
+
+ // exchange parameter update with master
+ ParameterMessage msg = new ParameterMessage(
+ avgTrainingError, false, weightUpdates,
+ this.inMemoryModel.getPrevMatricesUpdates());
+
+ ParameterMessage inMessage = proxy.merge(msg);
+ DoubleMatrix[] newWeights = inMessage.getCurMatrices();
+ DoubleMatrix[] preWeightUpdates = inMessage.getPrevMatrices();
+ this.inMemoryModel.setWeightMatrices(newWeights);
+ this.inMemoryModel.setPrevWeightMatrices(preWeightUpdates);
+ this.isConverge.set(inMessage.isConverge());
+ }
+
+}
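
For reference, a minimal sketch (not part of this patch) of how the
configuration keys read in setup() above could be populated before submitting
a training job; the model path and the values shown are hypothetical:

    import org.apache.hama.HamaConfiguration;

    public class TrainerConfigSketch {
      public static void main(String[] args) {
        HamaConfiguration conf = new HamaConfiguration();
        conf.set("model.path", "/tmp/mlp.model");         // hypothetical path
        conf.setInt("training.batch.size", 50);           // mini-batch size per update
        conf.setInt("training.max.iterations", 100000);   // merge limit on the master
        conf.setInt("convergence.check.interval", 2000);  // averaged-error window
        conf.setInt("sync.server.port", 40052);           // parameter-merger RPC port
      }
    }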
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/core/Neuron.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/Neuron.java b/src/main/java/org/apache/horn/core/Neuron.java
new file mode 100644
index 0000000..357b42f
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/Neuron.java
@@ -0,0 +1,82 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.hama.commons.math.DoubleFunction;
+
+public abstract class Neuron<M extends Writable> implements NeuronInterface<M> {
+ double output;
+ double weight;
+ double delta;
+ protected DoubleFunction squashingFunction;
+
+ public void feedforward(double sum) {
+ // store the activation; concrete neurons apply the squashing function
+ // before calling this (see the forward() implementations)
+ this.output = sum;
+ }
+
+ public void backpropagate(double gradient) {
+ // store the error term (delta) for this neuron
+ this.delta = gradient;
+ }
+
+ public double getDelta() {
+ return delta;
+ }
+
+ public void setWeight(double weight) {
+ this.weight = weight;
+ }
+
+ public void setOutput(double output) {
+ this.output = output;
+ }
+
+ public double getOutput() {
+ return output;
+ }
+
+ /* The methods below communicate with the parameter server. */
+
+ private int i; // write cursor into the weights array
+ double[] weights; // per-neuron weight vector exchanged with the server
+
+ public void push(double weight) {
+ weights[i++] = weight;
+ }
+
+ public double getUpdate() {
+ return weight;
+ }
+
+ public void setWeightVector(int rowCount) {
+ i = 0;
+ weights = new double[rowCount];
+ }
+
+ public double[] getWeights() {
+ return weights;
+ }
+
+ public void setSquashingFunction(DoubleFunction squashingFunction) {
+ this.squashingFunction = squashingFunction;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/core/NeuronInterface.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/NeuronInterface.java b/src/main/java/org/apache/horn/core/NeuronInterface.java
new file mode 100644
index 0000000..5e4c113
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/NeuronInterface.java
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.hama.HamaConfiguration;
+
+public interface NeuronInterface<M extends Writable> {
+
+ public void setup(HamaConfiguration conf);
+
+ /**
+ * This method is called when the messages are propagated from the lower
+ * layer. It can be used to determine if the neuron would activate, or fire.
+ *
+ * @param messages
+ * @throws IOException
+ */
+ public void forward(Iterable<M> messages) throws IOException;
+
+ /**
+ * This method is called when the errors are propagated from the upper layer.
+ * It can be used to calculate the error of each neuron and change the
+ * weights.
+ *
+ * @param messages
+ * @throws IOException
+ */
+ public void backward(Iterable<M> messages) throws IOException;
+
+}
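
To make the forward/backward contract above concrete, here is a hedged sketch
of a sigmoid-style neuron written against the Neuron and Synapse classes from
this patch. The class name is made up, and it assumes Hama's DoubleFunction
exposes apply() and applyDerivative():

    import java.io.IOException;

    import org.apache.hadoop.io.DoubleWritable;
    import org.apache.hama.HamaConfiguration;
    import org.apache.horn.core.Neuron;
    import org.apache.horn.core.Synapse;

    public class ExampleNeuron extends
        Neuron<Synapse<DoubleWritable, DoubleWritable>> {

      @Override
      public void setup(HamaConfiguration conf) {
        // no per-neuron configuration needed in this sketch
      }

      @Override
      public void forward(
          Iterable<Synapse<DoubleWritable, DoubleWritable>> messages)
          throws IOException {
        double sum = 0;
        for (Synapse<DoubleWritable, DoubleWritable> m : messages) {
          sum += m.getInput() * m.getWeight(); // weighted sum of inputs
        }
        this.feedforward(this.squashingFunction.apply(sum)); // squash, then emit
      }

      @Override
      public void backward(
          Iterable<Synapse<DoubleWritable, DoubleWritable>> messages)
          throws IOException {
        for (Synapse<DoubleWritable, DoubleWritable> m : messages) {
          // upper-layer error scaled by the local slope of the squashing function
          double gradient = this.squashingFunction
              .applyDerivative(this.getOutput()) * (m.getDelta() * m.getWeight());
          this.backpropagate(gradient);
          this.push(this.getOutput() * m.getDelta()); // update for the merger
        }
      }
    }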
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/core/ParameterMerger.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/ParameterMerger.java b/src/main/java/org/apache/horn/core/ParameterMerger.java
new file mode 100644
index 0000000..512b402
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/ParameterMerger.java
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import org.apache.hama.ipc.VersionedProtocol;
+
+public interface ParameterMerger extends VersionedProtocol {
+ long versionID = 1L;
+
+ ParameterMessage merge(ParameterMessage msg);
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/core/ParameterMergerServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/ParameterMergerServer.java b/src/main/java/org/apache/horn/core/ParameterMergerServer.java
new file mode 100644
index 0000000..c76a4d0
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/ParameterMergerServer.java
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import java.io.IOException;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hama.commons.math.DoubleMatrix;
+
+import com.google.common.base.Preconditions;
+
+public class ParameterMergerServer implements ParameterMerger {
+
+ private static final Log LOG = LogFactory.getLog(ParameterMergerServer.class);
+
+ /* The base model into which parameter updates are merged. */
+ protected LayeredNeuralNetwork inMemoryModel;
+
+ /* Whether training has converged and should terminate. */
+ protected AtomicBoolean isConverge;
+
+ /* The number of slave workers that request merges. */
+ protected int slaveCount;
+
+ /* After mergeLimit merges, terminate whether training has converged or not. */
+ protected int mergeLimit;
+
+ /*
+ * The last n training errors. Convergence is decided based on the average
+ * of these errors.
+ */
+ protected double[] trainingErrors;
+
+ /*
+ * The average of the previous n training errors. Training is considered
+ * converged once the current average no longer falls below this value.
+ */
+ protected double prevAvgTrainingError = Double.MAX_VALUE;
+
+ /* Current write index into trainingErrors. */
+ protected int curTrainingError = 0;
+
+ /* Number of merges conducted so far. */
+ protected int mergeCount = 0;
+
+ public ParameterMergerServer(LayeredNeuralNetwork inMemoryModel,
+ AtomicBoolean isConverge, int slaveCount, int mergeLimit,
+ int convergenceCheckInterval) {
+ this.inMemoryModel = inMemoryModel;
+ this.isConverge = isConverge;
+ this.slaveCount = slaveCount;
+ this.mergeLimit = mergeLimit;
+ this.trainingErrors = new double[convergenceCheckInterval];
+ }
+
+ @Override
+ public long getProtocolVersion(String s, long l) throws IOException {
+ return ParameterMerger.versionID;
+ }
+
+ @Override
+ public ParameterMessage merge(ParameterMessage msg) {
+
+ double trainingError = msg.getTrainingError();
+ DoubleMatrix[] weightUpdates = msg.getCurMatrices();
+ DoubleMatrix[] prevWeightUpdates = msg.getPrevMatrices();
+
+ Preconditions
+ .checkArgument(weightUpdates.length == prevWeightUpdates.length);
+
+ LOG.info("Start merging: " + this.mergeCount);
+
+ if (!this.isConverge.get()) {
+ for (int i = 0; i < weightUpdates.length; ++i) {
+ weightUpdates[i] = weightUpdates[i].divide(this.slaveCount);
+ prevWeightUpdates[i] = prevWeightUpdates[i].divide(this.slaveCount);
+ }
+
+ synchronized (inMemoryModel) {
+ this.inMemoryModel.updateWeightMatrices(weightUpdates);
+ this.inMemoryModel.setPrevWeightMatrices(prevWeightUpdates);
+
+ // add trainingError to trainingErrors
+ this.trainingErrors[this.curTrainingError++] = trainingError;
+
+ // check convergence
+ if (this.trainingErrors.length == this.curTrainingError) {
+ double curAvgTrainingError = 0.0;
+ for (int i = 0; i < this.curTrainingError; ++i) {
+ curAvgTrainingError += this.trainingErrors[i];
+ }
+ curAvgTrainingError /= this.trainingErrors.length;
+
+ if (prevAvgTrainingError < curAvgTrainingError) {
+ this.isConverge.set(true);
+ } else {
+ // update
+ prevAvgTrainingError = curAvgTrainingError;
+ this.curTrainingError = 0;
+ }
+ }
+
+ if (++this.mergeCount == this.mergeLimit) {
+ this.isConverge.set(true);
+ }
+ }
+ }
+
+ return new ParameterMessage(0, this.isConverge.get(),
+ this.inMemoryModel.getWeightMatrices(),
+ this.inMemoryModel.getPrevMatricesUpdates());
+ }
+
+}
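
The convergence rule inside merge() above, isolated as a self-contained sketch
(the class name and values here are illustrative only): errors are buffered,
and once the window fills, its average is compared with the previous window's
average; training is flagged as converged when the error stops decreasing.

    public class ConvergenceRuleSketch {
      public static void main(String[] args) {
        double[] window = { 0.9, 0.8, 0.7, 0.75 }; // last n training errors
        double prevAvg = 0.6; // average of the previous window

        double avg = 0.0;
        for (double e : window) {
          avg += e;
        }
        avg /= window.length; // 0.7875

        // same test as merge(): converged once the average no longer improves
        boolean converged = prevAvg < avg;
        System.out.println("converged = " + converged); // true, since 0.6 < 0.7875
      }
    }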
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/core/ParameterMessage.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/ParameterMessage.java b/src/main/java/org/apache/horn/core/ParameterMessage.java
new file mode 100644
index 0000000..3905e25
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/ParameterMessage.java
@@ -0,0 +1,125 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.hama.commons.io.MatrixWritable;
+import org.apache.hama.commons.math.DenseDoubleMatrix;
+import org.apache.hama.commons.math.DoubleMatrix;
+
+/**
+ * ParameterMessage transmits messages between workers and parameter servers
+ * during the training of neural networks.
+ */
+public class ParameterMessage implements Writable {
+
+ protected double trainingError;
+ protected DoubleMatrix[] curMatrices;
+ protected DoubleMatrix[] prevMatrices;
+ protected boolean converge;
+
+ public ParameterMessage() {
+ }
+
+ public ParameterMessage(double trainingError, boolean converge,
+ DoubleMatrix[] weightMatrices, DoubleMatrix[] prevMatrices) {
+ this.trainingError = trainingError;
+ this.converge = converge;
+ this.curMatrices = weightMatrices;
+ this.prevMatrices = prevMatrices;
+ }
+
+ @Override
+ public void readFields(DataInput input) throws IOException {
+ trainingError = input.readDouble();
+ converge = input.readBoolean();
+ int numMatrices = input.readInt();
+ boolean hasPrevMatrices = input.readBoolean();
+ curMatrices = new DenseDoubleMatrix[numMatrices];
+ // read the current matrix updates
+ for (int i = 0; i < curMatrices.length; ++i) {
+ curMatrices[i] = (DenseDoubleMatrix) MatrixWritable.read(input);
+ }
+
+ if (hasPrevMatrices) {
+ prevMatrices = new DenseDoubleMatrix[numMatrices];
+ // read the previous matrix updates
+ for (int i = 0; i < prevMatrices.length; ++i) {
+ prevMatrices[i] = (DenseDoubleMatrix) MatrixWritable.read(input);
+ }
+ }
+ }
+
+ @Override
+ public void write(DataOutput output) throws IOException {
+ output.writeDouble(trainingError);
+ output.writeBoolean(converge);
+ output.writeInt(curMatrices.length);
+ output.writeBoolean(prevMatrices != null);
+ for (DoubleMatrix matrix : curMatrices) {
+ MatrixWritable.write(matrix, output);
+ }
+ if (prevMatrices != null) {
+ for (DoubleMatrix matrix : prevMatrices) {
+ MatrixWritable.write(matrix, output);
+ }
+ }
+ }
+
+ public double getTrainingError() {
+ return trainingError;
+ }
+
+ public void setTrainingError(double trainingError) {
+ this.trainingError = trainingError;
+ }
+
+ public boolean isConverge() {
+ return converge;
+ }
+
+ public void setConverge(boolean converge) {
+ this.converge = converge;
+ }
+
+ public DoubleMatrix[] getCurMatrices() {
+ return curMatrices;
+ }
+
+ public void setMatrices(DoubleMatrix[] curMatrices) {
+ this.curMatrices = curMatrices;
+ }
+
+ public DoubleMatrix[] getPrevMatrices() {
+ return prevMatrices;
+ }
+
+ public void setPrevMatrices(DoubleMatrix[] prevMatrices) {
+ this.prevMatrices = prevMatrices;
+ }
+
+}
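
Because ParameterMessage defines its own wire format, a quick hedged sketch of
round-tripping one through write()/readFields(), e.g. from a unit test (the
matrix shape and error value are arbitrary):

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;

    import org.apache.hama.commons.math.DenseDoubleMatrix;
    import org.apache.hama.commons.math.DoubleMatrix;
    import org.apache.horn.core.ParameterMessage;

    public class ParameterMessageRoundTrip {
      public static void main(String[] args) throws Exception {
        DoubleMatrix[] cur = { new DenseDoubleMatrix(2, 3) }; // arbitrary shape
        ParameterMessage out = new ParameterMessage(0.25, false, cur, null);

        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        out.write(new DataOutputStream(bytes)); // writes a false prev-matrices flag

        ParameterMessage in = new ParameterMessage();
        in.readFields(new DataInputStream(
            new ByteArrayInputStream(bytes.toByteArray())));
        System.out.println(in.getTrainingError()); // 0.25
      }
    }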
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/core/Synapse.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/core/Synapse.java b/src/main/java/org/apache/horn/core/Synapse.java
new file mode 100644
index 0000000..714767b
--- /dev/null
+++ b/src/main/java/org/apache/horn/core/Synapse.java
@@ -0,0 +1,85 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.core;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Wrapper for a message propagating along a connection between two neurons.
+ */
+public class Synapse<M extends Writable, W extends Writable> implements
+ Writable {
+
+ DoubleWritable message;
+ DoubleWritable weight;
+ DoubleWritable prevWeight;
+
+ public Synapse() {
+ // no-arg constructor required for Writable deserialization; initialize
+ // the fields so that readFields() has instances to populate
+ this.message = new DoubleWritable();
+ this.weight = new DoubleWritable();
+ this.prevWeight = new DoubleWritable();
+ }
+
+ public Synapse(DoubleWritable message, DoubleWritable weight) {
+ this.message = message;
+ this.weight = weight;
+ }
+
+ public Synapse(DoubleWritable message, DoubleWritable weight, DoubleWritable prevWeight) {
+ this.message = message;
+ this.weight = weight;
+ this.prevWeight = prevWeight;
+ }
+
+ /**
+ * @return the activation or error message
+ */
+ public double getMessage() {
+ return message.get();
+ }
+
+ public double getInput() {
+ return message.get(); // the input carries the same payload as the message
+ }
+
+ public double getDelta() {
+ return message.get(); // the delta carries the same payload as the message
+ }
+
+ public double getWeight() {
+ return weight.get();
+ }
+
+ public double getPrevWeight() {
+ return prevWeight.get();
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ // note: prevWeight is not part of the wire format
+ message.readFields(in);
+ weight.readFields(in);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ message.write(out);
+ weight.write(out);
+ }
+
+}
\ No newline at end of file
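
A small hedged usage sketch for the class above (values arbitrary): a synapse
carries one message together with the weight of the connection it travels on.

    import org.apache.hadoop.io.DoubleWritable;
    import org.apache.horn.core.Synapse;

    public class SynapseSketch {
      public static void main(String[] args) {
        // an input of 0.5 travelling over a connection of weight 0.1
        Synapse<DoubleWritable, DoubleWritable> s =
            new Synapse<DoubleWritable, DoubleWritable>(
                new DoubleWritable(0.5), new DoubleWritable(0.1));
        System.out.println(s.getInput() * s.getWeight()); // 0.05
      }
    }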
http://git-wip-us.apache.org/repos/asf/incubator-horn/blob/ac8eaf8e/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
index f66344c..c3bf180 100644
--- a/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
+++ b/src/main/java/org/apache/horn/examples/MultiLayerPerceptron.java
@@ -21,9 +21,9 @@ import java.io.IOException;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hama.HamaConfiguration;
-import org.apache.horn.bsp.HornJob;
-import org.apache.horn.bsp.Neuron;
-import org.apache.horn.bsp.Synapse;
+import org.apache.horn.core.HornJob;
+import org.apache.horn.core.Neuron;
+import org.apache.horn.core.Synapse;
import org.apache.horn.funcs.CrossEntropy;
import org.apache.horn.funcs.Sigmoid;
@@ -101,7 +101,7 @@ public class MultiLayerPerceptron {
InterruptedException, ClassNotFoundException {
if (args.length < 9) {
System.out
- .println("Usage: model_path training_set learning_rate momentum regularization_weight feature_dimension label_dimension max_iteration num_tasks");
+ .println("Usage: <MODEL_PATH> <INPUT_PATH> <LEARNING_RATE> <MOMEMTUM_WEIGHT> <REGULARIZATION_WEIGHT> <FEATURE_DIMENSION> <LABEL_DIMENSION> <MAX_ITERATION> <NUM_TASKS>");
System.exit(1);
}
HornJob ann = createJob(new HamaConfiguration(), args[0], args[1],