You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@labs.apache.org by to...@apache.org on 2015/10/18 15:59:40 UTC

svn commit: r1709279 - in /labs/yay/trunk: api/src/main/java/org/apache/yay/ core/src/main/java/org/apache/yay/core/ core/src/test/java/org/apache/yay/core/

Author: tommaso
Date: Sun Oct 18 13:59:39 2015
New Revision: 1709279

URL: http://svn.apache.org/viewvc?rev=1709279&view=rev
Log:
update NN api

Modified:
    labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
    labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java?rev=1709279&r1=1709278&r2=1709279&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java Sun Oct 18 13:59:39 2015
@@ -47,12 +47,4 @@ public interface Hypothesis<T, I, O> {
    */
   O[] predict(Input<I> input) throws PredictionException;
 
-  /**
-   * Let this <code>Hypothesis</code> learn by experience, in the form of a {@link org.apache.yay.TrainingSet}
-   *
-   * @param trainingExamples the learning {@link org.apache.yay.TrainingSet}
-   * @throws LearningException if any error occurs during the learning phase
-   */
-  void learn(TrainingSet<I, O> trainingExamples) throws LearningException;
-
 }

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1709279&r1=1709278&r2=1709279&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Sun Oct 18 13:59:39 2015
@@ -27,4 +27,13 @@ import org.apache.commons.math3.linear.R
  */
 public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, Double> {
 
+  /**
+   * Let this <code>Hypothesis</code> learn by experience, in the form of a {@link org.apache.yay.TrainingSet}
+   *
+   * @param trainingExamples the learning {@link org.apache.yay.TrainingSet}
+   * @return the learned weights
+   * @throws LearningException if any error occurs during the learning phase
+   */
+  RealMatrix[] learn(TrainingSet<Double, Double> trainingExamples) throws LearningException;
+
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1709279&r1=1709278&r2=1709279&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java Sun Oct 18 13:59:39 2015
@@ -52,10 +52,11 @@ public class BasicPerceptron implements
   }
 
   @Override
-  public void learn(TrainingSet<Double, Double> trainingExamples) throws LearningException {
+  public RealMatrix[] learn(TrainingSet<Double, Double> trainingExamples) throws LearningException {
     for (TrainingExample<Double, Double> example : trainingExamples) {
       learn(example);
     }
+    return new RealMatrix[]{new Array2DRowRealMatrix(currentWeights)};
   }
 
   public void learn(TrainingExample<Double, Double> example) {

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1709279&r1=1709278&r2=1709279&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java Sun Oct 18 13:59:39 2015
@@ -56,9 +56,10 @@ class NeuralNetworkFactory {
       private RealMatrix[] updatedRealMatrixSet = realMatrixSet;
 
       @Override
-      public void learn(TrainingSet<Double, Double> samples) throws LearningException {
+      public RealMatrix[] learn(TrainingSet<Double, Double> samples) throws LearningException {
         try {
           updatedRealMatrixSet = learningStrategy.learnWeights(realMatrixSet, samples);
+          return updatedRealMatrixSet;
         } catch (WeightLearningException e) {
           throw new LearningException(e);
         }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1709279&r1=1709278&r2=1709279&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Sun Oct 18 13:59:39 2015
@@ -66,24 +66,29 @@ public class WordVectorsTest {
     assertFalse(fragments.isEmpty());
 
     TrainingSet<Double, Double> trainingSet = createTrainingSet(vocabulary, fragments);
-
     TrainingExample<Double, Double> next = trainingSet.iterator().next();
+
     int inputSize = next.getFeatures().size() ;
     int outputSize = next.getOutput().length;
     int hiddenSize = new Random().nextInt(50) + 15;
+    System.err.println("i:"+inputSize+",h:"+hiddenSize+",o:"+outputSize);
     RealMatrix[] randomWeights = createRandomWeights(inputSize, hiddenSize, outputSize);
 
     Map<Integer, ActivationFunction<Double>> activationFunctions = new HashMap<Integer, ActivationFunction<Double>>();
     activationFunctions.put(0, new IdentityActivationFunction<Double>());
     activationFunctions.put(1, new SoftmaxActivationFunction());
     FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions);
-    BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.03d, 10,
-            BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LMSCostFunction(), 10);
+    BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.03d, 1,
+            BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LogisticRegressionCostFunction(), 10);
     NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy);
 
-    neuralNetwork.learn(trainingSet);
+    RealMatrix[] learnedWeights = neuralNetwork.learn(trainingSet);
+
+    RealMatrix wordVectors = learnedWeights[learnedWeights.length - 1];
+
+    assertNotNull(wordVectors);
 
-    RealMatrix vectorsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(), next.getOutput().length);
+    RealMatrix mappingsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(), next.getOutput().length);
 
     BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.txt")));
     int m = 0;
@@ -115,7 +120,7 @@ public class WordVectorsTest {
       for (int x = 0; x < row.length; x++) {
         row[x] = predict[x];
       }
-      vectorsMatrix.setRow(m, row);
+      mappingsMatrix.setRow(m, row);
       m++;
 
       String vectorString = Arrays.toString(predict);
@@ -145,7 +150,7 @@ public class WordVectorsTest {
     bufferedWriter.close();
 
     ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(new File("target/sg-vectors.bin")));
-    MatrixUtils.serializeRealMatrix(vectorsMatrix, os);
+    MatrixUtils.serializeRealMatrix(mappingsMatrix, os);
   }
 
   private String hotDecode(Double[] doubles, List<String> vocabulary) {



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org