You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2020/11/17 22:19:25 UTC

[systemds] branch master updated: [SYSTEMDS-2730] Modified fed removeEmpty

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 2a8cb78  [SYSTEMDS-2730] Modified fed removeEmpty
2a8cb78 is described below

commit 2a8cb78827daed00fe016f6af22ab24f154be40c
Author: Olga <ov...@gmail.com>
AuthorDate: Tue Nov 17 21:21:55 2020 +0100

    [SYSTEMDS-2730] Modified fed removeEmpty
    
    This commit changes the federated removeEmpty command to, among other
    things, improve the performance of the split function.
    
    Closes #1109
---
 scripts/builtin/split.dml                          |  13 +-
 .../controlprogram/federated/FederationMap.java    | 204 +++++++++---------
 .../fed/ParameterizedBuiltinFEDInstruction.java    | 234 +++++++++------------
 .../primitives/FederatedRemoveEmptyTest.java       |  10 +-
 .../federated/primitives/FederatedSplitTest.java   |   8 +-
 .../functions/federated/FederatedSplitTest.dml     |   3 +-
 .../federated/FederatedSplitTestReference.dml      |   2 +-
 7 files changed, 217 insertions(+), 257 deletions(-)

diff --git a/scripts/builtin/split.dml b/scripts/builtin/split.dml
index 5e6f1c5..c5c1066 100644
--- a/scripts/builtin/split.dml
+++ b/scripts/builtin/split.dml
@@ -53,12 +53,13 @@ m_split = function(Matrix[Double] X, Matrix[Double] Y, Double f=0.7, Boolean con
   }
   # sampled train/test splits
   else {
+    # create random select vector according to f and then
+    # extract tuples via permutation (selection) matrix multiply
+    # or directly via removeEmpty by selection vector
     I = rand(rows=nrow(X), cols=1, seed=seed) <= f;
-    P1 = removeEmpty(target=diag(I), margin="rows", select=I);
-    P2 = removeEmpty(target=diag(I==0), margin="rows", select=I==0);
-    Xtrain = P1 %*% X;
-    Ytrain = P1 %*% Y;
-    Xtest = P2 %*% X;
-    Ytest = P2 %*% Y;
+    Xtrain = removeEmpty(target=X, margin="rows", select=I);
+    Ytrain = removeEmpty(target=Y, margin="rows", select=I);
+    Xtest = removeEmpty(target=X, margin="rows", select=(I==0));
+    Ytest = removeEmpty(target=Y, margin="rows", select=(I==0));
   }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 037ce8c..2ce3cb7 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -43,12 +43,11 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.IndexRange;
 
