You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2013/03/24 08:04:51 UTC

svn commit: r1460269 - in /labs/yay/trunk: api/src/main/java/org/apache/yay/ core/src/main/java/org/apache/yay/ core/src/main/java/org/apache/yay/neuron/ core/src/main/java/org/apache/yay/utils/ core/src/test/java/org/apache/yay/

Author: tommaso
Date: Sun Mar 24 07:04:50 2013
New Revision: 1460269

URL: http://svn.apache.org/r1460269
Log:
api refactoring

Modified:
    labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java
    labs/yay/trunk/api/src/main/java/org/apache/yay/ErrorFunction.java
    labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
    labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java
    labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java
    labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/neuron/BinaryThresholdNeuron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ExamplesFactory.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java Sun Mar 24 07:04:50 2013
@@ -21,15 +21,14 @@ package org.apache.yay;
 import java.util.Collection;
 
 /**
- * A cost function calculates the cost of using a specified model (via its
- * {@link ActivationFunction}) for fitting the given corpus (a {@link Collection}
+ * A cost function calculates the cost of using a specified {@link Hypothesis}
+ * for fitting the given corpus (a {@link Collection}
  * of {@link TrainingExample}s).
- *
  */
 public interface CostFunction<T, I, O> {
 
-  public Double calculateAggregatedCost(Collection<TrainingExample<I, O>> trainingExamples,
-                              Hypothesis<T, I, O> hypothesis) throws Exception;
+  public Double calculateAggregatedCost(TrainingSet<I, O> trainingExamples,
+                                        Hypothesis<T, I, O> hypothesis) throws Exception;
 
   public Double calculateCost(TrainingExample<I, O> trainingExample,
                               Hypothesis<T, I, O> hypothesis) throws Exception;

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/ErrorFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/ErrorFunction.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/ErrorFunction.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/ErrorFunction.java Sun Mar 24 07:04:50 2013
@@ -26,13 +26,12 @@ import java.util.Collection;
  */
 public interface ErrorFunction<T, I, O> {
 
+  public Double calculateSingleError(O output, TrainingExample<I, O> trainingExample,
+                                     CostFunction<T, I, O> costFunction,
+                                     Hypothesis<T, I, O> hypothesis);
 
-    public Double calculateSingleError(O output, TrainingExample<I, O> trainingExample,
-                                       CostFunction<T, I, O> costFunction,
-                                       Hypothesis<T, I, O> hypothesis);
-
-    public Double calculateAggregateError(O output, TrainingSet<I, O> trainingSet,
-                                          CostFunction<T, I, O> costFunction,
-                                          Hypothesis<T, I, O> hypothesis);
+  public Double calculateAggregateError(O output, TrainingSet<I, O> trainingExamples,
+                                        CostFunction<T, I, O> costFunction,
+                                        Hypothesis<T, I, O> hypothesis);
 
 }

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java Sun Mar 24 07:04:50 2013
@@ -24,8 +24,12 @@ package org.apache.yay;
  */
 public interface Hypothesis<T, I, O> {
 
-    public void setParameters(T... parameters);
+  public void setParameters(T... parameters);
 
-    public O predict(Input<I> input) throws PredictionException;
+  T[] getParameters();
+
+  public O predict(Input<I> input) throws PredictionException;
+
+  public void learn(TrainingSet<I, O> trainingExamples) throws LearningException;
 
 }

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java Sun Mar 24 07:04:50 2013
@@ -18,14 +18,12 @@
  */
 package org.apache.yay;
 
-import java.util.Collection;
-
 /**
- * Add javadoc here
+ * A factory class for {@link Hypothesis}.
  */
 public interface HypothesisFactory {
 
-  <T, I, O> Hypothesis<T, I, O> createHypothesis(TrainingSet<I, O> trainingSet,
+  <T, I, O> Hypothesis<T, I, O> createHypothesis(TrainingSet<I, O> trainingExamples,
                                                  LearningStrategy<I, O> learningStrategy,
                                                  T... parameters);
 }

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/LearningStrategy.java Sun Mar 24 07:04:50 2013
@@ -20,14 +20,12 @@ package org.apache.yay;
 
 import org.apache.commons.math3.linear.RealMatrix;
 
-import java.util.Collection;
-
 /**
  * A {@link LearningStrategy}<F,O> defines a learning algorithm to learn the weights of the neural network's layer
  */
 public interface LearningStrategy<F, O> {
 
   public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, TrainingSet<F, O>
-         trainingSet) throws WeightLearningException;
+          trainingExamples) throws WeightLearningException;
 
 }

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Sun Mar 24 07:04:50 2013
@@ -18,13 +18,11 @@
  */
 package org.apache.yay;
 
