You are viewing a plain-text version of this content. The canonical link to it (originally a "here" hyperlink) was not preserved in this copy.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/07/08 01:45:21 UTC

systemml git commit: [SYSTEMML-2403] Fix accuracy issue paramserv BSP batch updates

Repository: systemml
Updated Branches:
  refs/heads/master eb179b151 -> 63a1e2ac5


[SYSTEMML-2403] Fix accuracy issue paramserv BSP batch updates

Closes #791.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/63a1e2ac
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/63a1e2ac
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/63a1e2ac

Branch: refs/heads/master
Commit: 63a1e2ac59f3201ab99a6e5e71636133eec96b1b
Parents: eb179b1
Author: EdgarLGB <gu...@atos.net>
Authored: Sat Jul 7 18:40:25 2018 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sat Jul 7 18:40:26 2018 -0700

----------------------------------------------------------------------
 .../controlprogram/paramserv/LocalPSWorker.java | 19 ++---------
 .../controlprogram/paramserv/ParamServer.java   | 34 +++++++++++++++-----
 .../paramserv/ParamservUtils.java               | 24 ++++++++++++++
 3 files changed, 52 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/63a1e2ac/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
index 0ed7c81..366284c 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/LocalPSWorker.java
@@ -20,7 +20,6 @@
 package org.apache.sysml.runtime.controlprogram.paramserv;
 
 import java.util.concurrent.Callable;
-import java.util.stream.IntStream;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -30,10 +29,7 @@ import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
-import org.apache.sysml.runtime.functionobjects.Plus;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysml.utils.Statistics;
 
 public class LocalPSWorker extends PSWorker implements Callable<Void> {
@@ -84,13 +80,12 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 				ListObject gradients = computeGradients(dataSize, totalIter, i, j);
 
 				// Accumulate the intermediate gradients
-				accGradients = (accGradients==null) ?
-					ParamservUtils.copyList(gradients) :
-					accrueGradients(accGradients, gradients);
+				accGradients = ParamservUtils.accrueGradients(accGradients, gradients);
 
 				// Update the local model with gradients
 				if( j < totalIter - 1 )
 					params = updateModel(params, gradients, i, j, totalIter);
+				ParamservUtils.cleanupListObject(gradients);
 			}
 
 			// Push the gradients to ps
@@ -193,14 +188,4 @@ public class LocalPSWorker extends PSWorker implements Callable<Void> {
 		return gradients;
 	}
 
-	private ListObject accrueGradients(ListObject accGradients, ListObject gradients) {
-		IntStream.range(0, accGradients.getLength()).forEach(i -> {
-			MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireRead();
-			MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireRead();
-			mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2);
-			((MatrixObject) accGradients.getData().get(i)).release();
-			((MatrixObject) gradients.getData().get(i)).release();
-		});
-		return accGradients;
-	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/63a1e2ac/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
index abec267..432d4fc 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamServer.java
@@ -49,7 +49,8 @@ import org.apache.sysml.utils.Statistics;
 
 public abstract class ParamServer 
 {
-	protected final Log LOG = LogFactory.getLog(ParamServer.class.getName());
+	protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName());
+	protected static final boolean ACCRUE_BSP_GRADIENTS = true;
 	
 	// worker input queues and global model
 	protected final Map<Integer, BlockingQueue<ListObject>> _modelMap;
@@ -61,6 +62,7 @@ public abstract class ParamServer
 	private final FunctionCallCPInstruction _inst;
 	private final String _outputName;
 	private final boolean[] _finishedStates;  // Workers' finished states
+	private ListObject _accGradients = null;
 
 	protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
 		// init worker queues and global model
@@ -126,17 +128,25 @@ public abstract class ParamServer
 					gradients.getDataSize() / 1024, workerID));
 			}
 
-			// Update and redistribute the model
-			Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
-			_model = updateLocalModel(_ec, gradients, _model);
-			if (DMLScript.STATISTICS)
-				Statistics.accPSAggregationTime((long) tAgg.stop());
-
-			// Redistribute model according to update type
 			switch(_updateType) {
 				case BSP: {
 					setFinishedState(workerID);
+
+					// Accumulate the intermediate gradients
+					if( ACCRUE_BSP_GRADIENTS )
+						_accGradients = ParamservUtils.accrueGradients(
+							_accGradients, gradients, true);
+					else
+						updateGlobalModel(gradients);
+					ParamservUtils.cleanupListObject(gradients);
+
 					if (allFinished()) {
+						// Update the global model with accrued gradients
+						if( ACCRUE_BSP_GRADIENTS ) {
+							updateGlobalModel(_accGradients);
+							_accGradients = null;
+						}
+						
 						// Broadcast the updated model
 						resetFinishedStates();
 						broadcastModel();
@@ -146,6 +156,7 @@ public abstract class ParamServer
 					break;
 				}
 				case ASP: {
+					updateGlobalModel(gradients);
 					broadcastModel(workerID);
 					break;
 				}
@@ -158,6 +169,13 @@ public abstract class ParamServer
 		}
 	}
 
+	private void updateGlobalModel(ListObject gradients) {
+		Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
+		_model = updateLocalModel(_ec, gradients, _model);
+		if (DMLScript.STATISTICS)
+			Statistics.accPSAggregationTime((long) tAgg.stop());
+	}
+
 	/**
 	 * A service method for updating model with gradients
 	 *

http://git-wip-us.apache.org/repos/asf/systemml/blob/63a1e2ac/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
index ecfac66..3aee170 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -50,6 +50,7 @@ import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
+import org.apache.sysml.runtime.functionobjects.Plus;
 import org.apache.sysml.runtime.instructions.cp.Data;
 import org.apache.sysml.runtime.instructions.cp.ListObject;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -57,6 +58,7 @@ import org.apache.sysml.runtime.matrix.MetaDataFormat;
 import org.apache.sysml.runtime.matrix.data.InputInfo;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
 
 public class ParamservUtils {
 
@@ -88,6 +90,10 @@ public class ParamservUtils {
 
 	public static void cleanupListObject(ExecutionContext ec, String lName) {
 		ListObject lo = (ListObject) ec.removeVariable(lName);
+		cleanupListObject(lo);
+	}
+
+	public static void cleanupListObject(ListObject lo) {
 		lo.getData().forEach(ParamservUtils::cleanupData);
 	}
 
@@ -258,4 +264,22 @@ public class ParamservUtils {
 		String fname = cfn[1];
 		return ec.getProgram().getFunctionProgramBlock(ns, fname);
 	}
+	
+	public static ListObject accrueGradients(ListObject accGradients, ListObject gradients) {
+		return accrueGradients(accGradients, gradients, false);
+	}
+	
+	public static ListObject accrueGradients(ListObject accGradients, ListObject gradients, boolean par) {
+		if (accGradients == null)
+			return ParamservUtils.copyList(gradients);
+		IntStream range = IntStream.range(0, accGradients.getLength());
+		(par ? range.parallel() : range).forEach(i -> {
+			MatrixBlock mb1 = ((MatrixObject) accGradients.getData().get(i)).acquireRead();
+			MatrixBlock mb2 = ((MatrixObject) gradients.getData().get(i)).acquireRead();
+			mb1.binaryOperationsInPlace(new BinaryOperator(Plus.getPlusFnObject()), mb2);
+			((MatrixObject) accGradients.getData().get(i)).release();
+			((MatrixObject) gradients.getData().get(i)).release();
+		});
+		return accGradients;
+	}
 }