You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2012/07/30 15:00:12 UTC

svn commit: r1367075 - in /labs/yay/trunk/core/src: main/java/org/apache/yay/CostFunction.java main/java/org/apache/yay/LogisticRegressionCostFunction.java test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java

Author: tommaso
Date: Mon Jul 30 13:00:12 2012
New Revision: 1367075

URL: http://svn.apache.org/viewvc?rev=1367075&view=rev
Log:
adding CostFunction and LogisticRegression first sketch up impl

Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java
Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/CostFunction.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/CostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/CostFunction.java?rev=1367075&r1=1367074&r2=1367075&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/CostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/CostFunction.java Mon Jul 30 13:00:12 2012
@@ -18,12 +18,31 @@
  */
 package org.apache.yay;
 
+import java.util.Collection;
+
 /**
- * A cost function helps on figuring out how to fit best a specific model to
- * given data (inputs)
+ * A cost function calculates the cost of using a specified model (via its
+ * {@link ActivationFunction}) for fitting the given corpus (a {@link Collection}
+ * of {@link TrainingExample}s).
+ *
+ * @param <T> the type of the model parameters
+ * @param <S> the type used for both type arguments of the {@link TrainingExample}s
+ *            and by the {@link ActivationFunction}
  */
-public interface CostFunction<T> {
+public interface CostFunction<T, S> {
 
-  public Double calculateCost(T... parameters);
+  /**
+   * Calculates the cost of fitting the given training examples with the model
+   * defined by the given activation function and parameters.
+   *
+   * @param trainingExamples   the corpus to fit
+   * @param activationFunction the activation function of the model
+   * @param parameters         the model parameters
+   * @return the cost of the model on the given training examples
+   * @throws Exception if the cost cannot be calculated
+   */
+  public Double calculateCost(Collection<TrainingExample<S, S>> trainingExamples,
+                              ActivationFunction<S> activationFunction,
+                              T... parameters) throws Exception;
 
 }

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java?rev=1367075&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java Mon Jul 30 13:00:12 2012
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import java.util.Collection;
+
+/**
+ * Calculates the regularized logistic regression cost of a neural network on a
+ * set of training examples:
+ * <pre>
+ * J = -1/m * sum_i [ y_i * log(h(x_i)) + (1 - y_i) * log(1 - h(x_i)) ]
+ *     + lambda / (2m) * sum(w^2)
+ * </pre>
+ * where {@code m} is the number of training examples, {@code h(x_i)} is the
+ * network's prediction for the i-th example, {@code y_i} is its expected output
+ * and the last sum runs over every entry of every weights matrix.
+ */
+public class LogisticRegressionCostFunction implements CostFunction<WeightsMatrix, Double> {
+
+    /** Regularization parameter: scales the squared-weights penalty ({@code 0} disables it). */
+    private final Double lambda;
+
+    /**
+     * Creates a logistic regression cost function.
+     *
+     * @param lambda the regularization parameter
+     */
+    public LogisticRegressionCostFunction(Double lambda) {
+        this.lambda = lambda;
+    }
+
+    /**
+     * Calculates the cost of the network defined by {@code hypothesis} and
+     * {@code parameters} on the given training examples.
+     *
+     * @param trainingExamples the examples to evaluate the network on; must not be empty
+     * @param hypothesis       the activation function used by the feed-forward network
+     * @param parameters       the weights matrices of the network
+     * @return the error term plus the regularization term
+     * @throws IllegalArgumentException if {@code trainingExamples} is empty
+     * @throws Exception if the network cannot be created or a prediction fails
+     */
+    @Override
+    public Double calculateCost(Collection<TrainingExample<Double, Double>> trainingExamples,
+                                ActivationFunction<Double> hypothesis,
+                                WeightsMatrix... parameters) throws Exception {
+        // both terms are averaged over the number of examples: an empty set would silently yield NaN
+        if (trainingExamples.isEmpty()) {
+            throw new IllegalArgumentException("at least one training example is required");
+        }
+        Double errorTerm = calculateErrorTerm(parameters, hypothesis, trainingExamples);
+        Double regularizationTerm = calculateRegularizationTerm(parameters, trainingExamples);
+        return errorTerm + regularizationTerm;
+    }
+
+    /**
+     * Computes {@code lambda / (2m) * sum(w^2)} over every entry of every weights matrix.
+     */
+    private Double calculateRegularizationTerm(WeightsMatrix[] parameters,
+                                               Collection<TrainingExample<Double, Double>> trainingExamples) {
+        // the sum of squares starts from 0: starting from 1 added a spurious lambda / (2m) to every cost
+        double res = 0d;
+        for (WeightsMatrix layerMatrix : parameters) {
+            for (int i = 0; i < layerMatrix.getColumnDimension(); i++) {
+                for (double weight : layerMatrix.getColumn(i)) {
+                    res += weight * weight;
+                }
+            }
+        }
+        // NOTE(review): bias weights (if any) are penalized as well; regularization usually excludes them - confirm
+        return (lambda / (2d * trainingExamples.size())) * res;
+    }
+
+    /**
+     * Computes {@code -1/m * sum_i [ y_i * log(h(x_i)) + (1 - y_i) * log(1 - h(x_i)) ]},
+     * where {@code h(x_i)} is the output predicted for the i-th example by a feed-forward
+     * network built from {@code parameters} and {@code hypothesis}.
+     */
+    private Double calculateErrorTerm(WeightsMatrix[] parameters, ActivationFunction<Double> hypothesis,
+                                      Collection<TrainingExample<Double, Double>> trainingExamples) throws PredictionException, CreationException {
+        // the network is only used for predictions here (VoidLearningStrategy presumably learns nothing)
+        NeuralNetwork<Double, Double> neuralNetwork = NeuralNetworkFactory.create(trainingExamples,
+                parameters, new VoidLearningStrategy(), new FeedForwardStrategy(hypothesis));
+        double res = 0d;
+        for (TrainingExample<Double, Double> input : trainingExamples) {
+            // TODO : handle this for multiple outputs (multi class classification)
+            double predictedOutput = neuralNetwork.predict(input);
+            double sampleOutput = input.getOutput();
+            // only add terms with a non-zero coefficient: a saturated prediction (exactly 0 or 1)
+            // would otherwise give 0 * log(0) = 0 * -Infinity = NaN and poison the whole cost
+            double term = 0d;
+            if (sampleOutput != 0d) {
+                term += sampleOutput * Math.log(predictedOutput);
+            }
+            if (sampleOutput != 1d) {
+                term += (1d - sampleOutput) * Math.log(1d - predictedOutput);
+            }
+            res += term;
+        }
+        return (-1d / trainingExamples.size()) * res;
+    }
+}

Added: labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java?rev=1367075&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java (added)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java Mon Jul 30 13:00:12 2012
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import java.util.ArrayList;
+import java.util.Collection;
+
+import org.apache.yay.utils.ExamplesFactory;
+import org.junit.Test;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Testcase for {@link LogisticRegressionCostFunction}
+ */
+public class LogisticRegressionCostFunctionTest {
+
+    /**
+     * Checks that the cost of a single-layer network with OR weights on the OR
+     * examples is a finite, positive number.
+     */
+    @Test
+    public void testORParametersCost() throws Exception {
+        CostFunction<WeightsMatrix, Double> costFunction = new LogisticRegressionCostFunction(0.1d);
+        Collection<TrainingExample<Double, Double>> trainingExamples = new ArrayList<TrainingExample<Double, Double>>();
+        trainingExamples.add(ExamplesFactory.createDoubleTrainingExample(1d, 0d, 1d));
+        trainingExamples.add(ExamplesFactory.createDoubleTrainingExample(1d, 1d, 1d));
+        trainingExamples.add(ExamplesFactory.createDoubleTrainingExample(0d, 1d, 1d));
+        trainingExamples.add(ExamplesFactory.createDoubleTrainingExample(0d, 0d, 0d));
+        double[][] weights = {{-10d, 20d, 20d}};
+        WeightsMatrix[] orWeightsMatrixSet = new WeightsMatrix[]{new WeightsMatrix(weights)};
+        Double cost = costFunction.calculateCost(trainingExamples, new SigmoidFunction(), orWeightsMatrixSet);
+        assertFalse("cost should be a number", cost.isNaN());
+        assertFalse("cost should be finite", cost.isInfinite());
+        assertTrue("cost should be positive", cost > 0d);
+    }
+
+}



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org