Posted to dev@mahout.apache.org by "Eric Springer (JIRA)" <ji...@apache.org> on 2012/10/09 07:52:03 UTC

[jira] [Updated] (MAHOUT-1093) CrossFoldLearner trains in all folds if tracking key is negative

     [ https://issues.apache.org/jira/browse/MAHOUT-1093?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

Eric Springer updated MAHOUT-1093:
----------------------------------

    Status: Patch Available  (was: Open)

Since there's no option to upload a file, I'll paste the patch inline:

----

commit 831ca2200df9802f24c8a92077377f677be746ef
Author: Eric Springer <er...@gmail.com>
Date:   Tue Oct 9 12:36:14 2012 +1100

    CrossFoldLearner shouldn't train on all folds if TrackingKey is negative

diff --git a/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java b/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
index 33f0266..f8b5b67 100644
--- a/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
+++ b/core/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
@@ -123,7 +123,7 @@ public class CrossFoldLearner extends AbstractVectorClassifier implements Online
     record++;
     int k = 0;
     for (OnlineLogisticRegression model : models) {
-      if (k == trackingKey % models.size()) {
+      if (k == mod(trackingKey, models.size())) {
         Vector v = model.classifyFull(instance);
         double score = Math.max(v.get(actual), MIN_SCORE);
         logLikelihood += (Math.log(score) - logLikelihood) / Math.min(record, windowSize);
@@ -140,6 +140,11 @@ public class CrossFoldLearner extends AbstractVectorClassifier implements Online
     }
   }
 
+  private static long mod(long x, int y) {
+    long r = x % y;
+    return r < 0 ? r + y : r;
+  }
+
   @Override
   public void close() {
     for (OnlineLogisticRegression m : models) {
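The standalone sketch below is not part of the patch. It reproduces the fold-selection check outside Mahout to show why a negative key slips through. The class `FoldAssignment` and its methods `heldOutFoldBuggy` and `heldOutFoldFixed` are hypothetical stand-ins for the loop in `CrossFoldLearner.train`. The sketch assumes the tracking key is a `long`, and that the fold whose index equals the key's remainder is held out for evaluation while every other fold trains on the record.

```java
// Hypothetical standalone sketch (not Mahout code): mimics the fold check in
// CrossFoldLearner.train, assuming a long tracking key and that the fold whose
// index equals the key's remainder is held out while all other folds train.
final class FoldAssignment {

  // Non-negative remainder, equivalent to the patch's mod() helper.
  // Java's % takes the sign of the dividend, so -1 % 5 == -1.
  static long mod(long x, int y) {
    long r = x % y;
    return r < 0 ? r + y : r;
  }

  // Pre-patch check: returns the held-out fold, or -1 if no fold matches
  // (meaning the record would be trained into every fold).
  static int heldOutFoldBuggy(long trackingKey, int numFolds) {
    for (int k = 0; k < numFolds; k++) {
      if (k == trackingKey % numFolds) {
        return k;
      }
    }
    return -1;
  }

  // Post-patch check using the non-negative remainder.
  static int heldOutFoldFixed(long trackingKey, int numFolds) {
    for (int k = 0; k < numFolds; k++) {
      if (k == mod(trackingKey, numFolds)) {
        return k;
      }
    }
    return -1;
  }

  public static void main(String[] args) {
    int numFolds = 5;
    long[] keys = {7L, 0L, -1L, -5L, -7L};
    for (long key : keys) {
      System.out.printf("key=%d raw%%=%d mod=%d buggyFold=%d fixedFold=%d%n",
          key, key % numFolds, mod(key, numFolds),
          heldOutFoldBuggy(key, numFolds), heldOutFoldFixed(key, numFolds));
    }
  }
}
```

With five folds, a key of -1 gives `-1 % 5 == -1`, which no fold index from 0 to 4 can equal, so the pre-patch check holds out nothing and the record trains every fold. The non-negative remainder maps the same key to fold 4. On Java 8 and later, `Math.floorMod` returns the same value as this helper for a positive divisor.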
                
> CrossFoldLearner trains in all folds if tracking key is negative
> ----------------------------------------------------------------
>
>                 Key: MAHOUT-1093
>                 URL: https://issues.apache.org/jira/browse/MAHOUT-1093
>             Project: Mahout
>          Issue Type: Bug
>          Components: Classification
>            Reporter: Eric Springer
>
> See: https://github.com/apache/mahout/pull/7

--
This message is automatically generated by JIRA.
If you think it was sent incorrectly, please contact your JIRA administrators
For more information on JIRA, see: http://www.atlassian.com/software/jira