You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/08/18 18:13:52 UTC

[systemds] branch master updated: [SYSTEMDS-2600, 2626] Fix federated backend request interference

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 8356ea6  [SYSTEMDS-2600,2626] Fix federated backend request interference
8356ea6 is described below

commit 8356ea6861ed5dc5cc15430bce26cf5e4bdd7c9e
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Tue Aug 18 19:47:16 2020 +0200

    [SYSTEMDS-2600,2626] Fix federated backend request interference
    
    This patch fixes two major issues of request interference from multiple
    coordinator threads.
    
    First, we now properly maintain separate execution contexts at the
    federated site for the different request streams of parfor workers,
    which could otherwise interfere (e.g., on rmvar instructions for shared
    input variables).
    
    Second, even requests within a single stream of federated requests
    (e.g., execute and cleanup) could overtake each other if there were no
    data dependencies or synchronization between them. We now add barriers
    (waiting on the responses) between federated requests wherever this is
    necessary.
    
    Finally, this patch also removes unnecessary warning messages of the
    parfor optimizer, specifically in settings with forced single-node
    execution.
    
    Closes #1028.
---
 .../runtime/controlprogram/ParForProgramBlock.java |  1 +
 .../controlprogram/caching/CacheableData.java      | 10 +++-
 .../controlprogram/context/ExecutionContext.java   | 11 +++-
 .../context/SparkExecutionContext.java             |  2 +-
 .../federated/ExecutionContextMap.java             | 61 ++++++++++++++++++++++
 .../controlprogram/federated/FederatedRequest.java | 13 ++++-
 .../controlprogram/federated/FederatedWorker.java  | 12 ++---
 .../federated/FederatedWorkerHandler.java          | 44 ++++++++--------
 .../controlprogram/federated/FederationMap.java    | 45 ++++++++++++----
 .../parfor/opt/CostEstimatorHops.java              |  5 +-
 .../instructions/cp/VariableCPInstruction.java     |  2 +-
 .../fed/AggregateBinaryFEDInstruction.java         | 12 ++---
 .../fed/AggregateUnaryFEDInstruction.java          |  4 +-
 .../instructions/fed/AppendFEDInstruction.java     |  2 +-
 .../fed/BinaryMatrixMatrixFEDInstruction.java      |  8 +--
 .../fed/BinaryMatrixScalarFEDInstruction.java      |  4 +-
 .../runtime/instructions/fed/FEDInstruction.java   |  9 ++++
 .../instructions/fed/FEDInstructionUtils.java      | 20 ++++---
 .../fed/ParameterizedBuiltinFEDInstruction.java    |  2 +-
 .../instructions/fed/TsmmFEDInstruction.java       |  4 +-
 .../org/apache/sysds/test/AutomatedTestBase.java   |  2 +-
 .../functions/federated/FederatedKmeansTest.java   |  3 +-
 22 files changed, 200 insertions(+), 76 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 9e15139..c6d7e6e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -1177,6 +1177,7 @@ public class ParForProgramBlock extends ForProgramBlock
 			
 			//deep copy execution context (including prepare parfor update-in-place)
 			ExecutionContext cpEc = ProgramConverter.createDeepCopyExecutionContext(ec);
+			cpEc.setTID(pwID);
 
 			// If GPU mode is enabled, gets a GPUContext from the pool of GPUContexts
 			// and sets it in the ExecutionContext of the parfor
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index c809a84..949e60a 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -629,6 +629,10 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 		}
 	}
 	
+	public void clearData() {
+		clearData(-1);
+	}
+	
 	/**
 	 * Sets the cache block reference to <code>null</code>, abandons the old block.
 	 * Makes the "envelope" empty.  Run it to finalize the object (otherwise the
@@ -637,8 +641,10 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 	 * In-Status:  EMPTY, EVICTABLE, EVICTED;
 	 * Out-Status: EMPTY.
 	 * 
+	 * @param tid thread ID
+	 * 
 	 */
