You are viewing a plain-text version of this content; the hyperlink to its canonical version is not preserved in this rendering.
Posted to commits@labs.apache.org by to...@apache.org on 2013/03/24 08:04:51 UTC
svn commit: r1460269 - in /labs/yay/trunk: api/src/main/java/org/apache/yay/
core/src/main/java/org/apache/yay/ core/src/main/java/org/apache/yay/neuron/
core/src/main/java/org/apache/yay/utils/ core/src/test/java/org/apache/yay/
Author: tommaso
Date: Sun Mar 24 07:04:50 2013
New Revision: 1460269
URL: http://svn.apache.org/r1460269
Log:
api refactoring
Modified:
labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java
labs/yay/trunk/api/src/main/java/org/apache/yay/ErrorFunction.java
labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java
labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java
labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java
labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java
labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
labs/yay/trunk/core/src/main/java/org/apache/yay/neuron/BinaryThresholdNeuron.java
labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ExamplesFactory.java
labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java
labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java Sun Mar 24 07:04:50 2013
@@ -21,15 +21,14 @@ package org.apache.yay;
import java.util.Collection;
/**
- * A cost function calculates the cost of using a specified model (via its
- * {@link ActivationFunction}) for fitting the given corpus (a {@link Collection}
+ * A cost function calculates the cost of using a specified {@link Hypothesis})
+ * for fitting the given corpus (a {@link Collection}
* of {@link TrainingExample}s).
- *
*/
public interface CostFunction<T, I, O> {
- public Double calculateAggregatedCost(Collection<TrainingExample<I, O>> trainingExamples,
- Hypothesis<T, I, O> hypothesis) throws Exception;
+ public Double calculateAggregatedCost(TrainingSet<I, O> trainingExamples,
+ Hypothesis<T, I, O> hypothesis) throws Exception;
public Double calculateCost(TrainingExample<I, O> trainingExample,
Hypothesis<T, I, O> hypothesis) throws Exception;
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/ErrorFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/ErrorFunction.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/ErrorFunction.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/ErrorFunction.java Sun Mar 24 07:04:50 2013
@@ -26,13 +26,12 @@ import java.util.Collection;
*/
public interface ErrorFunction<T, I, O> {
+ public Double calculateSingleError(O output, TrainingExample<I, O> trainingExample,
+ CostFunction<T, I, O> costFunction,
+ Hypothesis<T, I, O> hypothesis);
- public Double calculateSingleError(O output, TrainingExample<I, O> trainingExample,
- CostFunction<T, I, O> costFunction,
- Hypothesis<T, I, O> hypothesis);
-
- public Double calculateAggregateError(O output, TrainingSet<I, O> trainingSet,
- CostFunction<T, I, O> costFunction,
- Hypothesis<T, I, O> hypothesis);
+ public Double calculateAggregateError(O output, TrainingSet<I, O> trainingExamples,
+ CostFunction<T, I, O> costFunction,
+ Hypothesis<T, I, O> hypothesis);
}
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java Sun Mar 24 07:04:50 2013
@@ -24,8 +24,12 @@ package org.apache.yay;
*/
public interface Hypothesis<T, I, O> {
- public void setParameters(T... parameters);
+ public void setParameters(T... parameters);
- public O predict(Input<I> input) throws PredictionException;
+ T[] getParameters();
+
+ public O predict(Input<I> input) throws PredictionException;
+
+ public void learn(TrainingSet<I, O> trainingExamples) throws LearningException;
}
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java Sun Mar 24 07:04:50 2013
@@ -18,14 +18,12 @@
*/
package org.apache.yay;
-import java.util.Collection;
-
/**
- * Add javadoc here
+ * A factory class for {@link Hypothesis}.
*/
public interface HypothesisFactory {
- <T, I, O> Hypothesis<T, I, O> createHypothesis(TrainingSet<I, O> trainingSet,
+ <T, I, O> Hypothesis<T, I, O> createHypothesis(TrainingSet<I, O> trainingExamples,
LearningStrategy<I, O> learningStrategy,
T... parameters);
}
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java Sun Mar 24 07:04:50 2013
@@ -20,14 +20,12 @@ package org.apache.yay;
import org.apache.commons.math3.linear.RealMatrix;
-import java.util.Collection;
-
/**
* A {@link LearningStrategy}<F,O> defines a learning algorithm to learn the weights of the neural network's layer
*/
public interface LearningStrategy<F, O> {
public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, TrainingSet<F, O>
- trainingSet) throws WeightLearningException;
+ trainingExamples) throws WeightLearningException;
}
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Sun Mar 24 07:04:50 2013
@@ -18,13 +18,11 @@
*/
package org.apache.yay;
+import org.apache.commons.math3.linear.RealMatrix;
+
/**
* A neural network is a layered connected graph of elaboration units
*/
-public interface NeuralNetwork<I, O> {
-
- public void learn(TrainingSet<I, O> trainingSet) throws LearningException;
-
- public O predict(Input<I> input) throws PredictionException;
+public interface NeuralNetwork<I, O> extends Hypothesis<RealMatrix, I, O>{
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Sun Mar 24 07:04:50 2013
@@ -24,8 +24,6 @@ import org.apache.commons.math3.linear.R
import org.apache.commons.math3.linear.RealVector;
import org.apache.yay.utils.ConversionUtils;
-import java.util.Collection;
-
/**
* Back propagation learning algorithm for neural networks implementation (see
* <code>http://en.wikipedia.org/wiki/Backpropagation</code>).
@@ -33,9 +31,9 @@ import java.util.Collection;
public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double[]> {
private final PredictionStrategy<Double, Double[]> predictionStrategy;
- private CostFunction<RealMatrix, Double> costFunction;
+ private CostFunction<RealMatrix, Double, Double> costFunction;
- public BackPropagationLearningStrategy(PredictionStrategy<Double, Double[]> predictionStrategy, CostFunction<RealMatrix, Double> costFunction) {
+ public BackPropagationLearningStrategy(PredictionStrategy<Double, Double[]> predictionStrategy, CostFunction<RealMatrix, Double, Double> costFunction) {
this.predictionStrategy = predictionStrategy;
this.costFunction = costFunction;
}
@@ -44,6 +42,8 @@ public class BackPropagationLearningStra
public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double[]> trainingExamples) throws WeightLearningException {
// set up the accumulator matrix(es)
RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
+
+ int count = 0;
for (TrainingExample<Double, Double[]> trainingExample : trainingExamples) {
try {
// contains activation errors for the current training example
@@ -82,10 +82,11 @@ public class BackPropagationLearningStra
} catch (Exception e) {
throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
}
+ count++;
}
for (int i = 0; i < triangle.length; i++) {
// TODO : introduce regularization diversification on bias term (currently not regularized)
- triangle[i] = triangle[i].scalarMultiply(1 / trainingExamples.size());
+ triangle[i] = triangle[i].scalarMultiply(1 / count);
}
// TODO : now apply gradient descent (or other optimization/minimization algorithms) with this derivative terms and the cost function
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java Sun Mar 24 07:04:50 2013
@@ -18,11 +18,13 @@
*/
package org.apache.yay;
-import java.util.Collection;
-
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
import org.apache.yay.neuron.BinaryThresholdNeuron;
import org.apache.yay.utils.ConversionUtils;
+import java.util.Collection;
+
/**
* A perceptron {@link NeuralNetwork} implementation based on
* {@link org.apache.yay.neuron.BinaryThresholdNeuron}s
@@ -31,7 +33,7 @@ public class BasicPerceptron implements
private final BinaryThresholdNeuron perceptronNeuron;
- private final Double[] currentWeights;
+ private double[] currentWeights;
/**
* Create a perceptron given its input weights. Assume bias weight is given and all the input
@@ -39,7 +41,7 @@ public class BasicPerceptron implements
*
* @param inputWeights the array of starting weights for the perceptron
*/
- public BasicPerceptron(Double... inputWeights) {
+ public BasicPerceptron(double... inputWeights) {
this.perceptronNeuron = new BinaryThresholdNeuron(0d, inputWeights);
this.currentWeights = inputWeights;
}
@@ -65,6 +67,18 @@ public class BasicPerceptron implements
}
@Override
+ public void setParameters(RealMatrix... parameters) {
+ assert parameters.length == 1 : "a perceptron has only one layer";
+
+ this.currentWeights = parameters[0].getRow(0);
+ }
+
+ @Override
+ public RealMatrix[] getParameters() {
+ return new RealMatrix[]{new Array2DRowRealMatrix(currentWeights)};
+ }
+
+ @Override
public Double predict(Input<Double> input) throws PredictionException {
return perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
new Double[input.getFeatures().size()]));
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java Sun Mar 24 07:04:50 2013
@@ -20,12 +20,10 @@ package org.apache.yay;
import org.apache.commons.math3.linear.RealMatrix;
-import java.util.Collection;
-
/**
* This calculates the logistic regression cost function for neural networks
*/
-public class LogisticRegressionCostFunction implements CostFunction<RealMatrix, Double> {
+public class LogisticRegressionCostFunction implements CostFunction<RealMatrix, Double, Double> {
private final Double lambda;
@@ -34,18 +32,18 @@ public class LogisticRegressionCostFunct
}
@Override
- public Double calculateCost(Collection<TrainingExample<Double, Double>> trainingExamples,
- ActivationFunction<Double> hypothesis, RealMatrix... parameters) throws Exception {
+ public Double calculateAggregatedCost(TrainingSet<Double, Double> trainingExamples,
+ Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception {
- Double errorTerm = calculateErrorTerm(parameters, hypothesis, trainingExamples);
- Double regularizationTerm = calculateRegularizationTerm(parameters, trainingExamples);
+ Double errorTerm = calculateErrorTerm(hypothesis, trainingExamples);
+ Double regularizationTerm = calculateRegularizationTerm(hypothesis, trainingExamples);
return errorTerm + regularizationTerm;
}
- private Double calculateRegularizationTerm(RealMatrix[] parameters,
- Collection<TrainingExample<Double, Double>> trainingExamples) {
+ private Double calculateRegularizationTerm(Hypothesis<RealMatrix, Double, Double> hypothesis,
+ TrainingSet<Double, Double> trainingExamples) {
Double res = 1d;
- for (RealMatrix layerMatrix : parameters) {
+ for (RealMatrix layerMatrix : hypothesis.getParameters()) {
for (int i = 0; i < layerMatrix.getColumnDimension(); i++) {
double[] column = layerMatrix.getColumn(i);
// starting from 1 to avoid including the bias unit in regularization
@@ -57,22 +55,23 @@ public class LogisticRegressionCostFunct
return (lambda / (2d * trainingExamples.size())) * res;
}
- private Double calculateErrorTerm(RealMatrix[] parameters,
- ActivationFunction<Double> hypothesis,
- Collection<TrainingExample<Double, Double>> trainingExamples) throws PredictionException,
+ private Double calculateErrorTerm(Hypothesis<RealMatrix, Double, Double> hypothesis,
+ TrainingSet<Double, Double> trainingExamples) throws PredictionException,
CreationException {
Double res = 0d;
- NeuralNetwork<Double, Double> neuralNetwork = NeuralNetworkFactory.create(trainingExamples,
- parameters, new VoidLearningStrategy<Double, Double>(), new FeedForwardStrategy(
- hypothesis));
for (TrainingExample<Double, Double> input : trainingExamples) {
// TODO : handle this for multiple outputs (multi class classification)
- Double predictedOutput = neuralNetwork.predict(input);
+ Double predictedOutput = hypothesis.predict(input);
Double sampleOutput = input.getOutput();
res += sampleOutput * Math.log(predictedOutput) + (1d - sampleOutput)
* Math.log(1d - predictedOutput);
}
return (-1d / trainingExamples.size()) * res;
}
+
+ @Override
+ public Double calculateCost(TrainingExample<Double, Double> trainingExample, Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception {
+ return null;
+ }
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java Sun Mar 24 07:04:50 2013
@@ -32,30 +32,38 @@ public class NeuralNetworkFactory {
* creates a neural network via a supervised learning method, given a training set, the initial set of layers defined
* by a set of matrices, the learning and prediction strategies to be used.
*
- * @param trainingExamples the training set
- * @param RealMatrixSet the initial settings for weights matrices
+ * @param realMatrixSet the initial settings for weights matrices
* @param learningStrategy a learning strategy
* @param predictionStrategy a prediction strategy
* @return a NeuralNetwork instance
* @throws CreationException
*/
- public static NeuralNetwork<Double, Double> create(final Collection<TrainingExample<Double, Double>> trainingExamples,
- final RealMatrix[] RealMatrixSet, final LearningStrategy<Double, Double> learningStrategy,
+ public static NeuralNetwork<Double, Double> create(final RealMatrix[] realMatrixSet, final LearningStrategy<Double, Double> learningStrategy,
final PredictionStrategy<Double, Double> predictionStrategy) throws CreationException {
return new NeuralNetwork<Double, Double>() {
- private RealMatrix[] updatedRealMatrixSet = RealMatrixSet;
+ private RealMatrix[] updatedRealMatrixSet = realMatrixSet;
@Override
public void learn(TrainingSet<Double, Double> samples) throws LearningException {
try {
- updatedRealMatrixSet = learningStrategy.learnWeights(RealMatrixSet, samples);
+ updatedRealMatrixSet = learningStrategy.learnWeights(realMatrixSet, samples);
} catch (WeightLearningException e) {
throw new LearningException(e);
}
}
@Override
+ public void setParameters(RealMatrix... parameters) {
+ updatedRealMatrixSet = parameters;
+ }
+
+ @Override
+ public RealMatrix[] getParameters() {
+ return updatedRealMatrixSet;
+ }
+
+ @Override
public Double predict(Input<Double> input) throws PredictionException {
try {
Collection<Double> inputVector = ConversionUtils.toValuesCollection(input.getFeatures());
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/neuron/BinaryThresholdNeuron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/neuron/BinaryThresholdNeuron.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/neuron/BinaryThresholdNeuron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/neuron/BinaryThresholdNeuron.java Sun Mar 24 07:04:50 2013
@@ -30,14 +30,14 @@ import org.apache.yay.StepActivationFunc
*/
public class BinaryThresholdNeuron extends Neuron<Double> {
- private Double[] weights;
+ private double[] weights;
- public BinaryThresholdNeuron(Double threshold, Double... weights) {
+ public BinaryThresholdNeuron(double threshold, double... weights) {
super(new StepActivationFunction(threshold));
this.weights = weights;
}
- public void updateWeights(Double... weights) {
+ public void updateWeights(double... weights) {
this.weights = weights;
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ExamplesFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ExamplesFactory.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ExamplesFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ExamplesFactory.java Sun Mar 24 07:04:50 2013
@@ -25,7 +25,7 @@ import org.apache.yay.TrainingExample;
import java.util.ArrayList;
/**
- * Factory class for {@link org.apache.yay.Input}s
+ * Factory class for {@link org.apache.yay.Input}s and {@link TrainingExample}s.
*/
public class ExamplesFactory {
@@ -44,7 +44,7 @@ public class ExamplesFactory {
};
}
- public static Input<Double> createDoubleExample(final Double... featuresValues) {
+ public static Input<Double> createDoubleInput(final Double... featuresValues) {
return new Input<Double>() {
@Override
public ArrayList<Feature<Double>> getFeatures() {
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java Sun Mar 24 07:04:50 2013
@@ -21,6 +21,7 @@ package org.apache.yay;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.yay.utils.ExamplesFactory;
+import org.junit.Before;
import org.junit.Test;
import java.util.Collection;
@@ -33,9 +34,13 @@ import static org.junit.Assert.assertTru
*/
public class LogisticRegressionCostFunctionTest {
- @Test
- public void testORParametersCost() throws Exception {
- CostFunction<RealMatrix, Double> costFunction = new LogisticRegressionCostFunction(0.1d);
+ private CostFunction<RealMatrix,Double,Double> costFunction;
+ private TrainingSet trainingSet;
+
+ @Before
+ public void setUp() throws Exception {
+
+ costFunction = new LogisticRegressionCostFunction(0.1d);
Collection<TrainingExample<Double, Double>> trainingExamples = new LinkedList<TrainingExample<Double, Double>>();
TrainingExample<Double, Double> example1 = ExamplesFactory.createDoubleTrainingExample(1d, 0d, 1d);
TrainingExample<Double, Double> example2 = ExamplesFactory.createDoubleTrainingExample(1d, 1d, 1d);
@@ -45,11 +50,23 @@ public class LogisticRegressionCostFunct
trainingExamples.add(example2);
trainingExamples.add(example3);
trainingExamples.add(example4);
+ trainingSet = new TrainingSet(trainingExamples);
+
+
+ }
+
+ @Test
+ public void testORParametersCost() throws Exception {
double[][] weights = {{-10d, 20d, 20d}};
RealMatrix singleOrLayerWeights = new Array2DRowRealMatrix(weights);
- RealMatrix[] orWeightsMatrixSet = new RealMatrix[]{singleOrLayerWeights};
- Double cost = costFunction.calculateCost(trainingExamples, new SigmoidFunction(),
- orWeightsMatrixSet);
+ final RealMatrix[] orWeightsMatrixSet = new RealMatrix[]{singleOrLayerWeights};
+
+ final NeuralNetwork<Double,Double> neuralNetwork = NeuralNetworkFactory.create(orWeightsMatrixSet,
+ new VoidLearningStrategy<Double, Double>(), new FeedForwardStrategy(
+ new SigmoidFunction()));
+
+
+ Double cost = costFunction.calculateAggregatedCost(trainingSet, neuralNetwork);
assertTrue("cost should not be negative", cost > 0d);
}
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java Sun Mar 24 07:04:50 2013
@@ -23,7 +23,6 @@ import org.apache.commons.math3.linear.R
import org.junit.Test;
import java.util.ArrayList;
-import java.util.LinkedList;
import static org.junit.Assert.assertEquals;
@@ -89,10 +88,9 @@ public class NeuralNetworkFactoryTest {
assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue);
}
- private NeuralNetwork<Double, Double> createFFNN(RealMatrix[] andRealMatrixSet)
+ private NeuralNetwork<Double, Double> createFFNN(RealMatrix[] realMatrixes)
throws CreationException {
- return NeuralNetworkFactory.create(new LinkedList<TrainingExample<Double, Double>>(),
- andRealMatrixSet, new VoidLearningStrategy<Double, Double>(),
+ return NeuralNetworkFactory.create(realMatrixes, new VoidLearningStrategy<Double, Double>(),
new FeedForwardStrategy(new SigmoidFunction()));
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org