+import org.apache.commons.math3.linear.RealMatrix;
+
 /**
  * A neural network is a layered connected graph of elaboration units
  */
-public interface NeuralNetwork<I, O> {
-
-  public void learn(TrainingSet<I, O> trainingSet) throws LearningException;
-
-  public O predict(Input<I> input) throws PredictionException;
+public interface NeuralNetwork<I, O> extends Hypothesis<RealMatrix, I, O>{
 
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Sun Mar 24 07:04:50 2013
@@ -24,8 +24,6 @@ import org.apache.commons.math3.linear.R
 import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.utils.ConversionUtils;
 
-import java.util.Collection;
-
 /**
  * Back propagation learning algorithm for neural networks implementation (see
  * <code>http://en.wikipedia.org/wiki/Backpropagation</code>).
@@ -33,9 +31,9 @@ import java.util.Collection;
 public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double[]> {
 
   private final PredictionStrategy<Double, Double[]> predictionStrategy;
-  private CostFunction<RealMatrix, Double> costFunction;
+  private CostFunction<RealMatrix, Double, Double> costFunction;
 
-  public BackPropagationLearningStrategy(PredictionStrategy<Double, Double[]> predictionStrategy, CostFunction<RealMatrix, Double> costFunction) {
+  public BackPropagationLearningStrategy(PredictionStrategy<Double, Double[]> predictionStrategy, CostFunction<RealMatrix, Double, Double> costFunction) {
     this.predictionStrategy = predictionStrategy;
     this.costFunction = costFunction;
   }
@@ -44,6 +42,8 @@ public class BackPropagationLearningStra
   public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double[]> trainingExamples) throws WeightLearningException {
     // set up the accumulator matrix(es)
     RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
+
+    int count = 0;
     for (TrainingExample<Double, Double[]> trainingExample : trainingExamples) {
       try {
         // contains activation errors for the current training example
@@ -82,10 +82,11 @@ public class BackPropagationLearningStra
       } catch (Exception e) {
         throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
       }
+      count++;
     }
     for (int i = 0; i < triangle.length; i++) {
       // TODO : introduce regularization diversification on bias term (currently not regularized)
-      triangle[i] = triangle[i].scalarMultiply(1 / trainingExamples.size());
+      triangle[i] = triangle[i].scalarMultiply(1 / count);
     }
 
     // TODO : now apply gradient descent (or other optimization/minimization algorithms) with this derivative terms and the cost function

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BasicPerceptron.java Sun Mar 24 07:04:50 2013
@@ -18,11 +18,13 @@
  */
 package org.apache.yay;
 
-import java.util.Collection;
-
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.yay.neuron.BinaryThresholdNeuron;
 import org.apache.yay.utils.ConversionUtils;
 
+import java.util.Collection;
+
 /**
  * A perceptron {@link NeuralNetwork} implementation based on
  * {@link org.apache.yay.neuron.BinaryThresholdNeuron}s
@@ -31,7 +33,7 @@ public class BasicPerceptron implements 
 
   private final BinaryThresholdNeuron perceptronNeuron;
 
-  private final Double[] currentWeights;
+  private double[] currentWeights;
 
   /**
    * Create a perceptron given its input weights. Assume bias weight is given and all the input
@@ -39,7 +41,7 @@ public class BasicPerceptron implements 
    *
    * @param inputWeights the array of starting weights for the perceptron
    */
-  public BasicPerceptron(Double... inputWeights) {
+  public BasicPerceptron(double... inputWeights) {
     this.perceptronNeuron = new BinaryThresholdNeuron(0d, inputWeights);
     this.currentWeights = inputWeights;
   }
@@ -65,6 +67,18 @@ public class BasicPerceptron implements 
   }
 
   @Override
+  public void setParameters(RealMatrix... parameters) {
+    assert parameters.length == 1 : "a perceptron has only one layer";
+
+    this.currentWeights = parameters[0].getRow(0);
+  }
+
+  @Override
+  public RealMatrix[] getParameters() {
+    return new RealMatrix[]{new Array2DRowRealMatrix(currentWeights)};
+  }
+
+  @Override
   public Double predict(Input<Double> input) throws PredictionException {
     return perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
             new Double[input.getFeatures().size()]));

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java Sun Mar 24 07:04:50 2013
@@ -20,12 +20,10 @@ package org.apache.yay;
 
 import org.apache.commons.math3.linear.RealMatrix;
 
-import java.util.Collection;
-
 /**
  * This calculates the logistic regression cost function for neural networks
  */
-public class LogisticRegressionCostFunction implements CostFunction<RealMatrix, Double> {
+public class LogisticRegressionCostFunction implements CostFunction<RealMatrix, Double, Double> {
 
   private final Double lambda;
 
@@ -34,18 +32,18 @@ public class LogisticRegressionCostFunct
   }
 
   @Override
-  public Double calculateCost(Collection<TrainingExample<Double, Double>> trainingExamples,
-                              ActivationFunction<Double> hypothesis, RealMatrix... parameters) throws Exception {
+  public Double calculateAggregatedCost(TrainingSet<Double, Double> trainingExamples,
+                              Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception {
 
-    Double errorTerm = calculateErrorTerm(parameters, hypothesis, trainingExamples);
-    Double regularizationTerm = calculateRegularizationTerm(parameters, trainingExamples);
+    Double errorTerm = calculateErrorTerm(hypothesis, trainingExamples);
+    Double regularizationTerm = calculateRegularizationTerm(hypothesis, trainingExamples);
     return errorTerm + regularizationTerm;
   }
 
-  private Double calculateRegularizationTerm(RealMatrix[] parameters,
-                                             Collection<TrainingExample<Double, Double>> trainingExamples) {
+  private Double calculateRegularizationTerm(Hypothesis<RealMatrix, Double, Double> hypothesis,
+                                             TrainingSet<Double, Double> trainingExamples) {
     Double res = 1d;
-    for (RealMatrix layerMatrix : parameters) {
+    for (RealMatrix layerMatrix : hypothesis.getParameters()) {
       for (int i = 0; i < layerMatrix.getColumnDimension(); i++) {
         double[] column = layerMatrix.getColumn(i);
         // starting from 1 to avoid including the bias unit in regularization
@@ -57,22 +55,23 @@ public class LogisticRegressionCostFunct
     return (lambda / (2d * trainingExamples.size())) * res;
   }
 
-  private Double calculateErrorTerm(RealMatrix[] parameters,
-                                    ActivationFunction<Double> hypothesis,
-                                    Collection<TrainingExample<Double, Double>> trainingExamples) throws PredictionException,
+  private Double calculateErrorTerm(Hypothesis<RealMatrix, Double, Double> hypothesis,
+                                    TrainingSet<Double, Double> trainingExamples) throws PredictionException,
           CreationException {
     Double res = 0d;
-    NeuralNetwork<Double, Double> neuralNetwork = NeuralNetworkFactory.create(trainingExamples,
-            parameters, new VoidLearningStrategy<Double, Double>(), new FeedForwardStrategy(
-            hypothesis));
 
     for (TrainingExample<Double, Double> input : trainingExamples) {
       // TODO : handle this for multiple outputs (multi class classification)
-      Double predictedOutput = neuralNetwork.predict(input);
+      Double predictedOutput = hypothesis.predict(input);
       Double sampleOutput = input.getOutput();
       res += sampleOutput * Math.log(predictedOutput) + (1d - sampleOutput)
               * Math.log(1d - predictedOutput);
     }
     return (-1d / trainingExamples.size()) * res;
   }
+
+  @Override
+  public Double calculateCost(TrainingExample<Double, Double> trainingExample, Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception {
+    return null;
+  }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java Sun Mar 24 07:04:50 2013
@@ -32,30 +32,38 @@ public class NeuralNetworkFactory {
    * creates a neural network via a supervised learning method, given a training set, the initial set of layers defined
    * by a set of matrices, the learning and prediction strategies to be used.
    *
-   * @param trainingExamples   the training set
-   * @param RealMatrixSet      the initial settings for weights matrices
+   * @param realMatrixSet      the initial settings for weights matrices
    * @param learningStrategy   a learning strategy
    * @param predictionStrategy a prediction strategy
    * @return a NeuralNetwork instance
    * @throws CreationException
    */
-  public static NeuralNetwork<Double, Double> create(final Collection<TrainingExample<Double, Double>> trainingExamples,
-                                                     final RealMatrix[] RealMatrixSet, final LearningStrategy<Double, Double> learningStrategy,
+  public static NeuralNetwork<Double, Double> create(final RealMatrix[] realMatrixSet, final LearningStrategy<Double, Double> learningStrategy,
                                                      final PredictionStrategy<Double, Double> predictionStrategy) throws CreationException {
     return new NeuralNetwork<Double, Double>() {
 
-      private RealMatrix[] updatedRealMatrixSet = RealMatrixSet;
+      private RealMatrix[] updatedRealMatrixSet = realMatrixSet;
 
       @Override
       public void learn(TrainingSet<Double, Double> samples) throws LearningException {
         try {
-          updatedRealMatrixSet = learningStrategy.learnWeights(RealMatrixSet, samples);
+          updatedRealMatrixSet = learningStrategy.learnWeights(realMatrixSet, samples);
         } catch (WeightLearningException e) {
           throw new LearningException(e);
         }
       }
 
       @Override
+      public void setParameters(RealMatrix... parameters) {
+        updatedRealMatrixSet = parameters;
+      }
+
+      @Override
+      public RealMatrix[] getParameters() {
+        return updatedRealMatrixSet;
+      }
+
+      @Override
       public Double predict(Input<Double> input) throws PredictionException {
         try {
           Collection<Double> inputVector = ConversionUtils.toValuesCollection(input.getFeatures());

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/neuron/BinaryThresholdNeuron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/neuron/BinaryThresholdNeuron.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/neuron/BinaryThresholdNeuron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/neuron/BinaryThresholdNeuron.java Sun Mar 24 07:04:50 2013
@@ -30,14 +30,14 @@ import org.apache.yay.StepActivationFunc
  */
 public class BinaryThresholdNeuron extends Neuron<Double> {
 
-  private Double[] weights;
+  private double[] weights;
 
-  public BinaryThresholdNeuron(Double threshold, Double... weights) {
+  public BinaryThresholdNeuron(double threshold, double... weights) {
     super(new StepActivationFunction(threshold));
     this.weights = weights;
   }
 
-  public void updateWeights(Double... weights) {
+  public void updateWeights(double... weights) {
     this.weights = weights;
   }
 

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ExamplesFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ExamplesFactory.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ExamplesFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ExamplesFactory.java Sun Mar 24 07:04:50 2013
@@ -25,7 +25,7 @@ import org.apache.yay.TrainingExample;
 import java.util.ArrayList;
 
 /**
- * Factory class for {@link org.apache.yay.Input}s
+ * Factory class for {@link org.apache.yay.Input}s and {@link TrainingExample}s.
  */
 public class ExamplesFactory {
 
@@ -44,7 +44,7 @@ public class ExamplesFactory {
     };
   }
 
-  public static Input<Double> createDoubleExample(final Double... featuresValues) {
+  public static Input<Double> createDoubleInput(final Double... featuresValues) {
     return new Input<Double>() {
       @Override
       public ArrayList<Feature<Double>> getFeatures() {

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java Sun Mar 24 07:04:50 2013
@@ -21,6 +21,7 @@ package org.apache.yay;
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.yay.utils.ExamplesFactory;
+import org.junit.Before;
 import org.junit.Test;
 
 import java.util.Collection;
@@ -33,9 +34,13 @@ import static org.junit.Assert.assertTru
  */
 public class LogisticRegressionCostFunctionTest {
 
-  @Test
-  public void testORParametersCost() throws Exception {
-    CostFunction<RealMatrix, Double> costFunction = new LogisticRegressionCostFunction(0.1d);
+  private CostFunction<RealMatrix,Double,Double> costFunction;
+  private TrainingSet trainingSet;
+
+  @Before
+  public void setUp() throws Exception {
+
+    costFunction = new LogisticRegressionCostFunction(0.1d);
     Collection<TrainingExample<Double, Double>> trainingExamples = new LinkedList<TrainingExample<Double, Double>>();
     TrainingExample<Double, Double> example1 = ExamplesFactory.createDoubleTrainingExample(1d, 0d, 1d);
     TrainingExample<Double, Double> example2 = ExamplesFactory.createDoubleTrainingExample(1d, 1d, 1d);
@@ -45,11 +50,23 @@ public class LogisticRegressionCostFunct
     trainingExamples.add(example2);
     trainingExamples.add(example3);
     trainingExamples.add(example4);
+    trainingSet = new TrainingSet(trainingExamples);
+
+
+  }
+
+  @Test
+  public void testORParametersCost() throws Exception {
     double[][] weights = {{-10d, 20d, 20d}};
     RealMatrix singleOrLayerWeights = new Array2DRowRealMatrix(weights);
-    RealMatrix[] orWeightsMatrixSet = new RealMatrix[]{singleOrLayerWeights};
-    Double cost = costFunction.calculateCost(trainingExamples, new SigmoidFunction(),
-            orWeightsMatrixSet);
+    final RealMatrix[] orWeightsMatrixSet = new RealMatrix[]{singleOrLayerWeights};
+
+    final NeuralNetwork<Double,Double> neuralNetwork = NeuralNetworkFactory.create(orWeightsMatrixSet,
+            new VoidLearningStrategy<Double, Double>(), new FeedForwardStrategy(
+            new SigmoidFunction()));
+
+
+    Double cost = costFunction.calculateAggregatedCost(trainingSet, neuralNetwork);
     assertTrue("cost should not be negative", cost > 0d);
   }
 

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java?rev=1460269&r1=1460268&r2=1460269&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/NeuralNetworkFactoryTest.java Sun Mar 24 07:04:50 2013
@@ -23,7 +23,6 @@ import org.apache.commons.math3.linear.R
 import org.junit.Test;
 
 import java.util.ArrayList;
-import java.util.LinkedList;
 
 import static org.junit.Assert.assertEquals;
 
@@ -89,10 +88,9 @@ public class NeuralNetworkFactoryTest {
     assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue);
   }
 
-  private NeuralNetwork<Double, Double> createFFNN(RealMatrix[] andRealMatrixSet)
+  private NeuralNetwork<Double, Double> createFFNN(RealMatrix[] realMatrixes)
           throws CreationException {
-    return NeuralNetworkFactory.create(new LinkedList<TrainingExample<Double, Double>>(),
-            andRealMatrixSet, new VoidLearningStrategy<Double, Double>(),
+    return NeuralNetworkFactory.create(realMatrixes, new VoidLearningStrategy<Double, Double>(),
             new FeedForwardStrategy(new SigmoidFunction()));
   }
 



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org