You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2015/11/12 17:11:26 UTC

svn commit: r1714080 - in /labs/yay/trunk/core/src: main/java/org/apache/yay/core/CrossEntropyCostFunction.java test/java/org/apache/yay/core/WordVectorsTest.java

Author: tommaso
Date: Thu Nov 12 16:11:25 2015
New Revision: 1714080

URL: http://svn.apache.org/viewvc?rev=1714080&view=rev
Log:
added cross-entropy cost function

Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java   (with props)
Modified:
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java?rev=1714080&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java Thu Nov 12 16:11:25 2015
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.yay.Hypothesis;
+import org.apache.yay.NeuralNetworkCostFunction;
+import org.apache.yay.PredictionException;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+
+/**
+ * Cross-entropy cost function for neural networks: the sum (not the mean) over all examples
+ * and output units of {@code -expected * log(predicted)}. Terms with a zero expected output are
+ * skipped, matching the convention 0 * log(0) = 0.
+ */
+public class CrossEntropyCostFunction implements NeuralNetworkCostFunction {
+
+  @Override
+  @SuppressWarnings("unchecked") // generic array creation; every element is a TrainingExample<Double, Double>
+  public Double calculateAggregatedCost(TrainingSet<Double, Double> trainingSet,
+                                        Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception {
+    TrainingExample<Double, Double>[] samples = new TrainingExample[trainingSet.size()];
+    int i = 0;
+    for (TrainingExample<Double, Double> sample : trainingSet) {
+      samples[i++] = sample;
+    }
+    return calculateCost(hypothesis, samples);
+  }
+
+  @Override
+  public Double calculateCost(Hypothesis<RealMatrix, Double, Double> hypothesis,
+                              TrainingExample<Double, Double>... trainingExamples) throws Exception {
+    double res = 0d; // primitive accumulator: avoids re-boxing a Double for every term
+    for (TrainingExample<Double, Double> input : trainingExamples) {
+      Double[] predictedOutput = hypothesis.predict(input);
+      Double[] sampleOutput = input.getOutput();
+      for (int i = 0; i < predictedOutput.length; i++) {
+        double so = sampleOutput[i];
+        // 0 * log(0) evaluates to NaN, which would poison the whole sum when e.g. a softmax
+        // output underflows to 0 for a non-target unit; such terms contribute 0 by definition
+        if (so != 0d) {
+          res -= so * Math.log(predictedOutput[i]);
+        }
+      }
+    }
+    return res;
+  }
+}

Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java
------------------------------------------------------------------------------
    svn:eol-style = native

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1714080&r1=1714079&r2=1714080&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Thu Nov 12 16:11:25 2015
@@ -105,7 +105,7 @@ public class WordVectorsTest {
     activationFunctions.put(1, new SoftmaxActivationFunction());
     FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions);
     BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.01d, 1,
-            BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LogisticRegressionCostFunction(),
+            BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new CrossEntropyCostFunction(),
             trainingSet.size());
     NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy);
 
@@ -238,9 +238,8 @@ public class WordVectorsTest {
     }
   }
 
-  private TrainingSet<Double, Double> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) {
+  private TrainingSet<Double, Double> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) throws IOException {
     long start = System.currentTimeMillis();
-    Path file = Paths.get("/Users/teofili/Desktop/ts.txt");
     Collection<TrainingExample<Double, Double>> samples = new LinkedList<>();
     List<byte[]> fragment;
     while ((fragment = fragments.poll()) != null) {



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org