You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/07/08 01:45:21 UTC
systemml git commit: [SYSTEMML-2403] Fix accuracy issue paramserv BSP batch updates
Repository: systemml
Updated Branches:
refs/heads/master eb179b151 -> 63a1e2ac5
[SYSTEMML-2403] Fix accuracy issue paramserv BSP batch updates
Closes #791.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/63a1e2ac
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/63a1e2ac
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/63a1e2ac
Branch: refs/heads/master
Commit: 63a1e2ac59f3201ab99a6e5e71636133eec96b1b
Parents: eb179b1
Author: EdgarLGB <gu...@atos.net>
Authored: Sat Jul 7 18:40:25 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sat Jul 7 18:40:26 2018 -0700
----------------------------------------------------------------------
.../controlprogram/paramserv/LocalPSWorker.java | 19 ++---------
.../controlprogram/paramserv/ParamServer.java | 34 +++++++++++++++-----
.../paramserv/ParamservUtils.java | 24 ++++++++++++++
3 files changed, 52 insertions(+), 25 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/63a1e2ac/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index 0ed7c81..366284c 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -20,7 +20,6 @@
package org.apache.sysml.runtime.controlprogram.paramserv;
import java.util.concurrent.Callable;
-import java.util.stream.IntStream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -30,10 +29,7 @@ import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
-import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
import org.apache.sysml.utils.Statistics;
public class LocalPSWorker extends PSWorker implements Callable<Void> {
@@ -84,13 +80,12 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
ListObject gradients = computeGradients(dataSize, totalIter, i, j);
// Accumulate the intermediate gradients
- accGradients = (accGradients==null) ?
- ParamservUtils.copyList(gradients) :
- accrueGradients(accGradients, gradients);
+ accGradients = ParamservUtils.accrueGradients(accGradients, gradients);
// Update the local model with gradients
if( j < totalIter - 1 )
params = updateModel(params, gradients, i, j, totalIter);
+ ParamservUtils.cleanupListObject(gradients);
}
// Push the gradients to ps
@@ -193,14 +188,4 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
return gradients;
}
- private ListObject accrueGradients(ListObject accGradients, ListObject gradients) {
- IntStream.range(0, accGradients.getLength()).forEach(i -> {
- MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireRead();
- MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireRead();
- mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2);
- ((MatrixObject) accGradients.getData().get(i)).release();
- ((MatrixObject) gradients.getData().get(i)).release();
- });
- return accGradients;
- }
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/63a1e2ac/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
index abec267..432d4fc 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -49,7 +49,8 @@ import org.apache.sysml.utils.Statistics;
public abstract class ParamServer
{
- protected final Log LOG = LogFactory.getLog(ParamServer.class.getName());
+ protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName());
+ protected static final boolean ACCRUE_BSP_GRADIENTS = true;
// worker input queues and global model
protected final Map<Integer, BlockingQueue<ListObject>> _modelMap;
@@ -61,6 +62,7 @@ public abstract class ParamServer
private final FunctionCallCPInstruction _inst;
private final String _outputName;
private final boolean[] _finishedStates; // Workers' finished states
+ private ListObject _accGradients = null;
protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
// init worker queues and global model
@@ -126,17 +128,25 @@ public abstract class ParamServer
gradients.getDataSize() / 1024, workerID));
}
- // Update and redistribute the model
- Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
- _model = updateLocalModel(_ec, gradients, _model);
- if (DMLScript.STATISTICS)
- Statistics.accPSAggregationTime((long) tAgg.stop());
-
- // Redistribute model according to update type
switch(_updateType) {
case BSP: {
setFinishedState(workerID);
+
+ // Accumulate the intermediate gradients
+ if( ACCRUE_BSP_GRADIENTS )
+ _accGradients = ParamservUtils.accrueGradients(
+ _accGradients, gradients, true);
+ else
+ updateGlobalModel(gradients);
+ ParamservUtils.cleanupListObject(gradients);
+
if (allFinished()) {
+ // Update the global model with accrued gradients
+ if( ACCRUE_BSP_GRADIENTS ) {
+ updateGlobalModel(_accGradients);
+ _accGradients = null;
+ }
+
// Broadcast the updated model
resetFinishedStates();
broadcastModel();
@@ -146,6 +156,7 @@ public abstract class ParamServer
break;
}
case ASP: {
+ updateGlobalModel(gradients);
broadcastModel(workerID);
break;
}
@@ -158,6 +169,13 @@ public abstract class ParamServer
}
}
+ private void updateGlobalModel(ListObject gradients) {
+ Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
+ _model = updateLocalModel(_ec, gradients, _model);
+ if (DMLScript.STATISTICS)
+ Statistics.accPSAggregationTime((long) tAgg.stop());
+ }
+
/**
* A service method for updating model with gradients
*
http://git-wip-us.apache.org/repos/asf/systemml/blob/63a1e2ac/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
index ecfac66..3aee170 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -50,6 +50,7 @@ import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
+import org.apache.sysml.runtime.functionobjects.Plus;
import org.apache.sysml.runtime.instructions.cp.Data;
import org.apache.sysml.runtime.instructions.cp.ListObject;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -57,6 +58,7 @@ import org.apache.sysml.runtime.matrix.MetaDataFormat;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
public class ParamservUtils {
@@ -88,6 +90,10 @@ public class ParamservUtils {
public static void cleanupListObject(ExecutionContext ec, String lName) {
ListObject lo = (ListObject) ec.removeVariable(lName);
+ cleanupListObject(lo);
+ }
+
+ public static void cleanupListObject(ListObject lo) {
lo.getData().forEach(ParamservUtils::cleanupData);
}
@@ -258,4 +264,22 @@ public class ParamservUtils {
String fname = cfn[1];
return ec.getProgram().getFunctionProgramBlock(ns, fname);
}
+
+ public static ListObject accrueGradients(ListObject accGradients, ListObject gradients) {
+ return accrueGradients(accGradients, gradients, false);
+ }
+
+ public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean par) {
+ if (accGradients == null)
+ return ParamservUtils.copyList(gradients);
+ IntStream range = IntStream.range(0, accGradients.getLength());
+ (par ? range.parallel() : range).forEach(i -> {
+ MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireRead();
+ MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireRead();
+ mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2);
+ ((MatrixObject) accGradients.getData().get(i)).release();
+ ((MatrixObject) gradients.getData().get(i)).release();
+ });
+ return accGradients;
+ }
}