You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@labs.apache.org by to...@apache.org on 2012/07/30 15:00:12 UTC
svn commit: r1367075 - in /labs/yay/trunk/core/src:
main/java/org/apache/yay/CostFunction.java
main/java/org/apache/yay/LogisticRegressionCostFunction.java
test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java
Author: tommaso
Date: Mon Jul 30 13:00:12 2012
New Revision: 1367075
URL: http://svn.apache.org/viewvc?rev=1367075&view=rev
Log:
adding CostFunction and LogisticRegression first sketch up impl
Added:
labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java
labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/CostFunction.java
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/CostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/CostFunction.java?rev=1367075&r1=1367074&r2=1367075&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/CostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/CostFunction.java Mon Jul 30 13:00:12 2012
@@ -18,12 +18,18 @@
*/
package org.apache.yay;
+import java.util.Collection;
+
/**
- * A cost function helps on figuring out how to fit best a specific model to
- * given data (inputs)
+ * A cost function calculates the cost of using a specified model (via its
+ * {@link ActivationFunction}) for fitting the given corpus (a {@link Collection}
+ * of {@link TrainingExample}s).
+ *
*/
-public interface CostFunction<T> {
+public interface CostFunction<T, S> { // T: type of the parameters, S: type argument shared by the examples and the activation function
- public Double calculateCost(T... parameters);
+ public Double calculateCost(Collection<TrainingExample<S, S>> trainingExamples, // corpus the cost is calculated on
+ ActivationFunction<S> activationFunction, // activation function of the model being evaluated
+ T... parameters) throws Exception; // parameters the cost is calculated for; implementations may throw any checked Exception
}
Added: labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java?rev=1367075&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java Mon Jul 30 13:00:12 2012
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import java.util.Collection;
+
+/**
+ * Regularized logistic regression (cross entropy) cost function for neural networks: the average
+ * log-loss of the predictions plus a {@code lambda}-scaled penalty on the squared weights.
+ */
+public class LogisticRegressionCostFunction implements CostFunction<WeightsMatrix, Double> {
+
+  private final Double lambda;
+
+  public LogisticRegressionCostFunction(Double lambda) {
+    this.lambda = lambda;
+  }
+
+  /** @return the error term plus the regularization term calculated for the given weights */
+  @Override
+  public Double calculateCost(Collection<TrainingExample<Double, Double>> trainingExamples,
+      ActivationFunction<Double> hypothesis, WeightsMatrix... parameters) throws Exception {
+    return calculateErrorTerm(parameters, hypothesis, trainingExamples)
+        + calculateRegularizationTerm(parameters, trainingExamples);
+  }
+
+  /** @return {@code lambda / (2 * m)} times the sum of all the squared weights, m = number of examples */
+  private Double calculateRegularizationTerm(WeightsMatrix[] parameters,
+      Collection<TrainingExample<Double, Double>> trainingExamples) {
+    double squaredWeightsSum = 0d; // must start from 0, any other value adds a spurious penalty
+    for (WeightsMatrix layerMatrix : parameters) {
+      for (int i = 0; i < layerMatrix.getColumnDimension(); i++) {
+        for (double weight : layerMatrix.getColumn(i)) {
+          squaredWeightsSum += weight * weight;
+        }
+      }
+    }
+    return (lambda / (2d * trainingExamples.size())) * squaredWeightsSum;
+  }
+
+  /** @return the negated average of {@code y * log(h) + (1 - y) * log(1 - h)} over the examples */
+  private Double calculateErrorTerm(WeightsMatrix[] parameters, ActivationFunction<Double> hypothesis,
+      Collection<TrainingExample<Double, Double>> trainingExamples) throws PredictionException, CreationException {
+    double logLikelihood = 0d;
+    NeuralNetwork<Double, Double> neuralNetwork = NeuralNetworkFactory.create(trainingExamples,
+        parameters, new VoidLearningStrategy(), new FeedForwardStrategy(hypothesis));
+    for (TrainingExample<Double, Double> input : trainingExamples) {
+      // TODO : handle this for multiple outputs (multi class classification)
+      double predictedOutput = neuralNetwork.predict(input);
+      double sampleOutput = input.getOutput();
+      // a zero coefficient contributes nothing: skipping it avoids 0 * log(0) = NaN on saturated outputs
+      logLikelihood += (sampleOutput == 0d ? 0d : sampleOutput * Math.log(predictedOutput))
+          + (sampleOutput == 1d ? 0d : (1d - sampleOutput) * Math.log(1d - predictedOutput));
+    }
+    return (-1d / trainingExamples.size()) * logLikelihood;
+  }
+}
Added: labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java?rev=1367075&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java (added)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/LogisticRegressionCostFunctionTest.java Mon Jul 30 13:00:12 2012
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import java.util.Collection;
+import java.util.LinkedList;
+
+import org.apache.yay.utils.ExamplesFactory;
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Unit test checking the cost calculated by {@link LogisticRegressionCostFunction}
+ */
+public class LogisticRegressionCostFunctionTest {
+
+  @Test
+  public void testORParametersCost() throws Exception {
+    CostFunction<WeightsMatrix, Double> costFunction = new LogisticRegressionCostFunction(0.1d);
+
+    // the four samples of the corpus the cost is calculated on
+    Collection<TrainingExample<Double, Double>> corpus = new LinkedList<TrainingExample<Double, Double>>();
+    corpus.add(ExamplesFactory.createDoubleTrainingExample(1d, 0d, 1d));
+    corpus.add(ExamplesFactory.createDoubleTrainingExample(1d, 1d, 1d));
+    corpus.add(ExamplesFactory.createDoubleTrainingExample(0d, 1d, 1d));
+    corpus.add(ExamplesFactory.createDoubleTrainingExample(0d, 0d, 0d));
+
+    // a single layer of weights, passed as the only element of the parameters
+    WeightsMatrix orLayerWeights = new WeightsMatrix(new double[][]{{-10d, 20d, 20d}});
+    Double cost = costFunction.calculateCost(corpus, new SigmoidFunction(), orLayerWeights);
+
+    // regardless of its message, the assertion requires a strictly positive cost
+    assertTrue("cost should not be negative", cost > 0d);
+  }
+
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org
For additional commands, e-mail: commits-help@labs.apache.org