You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hama.apache.org by to...@apache.org on 2012/10/07 17:01:19 UTC
svn commit: r1395322 -
/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Author: tommaso
Date: Sun Oct 7 15:01:18 2012
New Revision: 1395322
URL: http://svn.apache.org/viewvc?rev=1395322&view=rev
Log:
[HAMA-651] - added gradient descent cost check with threshold
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1395322&r1=1395321&r2=1395322&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java Sun Oct 7 15:01:18 2012
@@ -42,10 +42,16 @@ public abstract class GradientDescentBSP
private boolean master;
private DoubleVector theta;
+ private double cost;
+ private double threshold;
+ private float alpha;
@Override
public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException, SyncException, InterruptedException {
master = peer.getPeerIndex() == peer.getNumPeers() / 2;
+ cost = Integer.MAX_VALUE;
+ threshold = peer.getConfiguration().getFloat("threashold", 0.01f);
+ alpha = peer.getConfiguration().getFloat(ALPHA, 0.3f);
}
@Override
@@ -93,10 +99,24 @@ public abstract class GradientDescentBSP
totalCost /= numRead;
+ if (cost - totalCost < 0){
+ throw new RuntimeException("gradient descent failed to converge with alpha " + alpha);
+ }
+ else if (totalCost == 0 || cost - totalCost < threshold) {
+ cost = totalCost;
+ break;
+ }
+ else {
+ cost = totalCost;
+ }
+
+
if (log.isInfoEnabled()) {
- log.info("cost is " + totalCost);
+ log.info("cost is " + cost);
}
+
+
peer.sync();
peer.reopenInput();
@@ -130,7 +150,7 @@ public abstract class GradientDescentBSP
}
for (int j = 0; j < theta.getLength(); j++) {
- newTheta[j] = theta.get(j) - newTheta[j] * peer.getConfiguration().getFloat(ALPHA, 0.3f);
+ newTheta[j] = theta.get(j) - newTheta[j] * alpha;
}
theta = new DenseDoubleVector(newTheta);
@@ -145,11 +165,6 @@ public abstract class GradientDescentBSP
}
peer.sync();
- // eventually break execution !?
- if (totalCost == 0) {
- // TODO change this as just 0 is too strict
- break;
- }
}
}