You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ss...@apache.org on 2013/03/11 17:55:10 UTC
svn commit: r1455231 -
/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
Author: ssc
Date: Mon Mar 11 16:55:09 2013
New Revision: 1455231
URL: http://svn.apache.org/r1455231
Log:
MAHOUT-1093 CrossFoldLearner trains in all folds if tracking key is negative
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java?rev=1455231&r1=1455230&r2=1455231&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java Mon Mar 11 16:55:09 2013
@@ -123,7 +123,7 @@ public class CrossFoldLearner extends Ab
record++;
int k = 0;
for (OnlineLogisticRegression model : models) {
- if (k == trackingKey % models.size()) {
+ if (k == mod(trackingKey, models.size())) {
Vector v = model.classifyFull(instance);
double score = Math.max(v.get(actual), MIN_SCORE);
logLikelihood += (Math.log(score) - logLikelihood) / Math.min(record, windowSize);
@@ -140,6 +140,11 @@ public class CrossFoldLearner extends Ab
}
}
+ private long mod(long x, int y) {
+ long r = x % y;
+ return r < 0 ? r + y : r;
+ }
+
@Override
public void close() {
for (OnlineLogisticRegression m : models) {