-public class FederationMap
-{
+public class FederationMap {
 	public enum FType {
-		ROW, //row partitioned, groups of rows
-		COL, //column partitioned, groups of columns
-		FULL,  // Meaning both Row and Column indicating a single federated location and a full matrix
+		ROW, // row partitioned, groups of rows
+		COL, // column partitioned, groups of columns
+		FULL, // Meaning both Row and Column indicating a single federated location and a full matrix
 		OTHER;
 
 		public boolean isRowPartitioned() {
@@ -56,11 +55,11 @@ public class FederationMap
 		}
 
 		public boolean isColPartitioned() {
-			return this == ROW || this == FULL;
+			return this == COL || this == FULL;
 		}
 
-		public boolean isType(FType t){
-			switch (t) {
+		public boolean isType(FType t) {
+			switch(t) {
 				case ROW:
 					return isRowPartitioned();
 				case COL:
@@ -72,161 +71,159 @@ public class FederationMap
 			}
 		}
 	}
-	
+
 	private long _ID = -1;
 	private final Map<FederatedRange, FederatedData> _fedMap;
 	private FType _type;
-	
+
 	public FederationMap(Map<FederatedRange, FederatedData> fedMap) {
 		this(-1, fedMap);
 	}
-	
+
 	public FederationMap(long ID, Map<FederatedRange, FederatedData> fedMap) {
 		this(ID, fedMap, FType.OTHER);
 	}
-	
+
 	public FederationMap(long ID, Map<FederatedRange, FederatedData> fedMap, FType type) {
 		_ID = ID;
 		_fedMap = fedMap;
 		_type = type;
 	}
-	
+
 	public long getID() {
 		return _ID;
 	}
-	
+
 	public FType getType() {
 		return _type;
 	}
-	
+
 	public boolean isInitialized() {
 		return _ID >= 0;
 	}
-	
+
 	public void setType(FType type) {
 		_type = type;
 	}
-	
+
 	public int getSize() {
 		return _fedMap.size();
 	}
-	
+
 	public FederatedRange[] getFederatedRanges() {
 		return _fedMap.keySet().toArray(new FederatedRange[0]);
 	}
 
-	public Map<FederatedRange, FederatedData> getFedMapping(){
+	public Map<FederatedRange, FederatedData> getFedMapping() {
 		return _fedMap;
 	}
-	
+
 	public FederatedRequest broadcast(CacheableData<?> data) {
-		//prepare single request for all federated data
+		// prepare single request for all federated data
 		long id = FederationUtils.getNextFedDataID();
 		CacheBlock cb = data.acquireReadAndRelease();
 		return new FederatedRequest(RequestType.PUT_VAR, id, cb);
 	}
-	
+
 	public FederatedRequest broadcast(ScalarObject scalar) {
-		//prepare single request for all federated data
+		// prepare single request for all federated data
 		long id = FederationUtils.getNextFedDataID();
 		return new FederatedRequest(RequestType.PUT_VAR, id, scalar);
 	}
-	
+
 	/**
-	 * Creates separate slices of an input data object according
-	 * to the index ranges of federated data. Theses slices are then
-	 * wrapped in separate federated requests for broadcasting.
+	 * Creates separate slices of an input data object according to the index ranges of federated data. Theses slices
+	 * are then wrapped in separate federated requests for broadcasting.
 	 * 
-	 * @param data input data object (matrix, tensor, frame)
-	 * @param transposed false: slice according to federated data,
-	 *                   true: slice according to transposed federated data
+	 * @param data       input data object (matrix, tensor, frame)
+	 * @param transposed false: slice according to federated data, true: slice according to transposed federated data
 	 * @return array of federated requests corresponding to federated data
 	 */
 	public FederatedRequest[] broadcastSliced(CacheableData<?> data, boolean transposed) {
-		//prepare broadcast id and pin input
+		// prepare broadcast id and pin input
 		long id = FederationUtils.getNextFedDataID();
 		CacheBlock cb = data.acquireReadAndRelease();
-		
-		//prepare indexing ranges
+
+		// prepare indexing ranges
 		int[][] ix = new int[_fedMap.size()][];
 		int pos = 0;
 		for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
-			int rl = transposed ? 0 : e.getKey().getBeginDimsInt()[0];
-			int ru = transposed ? cb.getNumRows()-1 : e.getKey().getEndDimsInt()[0]-1;
-			int cl = transposed ? e.getKey().getBeginDimsInt()[0] : 0;
-			int cu = transposed ? e.getKey().getEndDimsInt()[0]-1 : cb.getNumColumns()-1;
+			int rl, ru, cl, cu;
+			// TODO Handle different cases than ROW aligned Matrices.
+			rl = transposed ? 0 : e.getKey().getBeginDimsInt()[0];
+			ru = transposed ? cb.getNumRows() - 1 : e.getKey().getEndDimsInt()[0] - 1;
+			cl = transposed ? e.getKey().getBeginDimsInt()[0] : 0;
+			cu = transposed ? e.getKey().getEndDimsInt()[0] - 1 : cb.getNumColumns() - 1;
 			ix[pos++] = new int[] {rl, ru, cl, cu};
 		}
-		
-		//multi-threaded block slicing and federation request creation
+
+		// multi-threaded block slicing and federation request creation
 		FederatedRequest[] ret = new FederatedRequest[ix.length];
-		Arrays.parallelSetAll(ret, i ->
-			new FederatedRequest(RequestType.PUT_VAR, id,
-			cb.slice(ix[i][0], ix[i][1], ix[i][2], ix[i][3], new MatrixBlock())));
+		Arrays.parallelSetAll(ret,
+			i -> new FederatedRequest(RequestType.PUT_VAR, id,
+				cb.slice(ix[i][0], ix[i][1], ix[i][2], ix[i][3], new MatrixBlock())));
 		return ret;
 	}
-	
+
 	public boolean isAligned(FederationMap that, boolean transposed) {
-		//determines if the two federated data are aligned row/column partitions
-		//at the same federated site (which allows for purely federated operation)
+		// determines if the two federated data are aligned row/column partitions
+		// at the same federated site (which allows for purely federated operation)
 		boolean ret = true;
 		for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
-			FederatedRange range = !transposed ? e.getKey() :
-				new FederatedRange(e.getKey()).transpose();
+			FederatedRange range = !transposed ? e.getKey() : new FederatedRange(e.getKey()).transpose();
 			FederatedData dat2 = that._fedMap.get(range);
 			ret &= e.getValue().equalAddress(dat2);
 		}
 		return ret;
 	}
-	
+
 	public Future<FederatedResponse>[] execute(long tid, FederatedRequest... fr) {
 		return execute(tid, false, fr);
 	}
-	
+
 	public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest... fr) {
 		return execute(tid, wait, null, fr);
 	}
-	
+
 	public Future<FederatedResponse>[] execute(long tid, FederatedRequest[] frSlices, FederatedRequest... fr) {
 		return execute(tid, false, frSlices, fr);
 	}
-	
+
 	@SuppressWarnings("unchecked")
-	public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest[] frSlices, FederatedRequest... fr) {
+	public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest[] frSlices,
+		FederatedRequest... fr) {
 		// executes step1[] - step 2 - ... step4 (only first step federated-data-specific)
 		setThreadID(tid, frSlices, fr);
-		List<Future<FederatedResponse>> ret = new ArrayList<>(); 
+		List<Future<FederatedResponse>> ret = new ArrayList<>();
 		int pos = 0;
 		for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
-			ret.add(e.getValue().executeFederatedOperation(
-				(frSlices!=null) ? addAll(frSlices[pos++], fr) : fr));
-		
-		// prepare results (future federated responses), with optional wait to ensure the 
+			ret.add(e.getValue().executeFederatedOperation((frSlices != null) ? addAll(frSlices[pos++], fr) : fr));
+
+		// prepare results (future federated responses), with optional wait to ensure the
 		// order of requests without data dependencies (e.g., cleanup RPCs)
-		if( wait )
+		if(wait)
 			FederationUtils.waitFor(ret);
 		return ret.toArray(new Future[0]);
 	}
-	
+
 	public List<Pair<FederatedRange, Future<FederatedResponse>>> requestFederatedData() {
-		if( !isInitialized() )
+		if(!isInitialized())
 			throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData");
-		
+
 		List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = new ArrayList<>();
 		FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, _ID);
 		for(Map.Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
-			readResponses.add(new ImmutablePair<>(e.getKey(), 
-				e.getValue().executeFederatedOperation(request)));
+			readResponses.add(new ImmutablePair<>(e.getKey(), e.getValue().executeFederatedOperation(request)));
 		return readResponses;
 	}
-	
+
 	public FederatedRequest cleanup(long tid, long... id) {
 		FederatedRequest request = new FederatedRequest(RequestType.EXEC_INST, -1,
 			VariableCPInstruction.prepareRemoveInstruction(id).toString());
 		request.setTID(tid);
 		return request;
 	}
-	
+
 	public void execCleanup(long tid, long... id) {
 		FederatedRequest request = new FederatedRequest(RequestType.EXEC_INST, -1,
 			VariableCPInstruction.prepareRemoveInstruction(id).toString());
@@ -234,16 +231,17 @@ public class FederationMap
 		List<Future<FederatedResponse>> tmp = new ArrayList<>();
 		for(FederatedData fd : _fedMap.values())
 			tmp.add(fd.executeFederatedOperation(request));
-		//wait to avoid interference w/ following requests
+		// wait to avoid interference w/ following requests
 		FederationUtils.waitFor(tmp);
 	}
-	
+
 	private static FederatedRequest[] addAll(FederatedRequest a, FederatedRequest[] b) {
 		FederatedRequest[] ret = new FederatedRequest[b.length + 1];
-		ret[0] = a; System.arraycopy(b, 0, ret, 1, b.length);
+		ret[0] = a;
+		System.arraycopy(b, 0, ret, 1, b.length);
 		return ret;
 	}
-	
+
 	public FederationMap identCopy(long tid, long id) {
 		Future<FederatedResponse>[] copyInstr = execute(tid,
 			new FederatedRequest(RequestType.EXEC_INST, _ID,
@@ -262,25 +260,25 @@ public class FederationMap
 		copyFederationMap._type = _type;
 		return copyFederationMap;
 	}
-	
+
 	public FederationMap copyWithNewID() {
 		return copyWithNewID(FederationUtils.getNextFedDataID());
 	}
-	
+
 	public FederationMap copyWithNewID(long id) {
 		Map<FederatedRange, FederatedData> map = new TreeMap<>();
-		//TODO handling of file path, but no danger as never written
-		for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() ) {
+		// TODO handling of file path, but no danger as never written
+		for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
 			if(e.getKey().getSize() != 0)
 				map.put(new FederatedRange(e.getKey()), e.getValue().copyWithNewID(id));
 		}
 		return new FederationMap(id, map, _type);
 	}
-	
+
 	public FederationMap copyWithNewID(long id, long clen) {
 		Map<FederatedRange, FederatedData> map = new TreeMap<>();
-		//TODO handling of file path, but no danger as never written
-		for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() )
+		// TODO handling of file path, but no danger as never written
+		for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
 			map.put(new FederatedRange(e.getKey(), clen), e.getValue().copyWithNewID(id));
 		return new FederationMap(id, map, _type);
 	}
@@ -295,24 +293,28 @@ public class FederationMap
 	public FederationMap transpose() {
 		Map<FederatedRange, FederatedData> tmp = new TreeMap<>(_fedMap);
 		_fedMap.clear();
-		for( Entry<FederatedRange, FederatedData> e : tmp.entrySet() ) {
+		for(Entry<FederatedRange, FederatedData> e : tmp.entrySet()) {
 			_fedMap.put(new FederatedRange(e.getKey()).transpose(), e.getValue().copyWithNewID(_ID));
 		}
-		//derive output type
+		// derive output type
 		switch(_type) {
-			case FULL: _type = FType.FULL; break;
-			case ROW: _type = FType.COL; break;
-			case COL: _type = FType.ROW; break;
-			default: _type = FType.OTHER;
+			case FULL:
+				_type = FType.FULL;
+				break;
+			case ROW:
+				_type = FType.COL;
+				break;
+			case COL:
+				_type = FType.ROW;
+				break;
+			default:
+				_type = FType.OTHER;
 		}
 		return this;
 	}
 
-	
 	public long getMaxIndexInRange(int dim) {
-		return _fedMap.keySet().stream()
-			.mapToLong(range -> range.getEndDims()[dim]).max()
-			.orElse(-1L);
+		return _fedMap.keySet().stream().mapToLong(range -> range.getEndDims()[dim]).max().orElse(-1L);
 	}
 
 	/**
@@ -352,27 +354,27 @@ public class FederationMap
 		fedMapCopy._ID = newVarID;
 		return fedMapCopy;
 	}
-	
+
 	public FederationMap filter(IndexRange ixrange) {
-		FederationMap ret = this.clone(); //same ID
-		
+		FederationMap ret = this.clone(); // same ID
+
 		Iterator<Entry<FederatedRange, FederatedData>> iter = ret._fedMap.entrySet().iterator();
-		while( iter.hasNext() ) {
+		while(iter.hasNext()) {
 			Entry<FederatedRange, FederatedData> e = iter.next();
 			FederatedRange range = e.getKey();
-			long rs = range.getBeginDims()[0], re = range.getEndDims()[0],
-				cs = range.getBeginDims()[1], ce = range.getEndDims()[1];
-			boolean overlap = ((ixrange.colStart <= ce) && (ixrange.colEnd >= cs)
-				&& (ixrange.rowStart <= re) && (ixrange.rowEnd >= rs));
-			if( !overlap )
+			long rs = range.getBeginDims()[0], re = range.getEndDims()[0], cs = range.getBeginDims()[1],
+				ce = range.getEndDims()[1];
+			boolean overlap = ((ixrange.colStart <= ce) && (ixrange.colEnd >= cs) && (ixrange.rowStart <= re) &&
+				(ixrange.rowEnd >= rs));
+			if(!overlap)
 				iter.remove();
 		}
 		return ret;
 	}
-	
+
 	private static void setThreadID(long tid, FederatedRequest[]... frsets) {
-		for( FederatedRequest[] frset : frsets )
-			if( frset != null )
+		for(FederatedRequest[] frset : frsets)
+			if(frset != null)
 				Arrays.stream(frset).forEach(fr -> fr.setTID(tid));
 	}
 
@@ -399,14 +401,14 @@ public class FederationMap
 	}
 
 	@Override
-	public String toString(){
+	public String toString() {
 		StringBuilder sb = new StringBuilder();
 		sb.append("Fed Map: " + _type);
 		sb.append("\t ID:" + _ID);
-		sb.append("\n"+ _fedMap);
+		sb.append("\n" + _fedMap);
 		return sb.toString();
 	}
-	
+
 	@Override
 	public FederationMap clone() {
 		return copyWithNewID(getID());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index c50671e..6588909 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -19,7 +19,6 @@
 
 package org.apache.sysds.runtime.instructions.fed;
 
-import java.util.AbstractMap;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
@@ -136,9 +135,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
 			out.getDataCharacteristics().set(mo.getDataCharacteristics());
 			out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
 		}
-		else if(opcode.equals("rmempty")) {
+		else if(opcode.equals("rmempty"))
 			rmempty(ec);
-		}
 		else if(opcode.equalsIgnoreCase("transformdecode"))
 			transformDecode(ec);
 		else if(opcode.equalsIgnoreCase("transformapply"))
@@ -149,32 +147,33 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
 	}
 
 	private void rmempty(ExecutionContext ec) {
+		String margin = params.get("margin");
+		if( !(margin.equals("rows") || margin.equals("cols")) )
+			throw new DMLRuntimeException("Unspupported margin identifier '"+margin+"'.");
+
 		MatrixObject mo = (MatrixObject) getTarget(ec);
+		MatrixObject select = params.containsKey("select") ? ec.getMatrixObject(params.get("select")) : null;
 		MatrixObject out = ec.getMatrixObject(output);
-		Map<FederatedRange, int[]> dcs;
-		if((instString.contains("margin=rows") && mo.isFederated(FederationMap.FType.ROW)) ||
-			(instString.contains("margin=cols") && mo.isFederated(FederationMap.FType.COL))) {
-			FederatedRequest fr1 = FederationUtils.callInstruction(instString,
-				output,
-				new CPOperand[] {getTargetOperand()},
-				new long[] {mo.getFedMapping().getID()});
-			mo.getFedMapping().execute(getTID(), true, fr1);
-			out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
 
-			// new ranges
-			dcs = new HashMap<>();
-			out.getFedMapping().forEachParallel((range, data) -> {
+		boolean marginRow = params.get("margin").equals("rows");
+		boolean k = ((marginRow && mo.getFedMapping().getType().isColPartitioned()) ||
+			(!marginRow && mo.getFedMapping().getType().isRowPartitioned()));
+
+		MatrixBlock s = new MatrixBlock();
+		if(select == null && k) {
+			List<MatrixBlock> colSums = new ArrayList<>();
+			mo.getFedMapping().forEachParallel((range, data) -> {
 				try {
 					FederatedResponse response = data
 						.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-							new GetDataCharacteristics(data.getVarID())))
+							new GetVector(data.getVarID(), margin.equals("rows"))))
 						.get();
 
 					if(!response.isSuccessful())
 						response.throwExceptionFromResponse();
-					int[] subRangeCharacteristics = (int[]) response.getData()[0];
-					synchronized(dcs) {
-						dcs.put(range, subRangeCharacteristics);
+					MatrixBlock vector = (MatrixBlock) response.getData()[0];
+					synchronized(colSums) {
+						colSums.add(vector);
 					}
 				}
 				catch(Exception e) {
@@ -182,53 +181,75 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
 				}
 				return null;
 			});
+			// find empty in matrix
+			BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
+			BinaryOperator greater = InstructionUtils.parseBinaryOperator(">");
+			s = colSums.get(0);
+			for(int i = 1; i < colSums.size(); i++)
+				s = s.binaryOperationsInPlace(plus, colSums.get(i));
+			s = s.binaryOperationsInPlace(greater, new MatrixBlock(s.getNumRows(), s.getNumColumns(), 0.0));
+			select = ExecutionContext.createMatrixObject(s);
+
+			long varID = FederationUtils.getNextFedDataID();
+			ec.setVariable(String.valueOf(varID), select);
+			params.put("select", String.valueOf(varID));
+			// construct new string
+			String[] oldString = InstructionUtils.getInstructionParts(instString);
+			String[] newString = new String[oldString.length+1];
+			newString[2] = "select="+varID;
+			System.arraycopy(oldString, 0, newString, 0,2);
+			System.arraycopy(oldString,2, newString, 3, newString.length-3);
+			instString = instString.replace(InstructionUtils.concatOperands(oldString), InstructionUtils.concatOperands(newString));
 		}
-		else {
-			Map.Entry<FederationMap, Map<FederatedRange, int[]>> entry = rmemptyC(ec, mo);
-			out.setFedMapping(entry.getKey());
-			dcs = entry.getValue();
-		}
-		out.getDataCharacteristics().set(mo.getDataCharacteristics());
-		for(int i = 0; i < mo.getFedMapping().getFederatedRanges().length; i++) {
-			int[] newRange = dcs.get(out.getFedMapping().getFederatedRanges()[i]);
-
-			out.getFedMapping().getFederatedRanges()[i].setBeginDim(0,
-				(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] == 0 ||
-					i == 0) ? 0 : out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[0]);
-
-			out.getFedMapping().getFederatedRanges()[i].setEndDim(0,
-				out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
-
-			out.getFedMapping().getFederatedRanges()[i].setBeginDim(1,
-				(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] == 0 ||
-					i == 0) ? 0 : out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[1]);
 
-			out.getFedMapping().getFederatedRanges()[i].setEndDim(1,
-				out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
+		if (select == null) {
+			FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+				new CPOperand[] {getTargetOperand()},
+				new long[] {mo.getFedMapping().getID()});
+			mo.getFedMapping().execute(getTID(), true, fr1);
+			out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
 		}
+		else if (!k) {
+			//construct commands: broadcast , fed rmempty, clean broadcast
+			FederatedRequest[] fr1 = mo.getFedMapping().broadcastSliced(select, !marginRow);
+			FederatedRequest fr2 = FederationUtils.callInstruction(instString,
+				output,
+				new CPOperand[] {getTargetOperand(), new CPOperand(params.get("select"), ValueType.FP64, DataType.MATRIX)},
+				new long[] {mo.getFedMapping().getID(), fr1[0].getID()});
+			FederatedRequest fr3 = mo.getFedMapping().cleanup(getTID(), fr1[0].getID());
+
+			//execute federated operations and set output
+			mo.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+			out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID()));
+		} else {
+			//construct commands: broadcast , fed rmempty, clean broadcast
+			FederatedRequest fr1 = mo.getFedMapping().broadcast(select);
+			FederatedRequest fr2 = FederationUtils.callInstruction(instString,
+				output,
+				new CPOperand[] {getTargetOperand(), new CPOperand(params.get("select"), ValueType.FP64, DataType.MATRIX)},
+				new long[] {mo.getFedMapping().getID(), fr1.getID()});
+			FederatedRequest fr3 = mo.getFedMapping().cleanup(getTID(), fr1.getID());
 
-		out.getDataCharacteristics().set(out.getFedMapping().getMaxIndexInRange(0),
-			out.getFedMapping().getMaxIndexInRange(1),
-			(int) mo.getBlocksize());
-	}
-
-	private Map.Entry<FederationMap, Map<FederatedRange, int[]>> rmemptyC(ExecutionContext ec, MatrixObject mo) {
-		boolean marginRow = instString.contains("margin=rows");
+			//execute federated operations and set output
+			mo.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+			out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID()));
+		}
 
-		// find empty in ranges
-		List<MatrixBlock> colSums = new ArrayList<>();
-		mo.getFedMapping().forEachParallel((range, data) -> {
+		// new ranges
+		Map<FederatedRange, int[]> dcs = new HashMap<>();
+		Map<FederatedRange, int[]> finalDcs1 = dcs;
+		out.getFedMapping().forEachParallel((range, data) -> {
 			try {
 				FederatedResponse response = data
 					.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-						new GetVector(data.getVarID(), marginRow)))
+						new GetDataCharacteristics(data.getVarID())))
 					.get();
 
 				if(!response.isSuccessful())
 					response.throwExceptionFromResponse();
-				MatrixBlock vector = (MatrixBlock) response.getData()[0];
-				synchronized(colSums) {
-					colSums.add(vector);
+				int[] subRangeCharacteristics = (int[]) response.getData()[0];
+				synchronized(finalDcs1) {
+					finalDcs1.put(range, subRangeCharacteristics);
 				}
 			}
 			catch(Exception e) {
@@ -236,46 +257,28 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
 			}
 			return null;
 		});
+		dcs = finalDcs1;
+		out.getDataCharacteristics().set(mo.getDataCharacteristics());
+		for(int i = 0; i < mo.getFedMapping().getFederatedRanges().length; i++) {
+			int[] newRange = dcs.get(out.getFedMapping().getFederatedRanges()[i]);
 
-		// find empty in matrix
-		BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
-		BinaryOperator greater = InstructionUtils.parseBinaryOperator(">");
-		MatrixBlock tmp1 = colSums.get(0);
-		for(int i = 1; i < colSums.size(); i++)
-			tmp1 = tmp1.binaryOperationsInPlace(plus, colSums.get(i));
-		tmp1 = tmp1.binaryOperationsInPlace(greater, new MatrixBlock(tmp1.getNumRows(), tmp1.getNumColumns(), 0.0));
+			out.getFedMapping().getFederatedRanges()[i].setBeginDim(0,
+				(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] == 0 ||
+					i == 0) ? 0 : out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[0]);
 
-		// remove empty from matrix
-		Map<FederatedRange, int[]> dcs = new HashMap<>();
-		long varID = FederationUtils.getNextFedDataID();
-		MatrixBlock finalTmp = new MatrixBlock(tmp1);
-		FederationMap resMapping;
-		if(tmp1.sum() == (marginRow ? tmp1.getNumColumns() : tmp1.getNumRows())) {
-			resMapping = mo.getFedMapping();
-		}
-		else {
-			resMapping = mo.getFedMapping().mapParallel(varID, (range, data) -> {
-				try {
-					FederatedResponse response = data
-						.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-							new ParameterizedBuiltinFEDInstruction.RemoveEmpty(data.getVarID(), varID, finalTmp,
-								params.containsKey("select") ? ec.getMatrixInput(params.get("select")) : null,
-								Boolean.parseBoolean(params.get("empty.return").toLowerCase()), marginRow)))
-						.get();
-					if(!response.isSuccessful())
-						response.throwExceptionFromResponse();
-					int[] subRangeCharacteristics = (int[]) response.getData()[0];
-					synchronized(dcs) {
-						dcs.put(range, subRangeCharacteristics);
-					}
-				}
-				catch(Exception e) {
-					throw new DMLRuntimeException(e);
-				}
-				return null;
-			});
+			out.getFedMapping().getFederatedRanges()[i].setEndDim(0,
+				out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
+
+			out.getFedMapping().getFederatedRanges()[i].setBeginDim(1,
+				(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] == 0 ||
+					i == 0) ? 0 : out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[1]);
+
+			out.getFedMapping().getFederatedRanges()[i].setEndDim(1,
+				out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
 		}
-		return new AbstractMap.SimpleEntry<>(resMapping, dcs);
+
+		out.getDataCharacteristics().set(out.getFedMapping().getMaxIndexInRange(0),
+			out.getFedMapping().getMaxIndexInRange(1), (int) mo.getBlocksize());
 	}
 
 	private void transformDecode(ExecutionContext ec) {
@@ -506,52 +509,9 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
 		@Override
 		public FederatedResponse execute(ExecutionContext ec, Data... data) {
 			MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
-			return new FederatedResponse(ResponseType.SUCCESS, new int[] {mb.getNumRows(), mb.getNumColumns()});
-		}
-	}
-
-	private static class RemoveEmpty extends FederatedUDF {
-
-		private static final long serialVersionUID = 12341521331L;
-		private final MatrixBlock _vector;
-		private final long _outputID;
-		private MatrixBlock _select;
-		private boolean _emptyReturn;
-		private final boolean _marginRow;
-
-		public RemoveEmpty(long varID, long outputID, MatrixBlock vector, MatrixBlock select, boolean emptyReturn,
-			boolean marginRow) {
-			super(new long[] {varID});
-			_vector = vector;
-			_outputID = outputID;
-			_select = select;
-			_emptyReturn = emptyReturn;
-			_marginRow = marginRow;
-		}
-
-		@Override
-		public FederatedResponse execute(ExecutionContext ec, Data... data) {
-			MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
-
-			BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
-			BinaryOperator minus = InstructionUtils.parseBinaryOperator("-");
-
-			mb = mb.binaryOperationsInPlace(plus, new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), 1.0));
-			for(int i = 0; i < mb.getNumRows(); i++)
-				for(int j = 0; j < mb.getNumColumns(); j++)
-					if(_marginRow)
-						mb.setValue(i, j, _vector.getValue(i, 0) * mb.getValue(i, j));
-					else
-						mb.setValue(i, j, _vector.getValue(0, j) * mb.getValue(i, j));
-
-			MatrixBlock res = mb.removeEmptyOperations(new MatrixBlock(), _marginRow, _emptyReturn, _select);
-			res = res.binaryOperationsInPlace(minus, new MatrixBlock(res.getNumRows(), res.getNumColumns(), 1.0));
-
-			MatrixObject mout = ExecutionContext.createMatrixObject(res);
-			ec.setVariable(String.valueOf(_outputID), mout);
-
-			return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS,
-				new int[] {res.getNumRows(), res.getNumColumns()});
+			int r = mb.getDenseBlockValues() != null ? mb.getNumRows() : 0;
+			int c = mb.getDenseBlockValues() != null ? mb.getNumColumns(): 0;
+			return new FederatedResponse(ResponseType.SUCCESS, new int[] {r, c});
 		}
 	}
 
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
index a629270..10a6711 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRemoveEmptyTest.java
@@ -55,7 +55,10 @@ public class FederatedRemoveEmptyTest extends AutomatedTestBase {
 
 	@Parameterized.Parameters
 	public static Collection<Object[]> data() {
-		return Arrays.asList(new Object[][] {{20, 10, true}, {20, 12, false}});
+		return Arrays.asList(new Object[][] {
+			{20, 12, true},
+			{20, 12, false}
+		});
 	}
 
 	@Override
@@ -94,11 +97,6 @@ public class FederatedRemoveEmptyTest extends AutomatedTestBase {
 
 		for(int k : new int[] {1, 2, 3}) {
 			Arrays.fill(X3[k], 0);
-			if(!rowPartitioned) {
-				Arrays.fill(X1[k], 0);
-				Arrays.fill(X2[k], 0);
-				Arrays.fill(X4[k], 0);
-			}
 		}
 
 		MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
index 3e640c0..afd2ffe 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedSplitTest.java
@@ -54,7 +54,9 @@ public class FederatedSplitTest extends AutomatedTestBase {
 
 	@Parameterized.Parameters
 	public static Collection<Object[]> data() {
-		return Arrays.asList(new Object[][] {{152, 12, "TRUE"}, {132, 11, "FALSE"}});
+		return Arrays.asList(new Object[][] {
+			// {152, 12, "TRUE"}, 
+			{132, 11, "FALSE"}});
 	}
 
 	@Override
@@ -125,9 +127,7 @@ public class FederatedSplitTest extends AutomatedTestBase {
 		if(cont.equals("TRUE"))
 			Assert.assertTrue(heavyHittersContainsString("fed_rightIndex"));
 		else {
-			Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
-			// TODO add federated diag operator.
-			// Assert.assertTrue(heavyHittersContainsString("fed_rdiag"));
+			Assert.assertTrue(heavyHittersContainsString("fed_rmempty"));
 		}
 		
 		TestUtils.shutdownThreads(t1, t2);
diff --git a/src/test/scripts/functions/federated/FederatedSplitTest.dml b/src/test/scripts/functions/federated/FederatedSplitTest.dml
index 44c59a9..e1fc647 100644
--- a/src/test/scripts/functions/federated/FederatedSplitTest.dml
+++ b/src/test/scripts/functions/federated/FederatedSplitTest.dml
@@ -24,7 +24,6 @@ X = federated(addresses=list($X1, $X2),
 Y = federated(addresses=list($Y1, $Y2),
     ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
 
-
-[Xtr, Xte, Ytr, Yte] = split(X=X,Y=Y,f=0.95, cont=$Cont, seed = 13)
+[Xtr, Xte, Ytr, Yte] = split(X=X, Y=Y, f=0.95, cont=$Cont, seed = 13)
 write(Xte, $Z)
 print(toString(Xte))
diff --git a/src/test/scripts/functions/federated/FederatedSplitTestReference.dml b/src/test/scripts/functions/federated/FederatedSplitTestReference.dml
index 4db8e1f..962dd84 100644
--- a/src/test/scripts/functions/federated/FederatedSplitTestReference.dml
+++ b/src/test/scripts/functions/federated/FederatedSplitTestReference.dml
@@ -21,6 +21,6 @@
 
 X = rbind(read($X1), read($X2))
 Y = rbind(read($Y1), read($Y2))
-[Xtr, Xte, Ytr, Yte] = split(X=X,Y=Y, f=0.95 ,cont=$Cont, seed = 13)
+[Xtr, Xte, Ytr, Yte] = split(X=X, Y=Y, f=0.95, cont=$Cont, seed = 13)
 write(Xte, $Z)
 print(toString(Xte))