You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sm...@apache.org on 2013/12/19 20:29:03 UTC
svn commit: r1552403 - in /mahout/trunk: ./
core/src/main/java/org/apache/mahout/classifier/mlp/
core/src/test/java/org/apache/mahout/classifier/mlp/
Author: smarthi
Date: Thu Dec 19 19:29:02 2013
New Revision: 1552403
URL: http://svn.apache.org/r1552403
Log:
MAHOUT-1265: Multilayer Perceptron
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java
Modified:
mahout/trunk/CHANGELOG
Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1552403&r1=1552402&r2=1552403&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Thu Dec 19 19:29:02 2013
@@ -76,6 +76,8 @@ Release 0.9 - unreleased
MAHOUT-1275: Dropped bz2 distribution format for source and binaries (sslavic)
+ MAHOUT-1265: Multilayer Perceptron (Yexi Jiang via smarthi)
+
MAHOUT-1261: TasteHadoopUtils.idToIndex can return an int that has size Integer.MAX_VALUE (Carl Clark, smarthi)
MAHOUT-1249: Clusterdumper/loadTermDictionary crashes when highest index in (sparse) dictionary vector is larger than dictionary vector size (Andrew Musselman via smarthi)
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+
+/**
+ * A Multilayer Perceptron (MLP) is a kind of feed-forward artificial neural
+ * network, which is a mathematical model inspired by the biological neural
+ * network. The multilayer perceptron can be used for various machine learning
+ * tasks such as classification and regression.
+ *
+ * A detailed introduction about MLP can be found at
+ * http://ufldl.stanford.edu/wiki/index.php/Neural_Networks.
+ *
+ * For this particular implementation, the users can freely control the topology
+ * of the MLP, including: 1. The size of the input layer; 2. The number of
+ * hidden layers; 3. The size of each hidden layer; 4. The size of the output
+ * layer. 5. The cost function. 6. The squashing function.
+ *
+ * The model is trained in an online learning approach, where the weights of
+ * neurons in the MLP are updated incrementally using the backpropagation algorithm
+ * proposed by (Rumelhart, D. E., Hinton, G. E., and Williams, R. J. (1986)
+ * Learning representations by back-propagating errors. Nature, 323, 533--536.)
+ */
+public class MultilayerPerceptron extends NeuralNetwork implements OnlineLearner {
+
+ /**
+ * The default constructor.
+ */
+ public MultilayerPerceptron() {
+ super();
+ }
+
+ /**
+ * Initialize the MLP by specifying the location of the model.
+ *
+ * @param modelPath The path of the model.
+ */
+ public MultilayerPerceptron(String modelPath) {
+ super(modelPath);
+ }
+
+ @Override
+ public void train(int actual, Vector instance) {
+ // construct the training instance, where append the actual to instance
+ Vector trainingInstance = new DenseVector(instance.size() + 1);
+ for (int i = 0; i < instance.size(); ++i) {
+ trainingInstance.setQuick(i, instance.getQuick(i));
+ }
+ trainingInstance.setQuick(instance.size(), actual);
+ this.trainOnline(trainingInstance);
+ }
+
+ @Override
+ public void train(long trackingKey, String groupKey, int actual,
+ Vector instance) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void train(long trackingKey, int actual, Vector instance) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void close() {
+ // DO NOTHING
+ }
+
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,740 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+/**
+ * NeuralNetwork defines the general operations for a neural network
+ * based model. Typically, all derivative models such as Multilayer Perceptron
+ * and Autoencoder consist of neurons and the weights between neurons.
+ */
+public abstract class NeuralNetwork {
+
+ /* The default learning rate, applied by the no-argument constructor */
+ private static final double DEFAULT_LEARNING_RATE = 0.5;
+ /* The default regularization weight, applied by the no-argument constructor */
+ private static final double DEFAULT_REGULARIZATION_WEIGHT = 0;
+ /* The default momentum weight, applied by the no-argument constructor */
+ private static final double DEFAULT_MOMENTUM_WEIGHT = 0.1;
+
+ /* The training methods a model can be configured with */
+ public static enum TrainingMethod {
+ GRADIENT_DESCENT
+ }
+
+ /* The name of the model; the no-argument constructor sets it to the simple class name */
+ protected String modelType;
+
+ /* The path the model is read from and written to */
+ protected String modelPath;
+
+ /* The learning rate, which scales the weight updates computed during back-propagation */
+ protected double learningRate;
+
+ /* The weight of regularization */
+ protected double regularizationWeight;
+
+ /* The momentum weight, which scales the previous weight updates added to the new ones */
+ protected double momentumWeight;
+
+ /* The name of the cost function of the model */
+ protected String costFunctionName;
+
+ /* The size of each layer; addLayer counts the bias neuron for every non-final layer */
+ protected List<Integer> layerSizeList;
+
+ /* Training method used for training the model */
+ protected TrainingMethod trainingMethod;
+
+ /* Weights between neurons at adjacent layers; entry i connects layer i and layer i + 1 */
+ protected List<Matrix> weightMatrixList;
+
+ /* Previous weight updates between neurons at adjacent layers, used for the momentum term */
+ protected List<Matrix> prevWeightUpdatesList;
+
+ /* Names of the squashing functions; different layers can have different squashing functions */
+ protected List<String> squashingFunctionList;
+
+ /* The index of the final layer, assigned by addLayer when a layer is added as the final one */
+ protected int finalLayerIdx;
+
+  /**
+   * The default constructor. It applies the default learning rate, regularization weight and
+   * momentum weight, selects gradient descent as training method and "Minus_Squared" as cost
+   * function, uses the simple class name as model type, and starts with empty layer, weight,
+   * previous-update and squashing function lists.
+   */
+  public NeuralNetwork() {
+    this.learningRate = DEFAULT_LEARNING_RATE;
+    this.regularizationWeight = DEFAULT_REGULARIZATION_WEIGHT;
+    this.momentumWeight = DEFAULT_MOMENTUM_WEIGHT;
+    this.trainingMethod = TrainingMethod.GRADIENT_DESCENT;
+    this.costFunctionName = "Minus_Squared";
+    this.modelType = this.getClass().getSimpleName();
+
+    // every list is created exactly once; layers are appended later through addLayer
+    this.layerSizeList = Lists.newArrayList();
+    this.weightMatrixList = Lists.newArrayList();
+    this.prevWeightUpdatesList = Lists.newArrayList();
+    this.squashingFunctionList = Lists.newArrayList();
+  }
+
+  /**
+   * Creates a NeuralNetwork with the defaults of the no-argument constructor and then applies
+   * the given learning rate, momentum weight and regularization weight through their setters,
+   * so the setters' range checks apply.
+   *
+   * @param learningRate The learning rate.
+   * @param momentumWeight The momentum weight.
+   * @param regularizationWeight The regularization weight.
+   */
+  public NeuralNetwork(double learningRate, double momentumWeight, double regularizationWeight) {
+    this();
+    setLearningRate(learningRate);
+    setMomentumWeight(momentumWeight);
+    setRegularizationWeight(regularizationWeight);
+  }
+
+  /**
+   * Initialize the NeuralNetwork by reading a stored model.
+   *
+   * @param modelPath The location that the model is stored.
+   * @throws IllegalStateException if the model cannot be read; the original
+   *         {@link IOException} is attached as the cause.
+   */
+  public NeuralNetwork(String modelPath) {
+    this.modelPath = modelPath;
+    try {
+      this.readFromModel();
+    } catch (IOException e) {
+      // This constructor initializes no field besides modelPath, so after a failed read the
+      // list fields are still unset. Fail here instead of printing the trace and handing out
+      // an unusable instance.
+      throw new IllegalStateException("Failed to read the model from " + modelPath, e);
+    }
+  }
+
+  /**
+   * Returns the type of the model.
+   *
+   * @return The name stored as model type.
+   */
+  public String getModelType() {
+    return modelType;
+  }
+
+ /**
+ * Set the degree of aggression during model training, a large learning rate
+ * can increase the training speed, but it also decreases the chance of model
+ * converge.
+ *
+ * @param learningRate Learning rate must be a non-negative value. Recommend in range (0, 0.5).
+ * @return The model instance.
+ */
+ public NeuralNetwork setLearningRate(double learningRate) {
+ Preconditions.checkArgument(learningRate > 0, "Learning rate must be larger than 0.");
+ this.learningRate = learningRate;
+ return this;
+ }
+
+  /**
+   * Returns the learning rate.
+   *
+   * @return The current learning rate.
+   */
+  public double getLearningRate() {
+    return learningRate;
+  }
+
+ /**
+ * Set the regularization weight. More complex the model is, less weight the
+ * regularization is.
+ *
+ * @param regularizationWeight regularization must be in the range [0, 0.1).
+ * @return The model instance.
+ */
+ public NeuralNetwork setRegularizationWeight(double regularizationWeight) {
+ Preconditions.checkArgument(regularizationWeight >= 0
+ && regularizationWeight < 0.1, "Regularization weight must be in range [0, 0.1)");
+ this.regularizationWeight = regularizationWeight;
+ return this;
+ }
+
+  /**
+   * Returns the regularization weight.
+   *
+   * @return The current weight of regularization.
+   */
+  public double getRegularizationWeight() {
+    return regularizationWeight;
+  }
+
+ /**
+ * Set the momentum weight for the model.
+ *
+ * @param momentumWeight momentumWeight must be in range [0, 0.5].
+ * @return The model instance.
+ */
+ public NeuralNetwork setMomentumWeight(double momentumWeight) {
+ Preconditions.checkArgument(momentumWeight >= 0 && momentumWeight <= 1.0,
+ "Momentum weight must be in range [0, 1.0]");
+ this.momentumWeight = momentumWeight;
+ return this;
+ }
+
+  /**
+   * Returns the momentum weight.
+   *
+   * @return The current momentum weight.
+   */
+  public double getMomentumWeight() {
+    return momentumWeight;
+  }
+
+  /**
+   * Selects the training method. The value is stored as given, without validation.
+   *
+   * @param method The training method, currently supports GRADIENT_DESCENT.
+   * @return The instance of the model.
+   */
+  public NeuralNetwork setTrainingMethod(TrainingMethod method) {
+    trainingMethod = method;
+    return this;
+  }
+
+  /**
+   * Returns the training method.
+   *
+   * @return The currently selected training method.
+   */
+  public TrainingMethod getTrainingMethod() {
+    return trainingMethod;
+  }
+
+  /**
+   * Set the cost function for the model. The name is stored as given, without validation.
+   *
+   * @param costFunction the name of the cost function. Currently supports
+   *          "Minus_Squared", "Cross_Entropy".
+   * @return The model instance.
+   */
+  public NeuralNetwork setCostFunction(String costFunction) {
+    costFunctionName = costFunction;
+    return this;
+  }
+
+  /**
+   * Add a layer of neurons with the specified size. Unless it is the first layer, the new layer
+   * is connected to the previous one: a randomly initialized weight matrix, a zero matrix for
+   * the previous weight updates and the squashing function name are registered for it.
+   *
+   * @param size The size of the layer. (bias neuron excluded)
+   * @param isFinalLayer If false, add a bias neuron.
+   * @param squashingFunctionName The squashing function for this layer; it is ignored for the
+   *          first (input) layer.
+   * @return The layer index, starts with 0.
+   * @throws IllegalArgumentException if size is not larger than 0.
+   */
+  public int addLayer(int size, boolean isFinalLayer, String squashingFunctionName) {
+    Preconditions.checkArgument(size > 0, "Size of layer must be larger than 0.");
+    // every layer except the final one carries one extra bias neuron
+    int sizeWithBias = isFinalLayer ? size : size + 1;
+
+    this.layerSizeList.add(sizeWithBias);
+    int newLayerIdx = this.layerSizeList.size() - 1;
+    if (isFinalLayer) {
+      this.finalLayerIdx = newLayerIdx;
+    }
+
+    if (newLayerIdx == 0) {
+      // the input layer has neither incoming weights nor a squashing function
+      return newLayerIdx;
+    }
+
+    int prevLayerSize = this.layerSizeList.get(newLayerIdx - 1);
+    // one row per non-bias neuron of this layer, one column per neuron of the previous layer
+    Matrix weights = new DenseMatrix(size, prevLayerSize);
+    final RandomWrapper random = RandomUtils.getRandom();
+    weights.assign(new DoubleFunction() {
+      @Override
+      public double apply(double ignored) {
+        // each initial weight is a random draw shifted down by 0.5
+        return random.nextDouble() - 0.5;
+      }
+    });
+    this.weightMatrixList.add(weights);
+    this.prevWeightUpdatesList.add(new DenseMatrix(size, prevLayerSize));
+    this.squashingFunctionList.add(squashingFunctionName);
+    return newLayerIdx;
+  }
+
+ /**
+ * Get the size of a particular layer.
+ *
+ * @param layer The index of the layer, starting from 0.
+ * @return The size of the corresponding layer.
+ */
+ public int getLayerSize(int layer) {
+ Preconditions.checkArgument(layer >= 0 && layer < this.layerSizeList.size(),
+ String.format("Input must be in range [0, %d]\n", this.layerSizeList.size() - 1));
+ return this.layerSizeList.get(layer);
+ }
+
+  /**
+   * Returns the internal list of layer sizes (not a copy).
+   *
+   * @return The sizes of the layers.
+   */
+  protected List<Integer> getLayerSizeList() {
+    return layerSizeList;
+  }
+
+  /**
+   * Returns the weights between layer layerIdx and layerIdx + 1.
+   *
+   * @param layerIdx The index of the layer.
+   * @return The weights in form of {@link Matrix}.
+   */
+  public Matrix getWeightsByLayer(int layerIdx) {
+    return weightMatrixList.get(layerIdx);
+  }
+
+ /**
+ * Update the weight matrices with given matrices.
+ *
+ * @param matrices The weight matrices, must be the same dimension as the
+ * existing matrices.
+ */
+ public void updateWeightMatrices(Matrix[] matrices) {
+ for (int i = 0; i < matrices.length; ++i) {
+ Matrix matrix = this.weightMatrixList.get(i);
+ this.weightMatrixList.set(i, matrix.plus(matrices[i]));
+ }
+ }
+
+  /**
+   * Replaces all weight matrices with the given ones, in array order.
+   *
+   * @param matrices The weight matrices, must be the same dimension of the
+   *          existing matrices.
+   */
+  public void setWeightMatrices(Matrix[] matrices) {
+    List<Matrix> replacement = Lists.newArrayList();
+    this.weightMatrixList = replacement;
+    Collections.addAll(replacement, matrices);
+  }
+
+ /**
+ * Set the weight matrix for a specified layer.
+ *
+ * @param index The index of the matrix, starting from 0 (between layer 0 and 1).
+ * @param matrix The instance of {@link Matrix}.
+ */
+ public void setWeightMatrix(int index, Matrix matrix) {
+ Preconditions.checkArgument(0 <= index && index < this.weightMatrixList.size(),
+ String.format("index [%s] should be in range [%s, %s).", index, 0, this.weightMatrixList.size()));
+ this.weightMatrixList.set(index, matrix);
+ }
+
+  /**
+   * Returns all weight matrices as a newly allocated array; the matrices themselves are not
+   * copied.
+   *
+   * @return The weight matrices.
+   */
+  public Matrix[] getWeightMatrices() {
+    return this.weightMatrixList.toArray(new Matrix[this.weightMatrixList.size()]);
+  }
+
+ /**
+ * Get the output calculated by the model.
+ *
+ * @param instance The feature instance in form of {@link Vector}, each dimension contains the value of the corresponding feature.
+ * @return The output vector.
+ */
+ public Vector getOutput(Vector instance) {
+ Preconditions.checkArgument(this.layerSizeList.get(0) == instance.size() + 1,
+ String.format("The dimension of input instance should be %d, but the input has dimension %d.",
+ this.layerSizeList.get(0) - 1, instance.size()));
+
+ // add bias feature
+ Vector instanceWithBias = new DenseVector(instance.size() + 1);
+ // set bias to be a little bit less than 1.0
+ instanceWithBias.set(0, 0.99999);
+ for (int i = 1; i < instanceWithBias.size(); ++i) {
+ instanceWithBias.set(i, instance.get(i - 1));
+ }
+
+ List<Vector> outputCache = getOutputInternal(instanceWithBias);
+ // return the output of the last layer
+ Vector result = outputCache.get(outputCache.size() - 1);
+ // remove bias
+ return result.viewPart(1, result.size() - 1);
+ }
+
+  /**
+   * Propagates an instance through the network and keeps the output of every layer. The first
+   * cached entry is the instance itself, each following entry is the result of {@code forward}
+   * applied to the entry before it.
+   *
+   * @param instance The feature instance in form of {@link Vector}, each dimension contains the
+   *          value of the corresponding feature.
+   * @return Cached output of each layer.
+   */
+  protected List<Vector> getOutputInternal(Vector instance) {
+    List<Vector> outputCache = Lists.newArrayList();
+    outputCache.add(instance);
+
+    Vector current = instance;
+    int transitions = this.layerSizeList.size() - 1;
+    for (int from = 0; from < transitions; ++from) {
+      current = forward(from, current);
+      outputCache.add(current);
+    }
+    return outputCache;
+  }
+
+  /**
+   * Forward the calculation for one layer: multiply the weight matrix of the layer with the
+   * incoming vector, apply the squashing function registered for that weight matrix, and
+   * prepend a bias slot with value 1.
+   *
+   * @param fromLayer The index of the previous layer.
+   * @param intermediateOutput The intermediate output of previous layer.
+   * @return The intermediate results of the current layer, with the bias at index 0.
+   */
+  protected Vector forward(int fromLayer, Vector intermediateOutput) {
+    Matrix weights = this.weightMatrixList.get(fromLayer);
+    Vector weighted = weights.times(intermediateOutput);
+    String squashingName = this.squashingFunctionList.get(fromLayer);
+    Vector activated = weighted.assign(NeuralNetworkFunctions.getDoubleFunction(squashingName));
+
+    // shift the activations by one position to make room for the bias at index 0
+    int activationCount = activated.size();
+    Vector withBias = new DenseVector(activationCount + 1);
+    withBias.set(0, 1);
+    for (int i = 1; i <= activationCount; ++i) {
+      withBias.set(i, activated.get(i - 1));
+    }
+    return withBias;
+  }
+
+  /**
+   * Train the neural network incrementally with the given training instance: the weight updates
+   * computed by {@code trainByInstance} are applied immediately to the weight matrices.
+   *
+   * @param trainingInstance A training instance, including the features and the label(s). Its
+   *          dimension must equal the size of the input layer (bias neuron excluded) + the
+   *          size of the output layer (a.k.a. the dimension of the labels).
+   */
+  public void trainOnline(Vector trainingInstance) {
+    updateWeightMatrices(trainByInstance(trainingInstance));
+  }
+
+  /**
+   * Compute the weight updates for one training instance with the configured training method.
+   * The weight matrices themselves are not modified here.
+   *
+   * @param trainingInstance A training instance, including the features and the label(s). Its
+   *          dimension must equal the size of the input layer (bias neuron excluded) + the
+   *          size of the output layer (a.k.a. the dimension of the labels).
+   * @return The update of each weight, in form of {@link Matrix} list.
+   * @throws IllegalArgumentException if the instance has the wrong dimension or the training
+   *           method is not supported (including an unset, i.e. null, training method).
+   */
+  public Matrix[] trainByInstance(Vector trainingInstance) {
+    // validate training instance
+    int inputDimension = this.layerSizeList.get(0) - 1;
+    int outputDimension = this.layerSizeList.get(this.layerSizeList.size() - 1);
+    Preconditions.checkArgument(inputDimension + outputDimension == trainingInstance.size(),
+        String.format("The dimension of training instance is %d, but requires %d.", trainingInstance.size(),
+            inputDimension + outputDimension));
+
+    // identity comparison is the idiomatic enum check and, unlike equals(), does not throw a
+    // NullPointerException when the training method was set to null
+    if (this.trainingMethod == TrainingMethod.GRADIENT_DESCENT) {
+      return this.trainByInstanceGradientDescent(trainingInstance);
+    }
+    throw new IllegalArgumentException("Training method is not supported.");
+  }
+
+  /**
+   * Train by gradient descent. Computes the weight updates for one training instance: the
+   * instance is propagated forward, the delta of the output layer is derived from the cost
+   * function derivative (plus regularization) and the squashing function derivative, and the
+   * delta is then back-propagated layer by layer. The computed updates are also remembered as
+   * the previous weight updates.
+   *
+   * @param trainingInstance A training instance, including the features and the label(s). Its
+   *          dimension must equal the size of the input layer (bias neuron excluded) + the
+   *          size of the output layer (a.k.a. the dimension of the labels).
+   * @return The weight update matrices.
+   */
+  private Matrix[] trainByInstanceGradientDescent(Vector trainingInstance) {
+    int inputDimension = this.layerSizeList.get(0) - 1;
+
+    // slot 0 of the input is the bias, the features follow
+    Vector inputInstance = new DenseVector(this.layerSizeList.get(0));
+    inputInstance.set(0, 1);
+    for (int i = 0; i < inputDimension; ++i) {
+      inputInstance.set(i + 1, trainingInstance.get(i));
+    }
+
+    // everything after the features is the label part
+    Vector labels = trainingInstance.viewPart(inputDimension, trainingInstance.size() - inputDimension);
+
+    // initialize weight update matrices, one per weight matrix and of the same shape
+    Matrix[] weightUpdateMatrices = new Matrix[this.weightMatrixList.size()];
+    for (int m = 0; m < weightUpdateMatrices.length; ++m) {
+      Matrix weights = this.weightMatrixList.get(m);
+      weightUpdateMatrices[m] = new DenseMatrix(weights.rowSize(), weights.columnSize());
+    }
+
+    List<Vector> internalResults = this.getOutputInternal(inputInstance);
+
+    Vector deltaVec = new DenseVector(this.layerSizeList.get(this.layerSizeList.size() - 1));
+    Vector output = internalResults.get(internalResults.size() - 1);
+
+    final DoubleFunction derivativeSquashingFunction =
+        NeuralNetworkFunctions.getDerivativeDoubleFunction(
+            this.squashingFunctionList.get(this.squashingFunctionList.size() - 1));
+
+    // despite its name this is the derivative of the cost function
+    final DoubleDoubleFunction costFunction =
+        NeuralNetworkFunctions.getDerivativeDoubleDoubleFunction(this.costFunctionName);
+
+    Matrix lastWeightMatrix = this.weightMatrixList.get(this.weightMatrixList.size() - 1);
+
+    for (int i = 0; i < deltaVec.size(); ++i) {
+      // the cached output carries a leading bias slot, hence the i + 1
+      double outputValue = output.get(i + 1);
+      double costFuncDerivative = costFunction.apply(labels.get(i), outputValue);
+      // add regularization
+      costFuncDerivative += this.regularizationWeight * lastWeightMatrix.viewRow(i).zSum();
+      deltaVec.set(i, costFuncDerivative * derivativeSquashingFunction.apply(outputValue));
+    }
+
+    // start from previous layer of output layer
+    for (int layer = this.layerSizeList.size() - 2; layer >= 0; --layer) {
+      deltaVec = backPropagate(layer, deltaVec, internalResults, weightUpdateMatrices[layer]);
+    }
+
+    // Remember the updates in an independent, growable list. Arrays.asList would install a
+    // fixed-size view of the array returned to the caller, so a later addLayer (which appends
+    // to this list) would fail with UnsupportedOperationException.
+    this.prevWeightUpdatesList = Lists.newArrayList(weightUpdateMatrices);
+
+    return weightUpdateMatrices;
+  }
+
+ /**
+ * Back-propagate the errors from the next layer to the current layer. The
+ * weight updates for the weights between the current and the next layer are
+ * written into weightUpdateMatrix, and the delta of the current layer is
+ * returned.
+ *
+ * @param curLayerIdx Index of current layer.
+ * @param nextLayerDelta Delta of next layer.
+ * @param outputCache The cached output of each layer from the forward pass.
+ * @param weightUpdateMatrix The matrix that receives the weight updates, in form of {@link Matrix}.
+ * @return The delta of the current layer.
+ */
+ private Vector backPropagate(int curLayerIdx, Vector nextLayerDelta,
+ List<Vector> outputCache, Matrix weightUpdateMatrix) {
+
+ // get layer related information
+ // NOTE(review): forward() uses squashing function entry curLayerIdx to produce the output of
+ // layer curLayerIdx + 1, while its derivative is applied below to the output of layer
+ // curLayerIdx - confirm this is intended when layers use different squashing functions
+ final DoubleFunction derivativeSquashingFunction =
+ NeuralNetworkFunctions.getDerivativeDoubleFunction(this.squashingFunctionList.get(curLayerIdx));
+ Vector curLayerOutput = outputCache.get(curLayerIdx);
+ Matrix weightMatrix = this.weightMatrixList.get(curLayerIdx);
+ // the updates computed for this layer by the previous training step (momentum term)
+ Matrix prevWeightMatrix = this.prevWeightUpdatesList.get(curLayerIdx);
+
+ // next layer is not output layer, remove the delta of bias neuron
+ if (curLayerIdx != this.layerSizeList.size() - 2) {
+ nextLayerDelta = nextLayerDelta.viewPart(1, nextLayerDelta.size() - 1);
+ }
+
+ // propagate the delta backwards through the weights
+ Vector delta = weightMatrix.transpose().times(nextLayerDelta);
+
+ // scale each element by the squashing function derivative at the current layer's output
+ delta = delta.assign(curLayerOutput, new DoubleDoubleFunction() {
+ @Override
+ public double apply(double deltaElem, double curLayerOutputElem) {
+ return deltaElem * derivativeSquashingFunction.apply(curLayerOutputElem);
+ }
+ });
+
+ // update weights: gradient step scaled by the learning rate plus the momentum term
+ for (int i = 0; i < weightUpdateMatrix.rowSize(); ++i) {
+ for (int j = 0; j < weightUpdateMatrix.columnSize(); ++j) {
+ weightUpdateMatrix.set(i, j, -learningRate * nextLayerDelta.get(i) *
+ curLayerOutput.get(j) + this.momentumWeight * prevWeightMatrix.get(i, j));
+ }
+ }
+
+ return delta;
+ }
+
+  /**
+   * Read the model from the location stored in the model path by opening it through its
+   * file system and passing the stream to {@code readFields}.
+   *
+   * @throws IOException if the model cannot be opened or read.
+   * @throws IllegalArgumentException if the model path has not been set.
+   */
+  protected void readFromModel() throws IOException {
+    Preconditions.checkArgument(this.modelPath != null, "Model path has not been set.");
+    FSDataInputStream in = null;
+    try {
+      Path location = new Path(this.modelPath);
+      FileSystem fileSystem = location.getFileSystem(new Configuration());
+      in = new FSDataInputStream(fileSystem.open(location));
+      readFields(in);
+    } finally {
+      // an IOException raised while closing is swallowed (second argument is true)
+      Closeables.close(in, true);
+    }
+  }
+
+  /**
+   * Write the model to the location stored in the model path, overwriting an existing file.
+   *
+   * @throws IOException if the model cannot be written or the stream cannot be closed.
+   * @throws IllegalArgumentException if the model path has not been set.
+   */
+  public void writeModelToFile() throws IOException {
+    Preconditions.checkArgument(this.modelPath != null, "Model path has not been set.");
+    FSDataOutputStream out = null;
+    try {
+      Path location = new Path(this.modelPath);
+      FileSystem fileSystem = location.getFileSystem(new Configuration());
+      out = fileSystem.create(location, true);
+      write(out);
+    } finally {
+      // an IOException raised while closing is propagated (second argument is false)
+      Closeables.close(out, false);
+    }
+  }
+
+  /**
+   * Stores the model path; nothing is read or written by this call.
+   *
+   * @param modelPath The path of the model.
+   */
+  public void setModelPath(String modelPath) {
+    this.modelPath = modelPath;
+  }
+
+  /**
+   * Returns the model path.
+   *
+   * @return The path of the model, as currently stored.
+   */
+  public String getModelPath() {
+    return modelPath;
+  }
+
+ /**
+ * Write the fields of the model to output.
+ *
+ * @param output The output instance.
+ * @throws IOException
+ */
+ public void write(DataOutput output) throws IOException {
+ // write model type
+ WritableUtils.writeString(output, modelType);
+ // write learning rate
+ output.writeDouble(learningRate);
+ // write model path
+ if (this.modelPath != null) {
+ WritableUtils.writeString(output, modelPath);
+ } else {
+ WritableUtils.writeString(output, "null");
+ }
+
+ // write regularization weight
+ output.writeDouble(this.regularizationWeight);
+ // write momentum weight
+ output.writeDouble(this.momentumWeight);
+
+ // write cost function
+ WritableUtils.writeString(output, this.costFunctionName);
+
+ // write layer size list
+ output.writeInt(this.layerSizeList.size());
+ for (Integer aLayerSizeList : this.layerSizeList) {
+ output.writeInt(aLayerSizeList);
+ }
+
+ WritableUtils.writeEnum(output, this.trainingMethod);
+
+ // write squashing functions
+ output.writeInt(this.squashingFunctionList.size());
+ for (String aSquashingFunctionList : this.squashingFunctionList) {
+ WritableUtils.writeString(output, aSquashingFunctionList);
+ }
+
+ // write weight matrices
+ output.writeInt(this.weightMatrixList.size());
+ for (Matrix aWeightMatrixList : this.weightMatrixList) {
+ MatrixWritable.writeMatrix(output, aWeightMatrixList);
+ }
+ }
+
+  /**
+   * Read the fields of the model from input, in the order used by {@code write}. The stored
+   * model type must equal the simple name of this class. The model path is replaced by the
+   * stored one. A zero matrix of matching shape is created as previous weight update for every
+   * weight matrix read. The final layer index is not assigned by this method.
+   *
+   * @param input The input instance.
+   * @throws IOException if reading from the input fails.
+   * @throws IllegalArgumentException if the stored model type does not match this class.
+   */
+  public void readFields(DataInput input) throws IOException {
+    this.modelType = WritableUtils.readString(input);
+    String expectedType = this.getClass().getSimpleName();
+    if (!this.modelType.equals(expectedType)) {
+      throw new IllegalArgumentException("The specified location does not contains the valid NeuralNetwork model.");
+    }
+
+    this.learningRate = input.readDouble();
+
+    // the literal string "null" marks a model that was written without a path
+    this.modelPath = WritableUtils.readString(input);
+    if (this.modelPath.equals("null")) {
+      this.modelPath = null;
+    }
+
+    this.regularizationWeight = input.readDouble();
+    this.momentumWeight = input.readDouble();
+    this.costFunctionName = WritableUtils.readString(input);
+
+    // layer sizes
+    int layerCount = input.readInt();
+    this.layerSizeList = Lists.newArrayList();
+    for (int layer = 0; layer < layerCount; ++layer) {
+      this.layerSizeList.add(input.readInt());
+    }
+
+    this.trainingMethod = WritableUtils.readEnum(input, TrainingMethod.class);
+
+    // squashing function names
+    int squashingCount = input.readInt();
+    this.squashingFunctionList = Lists.newArrayList();
+    for (int s = 0; s < squashingCount; ++s) {
+      this.squashingFunctionList.add(WritableUtils.readString(input));
+    }
+
+    // weight matrices, each paired with a zero matrix for the previous updates
+    int matrixCount = input.readInt();
+    this.weightMatrixList = Lists.newArrayList();
+    this.prevWeightUpdatesList = Lists.newArrayList();
+    for (int m = 0; m < matrixCount; ++m) {
+      Matrix weights = MatrixWritable.readMatrix(input);
+      this.weightMatrixList.add(weights);
+      this.prevWeightUpdatesList.add(new DenseMatrix(weights.rowSize(), weights.columnSize()));
+    }
+  }
+
+}
\ No newline at end of file
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,150 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Squashing functions, cost functions and their derivatives used by NeuralNetwork,
+ * plus lookups that resolve each of them from a case-insensitive name.
+ */
+public class NeuralNetworkFunctions {
+
+  /** Substitute for an exact 1 so that cross entropy never evaluates log(0) or divides by zero. */
+  private static final double NEAR_ONE = 0.999;
+
+  /** Substitute for an exact 0 so that cross entropy never evaluates log(0) or divides by zero. */
+  private static final double NEAR_ZERO = 0.001;
+
+  /**
+   * The derivative of the identity function (f(x) = x), i.e. the constant 1.
+   */
+  public static final DoubleFunction derivativeIdentityFunction = new DoubleFunction() {
+    @Override
+    public double apply(double x) {
+      return 1;
+    }
+  };
+
+  /**
+   * The derivative, with respect to the output o, of the minus squared function
+   * (f(t, o) = (o - t)^2).
+   */
+  public static final DoubleDoubleFunction derivativeMinusSquared = new DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      return 2 * (output - target);
+    }
+  };
+
+  /**
+   * The cross entropy function (f(t, o) = -t * log(o) - (1 - t) * log(1 - o)).
+   * An output of exactly 0 or 1 is replaced by 0.001 or 0.999 respectively, the same
+   * substitution the derivative applies, so the result is always finite.
+   */
+  public static final DoubleDoubleFunction crossEntropy = new DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      // without the substitution log(0) is -Infinity, and a zero factor times -Infinity is NaN
+      double boundedOutput = awayFromBoundary(output);
+      return -target * Math.log(boundedOutput) - (1 - target) * Math.log(1 - boundedOutput);
+    }
+  };
+
+  /**
+   * The derivative, with respect to the output o, of the cross entropy function
+   * (f(t, o) = -t * log(o) - (1 - t) * log(1 - o)).
+   * Targets and outputs of exactly 0 or 1 are replaced by 0.001 or 0.999 before the
+   * derivative is evaluated, which keeps both divisions finite.
+   */
+  public static final DoubleDoubleFunction derivativeCrossEntropy = new DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      double adjustedTarget = awayFromBoundary(target);
+      double adjustedActual = awayFromBoundary(output);
+      return -adjustedTarget / adjustedActual + (1 - adjustedTarget) / (1 - adjustedActual);
+    }
+  };
+
+  /**
+   * Moves a value that is exactly 0 or 1 slightly inside the open interval (0, 1).
+   *
+   * @param value The value to adjust.
+   * @return 0.999 for 1, 0.001 for 0, otherwise the value unchanged.
+   */
+  private static double awayFromBoundary(double value) {
+    if (value == 1) {
+      return NEAR_ONE;
+    }
+    if (value == 0) {
+      return NEAR_ZERO;
+    }
+    return value;
+  }
+
+  /**
+   * Get the corresponding function by its name.
+   * Currently supports: "Identity", "Sigmoid".
+   *
+   * @param function The name of the function (case-insensitive).
+   * @return The corresponding double function.
+   * @throws IllegalArgumentException If the name is null or not supported.
+   */
+  public static DoubleFunction getDoubleFunction(String function) {
+    // the literal is the receiver so that a null name is reported as unsupported
+    if ("Identity".equalsIgnoreCase(function)) {
+      return Functions.IDENTITY;
+    } else if ("Sigmoid".equalsIgnoreCase(function)) {
+      return Functions.SIGMOID;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the derivative of a double function by the name of the function.
+   * Currently supports: "Identity", "Sigmoid".
+   *
+   * @param function The name of the function (case-insensitive).
+   * @return The derivative of the named function.
+   * @throws IllegalArgumentException If the name is null or not supported.
+   */
+  public static DoubleFunction getDerivativeDoubleFunction(String function) {
+    if ("Identity".equalsIgnoreCase(function)) {
+      return derivativeIdentityFunction;
+    } else if ("Sigmoid".equalsIgnoreCase(function)) {
+      return Functions.SIGMOIDGRADIENT;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the corresponding double-double function by the name.
+   * Currently supports: "Minus_Squared", "Cross_Entropy".
+   *
+   * @param function The name of the function (case-insensitive).
+   * @return The double-double function itself, not its derivative.
+   * @throws IllegalArgumentException If the name is null or not supported.
+   */
+  public static DoubleDoubleFunction getDoubleDoubleFunction(String function) {
+    if ("Minus_Squared".equalsIgnoreCase(function)) {
+      return Functions.MINUS_SQUARED;
+    } else if ("Cross_Entropy".equalsIgnoreCase(function)) {
+      // the derivative is served by getDerivativeDoubleDoubleFunction
+      return crossEntropy;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the derivative of a double-double function by the name of the function.
+   * Currently supports: "Minus_Squared", "Cross_Entropy".
+   *
+   * @param function The name of the function (case-insensitive).
+   * @return The derivative of the named double-double function.
+   * @throws IllegalArgumentException If the name is null or not supported.
+   */
+  public static DoubleDoubleFunction getDerivativeDoubleDoubleFunction(String function) {
+    if ("Minus_Squared".equalsIgnoreCase(function)) {
+      return derivativeMinusSquared;
+    } else if ("Cross_Entropy".equalsIgnoreCase(function)) {
+      return derivativeCrossEntropy;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+}
\ No newline at end of file
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.File;
+import java.io.IOException;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Arrays;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/**
+ * Exercises {@link MultilayerPerceptron} by learning XOR and then saving and reloading the model.
+ */
+public class TestMultilayerPerceptron extends MahoutTestCase {
+
+  @Test
+  public void testMLP() throws IOException {
+    // plain training first, then with momentum, then with momentum and regularization
+    trainAndVerifyXor("testMLPXORLocal", false, false, 8000);
+    trainAndVerifyXor("testMLPXORLocalWithMomentum", true, false, 4000);
+    trainAndVerifyXor("testMLPXORLocalWithRegularization", true, true, 2000);
+  }
+
+  /**
+   * Trains a 2-3-1 sigmoid network on XOR, checks its predictions, writes the model to a
+   * temporary file and checks the predictions of the model read back from that file.
+   *
+   * @param modelFilename Name of the temporary file the model is written to.
+   * @param useMomentum Whether a momentum weight of 0.6 is set.
+   * @param useRegularization Whether a regularization weight of 0.01 is set.
+   * @param iterations Number of passes over the four XOR instances.
+   */
+  private void trainAndVerifyXor(String modelFilename, boolean useMomentum,
+      boolean useRegularization, int iterations) throws IOException {
+    MultilayerPerceptron perceptron = new MultilayerPerceptron();
+    perceptron.addLayer(2, false, "Sigmoid");
+    perceptron.addLayer(3, false, "Sigmoid");
+    perceptron.addLayer(1, true, "Sigmoid");
+    perceptron.setCostFunction("Minus_Squared");
+    perceptron.setLearningRate(0.2);
+    if (useMomentum) {
+      perceptron.setMomentumWeight(0.6);
+    }
+    if (useRegularization) {
+      perceptron.setRegularizationWeight(0.01);
+    }
+
+    // each row holds two inputs followed by the XOR label
+    double[][] xorSamples = { { 0, 1, 1 }, { 0, 0, 0 }, { 1, 0, 1 }, { 1, 1, 0 } };
+    for (int pass = 0; pass < iterations; ++pass) {
+      for (double[] sample : xorSamples) {
+        double[] inputs = Arrays.copyOf(sample, sample.length - 1);
+        perceptron.train((int) sample[2], new DenseVector(inputs));
+      }
+    }
+    assertSolvesXor(perceptron, xorSamples);
+
+    // write model into file and read out
+    String modelLocation = this.getTestTempFile(modelFilename).getAbsolutePath();
+    perceptron.setModelPath(modelLocation);
+    perceptron.writeModelToFile();
+    perceptron.close();
+
+    MultilayerPerceptron restored = new MultilayerPerceptron(modelLocation);
+    assertSolvesXor(restored, xorSamples);
+    restored.close();
+  }
+
+  /**
+   * Asserts that the model's output falls on the same side of 0.5 as the label of every sample.
+   *
+   * @param model The model to query.
+   * @param samples Rows of two inputs followed by the label.
+   */
+  private void assertSolvesXor(MultilayerPerceptron model, double[][] samples) {
+    for (double[] sample : samples) {
+      Vector inputs = new DenseVector(sample).viewPart(0, sample.length - 1);
+      // the label is the last element in the array
+      double label = sample[2];
+      double prediction = model.getOutput(inputs).get(0);
+      boolean bothBelow = label < 0.5 && prediction < 0.5;
+      boolean bothAtOrAbove = label >= 0.5 && prediction >= 0.5;
+      assertTrue(bothBelow || bothAtOrAbove);
+    }
+  }
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,345 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.mahout.classifier.mlp.NeuralNetwork.TrainingMethod;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Lists;
+import com.google.common.io.Files;
+
+/**
+ * Test the functionality of {@link NeuralNetwork}.
+ */
+public class TestNeuralNetwork extends MahoutTestCase {
+
+  /** Tolerance used when comparing floating-point values. */
+  private static final double EPSILON = 0.000001;
+
+  /**
+   * Writes a configured network to a file, reads it back and checks that the settings and
+   * the weight matrices survive the round trip.
+   */
+  @Test
+  public void testReadWrite() throws IOException {
+    NeuralNetwork ann = new MultilayerPerceptron();
+    ann.addLayer(2, false, "Identity");
+    ann.addLayer(5, false, "Identity");
+    ann.addLayer(1, true, "Identity");
+    ann.setCostFunction("Minus_Squared");
+    double learningRate = 0.2;
+    double momentumWeight = 0.5;
+    double regularizationWeight = 0.05;
+    ann.setLearningRate(learningRate).setMomentumWeight(momentumWeight).setRegularizationWeight(regularizationWeight);
+
+    // manually set weights
+    Matrix[] matrices = new Matrix[2];
+    matrices[0] = filledMatrix(5, 3, 0.2);
+    matrices[1] = filledMatrix(1, 6, 0.8);
+    ann.setWeightMatrices(matrices);
+
+    // write to file
+    String modelFilename = "testNeuralNetworkReadWrite";
+    File tmpModelFile = this.getTestTempFile(modelFilename);
+    ann.setModelPath(tmpModelFile.getAbsolutePath());
+    ann.writeModelToFile();
+
+    // read from file
+    NeuralNetwork annCopy = new MultilayerPerceptron(tmpModelFile.getAbsolutePath());
+    assertEquals(annCopy.getClass().getSimpleName(), annCopy.getModelType());
+    assertEquals(tmpModelFile.getAbsolutePath(), annCopy.getModelPath());
+    assertEquals(learningRate, annCopy.getLearningRate(), EPSILON);
+    assertEquals(momentumWeight, annCopy.getMomentumWeight(), EPSILON);
+    assertEquals(regularizationWeight, annCopy.getRegularizationWeight(), EPSILON);
+    assertEquals(TrainingMethod.GRADIENT_DESCENT, annCopy.getTrainingMethod());
+
+    // compare weights; the count and the dimensions are checked first so that a model
+    // with missing or truncated matrices cannot pass the element comparison vacuously
+    Matrix[] weightsMatrices = annCopy.getWeightMatrices();
+    assertEquals(matrices.length, weightsMatrices.length);
+    for (int i = 0; i < matrices.length; ++i) {
+      Matrix expectMat = matrices[i];
+      Matrix actualMat = weightsMatrices[i];
+      assertEquals(expectMat.rowSize(), actualMat.rowSize());
+      assertEquals(expectMat.columnSize(), actualMat.columnSize());
+      for (int j = 0; j < expectMat.rowSize(); ++j) {
+        for (int k = 0; k < expectMat.columnSize(); ++k) {
+          assertEquals(expectMat.get(j, k), actualMat.get(j, k), EPSILON);
+        }
+      }
+    }
+  }
+
+  /**
+   * Test the forward functionality.
+   */
+  @Test
+  public void testOutput() {
+    // first network
+    NeuralNetwork ann = new MultilayerPerceptron();
+    ann.addLayer(2, false, "Identity");
+    ann.addLayer(5, false, "Identity");
+    ann.addLayer(1, true, "Identity");
+    ann.setCostFunction("Minus_Squared").setLearningRate(0.1);
+
+    // intentionally initialize all weights to 0.5
+    Matrix[] matrices = new Matrix[2];
+    matrices[0] = filledMatrix(5, 3, 0.5);
+    matrices[1] = filledMatrix(1, 6, 0.5);
+    ann.setWeightMatrices(matrices);
+
+    double[] arr = new double[]{0, 1};
+    Vector training = new DenseVector(arr);
+    Vector result = ann.getOutput(training);
+    assertEquals(1, result.size());
+    // with identity activations every hidden neuron yields 0.5 * (1 + 0 + 1) = 1 and the
+    // output neuron yields 0.5 * (1 + 5 * 1) = 3, where the leading 1 is the bias input
+    assertEquals(3.0, result.get(0), EPSILON);
+
+    // second network
+    NeuralNetwork ann2 = new MultilayerPerceptron();
+    ann2.addLayer(2, false, "Sigmoid");
+    ann2.addLayer(3, false, "Sigmoid");
+    ann2.addLayer(1, true, "Sigmoid");
+    ann2.setCostFunction("Minus_Squared");
+    ann2.setLearningRate(0.3);
+
+    // intentionally initialize all weights to 0.5
+    Matrix[] matrices2 = new Matrix[2];
+    matrices2[0] = filledMatrix(3, 3, 0.5);
+    matrices2[1] = filledMatrix(1, 4, 0.5);
+    ann2.setWeightMatrices(matrices2);
+
+    double[] test = {0, 0};
+    double[] result2 = {0.807476};
+
+    Vector vec = ann2.getOutput(new DenseVector(test));
+    double[] arrVec = new double[vec.size()];
+    for (int i = 0; i < arrVec.length; ++i) {
+      arrVec[i] = vec.getQuick(i);
+    }
+    assertArrayEquals(result2, arrVec, EPSILON);
+
+    // third network
+    NeuralNetwork ann3 = new MultilayerPerceptron();
+    ann3.addLayer(2, false, "Sigmoid");
+    ann3.addLayer(3, false, "Sigmoid");
+    ann3.addLayer(1, true, "Sigmoid");
+    ann3.setCostFunction("Minus_Squared").setLearningRate(0.3);
+
+    // intentionally initialize all weights to 0.5
+    Matrix[] initMatrices = new Matrix[2];
+    initMatrices[0] = filledMatrix(3, 3, 0.5);
+    initMatrices[1] = filledMatrix(1, 4, 0.5);
+    ann3.setWeightMatrices(initMatrices);
+
+    double[] instance = {0, 1};
+    Vector output = ann3.getOutput(new DenseVector(instance));
+    assertEquals(0.8315410, output.get(0), EPSILON);
+  }
+
+  @Test
+  public void testNeuralNetwork() throws IOException {
+    trainAndVerifyXor("testNeuralNetworkXORLocal", false, false, 10000);
+    trainAndVerifyXor("testNeuralNetworkXORWithMomentum", true, false, 5000);
+    trainAndVerifyXor("testNeuralNetworkXORWithRegularization", true, true, 5000);
+  }
+
+  /**
+   * Trains a 2-3-1 sigmoid network on XOR, checks its predictions, writes the model to a
+   * temporary file and checks the predictions of the model read back from that file.
+   *
+   * @param modelFilename Name of the temporary file the model is written to.
+   * @param useMomentum Whether a momentum weight of 0.6 is set.
+   * @param useRegularization Whether a regularization weight of 0.01 is set.
+   * @param iterations Number of passes over the four XOR instances.
+   */
+  private void trainAndVerifyXor(String modelFilename, boolean useMomentum,
+      boolean useRegularization, int iterations) throws IOException {
+    NeuralNetwork ann = new MultilayerPerceptron();
+    ann.addLayer(2, false, "Sigmoid");
+    ann.addLayer(3, false, "Sigmoid");
+    ann.addLayer(1, true, "Sigmoid");
+    ann.setCostFunction("Minus_Squared").setLearningRate(0.1);
+
+    if (useMomentum) {
+      ann.setMomentumWeight(0.6);
+    }
+
+    if (useRegularization) {
+      ann.setRegularizationWeight(0.01);
+    }
+
+    // each row holds two inputs followed by the XOR label
+    double[][] instances = {{0, 1, 1}, {0, 0, 0}, {1, 0, 1}, {1, 1, 0}};
+    for (int i = 0; i < iterations; ++i) {
+      for (double[] instance : instances) {
+        ann.trainOnline(new DenseVector(instance));
+      }
+    }
+    assertSolvesXor(ann, instances);
+
+    // write model into file and read out
+    File tmpModelFile = this.getTestTempFile(modelFilename);
+    ann.setModelPath(tmpModelFile.getAbsolutePath());
+    ann.writeModelToFile();
+
+    NeuralNetwork annCopy = new MultilayerPerceptron(tmpModelFile.getAbsolutePath());
+    // test on instances
+    assertSolvesXor(annCopy, instances);
+  }
+
+  @Test
+  public void testWithCancerDataSet() throws IOException {
+    List<Vector> records = Lists.newArrayList();
+    for (String line : readDataLines("src/test/resources/cancer.csv")) {
+      String[] tokens = line.split(",");
+      double[] values = new double[tokens.length];
+      for (int i = 0; i < tokens.length; ++i) {
+        values[i] = Double.parseDouble(tokens[i]);
+      }
+      records.add(new DenseVector(values));
+    }
+
+    int splitPoint = (int) (records.size() * 0.8);
+    List<Vector> trainingSet = records.subList(0, splitPoint);
+    List<Vector> testSet = records.subList(splitPoint, records.size());
+
+    // initialize neural network model; the last value of each record is the label
+    NeuralNetwork ann = new MultilayerPerceptron();
+    int featureDimension = records.get(0).size() - 1;
+    ann.addLayer(featureDimension, false, "Sigmoid");
+    ann.addLayer(featureDimension * 2, false, "Sigmoid");
+    ann.addLayer(1, true, "Sigmoid");
+    ann.setLearningRate(0.05).setMomentumWeight(0.5).setRegularizationWeight(0.001);
+
+    trainOnline(ann, trainingSet, 2000);
+
+    int correctInstances = 0;
+    for (Vector testInstance : testSet) {
+      Vector res = ann.getOutput(testInstance.viewPart(0, testInstance.size() - 1));
+      double predicted = res.get(0);
+      double label = testInstance.get(testInstance.size() - 1);
+      if (Math.abs(predicted - label) <= 0.1) {
+        ++correctInstances;
+      }
+    }
+    double accuracy = (double) correctInstances / testSet.size() * 100;
+    assertTrue("The classifier is even worse than a random guesser!", accuracy > 50);
+    System.out.printf("Cancer DataSet. Classification precision: %d/%d = %f%%\n",
+        correctInstances, testSet.size(), accuracy);
+  }
+
+  @Test
+  public void testWithIrisDataSet() throws IOException {
+    int numOfClasses = 3;
+    List<Vector> records = Lists.newArrayList();
+
+    for (String line : readDataLines("src/test/resources/iris.csv")) {
+      String[] tokens = line.split(",");
+      // the textual label is replaced by three dimensions, one per class
+      double[] values = new double[tokens.length + numOfClasses - 1];
+      Arrays.fill(values, 0.0);
+      for (int i = 0; i < tokens.length - 1; ++i) {
+        values[i] = Double.parseDouble(tokens[i]);
+      }
+      // add label values
+      String label = tokens[tokens.length - 1];
+      if (label.equalsIgnoreCase("setosa")) {
+        values[values.length - 3] = 1;
+      } else if (label.equalsIgnoreCase("versicolor")) {
+        values[values.length - 2] = 1;
+      } else { // label 'virginica'
+        values[values.length - 1] = 1;
+      }
+      records.add(new DenseVector(values));
+    }
+
+    Collections.shuffle(records);
+
+    int splitPoint = (int) (records.size() * 0.8);
+    List<Vector> trainingSet = records.subList(0, splitPoint);
+    List<Vector> testSet = records.subList(splitPoint, records.size());
+
+    // initialize neural network model
+    NeuralNetwork ann = new MultilayerPerceptron();
+    int featureDimension = records.get(0).size() - numOfClasses;
+    ann.addLayer(featureDimension, false, "Sigmoid");
+    ann.addLayer(featureDimension * 2, false, "Sigmoid");
+    ann.addLayer(3, true, "Sigmoid"); // 3-class classification
+    ann.setLearningRate(0.05).setMomentumWeight(0.4).setRegularizationWeight(0.005);
+
+    trainOnline(ann, trainingSet, 2000);
+
+    int correctInstances = 0;
+    for (Vector testInstance : testSet) {
+      int labelOffset = testInstance.size() - numOfClasses;
+      Vector res = ann.getOutput(testInstance.viewPart(0, labelOffset));
+
+      // an instance counts as correct only if every class output is within 0.1 of its label
+      boolean allCorrect = true;
+      for (int i = 0; i < numOfClasses; ++i) {
+        double label = testInstance.get(labelOffset + i);
+        double predicted = res.get(i);
+        if (Math.abs(label - predicted) >= 0.1) {
+          allCorrect = false;
+          break;
+        }
+      }
+      if (allCorrect) {
+        ++correctInstances;
+      }
+    }
+
+    double accuracy = (double) correctInstances / testSet.size() * 100;
+    assertTrue("The model is even worse than a random guesser.", accuracy > 50);
+
+    System.out.printf("Iris DataSet. Classification precision: %d/%d = %f%%\n",
+        correctInstances, testSet.size(), accuracy);
+  }
+
+  /**
+   * Creates a dense matrix whose entries all hold the same value.
+   *
+   * @param rows Number of rows.
+   * @param columns Number of columns.
+   * @param value The value assigned to every entry.
+   * @return The filled matrix.
+   */
+  private static Matrix filledMatrix(int rows, int columns, double value) {
+    Matrix matrix = new DenseMatrix(rows, columns);
+    matrix.assign(value);
+    return matrix;
+  }
+
+  /**
+   * Reads a CSV file as UTF-8 and drops its header line.
+   *
+   * @param dataSetPath Path of the CSV file.
+   * @return The lines of the file that follow the header.
+   */
+  private static List<String> readDataLines(String dataSetPath) throws IOException {
+    // Returns a mutable list of the data
+    List<String> lines = Files.readLines(new File(dataSetPath), Charsets.UTF_8);
+    // skip the header line, hence remove the first element in the list
+    lines.remove(0);
+    return lines;
+  }
+
+  /**
+   * Feeds every training instance to the network once per iteration.
+   *
+   * @param network The network to train.
+   * @param trainingSet Instances holding the features followed by the labels.
+   * @param iterations Number of passes over the training set.
+   */
+  private static void trainOnline(NeuralNetwork network, List<Vector> trainingSet, int iterations) {
+    for (int i = 0; i < iterations; ++i) {
+      for (Vector trainingInstance : trainingSet) {
+        network.trainOnline(trainingInstance);
+      }
+    }
+  }
+
+  /**
+   * Asserts that the network's output falls on the same side of 0.5 as the label of every instance.
+   *
+   * @param network The network to query.
+   * @param instances Rows of two inputs followed by the label.
+   */
+  private void assertSolvesXor(NeuralNetwork network, double[][] instances) {
+    for (double[] instance : instances) {
+      Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
+      // the label is the last element in array
+      double label = instance[2];
+      double predicted = network.getOutput(input).get(0);
+      assertTrue(label < 0.5 && predicted < 0.5 || label >= 0.5 && predicted >= 0.5);
+    }
+  }
+
+}