You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2015/10/18 15:59:40 UTC
svn commit: r1709279 - in /labs/yay/trunk: api/src/main/java/org/apache/yay/
core/src/main/java/org/apache/yay/core/
core/src/test/java/org/apache/yay/core/
Author: tommaso
Date: Sun Oct 18 13:59:39 2015
New Revision: 1709279
URL: http://svn.apache.org/viewvc?rev=1709279&view=rev
Log:
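Update the NN API: move learn() from Hypothesis to NeuralNetwork and have it return the learned weights (RealMatrix[]); adapt BasicPerceptron, NeuralNetworkFactory and WordVectorsTest accordingly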
Modified:
labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java?rev=1709279&r1=1709278&r2=1709279&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java Sun Oct 18 13:59:39 2015
@@ -47,12 +47,4 @@ public interface Hypothesis<T, I, O> {
*/
O[] predict(Input<I> input) throws PredictionException;
- /**
- * Let this <code>Hypothesis</code> learn by experience, in the form of a {@link org.apache.yay.TrainingSet}
- *
- * @param trainingExamples the learning {@link org.apache.yay.TrainingSet}
- * @throws LearningException if any error occurs during the learning phase
- */
- void learn(TrainingSet<I, O> trainingExamples) throws LearningException;
-
}
Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1709279&r1=1709278&r2=1709279&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Sun Oct 18 13:59:39 2015
@@ -27,4 +27,13 @@ import org.apache.commons.math3.linear.R
*/
public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, Double> {
+ /**
+ * Let this <code>Hypothesis</code> learn by experience, in the form of a {@link org.apache.yay.TrainingSet}
+ *
+ * @param trainingExamples the learning {@link org.apache.yay.TrainingSet}
+ * @return the learned weights
+ * @throws LearningException if any error occurs during the learning phase
+ */
+ RealMatrix[] learn(TrainingSet<Double, Double> trainingExamples) throws LearningException;
+
}
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1709279&r1=1709278&r2=1709279&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java Sun Oct 18 13:59:39 2015
@@ -52,10 +52,11 @@ public class BasicPerceptron implements
}
@Override
- public void learn(TrainingSet<Double, Double> trainingExamples) throws LearningException {
+ public RealMatrix[] learn(TrainingSet<Double, Double> trainingExamples) throws LearningException {
for (TrainingExample<Double, Double> example : trainingExamples) {
learn(example);
}
+ return new RealMatrix[]{new Array2DRowRealMatrix(currentWeights)};
}
public void learn(TrainingExample<Double, Double> example) {
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1709279&r1=1709278&r2=1709279&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java Sun Oct 18 13:59:39 2015
@@ -56,9 +56,10 @@ class NeuralNetworkFactory {
private RealMatrix[] updatedRealMatrixSet = realMatrixSet;
@Override
- public void learn(TrainingSet<Double, Double> samples) throws LearningException {
+ public RealMatrix[] learn(TrainingSet<Double, Double> samples) throws LearningException {
try {
updatedRealMatrixSet = learningStrategy.learnWeights(realMatrixSet, samples);
+ return updatedRealMatrixSet;
} catch (WeightLearningException e) {
throw new LearningException(e);
}
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1709279&r1=1709278&r2=1709279&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Sun Oct 18 13:59:39 2015
@@ -66,24 +66,29 @@ public class WordVectorsTest {
assertFalse(fragments.isEmpty());
TrainingSet<Double, Double> trainingSet = createTrainingSet(vocabulary, fragments);
-
TrainingExample<Double, Double> next = trainingSet.iterator().next();
+
int inputSize = next.getFeatures().size() ;
int outputSize = next.getOutput().length;
int hiddenSize = new Random().nextInt(50) + 15;
+ System.err.println("i:"+inputSize+",h:"+hiddenSize+",o:"+outputSize);
RealMatrix[] randomWeights = createRandomWeights(inputSize, hiddenSize, outputSize);
Map<Integer, ActivationFunction<Double>> activationFunctions = new HashMap<Integer, ActivationFunction<Double>>();
activationFunctions.put(0, new IdentityActivationFunction<Double>());
activationFunctions.put(1, new SoftmaxActivationFunction());
FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions);
- BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.03d, 10,
- BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LMSCostFunction(), 10);
+ BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.03d, 1,
+ BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LogisticRegressionCostFunction(), 10);
NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy);
- neuralNetwork.learn(trainingSet);
+ RealMatrix[] learnedWeights = neuralNetwork.learn(trainingSet);
+
+ RealMatrix wordVectors = learnedWeights[learnedWeights.length - 1];
+
+ assertNotNull(wordVectors);
- RealMatrix vectorsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(), next.getOutput().length);
+ RealMatrix mappingsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(), next.getOutput().length);
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.txt")));
int m = 0;
@@ -115,7 +120,7 @@ public class WordVectorsTest {
for (int x = 0; x < row.length; x++) {
row[x] = predict[x];
}
- vectorsMatrix.setRow(m, row);
+ mappingsMatrix.setRow(m, row);
m++;
String vectorString = Arrays.toString(predict);
@@ -145,7 +150,7 @@ public class WordVectorsTest {
bufferedWriter.close();
ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(new File("target/sg-vectors.bin")));
- MatrixUtils.serializeRealMatrix(vectorsMatrix, os);
+ MatrixUtils.serializeRealMatrix(mappingsMatrix, os);
}
private String hotDecode(Double[] doubles, List<String> vocabulary) {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org