Posted to commits@hama.apache.org by to...@apache.org on 2012/11/10 09:09:34 UTC
svn commit: r1407729 -
/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Author: tommaso
Date: Sat Nov 10 08:09:34 2012
New Revision: 1407729
URL: http://svn.apache.org/viewvc?rev=1407729&view=rev
Log:
[HAMA-669] - fixed wrongful array copy in derivatives aggregation
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
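
Background on the HAMA-669 fix: in the fourth superstep the old code seeded the
aggregation with "double[] newTheta = thetaDelta;", which merely aliases the local
deltas array, so accumulating the other peers' messages in place also mutated the
peer's own thetaDelta. The commit replaces the alias with Arrays.copyOf (see
aggregatePartialDerivatives below). A minimal standalone sketch of the difference,
with made-up values rather than anything from the Hama code:

    import java.util.Arrays;

    public class ArrayCopyDemo {
      public static void main(String[] args) {
        double[] thetaDelta = {1.0, 2.0, 3.0};

        // Wrongful "copy": both variables reference the same array, so
        // accumulating into the new variable silently mutates thetaDelta too.
        double[] aliased = thetaDelta;
        aliased[0] += 100.0;
        System.out.println(Arrays.toString(thetaDelta)); // [101.0, 2.0, 3.0]

        // The fix: Arrays.copyOf allocates an independent array.
        double[] copied = Arrays.copyOf(thetaDelta, thetaDelta.length);
        copied[1] += 100.0;
        System.out.println(Arrays.toString(thetaDelta)); // still [101.0, 2.0, 3.0]
      }
    }

Because Arrays.copyOf returns a fresh array, the accumulation in
aggregatePartialDerivatives can no longer corrupt the deltas that were broadcast.
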
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1407729&r1=1407728&r2=1407729&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java Sat Nov 10 08:09:34 2012
@@ -29,6 +29,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
+import java.util.Arrays;
/**
* A gradient descent (see <code>http://en.wikipedia.org/wiki/Gradient_descent</code>) BSP-based implementation.
@@ -54,135 +55,72 @@ public class GradientDescentBSP extends
@Override
public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException, SyncException, InterruptedException {
master = peer.getPeerIndex() == peer.getNumPeers() / 2;
- cost = Integer.MAX_VALUE;
+ cost = Double.MAX_VALUE;
costThreshold = peer.getConfiguration().getFloat(COST_THRESHOLD, 0.1f);
iterationsThreshold = peer.getConfiguration().getInt(ITERATIONS_THRESHOLD, 10000);
alpha = peer.getConfiguration().getFloat(ALPHA, 0.003f);
try {
- regressionModel = ((Class<? extends RegressionModel>) peer.getConfiguration().getClass(REGRESSION_MODEL_CLASS, LinearRegressionModel.class)).newInstance();
+ regressionModel = ((Class<? extends RegressionModel>) peer.getConfiguration().getClass(REGRESSION_MODEL_CLASS, LinearRegressionModel.class)).newInstance();
} catch (Exception e) {
- throw new IOException(e);
+ throw new IOException(e);
}
}
@Override
public void bsp(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException, SyncException, InterruptedException {
- // 0 superstep : count items
+ // 0a superstep: get initial theta
+ getInitialTheta(peer);
+ // 0b superstep: count items
int itemCount = 0;
while (peer.readNext() != null) {
// increment counter
itemCount++;
}
- for (String peerName : peer.getAllPeerNames()) {
- if (!peerName.equals(peer.getPeerName())) { // avoid sending to oneself
- peer.send(peerName, new VectorWritable(new DenseDoubleVector(new double[]{itemCount})));
- }
- }
+ broadcastVector(peer, new double[]{itemCount});
peer.sync();
// aggregate number of items
- VectorWritable itemsResult;
- while ((itemsResult = peer.getCurrentMessage()) != null) {
- itemCount += itemsResult.getVector().get(0);
- }
-
- m = itemCount;
+ aggregateItemsNumber(peer, itemCount);
peer.reopenInput();
int iterations = 0;
while (true) {
- getTheta(peer);
-
// first superstep: calculate cost function in parallel
-
- double localCost = 0d;
-
- // read an item
- KeyValuePair<VectorWritable, DoubleWritable> kvp;
- while ((kvp = peer.readNext()) != null) {
- // calculate cost for given input
- double y = kvp.getValue().get();
- DoubleVector x = kvp.getKey().getVector();
- double costForX = regressionModel.calculateCostForItem(x, y, m, theta);
-
- // adds to local cost
- localCost += costForX;
- }
+ double localCost = calculateLocalCost(peer);
// cost is sent to and aggregated by each peer
- for (String peerName : peer.getAllPeerNames()) {
- if (!peerName.equals(peer.getPeerName())) { // avoid sending to oneself
- peer.send(peerName, new VectorWritable(new DenseDoubleVector(new double[]{localCost})));
- }
- }
+ broadcastVector(peer, new double[]{localCost});
peer.sync();
// second superstep: aggregate cost calculation
- double totalCost = localCost;
- VectorWritable costResult;
- while ((costResult = peer.getCurrentMessage()) != null) {
- totalCost += costResult.getVector().get(0);
- }
+ double totalCost = aggregateTotalCost(peer, localCost);
// cost check
- if (cost - totalCost < 0) {
- throw new RuntimeException(new StringBuilder("gradient descent failed to converge with alpha ").
- append(alpha).toString());
- } else if (totalCost == 0 || totalCost < costThreshold || iterations >= iterationsThreshold) {
- cost = totalCost;
- break;
- } else {
- cost = totalCost;
- if (log.isDebugEnabled()) {
- log.debug(peer.getPeerName() + ": cost is " + cost);
- }
- }
+ if (checkCost(peer, iterations, totalCost)) break;
- peer.reopenInput();
peer.sync();
-
- double[] thetaDelta = new double[theta.getLength()];
+ peer.reopenInput();
// third superstep: calculate partial derivatives' deltas in parallel
- while ((kvp = peer.readNext()) != null) {
- DoubleVector x = kvp.getKey().getVector();
- double y = kvp.getValue().get();
- double difference = regressionModel.applyHypothesis(theta, x) - y;
- for (int j = 0; j < theta.getLength(); j++) {
- thetaDelta[j] += difference * x.get(j);
- }
- }
+ double[] thetaDelta = calculatePartialDerivatives(peer);
// send thetaDelta to each peer
- for (String peerName : peer.getAllPeerNames()) {
- if (!peerName.equals(peer.getPeerName())) { // avoid sending to oneself
- peer.send(peerName, new VectorWritable(new DenseDoubleVector(thetaDelta)));
- }
- }
+ broadcastVector(peer, thetaDelta);
peer.sync();
// fourth superstep: aggregate partial derivatives
- VectorWritable thetaDeltaSlice;
- double[] newTheta = thetaDelta;
- while ((thetaDeltaSlice = peer.getCurrentMessage()) != null) {
+ double[] newTheta = aggregatePartialDerivatives(peer, thetaDelta);
- for (int j = 0; j < theta.getLength(); j++) {
- newTheta[j] += thetaDeltaSlice.getVector().get(j);
- }
-
- for (int j = 0; j < theta.getLength(); j++) {
- newTheta[j] = theta.get(j) - newTheta[j] * alpha;
- }
- }
- theta = new DenseDoubleVector(newTheta);
+ // update theta
+ updateTheta(newTheta);
if (log.isDebugEnabled()) {
log.debug(new StringBuilder(peer.getPeerName()).append(": new theta for cost ").
- append(cost).append(" is ").append(theta.toString()).toString());
+ append(cost).append(" is ").append(theta.toString()).toString());
}
// master writes down the output
if (master) {
@@ -194,6 +132,98 @@ public class GradientDescentBSP extends
iterations++;
}
+}
+
+ private double aggregateTotalCost(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer, double localCost) throws IOException {
+ double totalCost = localCost;
+ VectorWritable costResult;
+ while ((costResult = peer.getCurrentMessage()) != null) {
+ totalCost += costResult.getVector().get(0);
+ }
+ return totalCost;
+ }
+
+ private double[] aggregatePartialDerivatives(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer, double[] thetaDelta) throws IOException {
+ VectorWritable thetaDeltaSlice;
+ double[] newTheta = Arrays.copyOf(thetaDelta, thetaDelta.length);
+ while ((thetaDeltaSlice = peer.getCurrentMessage()) != null) {
+ for (int j = 0; j < theta.getLength(); j++) {
+ newTheta[j] += thetaDeltaSlice.getVector().get(j);
+ }
+ }
+ return newTheta;
+ }
+
+ private void updateTheta(double[] thetaDiff) {
+ double[] newTheta = new double[theta.getLength()];
+ for (int j = 0; j < theta.getLength(); j++) {
+ newTheta[j] = theta.get(j) - thetaDiff[j] * alpha;
+ }
+ theta = new DenseDoubleVector(newTheta);
+ }
+
+ private void aggregateItemsNumber(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer, int itemCount) throws IOException {
+ VectorWritable itemsResult;
+ while ((itemsResult = peer.getCurrentMessage()) != null) {
+ itemCount += itemsResult.getVector().get(0);
+ }
+
+ m = itemCount;
+ }
+
+ private boolean checkCost(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer, int iterations, double totalCost) {
+ if (iterations > 0 && cost < totalCost) {
+ throw new RuntimeException(new StringBuilder("gradient descent failed to converge with alpha ").
+ append(alpha).toString());
+ } else if (totalCost == 0 || totalCost < costThreshold || iterations >= iterationsThreshold) {
+ cost = totalCost;
+ return true;
+ } else {
+ cost = totalCost;
+ if (log.isDebugEnabled()) {
+ log.debug(new StringBuilder(peer.getPeerName()).append(": current cost is ").append(cost).toString());
+ }
+ return false;
+ }
+}
+
+ private double calculateLocalCost(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException {
+ double localCost = 0d;
+
+ // read each input item
+ KeyValuePair<VectorWritable, DoubleWritable> kvp;
+ while ((kvp = peer.readNext()) != null) {
+ // calculate cost for given input
+ double y = kvp.getValue().get();
+ DoubleVector x = kvp.getKey().getVector();
+ double costForX = regressionModel.calculateCostForItem(x, y, m, theta);
+
+ // adds to local cost
+ localCost += costForX;
+ }
+ return localCost;
+}
+
+ private void broadcastVector(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer, double[] vector) throws IOException {
+ for (String peerName : peer.getAllPeerNames()) {
+ if (!peerName.equals(peer.getPeerName())) { // avoid sending to oneself
+ peer.send(peerName, new VectorWritable(new DenseDoubleVector(vector)));
+ }
+ }
+ }
+
+ private double[] calculatePartialDerivatives(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException {
+ KeyValuePair<VectorWritable, DoubleWritable> kvp;
+ double[] thetaDelta = new double[theta.getLength()];
+ while ((kvp = peer.readNext()) != null) {
+ DoubleVector x = kvp.getKey().getVector();
+ double y = kvp.getValue().get();
+ double difference = regressionModel.applyHypothesis(theta, x) - y;
+ for (int j = 0; j < theta.getLength(); j++) {
+ thetaDelta[j] += difference * x.get(j);
+ }
+ }
+ return thetaDelta;
}
@Override
@@ -203,31 +233,29 @@ public class GradientDescentBSP extends
peer.write(new VectorWritable(theta), new DoubleWritable(cost));
if (log.isInfoEnabled()) {
log.info(new StringBuilder(peer.getPeerName()).append(":computation finished with cost ").
- append(cost).append(" for theta ").append(theta).toString());
+ append(cost).append(" for theta ").append(theta).toString());
}
}
}
- public void getTheta(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException, SyncException, InterruptedException {
+ public void getInitialTheta(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException, SyncException, InterruptedException {
if (theta == null) {
if (master) {
int size = getXSize(peer);
theta = new DenseDoubleVector(size, peer.getConfiguration().getInt(INITIAL_THETA_VALUES, 1));
- for (String peerName : peer.getAllPeerNames()) {
- peer.send(peerName, new VectorWritable(theta));
- }
+ broadcastVector(peer, theta.toArray());
if (log.isDebugEnabled()) {
log.debug(new StringBuilder(peer.getPeerName()).append(": sending theta").toString());
}
peer.sync();
- } else {
- if (log.isDebugEnabled()) {
- log.debug(new StringBuilder(peer.getPeerName()).append(": getting theta").toString());
+ } else {
+ if (log.isDebugEnabled()) {
+ log.debug(new StringBuilder(peer.getPeerName()).append(": getting theta").toString());
+ }
+ peer.sync();
+ VectorWritable vectorWritable = peer.getCurrentMessage();
+ theta = vectorWritable.getVector();
}
- peer.sync();
- VectorWritable vectorWritable = peer.getCurrentMessage();
- theta = vectorWritable.getVector();
- }
}
}
@@ -237,8 +265,8 @@ public class GradientDescentBSP extends
peer.readNext(key, value);
peer.reopenInput(); // reset input to start
if (key.getVector() == null) {
- throw new IOException("cannot read input vector size");
+ throw new IOException("cannot read input vector size");
}
- return key.getVector().getLength();
+ return key.getVector().getDimension();
}
}
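
A note on the cost initialization change in setup (Integer.MAX_VALUE to
Double.MAX_VALUE): cost is a double, and the old sentinel capped its starting
value at roughly 2.147e9, so before this commit a first aggregated cost above
that bound tripped the divergence check immediately. The commit both widens the
sentinel and guards the check with iterations > 0. A minimal sketch, using a
made-up 5.0e9 first-iteration cost, of why the narrow sentinel alone was fragile:

    public class CostSentinelDemo {
      public static void main(String[] args) {
        double oldSentinel = Integer.MAX_VALUE;   // ~2.147e9 once widened to double
        double newSentinel = Double.MAX_VALUE;
        double firstTotalCost = 5.0e9;            // made-up large first-iteration cost

        // Shape of the divergence check: previous cost < freshly aggregated cost.
        System.out.println(oldSentinel < firstTotalCost); // true  -> spurious "failed to converge"
        System.out.println(newSentinel < firstTotalCost); // false -> iteration proceeds
      }
    }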