You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text rendering.
Posted to commits@hama.apache.org by to...@apache.org on 2012/10/12 15:02:26 UTC
svn commit: r1397550 - in /hama/trunk:
examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java
ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Author: tommaso
Date: Fri Oct 12 13:02:26 2012
New Revision: 1397550
URL: http://svn.apache.org/viewvc?rev=1397550&view=rev
Log:
[HAMA-651] - added iterations threshold
Modified:
hama/trunk/examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Modified: hama/trunk/examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java
URL: http://svn.apache.org/viewvc/hama/trunk/examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java?rev=1397550&r1=1397549&r2=1397550&view=diff
==============================================================================
--- hama/trunk/examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java (original)
+++ hama/trunk/examples/src/main/java/org/apache/hama/examples/GradientDescentExample.java Fri Oct 12 13:02:26 2012
@@ -44,7 +44,8 @@ public class GradientDescentExample {
// BSP job configuration
HamaConfiguration conf = new HamaConfiguration();
conf.setFloat(GradientDescentBSP.ALPHA, 0.002f);
- conf.setFloat(GradientDescentBSP.THRESHOLD, 0.2f);
+ conf.setFloat(GradientDescentBSP.COST_THRESHOLD, 0.5f);
+ conf.setInt(GradientDescentBSP.ITERATIONS_THRESHOLD, 300);
BSPJob bsp = new BSPJob(conf, GradientDescentExample.class);
// Set the job name
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1397550&r1=1397549&r2=1397550&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java Fri Oct 12 13:02:26 2012
@@ -38,21 +38,24 @@ public class GradientDescentBSP extends
private static final Logger log = LoggerFactory.getLogger(GradientDescentBSP.class);
public static final String INITIAL_THETA_VALUES = "gd.initial.theta";
public static final String ALPHA = "gd.alpha";
- public static final String THRESHOLD = "gd.threshold";
+ public static final String COST_THRESHOLD = "gd.cost.threshold";
+ public static final String ITERATIONS_THRESHOLD = "gd.iterations.threshold";
public static final String REGRESSION_MODEL_CLASS = "gd.regression.model";
private boolean master;
private DoubleVector theta;
private double cost;
- private double threshold;
+ private double costThreshold;
private float alpha;
private RegressionModel regressionModel;
+ private int iterationsThreshold;
@Override
public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException, SyncException, InterruptedException {
master = peer.getPeerIndex() == peer.getNumPeers() / 2;
cost = Integer.MAX_VALUE;
- threshold = peer.getConfiguration().getFloat(THRESHOLD, 0.1f);
+ costThreshold = peer.getConfiguration().getFloat(COST_THRESHOLD, 0.1f);
+ iterationsThreshold = peer.getConfiguration().getInt(ITERATIONS_THRESHOLD, 10000);
alpha = peer.getConfiguration().getFloat(ALPHA, 0.003f);
try {
regressionModel = ((Class<? extends RegressionModel>) peer.getConfiguration().getClass(REGRESSION_MODEL_CLASS, LinearRegressionModel.class)).newInstance();
@@ -63,7 +66,7 @@ public class GradientDescentBSP extends
@Override
public void bsp(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException, SyncException, InterruptedException {
-
+ int iterations = 0;
while (true) {
getTheta(peer);
@@ -109,7 +112,7 @@ public class GradientDescentBSP extends
if (cost - totalCost < 0) {
throw new RuntimeException(new StringBuilder("gradient descent failed to converge with alpha ").
append(alpha).toString());
- } else if (totalCost == 0 || totalCost < threshold) {
+ } else if (totalCost == 0 || totalCost < costThreshold || iterations >= iterationsThreshold) {
cost = totalCost;
break;
} else {
@@ -168,8 +171,8 @@ public class GradientDescentBSP extends
peer.reopenInput();
peer.sync();
+ iterations++;
}
-
}
@Override