-	public synchronized void clearData() 
+	public synchronized void clearData(long tid) 
 	{
 		// check if cleanup enabled and possible 
 		if( !isCleanupEnabled() ) 
@@ -669,7 +675,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
 		
 		//clear federated matrix
 		if( _fedMapping != null )
-			_fedMapping.cleanup(_fedMapping.getID());
+			_fedMapping.cleanup(tid, _fedMapping.getID());
 		
 		// change object state EMPTY
 		setDirty(false);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index fcb5db3..a34b77e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -70,6 +70,7 @@ public class ExecutionContext {
 	
 	//symbol table
 	protected LocalVariableMap _variables;
+	protected long _tid = -1;
 	protected boolean _autoCreateVars;
 	
 	//lineage map, cache, prepared dedup blocks
@@ -131,6 +132,14 @@ public class ExecutionContext {
 	public void setAutoCreateVars(boolean flag) {
 		_autoCreateVars = flag;
 	}
+	
+	public void setTID(long tid) {
+		_tid = tid;
+	}
+	
+	public long getTID() {
+		return _tid;
+	}
 
 	/**
 	 * Get the i-th GPUContext
@@ -750,7 +759,7 @@ public class ExecutionContext {
 		try {
 			//compute ref count only if matrix cleanup actually necessary
 			if ( mo.isCleanupEnabled() && !getVariables().hasReferences(mo) )  {
-				mo.clearData(); //clean cached data
+				mo.clearData(getTID()); //clean cached data
 				if( fileExists ) {
 					HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName());
 					HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName()+".mtd");
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index bb6a94d..65348f1 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -1350,7 +1350,7 @@ public class SparkExecutionContext extends ExecutionContext
 			//compute ref count only if matrix cleanup actually necessary
 			if( !getVariables().hasReferences(mo) ) {
 				//clean cached data
-				mo.clearData();
+				mo.clearData(getTID());
 
 				//clean hdfs data if no pending rdd operations on it
 				if( mo.isHDFSFileExists() && mo.getFileName()!=null ) {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/ExecutionContextMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/ExecutionContextMap.java
new file mode 100644
index 0000000..1d06f46
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/ExecutionContextMap.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.federated;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
+
+public class ExecutionContextMap {
+	private final ExecutionContext _main;
+	private final Map<Long, ExecutionContext> _parEc;
+	
+	public ExecutionContextMap() {
+		_main = createExecutionContext();
+		_parEc = new ConcurrentHashMap<>();
+	}
+	
+	public ExecutionContext get(long tid) {
+		//return main execution context
+		if( tid <= 0 )
+			return _main;
+		
+		//atomic probe, create if necessary, and return
+		return _parEc.computeIfAbsent(tid,
+			k -> deriveExecutionContext(_main));
+	}
+	
+	private static ExecutionContext createExecutionContext() {
+		ExecutionContext ec = ExecutionContextFactory.createContext();
+		ec.setAutoCreateVars(true); //w/o createvar inst
+		return ec;
+	}
+	
+	private static ExecutionContext deriveExecutionContext(ExecutionContext ec) {
+		//derive execution context from main to make shared variables available
+		//but allow normal instruction processing and removal if necessary
+		ExecutionContext ec2 = ExecutionContextFactory
+			.createContext(ec.getVariables(), ec.getProgram());
+		ec2.setAutoCreateVars(true); //w/o createvar inst
+		return ec2;
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 5618d36..d62e6f6 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -41,6 +41,7 @@ public class FederatedRequest implements Serializable {
 	
 	private RequestType _method;
 	private long _id;
+	private long _tid;
 	private List<Object> _data;
 	private boolean _checkPrivacy;
 	
@@ -73,6 +74,14 @@ public class FederatedRequest implements Serializable {
 		return _id;
 	}
 	
+	public long getTID() {
+		return _tid;
+	}
+	
+	public void setTID(long tid) {
+		_tid = tid;
+	}
+	
 	public Object getParam(int i) {
 		return _data.get(i);
 	}
@@ -112,7 +121,9 @@ public class FederatedRequest implements Serializable {
 		StringBuilder sb = new StringBuilder("FederatedRequest[");
 		sb.append(_method); sb.append(";");
 		sb.append(_id); sb.append(";");
-		sb.append(Arrays.toString(_data.toArray())); sb.append("]");
+		sb.append("t"); sb.append(_tid); sb.append(";");
+		if( _method != RequestType.PUT_VAR )
+			sb.append(Arrays.toString(_data.toArray())); sb.append("]");
 		return sb.toString();
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
index 1eca3a9..dae75e4 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
@@ -32,21 +32,15 @@ import io.netty.handler.codec.serialization.ObjectDecoder;
 import io.netty.handler.codec.serialization.ObjectEncoder;
 import org.apache.log4j.Logger;
 import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
-import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
 
 public class FederatedWorker {
 	protected static Logger log = Logger.getLogger(FederatedWorker.class);
 
 	private int _port;
-	private final ExecutionContext _ec;
-	private final BasicProgramBlock _pb;
+	private final ExecutionContextMap _ecm;
 	
 	public FederatedWorker(int port) {
-		_ec = ExecutionContextFactory.createContext();
-		_ec.setAutoCreateVars(true); //w/o createvar inst
-		_pb = new BasicProgramBlock(null);
+		_ecm = new ExecutionContextMap();
 		_port = (port == -1) ?
 			Integer.parseInt(DMLConfig.DEFAULT_FEDERATED_PORT) : port;
 	}
@@ -65,7 +59,7 @@ public class FederatedWorker {
 							new ObjectDecoder(Integer.MAX_VALUE,
 								ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())))
 						.addLast("ObjectEncoder", new ObjectEncoder())
-						.addLast("FederatedWorkerHandler", new FederatedWorkerHandler(_ec, _pb));
+						.addLast("FederatedWorkerHandler", new FederatedWorkerHandler(_ecm));
 				}
 			}).option(ChannelOption.SO_BACKLOG, 128).childOption(ChannelOption.SO_KEEPALIVE, true);
 		try {
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 1afbfb1..00a8685 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -62,12 +62,13 @@ import java.util.Arrays;
 public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 	protected static Logger log = Logger.getLogger(FederatedWorkerHandler.class);
 
-	private final ExecutionContext _ec;
-	private final BasicProgramBlock _pb;
+	private final ExecutionContextMap _ecm;
 	
-	public FederatedWorkerHandler(ExecutionContext ec, BasicProgramBlock pb) {
-		_ec = ec;
-		_pb = pb;
+	public FederatedWorkerHandler(ExecutionContextMap ecm) {
+		//Note: federated worker handler created for every command;
+		//and concurrent parfor threads at coordinator need separate
+		//execution contexts at the federated sites too
+		_ecm = ecm;
 	}
 
 	@Override
@@ -131,10 +132,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 		checkNumParams(request.getNumParams(), 2);
 		String filename = (String) request.getParam(0);
 		DataType dt = DataType.valueOf((String)request.getParam(1));
-		return readData(filename, dt, request.getID());
+		return readData(filename, dt, request.getID(), request.getTID());
 	}
 
-	private FederatedResponse readData(String filename, Types.DataType dataType, long id) {
+	private FederatedResponse readData(String filename, Types.DataType dataType, long id, long tid) {
 		MatrixCharacteristics mc = new MatrixCharacteristics();
 		mc.setBlocksize(ConfigurationManager.getBlocksize());
 		CacheableData<?> cd;
@@ -180,7 +181,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 		cd.release();
 		
 		//TODO spawn async load of data, otherwise on first access
-		_ec.setVariable(String.valueOf(id), cd);
+		_ecm.get(tid).setVariable(String.valueOf(id), cd);
 		cd.enableCleanup(false); //guard against deletion
 		
 		if (dataType == Types.DataType.FRAME) {
@@ -193,7 +194,8 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 	private FederatedResponse putVariable(FederatedRequest request) {
 		checkNumParams(request.getNumParams(), 1);
 		String varname = String.valueOf(request.getID());
-		if( _ec.containsVariable(varname) ) {
+		ExecutionContext ec = _ecm.get(request.getTID());
+		if( ec.containsVariable(varname) ) {
 			return new FederatedResponse(ResponseType.ERROR,
 				"Variable "+request.getID()+" already existing.");
 		}
@@ -206,22 +208,19 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 			data = (ScalarObject) request.getParam(0);
 		
 		//set variable and construct empty response
-		_ec.setVariable(varname, data);
+		ec.setVariable(varname, data);
 		return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
 	}
 	
 	private FederatedResponse getVariable(FederatedRequest request) {
 		checkNumParams(request.getNumParams(), 0);
-		if( !_ec.containsVariable(String.valueOf(request.getID())) ) {
+		ExecutionContext ec = _ecm.get(request.getTID());
+		if( !ec.containsVariable(String.valueOf(request.getID())) ) {
 			return new FederatedResponse(ResponseType.ERROR,
 				"Variable "+request.getID()+" does not exist at federated worker.");
 		}
 		//get variable and construct response
-		return getVariableData(request.getID());
-	}
-	
-	private FederatedResponse getVariableData(long varID) {
-		Data dataObject = _ec.getVariable(String.valueOf(varID));
+		Data dataObject = ec.getVariable(String.valueOf(request.getID()));
 		dataObject = PrivacyMonitor.handlePrivacy(dataObject);
 		switch (dataObject.getDataType()) {
 			case TENSOR:
@@ -240,11 +239,13 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 	}
 	
 	private FederatedResponse execInstruction(FederatedRequest request) {
-		_pb.getInstructions().clear();
-		_pb.getInstructions().add(InstructionParser
+		ExecutionContext ec = _ecm.get(request.getTID());
+		BasicProgramBlock pb = new BasicProgramBlock(null);
+		pb.getInstructions().clear();
+		pb.getInstructions().add(InstructionParser
 			.parseSingleInstruction((String)request.getParam(0)));
 		try {
-			_pb.execute(_ec); //execute single instruction
+			pb.execute(ec); //execute single instruction
 		}
 		catch(Exception ex) {
 			return new FederatedResponse(ResponseType.ERROR, ex.getMessage());
@@ -254,16 +255,17 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 	
 	private FederatedResponse execUDF(FederatedRequest request) {
 		checkNumParams(request.getNumParams(), 1);
+		ExecutionContext ec = _ecm.get(request.getTID());
 		
 		//get function and input parameters
 		FederatedUDF udf = (FederatedUDF) request.getParam(0);
 		Data[] inputs = Arrays.stream(udf.getInputIDs())
-			.mapToObj(id -> _ec.getVariable(String.valueOf(id)))
+			.mapToObj(id -> ec.getVariable(String.valueOf(id)))
 			.toArray(Data[]::new);
 		
 		//execute user-defined function
 		try {
-			return udf.execute(_ec, inputs);
+			return udf.execute(ec, inputs);
 		}
 		catch(Exception ex) {
 			return new FederatedResponse(ResponseType.ERROR, ex.getMessage());
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index d323bad..371c3ff 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.runtime.controlprogram.federated;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
@@ -95,22 +96,37 @@ public class FederationMap
 		return ret.toArray(new FederatedRequest[0]);
 	}
 	
-	@SuppressWarnings("unchecked")
-	public Future<FederatedResponse>[] execute(FederatedRequest... fr) {
-		List<Future<FederatedResponse>> ret = new ArrayList<>();
-		for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
-			ret.add(e.getValue().executeFederatedOperation(fr));
-		return ret.toArray(new Future[0]);
+	public Future<FederatedResponse>[] execute(long tid, FederatedRequest... fr) {
+		return execute(tid, false, fr);
+	}
+	
+	public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest... fr) {
+		return execute(tid, wait, null, fr);
+	}
+	
+	public Future<FederatedResponse>[] execute(long tid, FederatedRequest[] frSlices, FederatedRequest... fr) {
+		return execute(tid, false, frSlices, fr);
 	}
 	
 	@SuppressWarnings("unchecked")
-	public Future<FederatedResponse>[] execute(FederatedRequest[] frSlices, FederatedRequest... fr) {
-		//executes step1[] - step 2 - ... step4 (only first step federated-data-specific)
+	public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest[] frSlices, FederatedRequest... fr) {
+		// executes step1[] - step 2 - ... step4 (only first step federated-data-specific)
+		setThreadID(tid, frSlices, fr);
 		List<Future<FederatedResponse>> ret = new ArrayList<>(); 
 		int pos = 0;
 		for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
-			ret.add(e.getValue().executeFederatedOperation(addAll(frSlices[pos++], fr)));
-		return ret.toArray(new Future[0]);
+			ret.add(e.getValue().executeFederatedOperation(
+				(frSlices!=null) ? addAll(frSlices[pos++], fr) : fr));
+		
+		// prepare results (future federated responses), with optional wait to ensure the 
+		// order of requests without data dependencies (e.g., cleanup RPCs)
+		Future<FederatedResponse>[] ret2 = ret.toArray(new Future[0]);
+		if( wait ) {
+			Arrays.stream(ret2).forEach(e -> {
+				try {e.get();} catch(Exception ex) {throw new DMLRuntimeException(ex);}
+			});
+		}
+		return ret2;
 	}
 	
 	public List<Pair<FederatedRange, Future<FederatedResponse>>> requestFederatedData() {
@@ -125,9 +141,10 @@ public class FederationMap
 		return readResponses;
 	}
 	
-	public void cleanup(long... id) {
+	public void cleanup(long tid, long... id) {
 		FederatedRequest request = new FederatedRequest(RequestType.EXEC_INST, -1,
 			VariableCPInstruction.prepareRemoveInstruction(id).toString());
+		request.setTID(tid);
 		for(FederatedData fd : _fedMap.values())
 			fd.executeFederatedOperation(request);
 	}
@@ -204,6 +221,12 @@ public class FederationMap
 		fedMapCopy._ID = newVarID;
 		return fedMapCopy;
 	}
+	
+	private static void setThreadID(long tid, FederatedRequest[]... frsets) {
+		for( FederatedRequest[] frset : frsets )
+			if( frset != null )
+				Arrays.stream(frset).forEach(fr -> fr.setTID(tid));
+	}
 
 	private static class MappingTask implements Callable<Void> {
 		private final FederatedRange _range;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimatorHops.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimatorHops.java
index a881d37..eb70d0c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimatorHops.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimatorHops.java
@@ -58,6 +58,7 @@ public class CostEstimatorHops extends CostEstimator
 		
 		//handle specific cases 
 		double DEFAULT_MEM_REMOTE = OptimizerUtils.isSparkExecutionMode() ? DEFAULT_MEM_SP : 0;
+		boolean forcedExec =  DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE || h.getForcedExecType()!=null;
 		
 		if( value >= DEFAULT_MEM_REMOTE )
 		{
@@ -67,7 +68,7 @@ public class CostEstimatorHops extends CostEstimator
 			}
 			//check for invalid cp memory estimate
 			else if ( h.getExecType()==ExecType.CP && value >= OptimizerUtils.getLocalMemBudget() ) {
-				if( DMLScript.getGlobalExecMode() != ExecMode.SINGLE_NODE && h.getForcedExecType()==null )
+				if( !forcedExec )
 					LOG.warn("Memory estimate larger than budget but CP exec type (op="+h.getOpString()+", name="+h.getName()+", memest="+h.getMemEstimate()+").");
 				value = DEFAULT_MEM_REMOTE;
 			}
@@ -84,7 +85,7 @@ public class CostEstimatorHops extends CostEstimator
 			value = DEFAULT_MEM_REMOTE;
 		}
 		
-		if( value <= 0 ) { //no mem estimate
+		if( value <= 0 && !forcedExec ) { //no mem estimate
 			LOG.warn("Cannot get memory estimate for hop (op="+h.getOpString()+", name="+h.getName()+", memest="+h.getMemEstimate()+").");
 			value = CostEstimator.DEFAULT_MEM_ESTIMATE_CP;
 		}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index 96cb4c6..c2a05da 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -741,7 +741,7 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace
 			// no other variable in the symbol table points to the same Data object as that of input1.getName()
 			
 			//remove matrix object from cache
-			m.clearData();
+			m.clearData(ec.getTID());
 		}
 	}
 
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 14f81bf..6fd6173 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -69,15 +69,15 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
 			if( mo2.getNumColumns() == 1 ) { //MV
 				FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
 				//execute federated operations and aggregate
-				Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(fr1, fr2, fr3);
+				Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
 				MatrixBlock ret = FederationUtils.rbind(tmp);
-				mo1.getFedMapping().cleanup(fr1.getID(), fr2.getID());
+				mo1.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
 				ec.setMatrixOutput(output.getName(), ret);
 			}
 			else { //MM
 				//execute federated operations and aggregate
-				mo1.getFedMapping().execute(fr1, fr2);
-				mo1.getFedMapping().cleanup(fr1.getID());
+				mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
+				mo1.getFedMapping().cleanup(getTID(), fr1.getID());
 				MatrixObject out = ec.getMatrixObject(output);
 				out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int)mo1.getBlocksize());
 				out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr2.getID(), mo2.getNumColumns()));
@@ -91,9 +91,9 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
 				new CPOperand[]{input1, input2}, new long[]{fr1[0].getID(), mo2.getFedMapping().getID()});
 			FederatedRequest fr3 = new FederatedRequest(RequestType.GET_VAR, fr2.getID());
 			//execute federated operations and aggregate
-			Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(fr1, fr2, fr3);
+			Future<FederatedResponse>[] tmp = mo2.getFedMapping().execute(getTID(), fr1, fr2, fr3);
 			MatrixBlock ret = FederationUtils.aggAdd(tmp);
-			mo2.getFedMapping().cleanup(fr1[0].getID(), fr2.getID());
+			mo2.getFedMapping().cleanup(getTID(), fr1[0].getID(), fr2.getID());
 			ec.setMatrixOutput(output.getName(), ret);
 		}
 		else { //other combinations
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index a9b655b..e87bf57 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -63,11 +63,11 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
 		
 		//execute federated commands and cleanups
 		FederationMap map = in.getFedMapping();
-		Future<FederatedResponse>[] tmp = map.execute(fr1, fr2);
-		map.cleanup(fr1.getID());
+		Future<FederatedResponse>[] tmp = map.execute(getTID(), fr1, fr2);
 		if( output.isScalar() )
 			ec.setVariable(output.getName(), FederationUtils.aggScalar(aop, tmp));
 		else
 			ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(aop, tmp, map));
+		map.cleanup(getTID(), fr1.getID());
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index 8fed7f7..985d117 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -80,7 +80,7 @@ public class AppendFEDInstruction extends BinaryFEDInstruction {
 			FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
 			FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
 				new CPOperand[]{input1, input2}, new long[]{mo1.getFedMapping().getID(), fr1.getID()});
-			mo1.getFedMapping().execute(fr1, fr2);
+			mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
 			//derive new fed mapping for output
 			MatrixObject out = ec.getMatrixObject(output);
 			out.getDataCharacteristics().set(dc1.getRows(), dc1.getCols()+dc2.getCols(),
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 7813f6a..7166373 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -51,16 +51,16 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
 			fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
 				new long[]{mo1.getFedMapping().getID(), fr1[0].getID()});
 			//execute federated instruction and cleanup intermediates
-			mo1.getFedMapping().execute(fr1, fr2);
-			mo1.getFedMapping().cleanup(fr1[0].getID());
+			mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
+			mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
 		}
 		else { //MM or MV col vector
 			FederatedRequest fr1 = mo1.getFedMapping().broadcast(mo2);
 			fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
 				new long[]{mo1.getFedMapping().getID(), fr1.getID()});
 			//execute federated instruction and cleanup intermediates
-			mo1.getFedMapping().execute(fr1, fr2);
-			mo1.getFedMapping().cleanup(fr1.getID());
+			mo1.getFedMapping().execute(getTID(), true, fr1, fr2);
+			mo1.getFedMapping().cleanup(getTID(), fr1.getID());
 		}
 		
 		//derive new fed mapping for output
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
index 0e05ca8..75bfe33 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixScalarFEDInstruction.java
@@ -46,10 +46,10 @@ public class BinaryMatrixScalarFEDInstruction extends BinaryFEDInstruction
 			new CPOperand[]{matrix, (fr1 != null)?scalar:null},
 			new long[]{mo.getFedMapping().getID(), (fr1 != null)?fr1.getID():-1});
 		
-		mo.getFedMapping().execute((fr1!=null) ?
+		mo.getFedMapping().execute(getTID(), true, (fr1!=null) ?
 			new FederatedRequest[]{fr1, fr2}: new FederatedRequest[]{fr2});
 		if( fr1 != null )
-			mo.getFedMapping().cleanup(fr1.getID());
+			mo.getFedMapping().cleanup(getTID(), fr1.getID());
 		
 		//derive new fed mapping for output
 		MatrixObject out = ec.getMatrixObject(output);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
index 9e58e52..6df1b1e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstruction.java
@@ -39,6 +39,7 @@ public abstract class FEDInstruction extends Instruction {
 	
 	protected final FEDType _fedType;
 	protected final Operator _optr;
+	protected long _tid = -1; //main
 	
 	protected FEDInstruction(FEDType type, String opcode, String istr) {
 		this(type, null, opcode, istr);
@@ -60,6 +61,14 @@ public abstract class FEDInstruction extends Instruction {
 		return _fedType;
 	}
 	
+	public long getTID() {
+		return _tid;
+	}
+	
+	public void setTID(long tid) {
+		_tid = tid;
+	}
+	
 	@Override
 	public Instruction preprocessInstruction(ExecutionContext ec) {
 		Instruction tmp = super.preprocessInstruction(ec);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 0a5a2a2..4325456 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -34,13 +34,14 @@ public class FEDInstructionUtils {
 	// counterpart, since we do not propagate the information that a matrix is federated, therefore we can not decide
 	// to choose a federated instruction earlier.
 	public static Instruction checkAndReplaceCP(Instruction inst, ExecutionContext ec) {
+		FEDInstruction fedinst = null;
 		if (inst instanceof AggregateBinaryCPInstruction) {
 			AggregateBinaryCPInstruction instruction = (AggregateBinaryCPInstruction) inst;
 			if( instruction.input1.isMatrix() && instruction.input2.isMatrix() ) {
 				MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
 				MatrixObject mo2 = ec.getMatrixObject(instruction.input2);
 				if (mo1.isFederated() || mo2.isFederated()) {
-					return AggregateBinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+					fedinst = AggregateBinaryFEDInstruction.parseInstruction(inst.getInstructionString());
 				}
 			}
 		}
@@ -49,7 +50,7 @@ public class FEDInstructionUtils {
 			if( instruction.input1.isMatrix() && ec.containsVariable(instruction.input1) ) {
 				MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
 				if (mo1.isFederated() && instruction.getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT)
-					return AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
+					fedinst = AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
 			}
 		}
 		else if (inst instanceof BinaryCPInstruction) {
@@ -57,13 +58,13 @@ public class FEDInstructionUtils {
 			if( (instruction.input1.isMatrix() && ec.getMatrixObject(instruction.input1).isFederated())
 				|| (instruction.input2.isMatrix() && ec.getMatrixObject(instruction.input2).isFederated()) ) {
 				if(!instruction.getOpcode().equals("append")) //TODO support rbind/cbind
-					return BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
+					fedinst = BinaryFEDInstruction.parseInstruction(inst.getInstructionString());
 			}
 		}
 		else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
 			ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction)inst;
 			if(pinst.getOpcode().equals("replace") && pinst.getTarget(ec).isFederated()) {
-				return ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
+				fedinst = ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
 			}
 		}
 		else if (inst instanceof MultiReturnParameterizedBuiltinCPInstruction) {
@@ -71,7 +72,7 @@ public class FEDInstructionUtils {
 			if(minst.getOpcode().equals("transformencode") && minst.input1.isFrame()) {
 				CacheableData<?> fo = ec.getCacheableData(minst.input1);
 				if(fo.isFederated()) {
-					return MultiReturnParameterizedBuiltinFEDInstruction
+					fedinst = MultiReturnParameterizedBuiltinFEDInstruction
 						.parseInstruction(minst.getInstructionString());
 				}
 			}
@@ -80,8 +81,15 @@ public class FEDInstructionUtils {
 			MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
 			MatrixObject mo = ec.getMatrixObject(linst.input1);
 			if( mo.isFederated() )
-				return TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
+				fedinst = TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
 		}
+		
+		//set thread id for federated context management
+		if( fedinst != null ) {
+			fedinst.setTID(ec.getTID());
+			return fedinst;
+		}
+		
 		return inst;
 	}
 	
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index 3a5ff8a..ec28965 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -99,7 +99,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
 			MatrixObject mo = getTarget(ec);
 			FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
 				new CPOperand[]{getTargetOperand()}, new long[]{mo.getFedMapping().getID()});
-			mo.getFedMapping().execute(fr1);
+			mo.getFedMapping().execute(getTID(), true, fr1);
 			
 			//derive new fed mapping for output
 			MatrixObject out = ec.getMatrixObject(output);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index a3061ed..292bced 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -69,9 +69,9 @@ public class TsmmFEDInstruction extends BinaryFEDInstruction {
 			FederatedRequest fr2 = new FederatedRequest(RequestType.GET_VAR, fr1.getID());
 			
 			//execute federated operations and aggregate
-			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(fr1, fr2);
+			Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2);
 			MatrixBlock ret = FederationUtils.aggAdd(tmp);
-			mo1.getFedMapping().cleanup(fr1.getID());
+			mo1.getFedMapping().cleanup(getTID(), fr1.getID());
 			ec.setMatrixOutput(output.getName(), ret);
 		}
 		else { //other combinations
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 7e63127..b40a637 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -100,7 +100,7 @@ public abstract class AutomatedTestBase {
 	public static final boolean TEST_GPU = false;
 	public static final double GPU_TOLERANCE = 1e-9;
 
-	public static final int FED_WORKER_WAIT = 500; // in ms
+	public static final int FED_WORKER_WAIT = 750; // in ms
 
 	// With OpenJDK 8u242 on Windows, the new changes in JDK are not allowing
 	// to set the native library paths internally thus breaking the code.
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
index a216fb3..6991797 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/FederatedKmeansTest.java
@@ -62,8 +62,7 @@ public class FederatedKmeansTest extends AutomatedTestBase {
 		// rows have to be even and > 1
 		return Arrays.asList(new Object[][] {
 			{10000, 10, 1}, {2000, 50, 1}, {1000, 100, 1},
-			//TODO support for multi-threaded federated interactions
-			//{10000, 10, 16}, {2000, 50, 16}, {1000, 100, 16}, //concurrent requests
+			{10000, 10, 4}, {2000, 50, 4}, {1000, 100, 4}, //concurrent requests
 		});
 	}