Posted to commits@mahout.apache.org by sm...@apache.org on 2013/12/19 20:29:03 UTC

svn commit: r1552403 - in /mahout/trunk: ./ core/src/main/java/org/apache/mahout/classifier/mlp/ core/src/test/java/org/apache/mahout/classifier/mlp/

Author: smarthi
Date: Thu Dec 19 19:29:02 2013
New Revision: 1552403

URL: http://svn.apache.org/r1552403
Log:
MAHOUT-1265: Multilayer Perceptron

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java
Modified:
    mahout/trunk/CHANGELOG

Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1552403&r1=1552402&r2=1552403&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Thu Dec 19 19:29:02 2013
@@ -76,6 +76,8 @@ Release 0.9 - unreleased
 
   MAHOUT-1275: Dropped bz2 distribution format for source and binaries (sslavic)
 
+  MAHOUT-1265: Multilayer Perceptron (Yexi Jiang via smarthi)
+
   MAHOUT-1261: TasteHadoopUtils.idToIndex can return an int that has size Integer.MAX_VALUE (Carl Clark, smarthi)
 
   MAHOUT-1249: Clusterdumper/loadTermDictionary crashes when highest index in (sparse) dictionary vector is larger than dictionary vector size (Andrew Musselman via smarthi)

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+
+/**
+ * A Multilayer Perceptron (MLP) is a kind of feed-forward artificial neural
+ * network, which is a mathematical model inspired by the biological neural
+ * network. The multilayer perceptron can be used for various machine learning
+ * tasks such as classification and regression.
+ * 
+ * A detailed introduction about MLP can be found at
+ * http://ufldl.stanford.edu/wiki/index.php/Neural_Networks.
+ * 
+ * In this implementation, users can freely control the topology of the MLP,
+ * including: 1. the size of the input layer; 2. the number of hidden layers;
+ * 3. the size of each hidden layer; 4. the size of the output layer; 5. the
+ * cost function; 6. the squashing function.
+ * 
+ * The model is trained in an online learning manner, where the weights of the
+ * neurons in the MLP are updated incrementally using the back-propagation
+ * algorithm proposed by Rumelhart, D. E., Hinton, G. E., and Williams, R. J.
+ * (1986): Learning representations by back-propagating errors. Nature, 323,
+ * 533--536.
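+ * 
+ * A minimal usage sketch (the layer sizes and parameter values here are
+ * illustrative; see the unit tests added in this commit for full examples):
+ * 
+ * <pre>
+ * MultilayerPerceptron mlp = new MultilayerPerceptron();
+ * mlp.addLayer(2, false, "Sigmoid"); // input layer, bias neuron added
+ * mlp.addLayer(3, false, "Sigmoid"); // hidden layer
+ * mlp.addLayer(1, true, "Sigmoid");  // output layer, no bias neuron
+ * mlp.setCostFunction("Minus_Squared").setLearningRate(0.2);
+ * mlp.train(1, new DenseVector(new double[] {0, 1})); // label, features
+ * Vector output = mlp.getOutput(new DenseVector(new double[] {0, 1}));
+ * </pre>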
+ */
+public class MultilayerPerceptron extends NeuralNetwork implements OnlineLearner {
+
+  /**
+   * The default constructor.
+   */
+  public MultilayerPerceptron() {
+    super();
+  }
+
+  /**
+   * Initialize the MLP by specifying the location of the model.
+   * 
+   * @param modelPath The path of the model.
+   */
+  public MultilayerPerceptron(String modelPath) {
+    super(modelPath);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    // construct the training instance by appending the actual label to the features
+    Vector trainingInstance = new DenseVector(instance.size() + 1);
+    for (int i = 0; i < instance.size(); ++i) {
+      trainingInstance.setQuick(i, instance.getQuick(i));
+    }
+    trainingInstance.setQuick(instance.size(), actual);
+    this.trainOnline(trainingInstance);
+  }
+
+  @Override
+  public void train(long trackingKey, String groupKey, int actual,
+      Vector instance) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void close() {
+    // DO NOTHING
+  }
+
+}

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,740 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+/**
+ * NeuralNetwork defines the general operations for a neural network based
+ * model. Typically, derived models such as the Multilayer Perceptron and the
+ * Autoencoder consist of neurons and the weights between neurons.
+ */
+public abstract class NeuralNetwork {
+
+  /* The default learning rate */
+  private static final double DEFAULT_LEARNING_RATE = 0.5;
+  /* The default regularization weight */
+  private static final double DEFAULT_REGULARIZATION_WEIGHT = 0;
+  /* The default momentum weight */
+  private static final double DEFAULT_MOMENTUM_WEIGHT = 0.1;
+
+  public enum TrainingMethod {
+    GRADIENT_DESCENT
+  }
+
+  /* the name of the model */
+  protected String modelType;
+
+  /* the path to store the model */
+  protected String modelPath;
+
+  protected double learningRate;
+
+  /* The weight of regularization */
+  protected double regularizationWeight;
+
+  /* The momentum weight */
+  protected double momentumWeight;
+
+  /* The cost function of the model */
+  protected String costFunctionName;
+
+  /* Record the size of each layer */
+  protected List<Integer> layerSizeList;
+
+  /* Training method used for training the model */
+  protected TrainingMethod trainingMethod;
+
+  /* Weights between neurons at adjacent layers */
+  protected List<Matrix> weightMatrixList;
+
+  /* Previous weight updates between neurons at adjacent layers */
+  protected List<Matrix> prevWeightUpdatesList;
+
+  /* Different layers can have different squashing functions */
+  protected List<String> squashingFunctionList;
+
+  /* The index of the final layer */
+  protected int finalLayerIdx;
+
+  /**
+   * The default constructor that initializes the learning rate, regularization
+   * weight, and momentum weight by default.
+   */
+  public NeuralNetwork() {
+    this.learningRate = DEFAULT_LEARNING_RATE;
+    this.regularizationWeight = DEFAULT_REGULARIZATION_WEIGHT;
+    this.momentumWeight = DEFAULT_MOMENTUM_WEIGHT;
+    this.trainingMethod = TrainingMethod.GRADIENT_DESCENT;
+    this.costFunctionName = "Minus_Squared";
+    this.modelType = this.getClass().getSimpleName();
+
+    this.layerSizeList = Lists.newArrayList();
+    this.weightMatrixList = Lists.newArrayList();
+    this.prevWeightUpdatesList = Lists.newArrayList();
+    this.squashingFunctionList = Lists.newArrayList();
+  }
+
+  /**
+   * Initialize the NeuralNetwork by specifying learning rate, momentum weight
+   * and regularization weight.
+   * 
+   * @param learningRate The learning rate.
+   * @param momentumWeight The momentum weight.
+   * @param regularizationWeight The regularization weight.
+   */
+  public NeuralNetwork(double learningRate, double momentumWeight, double regularizationWeight) {
+    this();
+    this.setLearningRate(learningRate);
+    this.setMomentumWeight(momentumWeight);
+    this.setRegularizationWeight(regularizationWeight);
+  }
+
+  /**
+   * Initialize the NeuralNetwork by specifying the location of the model.
+   * 
+   * @param modelPath The location that the model is stored.
+   */
+  public NeuralNetwork(String modelPath) {
+    try {
+      this.modelPath = modelPath;
+      this.readFromModel();
+    } catch (IOException e) {
+      e.printStackTrace();
+    }
+  }
+
+  /**
+   * Get the type of the model.
+   * 
+   * @return The name of the model.
+   */
+  public String getModelType() {
+    return this.modelType;
+  }
+
+  /**
+   * Set the degree of aggressiveness during model training. A large learning
+   * rate can speed up training, but it also decreases the chance that the
+   * model converges.
+   * 
+   * @param learningRate The learning rate must be a positive value. A value in the range (0, 0.5) is recommended.
+   * @return The model instance.
+   */
+  public NeuralNetwork setLearningRate(double learningRate) {
+    Preconditions.checkArgument(learningRate > 0, "Learning rate must be larger than 0.");
+    this.learningRate = learningRate;
+    return this;
+  }
+
+  /**
+   * Get the value of learning rate.
+   * 
+   * @return The value of learning rate.
+   */
+  public double getLearningRate() {
+    return this.learningRate;
+  }
+
+  /**
+   * Set the regularization weight. The more complex the model is, the less
+   * weight the regularization carries.
+   * 
+   * @param regularizationWeight regularization must be in the range [0, 0.1).
+   * @return The model instance.
+   */
+  public NeuralNetwork setRegularizationWeight(double regularizationWeight) {
+    Preconditions.checkArgument(regularizationWeight >= 0
+        && regularizationWeight < 0.1, "Regularization weight must be in range [0, 0.1)");
+    this.regularizationWeight = regularizationWeight;
+    return this;
+  }
+
+  /**
+   * Get the weight of the regularization.
+   * 
+   * @return The weight of regularization.
+   */
+  public double getRegularizationWeight() {
+    return this.regularizationWeight;
+  }
+
+  /**
+   * Set the momentum weight for the model.
+   * 
+   * @param momentumWeight The momentum weight must be in range [0, 1.0].
+   * @return The model instance.
+   */
+  public NeuralNetwork setMomentumWeight(double momentumWeight) {
+    Preconditions.checkArgument(momentumWeight >= 0 && momentumWeight <= 1.0,
+        "Momentum weight must be in range [0, 1.0]");
+    this.momentumWeight = momentumWeight;
+    return this;
+  }
+
+  /**
+   * Get the momentum weight.
+   * 
+   * @return The value of momentum.
+   */
+  public double getMomentumWeight() {
+    return this.momentumWeight;
+  }
+
+  /**
+   * Set the training method.
+   * 
+   * @param method The training method, currently supports GRADIENT_DESCENT.
+   * @return The instance of the model.
+   */
+  public NeuralNetwork setTrainingMethod(TrainingMethod method) {
+    this.trainingMethod = method;
+    return this;
+  }
+
+  /**
+   * Get the training method.
+   * 
+   * @return The training method enumeration.
+   */
+  public TrainingMethod getTrainingMethod() {
+    return this.trainingMethod;
+  }
+
+  /**
+   * Set the cost function for the model.
+   * 
+   * @param costFunction The name of the cost function. Currently supports
+   *          "Minus_Squared" and "Cross_Entropy".
+   * @return The model instance.
+   */
+  public NeuralNetwork setCostFunction(String costFunction) {
+    this.costFunctionName = costFunction;
+    return this;
+  }
+
+  /**
+   * Add a layer of neurons with the specified size. If the added layer is not
+   * the first layer, its neurons will automatically be connected to those of
+   * the previous layer.
+   * 
+   * @param size The size of the layer. (bias neuron excluded)
+   * @param isFinalLayer If false, add a bias neuron.
+   * @param squashingFunctionName The squashing function for this layer; the
+   *          input layer uses f(x) = x by default.
+   * @return The layer index, starting from 0.
+   */
+  public int addLayer(int size, boolean isFinalLayer, String squashingFunctionName) {
+    Preconditions.checkArgument(size > 0, "Size of layer must be larger than 0.");
+    int actualSize = size;
+    if (!isFinalLayer) {
+      actualSize += 1;
+    }
+
+    this.layerSizeList.add(actualSize);
+    int layerIdx = this.layerSizeList.size() - 1;
+    if (isFinalLayer) {
+      this.finalLayerIdx = layerIdx;
+    }
+
+    // add weights between the current layer and the previous layer; the input
+    // layer has no squashing function
+    if (layerIdx > 0) {
+      int sizePrevLayer = this.layerSizeList.get(layerIdx - 1);
+      // the row count equals the size of the current layer and the column
+      // count equals the size of the previous layer
+      int row = isFinalLayer ? actualSize : actualSize - 1;
+      Matrix weightMatrix = new DenseMatrix(row, sizePrevLayer);
+      // initialize weights uniformly at random in [-0.5, 0.5)
+      final RandomWrapper rnd = RandomUtils.getRandom();
+      weightMatrix.assign(new DoubleFunction() {
+        @Override
+        public double apply(double value) {
+          return rnd.nextDouble() - 0.5;
+        }
+      });
+      this.weightMatrixList.add(weightMatrix);
+      this.prevWeightUpdatesList.add(new DenseMatrix(row, sizePrevLayer));
+      this.squashingFunctionList.add(squashingFunctionName);
+    }
+    return layerIdx;
+  }
+
+  /**
+   * Get the size of a particular layer.
+   * 
+   * @param layer The index of the layer, starting from 0.
+   * @return The size of the corresponding layer.
+   */
+  public int getLayerSize(int layer) {
+    Preconditions.checkArgument(layer >= 0 && layer < this.layerSizeList.size(),
+        String.format("Input must be in range [0, %d]\n", this.layerSizeList.size() - 1));
+    return this.layerSizeList.get(layer);
+  }
+
+  /**
+   * Get the layer size list.
+   * 
+   * @return The sizes of the layers.
+   */
+  protected List<Integer> getLayerSizeList() {
+    return this.layerSizeList;
+  }
+
+  /**
+   * Get the weights between layer layerIdx and layer layerIdx + 1.
+   * 
+   * @param layerIdx The index of the layer.
+   * @return The weights in form of {@link Matrix}.
+   */
+  public Matrix getWeightsByLayer(int layerIdx) {
+    return this.weightMatrixList.get(layerIdx);
+  }
+
+  /**
+   * Update the weight matrices with given matrices.
+   * 
+   * @param matrices The weight matrices, must be the same dimension as the
+   *          existing matrices.
+   */
+  public void updateWeightMatrices(Matrix[] matrices) {
+    for (int i = 0; i < matrices.length; ++i) {
+      Matrix matrix = this.weightMatrixList.get(i);
+      this.weightMatrixList.set(i, matrix.plus(matrices[i]));
+    }
+  }
+
+  /**
+   * Set the weight matrices.
+   * 
+   * @param matrices The weight matrices, must be of the same dimensions as
+   *          the existing matrices.
+   */
+  public void setWeightMatrices(Matrix[] matrices) {
+    this.weightMatrixList = Lists.newArrayList();
+    Collections.addAll(this.weightMatrixList, matrices);
+  }
+
+  /**
+   * Set the weight matrix for a specified layer.
+   * 
+   * @param index The index of the matrix, starting from 0 (between layer 0 and 1).
+   * @param matrix The instance of {@link Matrix}.
+   */
+  public void setWeightMatrix(int index, Matrix matrix) {
+    Preconditions.checkArgument(0 <= index && index < this.weightMatrixList.size(),
+        String.format("index [%s] should be in range [%s, %s).", index, 0, this.weightMatrixList.size()));
+    this.weightMatrixList.set(index, matrix);
+  }
+
+  /**
+   * Get all the weight matrices.
+   * 
+   * @return The weight matrices.
+   */
+  public Matrix[] getWeightMatrices() {
+    Matrix[] matrices = new Matrix[this.weightMatrixList.size()];
+    this.weightMatrixList.toArray(matrices);
+    return matrices;
+  }
+
+  /**
+   * Get the output calculated by the model.
+   * 
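+   * A bias entry is prepended to the instance internally, so callers pass
+   * only the raw features; the returned vector has one value per output
+   * neuron.
+   * 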
+   * @param instance The feature instance in the form of a {@link Vector}; each dimension contains the value of the corresponding feature.
+   * @return The output vector.
+   */
+  public Vector getOutput(Vector instance) {
+    Preconditions.checkArgument(this.layerSizeList.get(0) == instance.size() + 1,
+        String.format("The dimension of input instance should be %d, but the input has dimension %d.",
+            this.layerSizeList.get(0) - 1, instance.size()));
+
+    // add bias feature
+    Vector instanceWithBias = new DenseVector(instance.size() + 1);
+    // set bias to be a little bit less than 1.0
+    instanceWithBias.set(0, 0.99999);
+    for (int i = 1; i < instanceWithBias.size(); ++i) {
+      instanceWithBias.set(i, instance.get(i - 1));
+    }
+
+    List<Vector> outputCache = getOutputInternal(instanceWithBias);
+    // return the output of the last layer
+    Vector result = outputCache.get(outputCache.size() - 1);
+    // remove bias
+    return result.viewPart(1, result.size() - 1);
+  }
+
+  /**
+   * Calculate the output internally; the intermediate output of each layer
+   * will be stored.
+   * 
+   * @param instance The feature instance in the form of a {@link Vector}; each dimension contains the value of the corresponding feature.
+   * @return Cached output of each layer.
+   */
+  protected List<Vector> getOutputInternal(Vector instance) {
+    List<Vector> outputCache = Lists.newArrayList();
+    // fill with instance
+    Vector intermediateOutput = instance;
+    outputCache.add(intermediateOutput);
+
+    for (int i = 0; i < this.layerSizeList.size() - 1; ++i) {
+      intermediateOutput = forward(i, intermediateOutput);
+      outputCache.add(intermediateOutput);
+    }
+    return outputCache;
+  }
+
+  /**
+   * Forward the calculation for one layer.
+   * 
+   * @param fromLayer The index of the previous layer.
+   * @param intermediateOutput The intermediate output of previous layer.
+   * @return The intermediate results of the current layer.
+   */
+  protected Vector forward(int fromLayer, Vector intermediateOutput) {
+    Matrix weightMatrix = this.weightMatrixList.get(fromLayer);
+
+    Vector vec = weightMatrix.times(intermediateOutput);
+    vec = vec.assign(NeuralNetworkFunctions.getDoubleFunction(this.squashingFunctionList.get(fromLayer)));
+
+    // add bias
+    Vector vecWithBias = new DenseVector(vec.size() + 1);
+    vecWithBias.set(0, 1);
+    for (int i = 0; i < vec.size(); ++i) {
+      vecWithBias.set(i + 1, vec.get(i));
+    }
+    return vecWithBias;
+  }
+
+  /**
+   * Train the neural network incrementally with given training instance.
+   * 
+   * @param trainingInstance A training instance, including the features and the label(s). Its dimension must equal
+   *          the size of the input layer (bias neuron excluded) plus the size
+   *          of the output layer (i.e., the dimension of the labels).
+   */
+  public void trainOnline(Vector trainingInstance) {
+    Matrix[] matrices = this.trainByInstance(trainingInstance);
+    this.updateWeightMatrices(matrices);
+  }
+
+  /**
+   * Get the updated weights using one training instance.
+   * 
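+   * For example, a network built via addLayer as 2-3-1 expects a
+   * 3-dimensional training instance of the form {feature1, feature2, label}.
+   * 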
+   * @param trainingInstance A training instance, including the features and the label(s). Its dimension must equal
+   *          the size of the input layer (bias neuron excluded) plus the size
+   *          of the output layer (i.e., the dimension of the labels).
+   * @return The update of each weight, in form of {@link Matrix} list.
+   */
+  public Matrix[] trainByInstance(Vector trainingInstance) {
+    // validate training instance
+    int inputDimension = this.layerSizeList.get(0) - 1;
+    int outputDimension = this.layerSizeList.get(this.layerSizeList.size() - 1);
+    Preconditions.checkArgument(inputDimension + outputDimension == trainingInstance.size(),
+        String.format("The dimension of training instance is %d, but requires %d.", trainingInstance.size(),
+            inputDimension + outputDimension));
+
+    if (this.trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
+      return this.trainByInstanceGradientDescent(trainingInstance);
+    }
+    throw new IllegalArgumentException("Training method is not supported.");
+  }
+
+  /**
+   * Train by gradient descent. Get the updated weights using one training
+   * instance.
+   * 
+   * @param trainingInstance A training instance, including the features and the label(s). Its dimension must equal
+   *          the size of the input layer (bias neuron excluded) plus the size
+   *          of the output layer (i.e., the dimension of the labels).
+   * @return The weight update matrices.
+   */
+  private Matrix[] trainByInstanceGradientDescent(Vector trainingInstance) {
+    int inputDimension = this.layerSizeList.get(0) - 1;
+
+    Vector inputInstance = new DenseVector(this.layerSizeList.get(0));
+    inputInstance.set(0, 1); // add bias
+    for (int i = 0; i < inputDimension; ++i) {
+      inputInstance.set(i + 1, trainingInstance.get(i));
+    }
+
+    Vector labels = trainingInstance.viewPart(inputInstance.size() - 1, trainingInstance.size() - inputInstance.size() + 1);
+
+    // initialize weight update matrices
+    Matrix[] weightUpdateMatrices = new Matrix[this.weightMatrixList.size()];
+    for (int m = 0; m < weightUpdateMatrices.length; ++m) {
+      weightUpdateMatrices[m] = new DenseMatrix(this.weightMatrixList.get(m).rowSize(), this.weightMatrixList.get(m).columnSize());
+    }
+
+    List<Vector> internalResults = this.getOutputInternal(inputInstance);
+
+    Vector deltaVec = new DenseVector(this.layerSizeList.get(this.layerSizeList.size() - 1));
+    Vector output = internalResults.get(internalResults.size() - 1);
+
+    final DoubleFunction derivativeSquashingFunction =
+        NeuralNetworkFunctions.getDerivativeDoubleFunction(this.squashingFunctionList.get(this.squashingFunctionList.size() - 1));
+
+    final DoubleDoubleFunction costFunctionDerivative = NeuralNetworkFunctions.getDerivativeDoubleDoubleFunction(this.costFunctionName);
+
+    Matrix lastWeightMatrix = this.weightMatrixList.get(this.weightMatrixList.size() - 1);
+
+    for (int i = 0; i < deltaVec.size(); ++i) {
+      double costFuncDerivative = costFunctionDerivative.apply(labels.get(i), output.get(i + 1));
+      // add regularization
+      costFuncDerivative += this.regularizationWeight * lastWeightMatrix.viewRow(i).zSum();
+      deltaVec.set(i, costFuncDerivative * derivativeSquashingFunction.apply(output.get(i + 1)));
+    }
+
+    // start from previous layer of output layer
+    for (int layer = this.layerSizeList.size() - 2; layer >= 0; --layer) {
+      deltaVec = backPropagate(layer, deltaVec, internalResults, weightUpdateMatrices[layer]);
+    }
+
+    this.prevWeightUpdatesList = Arrays.asList(weightUpdateMatrices);
+
+    return weightUpdateMatrices;
+  }
+
+  /**
+   * Back-propagate the errors from the next layer to the current layer. The
+   * weight update information will be stored in weightUpdateMatrix, and the
+   * delta of the current layer will be returned.
+   * 
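+   * For each weight, the update computed in the loop below is:
+   * 
+   * <pre>
+   * update(i, j) = -learningRate * nextLayerDelta(i) * curLayerOutput(j)
+   *                + momentumWeight * prevUpdate(i, j)
+   * </pre>
+   * 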
+   * @param curLayerIdx Index of current layer.
+   * @param nextLayerDelta Delta of next layer.
+   * @param outputCache The output cache to store intermediate results.
+   * @param weightUpdateMatrix  The weight update, in form of {@link Matrix}.
+   * @return The delta of the current layer.
+   */
+  private Vector backPropagate(int curLayerIdx, Vector nextLayerDelta,
+                               List<Vector> outputCache, Matrix weightUpdateMatrix) {
+
+    // get layer related information
+    final DoubleFunction derivativeSquashingFunction =
+        NeuralNetworkFunctions.getDerivativeDoubleFunction(this.squashingFunctionList.get(curLayerIdx));
+    Vector curLayerOutput = outputCache.get(curLayerIdx);
+    Matrix weightMatrix = this.weightMatrixList.get(curLayerIdx);
+    Matrix prevWeightMatrix = this.prevWeightUpdatesList.get(curLayerIdx);
+
+    // if the next layer is not the output layer, its delta includes a bias entry; remove it
+    if (curLayerIdx != this.layerSizeList.size() - 2) {
+      nextLayerDelta = nextLayerDelta.viewPart(1, nextLayerDelta.size() - 1);
+    }
+
+    Vector delta = weightMatrix.transpose().times(nextLayerDelta);
+
+    delta = delta.assign(curLayerOutput, new DoubleDoubleFunction() {
+      @Override
+      public double apply(double deltaElem, double curLayerOutputElem) {
+        return deltaElem * derivativeSquashingFunction.apply(curLayerOutputElem);
+      }
+    });
+
+    // update weights
+    for (int i = 0; i < weightUpdateMatrix.rowSize(); ++i) {
+      for (int j = 0; j < weightUpdateMatrix.columnSize(); ++j) {
+        weightUpdateMatrix.set(i, j, -learningRate * nextLayerDelta.get(i) *
+            curLayerOutput.get(j) + this.momentumWeight * prevWeightMatrix.get(i, j));
+      }
+    }
+
+    return delta;
+  }
+
+  /**
+   * Read the model meta-data from the specified location.
+   * 
+   * @throws IOException
+   */
+  protected void readFromModel() throws IOException {
+    Preconditions.checkArgument(this.modelPath != null, "Model path has not been set.");
+    FSDataInputStream is = null;
+    try {
+      Path path = new Path(this.modelPath);
+      FileSystem fs = path.getFileSystem(new Configuration());
+      is = new FSDataInputStream(fs.open(path));
+      this.readFields(is);
+    } finally {
+      Closeables.close(is, true);
+    }
+  }
+
+  /**
+   * Write the model data to specified location.
+   * 
+   * @throws IOException
+   */
+  public void writeModelToFile() throws IOException {
+    Preconditions.checkArgument(this.modelPath != null, "Model path has not been set.");
+    FSDataOutputStream stream = null;
+    try {
+      Path path = new Path(this.modelPath);
+      FileSystem fs = path.getFileSystem(new Configuration());
+      stream = fs.create(path, true);
+      this.write(stream);
+    } finally {
+      Closeables.close(stream, false);
+    }
+  }
+
+  /**
+   * Set the model path.
+   * 
+   * @param modelPath The path of the model.
+   */
+  public void setModelPath(String modelPath) {
+    this.modelPath = modelPath;
+  }
+
+  /**
+   * Get the model path.
+   * 
+   * @return The path of the model.
+   */
+  public String getModelPath() {
+    return this.modelPath;
+  }
+
+  /**
+   * Write the fields of the model to output.
+   * 
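+   * The fields are written in a fixed order: model type, learning rate,
+   * model path, regularization weight, momentum weight, cost function name,
+   * layer sizes, training method, squashing function names, and finally the
+   * weight matrices. {@link #readFields(DataInput)} reads them back in the
+   * same order.
+   * 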
+   * @param output The output instance.
+   * @throws IOException
+   */
+  public void write(DataOutput output) throws IOException {
+    // write model type
+    WritableUtils.writeString(output, modelType);
+    // write learning rate
+    output.writeDouble(learningRate);
+    // write model path
+    if (this.modelPath != null) {
+      WritableUtils.writeString(output, modelPath);
+    } else {
+      WritableUtils.writeString(output, "null");
+    }
+
+    // write regularization weight
+    output.writeDouble(this.regularizationWeight);
+    // write momentum weight
+    output.writeDouble(this.momentumWeight);
+
+    // write cost function
+    WritableUtils.writeString(output, this.costFunctionName);
+
+    // write layer size list
+    output.writeInt(this.layerSizeList.size());
+    for (Integer aLayerSizeList : this.layerSizeList) {
+      output.writeInt(aLayerSizeList);
+    }
+
+    WritableUtils.writeEnum(output, this.trainingMethod);
+
+    // write squashing functions
+    output.writeInt(this.squashingFunctionList.size());
+    for (String aSquashingFunctionList : this.squashingFunctionList) {
+      WritableUtils.writeString(output, aSquashingFunctionList);
+    }
+
+    // write weight matrices
+    output.writeInt(this.weightMatrixList.size());
+    for (Matrix aWeightMatrixList : this.weightMatrixList) {
+      MatrixWritable.writeMatrix(output, aWeightMatrixList);
+    }
+  }
+
+  /**
+   * Read the fields of the model from input.
+   * 
+   * @param input The input instance.
+   * @throws IOException
+   */
+  public void readFields(DataInput input) throws IOException {
+    // read model type
+    this.modelType = WritableUtils.readString(input);
+    if (!this.modelType.equals(this.getClass().getSimpleName())) {
+      throw new IllegalArgumentException("The specified location does not contains the valid NeuralNetwork model.");
+    }
+    // read learning rate
+    this.learningRate = input.readDouble();
+    // read model path
+    this.modelPath = WritableUtils.readString(input);
+    if (this.modelPath.equals("null")) {
+      this.modelPath = null;
+    }
+
+    // read regularization weight
+    this.regularizationWeight = input.readDouble();
+    // read momentum weight
+    this.momentumWeight = input.readDouble();
+
+    // read cost function
+    this.costFunctionName = WritableUtils.readString(input);
+
+    // read layer size list
+    int numLayers = input.readInt();
+    this.layerSizeList = Lists.newArrayList();
+    for (int i = 0; i < numLayers; i++) {
+      this.layerSizeList.add(input.readInt());
+    }
+
+    this.trainingMethod = WritableUtils.readEnum(input, TrainingMethod.class);
+
+    // read squash functions
+    int squashingFunctionSize = input.readInt();
+    this.squashingFunctionList = Lists.newArrayList();
+    for (int i = 0; i < squashingFunctionSize; i++) {
+      this.squashingFunctionList.add(WritableUtils.readString(input));
+    }
+
+    // read weights and construct matrices of previous updates
+    int numOfMatrices = input.readInt();
+    this.weightMatrixList = Lists.newArrayList();
+    this.prevWeightUpdatesList = Lists.newArrayList();
+    for (int i = 0; i < numOfMatrices; i++) {
+      Matrix matrix = MatrixWritable.readMatrix(input);
+      this.weightMatrixList.add(matrix);
+      this.prevWeightUpdatesList.add(new DenseMatrix(matrix.rowSize(), matrix.columnSize()));
+    }
+  }
+
+}
\ No newline at end of file

Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,150 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * The functions that will be used by NeuralNetwork.
+ */
+public class NeuralNetworkFunctions {
+
+  /**
+   * The derivative of the identity function (f(x) = x).
+   */
+  public static DoubleFunction derivativeIdentityFunction = new DoubleFunction() {
+    @Override
+    public double apply(double x) {
+      return 1;
+    }
+  };
+
+  /**
+   * The derivative of the minus-squared cost function (f(t, o) = (o - t)^2) with respect to o.
+   */
+  public static DoubleDoubleFunction derivativeMinusSquared = new DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      return 2 * (output - target);
+    }
+  };
+
+  /**
+   * The cross entropy function (f(t, o) = -t * log(o) - (1 - t) * log(1 - o)).
+   */
+  public static DoubleDoubleFunction crossEntropy = new DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      return -target * Math.log(output) - (1 - target) * Math.log(1 - output);
+    }
+  };
+
+  /**
+   * The derivative of the cross entropy function (f(t, o) = -t * log(o) -
+   * (1 - t) * log(1 - o)) with respect to o.
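+   * Both arguments are clamped away from 0 and 1 before evaluating
+   * -t / o + (1 - t) / (1 - o), so the result stays finite.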
+   */
+  public static DoubleDoubleFunction derivativeCrossEntropy = new DoubleDoubleFunction() {
+    @Override
+    public double apply(double target, double output) {
+      double adjustedTarget = target;
+      double adjustedActual = output;
+      if (adjustedActual == 1) {
+        adjustedActual = 0.999;
+      } else if (adjustedActual == 0) {
+        adjustedActual = 0.001;
+      }
+      if (adjustedTarget == 1) {
+        adjustedTarget = 0.999;
+      } else if (adjustedTarget == 0) {
+        adjustedTarget = 0.001;
+      }
+      return -adjustedTarget / adjustedActual + (1 - adjustedTarget) / (1 - adjustedActual);
+    }
+  };
+
+  /**
+   * Get the corresponding function by its name.
+   * Currently supports: "Identity", "Sigmoid".
+   * 
+   * @param function The name of the function.
+   * @return The corresponding double function.
+   */
+  public static DoubleFunction getDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Identity")) {
+      return Functions.IDENTITY;
+    } else if (function.equalsIgnoreCase("Sigmoid")) {
+      return Functions.SIGMOID;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the derivative of the corresponding double function by its name.
+   * Currently supports: "Identity", "Sigmoid".
+   * 
+   * @param function The name of the function.
+   * @return The double function.
+   */
+  public static DoubleFunction getDerivativeDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Identity")) {
+      return derivativeIdentityFunction;
+    } else if (function.equalsIgnoreCase("Sigmoid")) {
+      return Functions.SIGMOIDGRADIENT;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the corresponding double-double function by the name.
+   * Currently supports: "Minus_Squared", "Cross_Entropy".
+   * 
+   * @param function The name of the function.
+   * @return The double-double function.
+   */
+  public static DoubleDoubleFunction getDoubleDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Minus_Squared")) {
+      return Functions.MINUS_SQUARED;
+    } else if (function.equalsIgnoreCase("Cross_Entropy")) {
+      return crossEntropy;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+  /**
+   * Get the derivative of the corresponding double-double function by its name.
+   * Currently supports: "Minus_Squared", "Cross_Entropy".
+   * 
+   * @param function The name of the function.
+   * @return The double-double function.
+   */
+  public static DoubleDoubleFunction getDerivativeDoubleDoubleFunction(String function) {
+    if (function.equalsIgnoreCase("Minus_Squared")) {
+      return derivativeMinusSquared;
+    } else if (function.equalsIgnoreCase("Cross_Entropy")) {
+      return derivativeCrossEntropy;
+    } else {
+      throw new IllegalArgumentException("Function not supported.");
+    }
+  }
+
+}
\ No newline at end of file

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.File;
+import java.io.IOException;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Arrays;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/**
+ * Test the functionality of {@link MultilayerPerceptron}.
+ */
+public class TestMultilayerPerceptron extends MahoutTestCase {
+
+  @Test
+  public void testMLP() throws IOException {
+    testMLP("testMLPXORLocal", false, false, 8000);
+    testMLP("testMLPXORLocalWithMomentum", true, false, 4000);
+    testMLP("testMLPXORLocalWithRegularization", true, true, 2000);
+  }
+
+  private void testMLP(String modelFilename, boolean useMomentum,
+      boolean useRegularization, int iterations) throws IOException {
+    MultilayerPerceptron mlp = new MultilayerPerceptron();
+    mlp.addLayer(2, false, "Sigmoid");
+    mlp.addLayer(3, false, "Sigmoid");
+    mlp.addLayer(1, true, "Sigmoid");
+    mlp.setCostFunction("Minus_Squared").setLearningRate(0.2);
+    if (useMomentum) {
+      mlp.setMomentumWeight(0.6);
+    }
+
+    if (useRegularization) {
+      mlp.setRegularizationWeight(0.01);
+    }
+
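+    // each row encodes one XOR example as {input1, input2, label}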
+    double[][] instances = { { 0, 1, 1 }, { 0, 0, 0 }, { 1, 0, 1 }, { 1, 1, 0 } };
+    for (int i = 0; i < iterations; ++i) {
+      for (double[] instance : instances) {
+        Vector features = new DenseVector(Arrays.copyOf(instance, instance.length - 1));
+        mlp.train((int) instance[2], features);
+      }
+    }
+
+    for (double[] instance : instances) {
+      Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
+      // the expected label is the last element of the array
+      double expected = instance[2];
+      double actual = mlp.getOutput(input).get(0);
+      assertTrue(expected < 0.5 && actual < 0.5 || expected >= 0.5 && actual >= 0.5);
+    }
+
+    // write model into file and read out
+    File modelFile = this.getTestTempFile(modelFilename);
+    mlp.setModelPath(modelFile.getAbsolutePath());
+    mlp.writeModelToFile();
+    mlp.close();
+
+    MultilayerPerceptron mlpCopy = new MultilayerPerceptron(modelFile.getAbsolutePath());
+    // test on instances
+    for (double[] instance : instances) {
+      Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
+      // the expected label is the last element of the array
+      double expected = instance[2];
+      double actual = mlpCopy.getOutput(input).get(0);
+      assertTrue(expected < 0.5 && actual < 0.5 || expected >= 0.5 && actual >= 0.5);
+    }
+    mlpCopy.close();
+  }
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java?rev=1552403&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java Thu Dec 19 19:29:02 2013
@@ -0,0 +1,345 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.mahout.classifier.mlp.NeuralNetwork.TrainingMethod;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Lists;
+import com.google.common.io.Files;
+
+/**
+ * Test the functionality of {@link NeuralNetwork}.
+ */
+public class TestNeuralNetwork extends MahoutTestCase {
+
+  @Test
+  public void testReadWrite() throws IOException {
+    NeuralNetwork ann = new MultilayerPerceptron();
+    ann.addLayer(2, false, "Identity");
+    ann.addLayer(5, false, "Identity");
+    ann.addLayer(1, true, "Identity");
+    ann.setCostFunction("Minus_Squared");
+    double learningRate = 0.2;
+    double momentumWeight = 0.5;
+    double regularizationWeight = 0.05;
+    ann.setLearningRate(learningRate).setMomentumWeight(momentumWeight).setRegularizationWeight(regularizationWeight);
+
+    // manually set weights
+    Matrix[] matrices = new DenseMatrix[2];
+    matrices[0] = new DenseMatrix(5, 3);
+    matrices[0].assign(0.2);
+    matrices[1] = new DenseMatrix(1, 6);
+    matrices[1].assign(0.8);
+    ann.setWeightMatrices(matrices);
+
+    // write to file
+    String modelFilename = "testNeuralNetworkReadWrite";
+    File tmpModelFile = this.getTestTempFile(modelFilename);
+    ann.setModelPath(tmpModelFile.getAbsolutePath());
+    ann.writeModelToFile();
+
+    // read from file
+    NeuralNetwork annCopy = new MultilayerPerceptron(tmpModelFile.getAbsolutePath());
+    assertEquals(annCopy.getClass().getSimpleName(), annCopy.getModelType());
+    assertEquals(tmpModelFile.getAbsolutePath(), annCopy.getModelPath());
+    assertEquals(learningRate, annCopy.getLearningRate(), 0.000001);
+    assertEquals(momentumWeight, annCopy.getMomentumWeight(), 0.000001);
+    assertEquals(regularizationWeight, annCopy.getRegularizationWeight(), 0.000001);
+    assertEquals(TrainingMethod.GRADIENT_DESCENT, annCopy.getTrainingMethod());
+
+    // compare weights
+    Matrix[] weightsMatrices = annCopy.getWeightMatrices();
+    for (int i = 0; i < weightsMatrices.length; ++i) {
+      Matrix expectMat = matrices[i];
+      Matrix actualMat = weightsMatrices[i];
+      for (int j = 0; j < expectMat.rowSize(); ++j) {
+        for (int k = 0; k < expectMat.columnSize(); ++k) {
+          assertEquals(expectMat.get(j, k), actualMat.get(j, k), 0.000001);
+        }
+      }
+    }
+  }
+
+  /**
+   * Test the forward functionality.
+   */
+  @Test
+  public void testOutput() {
+    // first network
+    NeuralNetwork ann = new MultilayerPerceptron();
+    ann.addLayer(2, false, "Identity");
+    ann.addLayer(5, false, "Identity");
+    ann.addLayer(1, true, "Identity");
+    ann.setCostFunction("Minus_Squared").setLearningRate(0.1);
+
+    // intentionally initialize all weights to 0.5
+    Matrix[] matrices = new Matrix[2];
+    matrices[0] = new DenseMatrix(5, 3);
+    matrices[0].assign(0.5);
+    matrices[1] = new DenseMatrix(1, 6);
+    matrices[1].assign(0.5);
+    ann.setWeightMatrices(matrices);
+
+    double[] arr = new double[]{0, 1};
+    Vector training = new DenseVector(arr);
+    Vector result = ann.getOutput(training);
+    assertEquals(1, result.size());
+
+    // second network
+    NeuralNetwork ann2 = new MultilayerPerceptron();
+    ann2.addLayer(2, false, "Sigmoid");
+    ann2.addLayer(3, false, "Sigmoid");
+    ann2.addLayer(1, true, "Sigmoid");
+    ann2.setCostFunction("Minus_Squared");
+    ann2.setLearningRate(0.3);
+
+    // intentionally initialize all weights to 0.5
+    Matrix[] matrices2 = new Matrix[2];
+    matrices2[0] = new DenseMatrix(3, 3);
+    matrices2[0].assign(0.5);
+    matrices2[1] = new DenseMatrix(1, 4);
+    matrices2[1].assign(0.5);
+    ann2.setWeightMatrices(matrices2);
+
+    double[] test = {0, 0};
+    double[] result2 = {0.807476};
+
+    Vector vec = ann2.getOutput(new DenseVector(test));
+    double[] arrVec = new double[vec.size()];
+    for (int i = 0; i < arrVec.length; ++i) {
+      arrVec[i] = vec.getQuick(i);
+    }
+    assertArrayEquals(result2, arrVec, 0.000001);
+
+    NeuralNetwork ann3 = new MultilayerPerceptron();
+    ann3.addLayer(2, false, "Sigmoid");
+    ann3.addLayer(3, false, "Sigmoid");
+    ann3.addLayer(1, true, "Sigmoid");
+    ann3.setCostFunction("Minus_Squared").setLearningRate(0.3);
+
+    // intentionally initialize all weights to 0.5
+    Matrix[] initMatrices = new Matrix[2];
+    initMatrices[0] = new DenseMatrix(3, 3);
+    initMatrices[0].assign(0.5);
+    initMatrices[1] = new DenseMatrix(1, 4);
+    initMatrices[1].assign(0.5);
+    ann3.setWeightMatrices(initMatrices);
+
+    double[] instance = {0, 1};
+    Vector output = ann3.getOutput(new DenseVector(instance));
+    assertEquals(0.8315410, output.get(0), 0.000001);
+  }
+
+  @Test
+  public void testNeuralNetwork() throws IOException {
+    testNeuralNetwork("testNeuralNetworkXORLocal", false, false, 10000);
+    testNeuralNetwork("testNeuralNetworkXORWithMomentum", true, false, 5000);
+    testNeuralNetwork("testNeuralNetworkXORWithRegularization", true, true, 5000);
+  }
+
+  private void testNeuralNetwork(String modelFilename, boolean useMomentum,
+                                 boolean useRegularization, int iterations) throws IOException {
+    NeuralNetwork ann = new MultilayerPerceptron();
+    ann.addLayer(2, false, "Sigmoid");
+    ann.addLayer(3, false, "Sigmoid");
+    ann.addLayer(1, true, "Sigmoid");
+    ann.setCostFunction("Minus_Squared").setLearningRate(0.1);
+
+    if (useMomentum) {
+      ann.setMomentumWeight(0.6);
+    }
+
+    if (useRegularization) {
+      ann.setRegularizationWeight(0.01);
+    }
+
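+    // each row encodes one XOR example as {input1, input2, label};
+    // trainOnline consumes the full vector, label included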
+    double[][] instances = {{0, 1, 1}, {0, 0, 0}, {1, 0, 1}, {1, 1, 0}};
+    for (int i = 0; i < iterations; ++i) {
+      for (double[] instance : instances) {
+        ann.trainOnline(new DenseVector(instance));
+      }
+    }
+
+    for (double[] instance : instances) {
+      Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
+      // the expected label is the last element of the array
+      double expected = instance[2];
+      double actual = ann.getOutput(input).get(0);
+      assertTrue(expected < 0.5 && actual < 0.5 || expected >= 0.5 && actual >= 0.5);
+    }
+
+    // write model into file and read out
+    File tmpModelFile = this.getTestTempFile(modelFilename);
+    ann.setModelPath(tmpModelFile.getAbsolutePath());
+    ann.writeModelToFile();
+
+    NeuralNetwork annCopy = new MultilayerPerceptron(tmpModelFile.getAbsolutePath());
+    // test on instances
+    for (double[] instance : instances) {
+      Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
+      // the expected label is the last element of the array
+      double expected = instance[2];
+      double actual = annCopy.getOutput(input).get(0);
+      assertTrue(expected < 0.5 && actual < 0.5 || expected >= 0.5 && actual >= 0.5);
+    }
+  }
+
+  @Test
+  public void testWithCancerDataSet() throws IOException {
+    String dataSetPath = "src/test/resources/cancer.csv";
+    List<Vector> records = Lists.newArrayList();
+    // Returns a mutable list of the data
+    List<String> cancerDataSetList = Files.readLines(new File(dataSetPath), Charsets.UTF_8);
+    // skip the header line by removing the first element of the list
+    cancerDataSetList.remove(0);
+    for (String line : cancerDataSetList) {
+      String[] tokens = line.split(",");
+      double[] values = new double[tokens.length];
+      for (int i = 0; i < tokens.length; ++i) {
+        values[i] = Double.parseDouble(tokens[i]);
+      }
+      records.add(new DenseVector(values));
+    }
+
+    int splitPoint = (int) (records.size() * 0.8);
+    List<Vector> trainingSet = records.subList(0, splitPoint);
+    List<Vector> testSet = records.subList(splitPoint, records.size());
+
+    // initialize neural network model
+    NeuralNetwork ann = new MultilayerPerceptron();
+    int featureDimension = records.get(0).size() - 1;
+    ann.addLayer(featureDimension, false, "Sigmoid");
+    ann.addLayer(featureDimension * 2, false, "Sigmoid");
+    ann.addLayer(1, true, "Sigmoid");
+    ann.setLearningRate(0.05).setMomentumWeight(0.5).setRegularizationWeight(0.001);
+
+    int iteration = 2000;
+    for (int i = 0; i < iteration; ++i) {
+      for (Vector trainingInstance : trainingSet) {
+        ann.trainOnline(trainingInstance);
+      }
+    }
+
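+    // a prediction counts as correct when the sigmoid output is within
+    // 0.1 of the recorded label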
+    int correctInstances = 0;
+    for (Vector testInstance : testSet) {
+      Vector res = ann.getOutput(testInstance.viewPart(0, testInstance.size() - 1));
+      double actual = res.get(0);
+      double expected = testInstance.get(testInstance.size() - 1);
+      if (Math.abs(actual - expected) <= 0.1) {
+        ++correctInstances;
+      }
+    }
+    double accuracy = (double) correctInstances / testSet.size() * 100;
+    assertTrue("The classifier is even worse than a random guesser!", accuracy > 50);
+    System.out.printf("Cancer DataSet. Classification precision: %d/%d = %f%%\n", correctInstances, testSet.size(), accuracy);
+  }
+
+  @Test
+  public void testWithIrisDataSet() throws IOException {
+    String dataSetPath = "src/test/resources/iris.csv";
+    int numOfClasses = 3;
+    List<Vector> records = Lists.newArrayList();
+    // Returns a mutable list of the data
+    List<String> irisDataSetList = Files.readLines(new File(dataSetPath), Charsets.UTF_8);
+    // skip the header line by removing the first element of the list
+    irisDataSetList.remove(0);
+
+    for (String line : irisDataSetList) {
+      String[] tokens = line.split(",");
+      // last three dimensions represent the labels
+      double[] values = new double[tokens.length + numOfClasses - 1];
+      Arrays.fill(values, 0.0);
+      for (int i = 0; i < tokens.length - 1; ++i) {
+        values[i] = Double.parseDouble(tokens[i]);
+      }
+      // add label values
+      String label = tokens[tokens.length - 1];
+      if (label.equalsIgnoreCase("setosa")) {
+        values[values.length - 3] = 1;
+      } else if (label.equalsIgnoreCase("versicolor")) {
+        values[values.length - 2] = 1;
+      } else { // label 'virginica'
+        values[values.length - 1] = 1;
+      }
+      records.add(new DenseVector(values));
+    }
+
+    Collections.shuffle(records);
+
+    int splitPoint = (int) (records.size() * 0.8);
+    List<Vector> trainingSet = records.subList(0, splitPoint);
+    List<Vector> testSet = records.subList(splitPoint, records.size());
+
+    // initialize neural network model
+    NeuralNetwork ann = new MultilayerPerceptron();
+    int featureDimension = records.get(0).size() - numOfClasses;
+    ann.addLayer(featureDimension, false, "Sigmoid");
+    ann.addLayer(featureDimension * 2, false, "Sigmoid");
+    ann.addLayer(3, true, "Sigmoid"); // 3-class classification
+    ann.setLearningRate(0.05).setMomentumWeight(0.4).setRegularizationWeight(0.005);
+
+    int iteration = 2000;
+    for (int i = 0; i < iteration; ++i) {
+      for (Vector trainingInstance : trainingSet) {
+        ann.trainOnline(trainingInstance);
+      }
+    }
+
+    int correctInstances = 0;
+    for (Vector testInstance : testSet) {
+      Vector res = ann.getOutput(testInstance.viewPart(0, testInstance.size() - numOfClasses));
+      double[] actualLabels = new double[numOfClasses];
+      for (int i = 0; i < numOfClasses; ++i) {
+        actualLabels[i] = res.get(i);
+      }
+      double[] expectedLabels = new double[numOfClasses];
+      for (int i = 0; i < numOfClasses; ++i) {
+        expectedLabels[i] = testInstance.get(testInstance.size() - numOfClasses + i);
+      }
+
+      boolean allCorrect = true;
+      for (int i = 0; i < numOfClasses; ++i) {
+        if (Math.abs(expectedLabels[i] - actualLabels[i]) >= 0.1) {
+          allCorrect = false;
+          break;
+        }
+      }
+      if (allCorrect) {
+        ++correctInstances;
+      }
+    }
+    
+    double accuracy = (double) correctInstances / testSet.size() * 100;
+    assertTrue("The model is even worse than a random guesser.", accuracy > 50);
+    
+    System.out.printf("Iris DataSet. Classification precision: %d/%d = %f%%\n", correctInstances, testSet.size(), accuracy);
+  }
+  
+}