You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@systemds.apache.org by GitBox <gi...@apache.org> on 2021/01/05 10:38:22 UTC

[GitHub] [systemds] sebwrede commented on a change in pull request #1141: [SYSTEMDS-2550] Batch scaling and weighing of imbalanced workers

sebwrede commented on a change in pull request #1141:
URL: https://github.com/apache/systemds/pull/1141#discussion_r551816533



##########
File path: src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
##########
@@ -492,8 +523,8 @@ public FederatedResponse execute(ExecutionContext ec, Data... data) {
 				ParamservUtils.cleanupData(ec, Statement.PS_FEATURES);
 				ParamservUtils.cleanupData(ec, Statement.PS_LABELS);
 				ec.removeVariable(ec.getVariable(Statement.PS_FED_BATCHCOUNTER_VARID).toString());
-				if( LOG.isInfoEnabled() )
-					LOG.info("[+]" + " completed batch " + localBatchNum);
+				/*if( LOG.isInfoEnabled() )
+					LOG.info("[+]" + " completed batch " + localBatchNum);*/

Review comment:
       If this commented-out logging is no longer needed, it should be deleted rather than left in as a comment.

##########
File path: src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
##########
@@ -58,48 +61,68 @@
 public class FederatedPSControlThread extends PSWorker implements Callable<Void> {
 	private static final long serialVersionUID = 6846648059569648791L;
 	protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName());
-	
-	Statement.PSRuntimeBalancing _runtimeBalancing;
+
 	FederatedData _featuresData;
 	FederatedData _labelsData;
 	final long _localStartBatchNumVarID;
 	final long _modelVarID;
-	int _numBatchesPerGlobalEpoch;
+
+	// runtime balancing
+	Statement.PSRuntimeBalancing _runtimeBalancing;
+	int _numBatchesPerEpoch;
 	int _possibleBatchesPerLocalEpoch;
+	boolean _weighing;
+	double _weighingFactor = 1;
 	boolean _cycleStartAt0 = false;
 
-	public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq, Statement.PSRuntimeBalancing runtimeBalancing, int epochs, long batchSize, int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps) {
+	public FederatedPSControlThread(int workerID, String updFunc, Statement.PSFrequency freq,
+									Statement.PSRuntimeBalancing runtimeBalancing, boolean weighing, int epochs, long batchSize,
+									int numBatchesPerGlobalEpoch, ExecutionContext ec, ParamServer ps) {
 		super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
 
-		_numBatchesPerGlobalEpoch = numBatchesPerGlobalEpoch;
+		_numBatchesPerEpoch = numBatchesPerGlobalEpoch;
 		_runtimeBalancing = runtimeBalancing;
+		_weighing = weighing;
 		// generate the IDs for model and batch counter. These get overwritten on the federated worker each time
 		_localStartBatchNumVarID = FederationUtils.getNextFedDataID();
 		_modelVarID = FederationUtils.getNextFedDataID();
 	}
 
 	/**
 	 * Sets up the federated worker and control thread
+	 *
+	 * @param weighingFactor Gradients from this worker will be multiplied by this factor if weighing is enabled
 	 */
-	public void setup() {
+	public void setup(double weighingFactor) {
 		// prepare features and labels
 		_featuresData = (FederatedData) _features.getFedMapping().getMap().values().toArray()[0];
 		_labelsData = (FederatedData) _labels.getFedMapping().getMap().values().toArray()[0];
 
-		// calculate number of batches and get data size
+		// weighing factor is always set, but only used when weighing is specified
+		_weighingFactor = weighingFactor;
+
+		// different runtime balancing calculations
 		long dataSize = _features.getNumRows();
-		_possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize);
-		if(!(_runtimeBalancing == Statement.PSRuntimeBalancing.RUN_MIN 
-			|| _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_AVG 
-			|| _runtimeBalancing == Statement.PSRuntimeBalancing.CYCLE_MAX)) {
-			_numBatchesPerGlobalEpoch = _possibleBatchesPerLocalEpoch;
+
+		// calculate scaled batch size if balancing via batch size.
+		// In some cases there will be some cycling
+		if(_runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH) {
+			_batchSize = (int) Math.ceil((float) dataSize / _numBatchesPerEpoch);
 		}
 
-		if(_runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH 
-			|| _runtimeBalancing == Statement.PSRuntimeBalancing.SCALE_BATCH_AND_WEIGH) {
-			throw new NotImplementedException();
+		// Calculate possible batches with batch size
+		_possibleBatchesPerLocalEpoch = (int) Math.ceil((double) dataSize / _batchSize);
+
+		// If no runtime balancing is specified, just run possible number of batches
+		// WARNING: Will get stuck on miss match
+		if(_runtimeBalancing == Statement.PSRuntimeBalancing.NONE) {
+			_numBatchesPerEpoch = _possibleBatchesPerLocalEpoch;

Review comment:
       Could you elaborate on what a "miss match" (i.e., a mismatch) means here, and also emit a warning in the log whenever such a mismatch could occur? If it is logged as a warning, the user is alerted that there could be problems with the way they are trying to execute the script, and they can then mitigate those problems.

##########
File path: src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -125,15 +127,24 @@ private void runFederated(ExecutionContext ec) {
 		System.out.println("[+] Running in federated mode");
 
 		// get inputs
-		PSFrequency freq = getFrequency();
-		PSUpdateType updateType = getUpdateType();
-		PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
-		FederatedPSScheme federatedPSScheme = getFederatedScheme();
 		String updFunc = getParam(PS_UPDATE_FUN);
 		String aggFunc = getParam(PS_AGGREGATION_FUN);
+		PSUpdateType updateType = getUpdateType();
+		PSFrequency freq = getFrequency();
+		FederatedPSScheme federatedPSScheme = getFederatedScheme();
+		PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
+		boolean weighing = getWeighing();
+		int seed = getSeed();
+
+		LOG.info("[+] Update Type: " + updateType);
+		LOG.info("[+] Frequency: " + freq);
+		LOG.info("[+] Data Partitioning: " + federatedPSScheme);
+		LOG.info("[+] Runtime Balancing: " + runtimeBalancing);
+		LOG.info("[+] Weighing: " + weighing);
+		LOG.info("[+] Seed: " + seed);
 
 		// partition federated data
-		DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme)
+		DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme, seed)
 				.doPartitioning(ec.getMatrixObject(getParam(PS_FEATURES)), ec.getMatrixObject(getParam(PS_LABELS)));
 		List<MatrixObject> pFeatures = result._pFeatures;
 		List<MatrixObject> pLabels = result._pLabels;

Review comment:
       Do you need to put pFeatures and pLabels in local variables? They appear to be used only once, much later in the method. The method could be shorter if you take pFeatures and pLabels directly from the result object.

##########
File path: src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java
##########
@@ -143,7 +154,8 @@ private void runFederated(ExecutionContext ec) {
 		int numBatchesPerEpoch = 0;
 		if(runtimeBalancing == PSRuntimeBalancing.RUN_MIN) {
 			numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._minRows / (float) getBatchSize());
-		} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG) {
+		} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_AVG
+				|| runtimeBalancing == PSRuntimeBalancing.SCALE_BATCH) {
 			numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._avgRows / (float) getBatchSize());
  		} else if (runtimeBalancing == PSRuntimeBalancing.CYCLE_MAX) {
 			numBatchesPerEpoch = (int) Math.ceil(result._balanceMetrics._maxRows / (float) getBatchSize());

Review comment:
       To make this method shorter and easier to read, these lines could be extracted into a separate method that returns numBatchesPerEpoch.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org