You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hama.apache.org by to...@apache.org on 2012/10/07 17:01:19 UTC

svn commit: r1395322 - /hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java

Author: tommaso
Date: Sun Oct  7 15:01:18 2012
New Revision: 1395322

URL: http://svn.apache.org/viewvc?rev=1395322&view=rev
Log:
[HAMA-651] - added gradient descent cost check with threshold

Modified:
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1395322&r1=1395321&r2=1395322&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java Sun Oct  7 15:01:18 2012
@@ -42,10 +42,16 @@ public abstract class GradientDescentBSP
 
   private boolean master;
   private DoubleVector theta;
+  private double cost;
+  private double threshold;
+  private float alpha;
 
   @Override
   public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException, SyncException, InterruptedException {
     master = peer.getPeerIndex() == peer.getNumPeers() / 2;
+    cost = Integer.MAX_VALUE;
+    threshold = peer.getConfiguration().getFloat("threashold", 0.01f);
+    alpha = peer.getConfiguration().getFloat(ALPHA, 0.3f);
   }
 
   @Override
@@ -93,10 +99,24 @@ public abstract class GradientDescentBSP
 
       totalCost /= numRead;
 
+      if (cost - totalCost < 0){
+        throw new RuntimeException("gradient descent failed to converge with alpha " + alpha);
+      }
+      else if (totalCost == 0 || cost - totalCost < threshold) {
+        cost = totalCost;
+        break;
+      }
+      else {
+        cost = totalCost;
+      }
+
+
       if (log.isInfoEnabled()) {
-        log.info("cost is " + totalCost);
+        log.info("cost is " + cost);
       }
 
+
+
       peer.sync();
 
       peer.reopenInput();
@@ -130,7 +150,7 @@ public abstract class GradientDescentBSP
         }
 
         for (int j = 0; j < theta.getLength(); j++) {
-          newTheta[j] = theta.get(j) - newTheta[j] * peer.getConfiguration().getFloat(ALPHA, 0.3f);
+          newTheta[j] = theta.get(j) - newTheta[j] * alpha;
         }
 
         theta = new DenseDoubleVector(newTheta);
@@ -145,11 +165,6 @@ public abstract class GradientDescentBSP
       }
       peer.sync();
 
-      // eventually break execution !?
-      if (totalCost == 0) {
-        // TODO change this as just 0 is too strict
-        break;
-      }
     }
 
   }