Posted to commits@systemds.apache.org by mb...@apache.org on 2020/09/06 17:22:45 UTC

[systemds] branch master updated: [SYSTEMDS-2556, 2560] Federated transform impute, improved omit

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new d15fa6e  [SYSTEMDS-2556,2560] Federated transform impute, improved omit
d15fa6e is described below

commit d15fa6e55682b8c8a1e85740b8f7cbffab4dd151
Author: Kevin Innerebner <ke...@yahoo.com>
AuthorDate: Sun Sep 6 19:21:57 2020 +0200

    [SYSTEMDS-2556,2560] Federated transform impute, improved omit
    
    Closes #1046.
---
 .../federated/FederatedWorkerHandler.java          |   4 +-
 ...tiReturnParameterizedBuiltinFEDInstruction.java |   6 +-
 .../fed/ParameterizedBuiltinFEDInstruction.java    |   2 +-
 .../sysds/runtime/transform/encode/Encoder.java    |   9 +-
 .../sysds/runtime/transform/encode/EncoderBin.java |   4 +-
 .../runtime/transform/encode/EncoderComposite.java |   8 +-
 .../runtime/transform/encode/EncoderDummycode.java |   4 +-
 .../runtime/transform/encode/EncoderFactory.java   |   2 +-
 .../transform/encode/EncoderFeatureHash.java       |   4 +-
 .../runtime/transform/encode/EncoderMVImpute.java  | 388 ++++++++++-----------
 .../runtime/transform/encode/EncoderOmit.java      |  74 ++--
 .../transform/encode/EncoderPassThrough.java       |   4 +-
 .../runtime/transform/encode/EncoderRecode.java    |   4 +-
 .../TransformFederatedEncodeApplyTest.java         | 100 ++++--
 .../transform/TransformFederatedEncodeApply.dml    |   8 +-
 15 files changed, 320 insertions(+), 301 deletions(-)
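
The core interface change: Encoder.mergeAt now takes a 1-based row offset in
addition to the column offset, so partial encoders built on row- and
column-partitioned federated data can be folded into one global encoder. A
minimal sketch of the coordinator-side pattern (illustrative types only, not
the SystemDS API): each worker's encoder is merged at 1-based offsets derived
from its partition's 0-based begin dimensions.

    import java.util.List;

    // Hypothetical stand-ins for the SystemDS types touched by this commit;
    // only the 1-based (row, col) merge offsets matter here.
    interface PartialEncoder {
        void mergeAt(PartialEncoder other, int row, int col);
    }

    class Partition {
        final long beginRow, beginCol; // 0-based begin dims of the federated range
        final PartialEncoder encoder;  // encoder built by a worker on its slice
        Partition(long r, long c, PartialEncoder e) { beginRow = r; beginCol = c; encoder = e; }
    }

    class MergeSketch {
        // Fold all worker results into the global encoder, translating each
        // partition's local coordinates by its 1-based global offsets.
        static void mergeAll(PartialEncoder global, List<Partition> parts) {
            for (Partition p : parts)
                global.mergeAt(p.encoder, (int) p.beginRow + 1, (int) p.beginCol + 1);
        }
    }
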

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index b5f0ec8..2690ee6 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -100,9 +100,9 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
 			
 			//select the response for the entire batch of requests
 			if (!tmp.isSuccessful()) {
-				log.error("Command " + request.getType() + " failed: " 
+				log.error("Command " + request.getType() + " failed: "
 					+ tmp.getErrorMessage() + "full command: \n" + request.toString());
-				response = (response == null || response.isSuccessful()) 
+				response = (response == null || response.isSuccessful())
 					? tmp : response; //return first error
 			}
 			else if( request.getType() == RequestType.GET_VAR ) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index 0fe12b9..047aff3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -48,6 +48,7 @@ import org.apache.sysds.runtime.transform.encode.EncoderComposite;
 import org.apache.sysds.runtime.transform.encode.EncoderDummycode;
 import org.apache.sysds.runtime.transform.encode.EncoderFactory;
 import org.apache.sysds.runtime.transform.encode.EncoderFeatureHash;
+import org.apache.sysds.runtime.transform.encode.EncoderMVImpute;
 import org.apache.sysds.runtime.transform.encode.EncoderOmit;
 import org.apache.sysds.runtime.transform.encode.EncoderPassThrough;
 import org.apache.sysds.runtime.transform.encode.EncoderRecode;
@@ -102,7 +103,8 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
 				new EncoderPassThrough(),
 				new EncoderBin(),
 				new EncoderDummycode(),
-				new EncoderOmit(true)));
+				new EncoderOmit(true),
+				new EncoderMVImpute()));
 		// first create encoders at the federated workers, then collect them and aggregate them to a single large
 		// encoder
 		FederationMap fedMapping = fin.getFedMapping();
@@ -120,7 +122,7 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
 				Encoder encoder = (Encoder) response.getData()[0];
 				// merge this encoder into a composite encoder
 				synchronized(globalEncoder) {
-					globalEncoder.mergeAt(encoder, columnOffset);
+					globalEncoder.mergeAt(encoder, (int) (range.getBeginDims()[0] + 1), columnOffset);
 				}
 				// no synchronization necessary since names should anyway match
 				String[] subRangeColNames = (String[]) response.getData()[1];
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index 204019f..4f31d4f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -266,7 +266,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
 
 				// no synchronization necessary since names should anyway match
 				Encoder builtEncoder = (Encoder) response.getData()[0];
-				newOmit.mergeAt(builtEncoder, (int) (range.getBeginDims()[1] + 1));
+				newOmit.mergeAt(builtEncoder, (int) (range.getBeginDims()[0] + 1), (int) (range.getBeginDims()[1] + 1));
 			}
 			catch(Exception e) {
 				throw new DMLRuntimeException(e);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
index 7f47192..0758620 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
@@ -177,9 +177,10 @@ public abstract class Encoder implements Serializable
 	 * other <code>Encoder</code>.
 	 * 
 	 * @param other the encoder that should be merged in
-	 * @param col   the position where it should be placed (1-based)
+	 * @param row   the row where it should be placed (1-based)
+	 * @param col   the column where it should be placed (1-based)
 	 */
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		throw new DMLRuntimeException(
 			this.getClass().getSimpleName() + " does not support merging with " + other.getClass().getSimpleName());
 	}
@@ -187,8 +188,8 @@ public abstract class Encoder implements Serializable
 	/**
 	 * Update index-ranges to after encoding. Note that only Dummycoding changes the ranges.
 	 *
-	 * @param beginDims the begin indexes before encoding
-	 * @param endDims   the end indexes before encoding
+	 * @param beginDims begin dimensions of range
+	 * @param endDims end dimensions of range
 	 */
 	public void updateIndexRanges(long[] beginDims, long[] endDims) {
 		// do nothing - default
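
Every concrete encoder below follows the same merge dispatch: handle a
compatible encoder type, otherwise defer to this base implementation, which
rejects the pairing. A compact sketch of that pattern with illustrative class
names (not SystemDS code):

    // Base rejects merges it does not understand; subclasses handle their
    // own type and fall back to super for anything else.
    abstract class BaseEncoder {
        public void mergeAt(BaseEncoder other, int row, int col) {
            throw new IllegalArgumentException(getClass().getSimpleName()
                + " does not support merging with " + other.getClass().getSimpleName());
        }
    }

    class RecodeLike extends BaseEncoder {
        @Override
        public void mergeAt(BaseEncoder other, int row, int col) {
            if (other instanceof RecodeLike) {
                // ... merge column info and per-column state at (row, col) ...
                return;
            }
            super.mergeAt(other, row, col); // unsupported pairing -> exception
        }
    }
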
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
index 351f68d..4caee9b 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
@@ -169,7 +169,7 @@ public class EncoderBin extends Encoder
 	}
 	
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if(other instanceof EncoderBin) {
 			EncoderBin otherBin = (EncoderBin) other;
 
@@ -217,7 +217,7 @@ public class EncoderBin extends Encoder
 			}
 			return;
 		}
-		super.mergeAt(other, col);
+		super.mergeAt(other, row, col);
 	}
 	
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
index c494676..cc59932 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
@@ -117,7 +117,7 @@ public class EncoderComposite extends Encoder
 	}
 
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if (other instanceof EncoderComposite) {
 			EncoderComposite otherComposite = (EncoderComposite) other;
 			// TODO maybe assert that the _encoders never have the same type of encoder twice or more
@@ -125,7 +125,7 @@ public class EncoderComposite extends Encoder
 				boolean mergedIn = false;
 				for (Encoder encoder : _encoders) {
 					if (encoder.getClass() == otherEnc.getClass()) {
-						encoder.mergeAt(otherEnc, col);
+						encoder.mergeAt(otherEnc, row, col);
 						mergedIn = true;
 						break;
 					}
@@ -146,7 +146,7 @@ public class EncoderComposite extends Encoder
 		}
 		for (Encoder encoder : _encoders) {
 			if (encoder.getClass() == other.getClass()) {
-				encoder.mergeAt(other, col);
+				encoder.mergeAt(other, row, col);
 				// update dummycode encoder domain sizes based on distinctness information from other encoders
 				for (Encoder encDummy : _encoders) {
 					if (encDummy instanceof EncoderDummycode) {
@@ -157,7 +157,7 @@ public class EncoderComposite extends Encoder
 				return;
 			}
 		}
-		super.mergeAt(other, col);
+		super.mergeAt(other, row, col);
 	}
 	
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
index 19d41ea..f590a04 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
@@ -128,7 +128,7 @@ public class EncoderDummycode extends Encoder
 	}
 
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if(other instanceof EncoderDummycode) {
 			mergeColumnInfo(other, col);
 
@@ -138,7 +138,7 @@ public class EncoderDummycode extends Encoder
 			Arrays.fill(_domainSizes, 0, _colList.length, 1);
 			return;
 		}
-		super.mergeAt(other, col);
+		super.mergeAt(other, row, col);
 	}
 	
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 313e5b2..af929be 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -104,7 +104,7 @@ public class EncoderFactory
 			if( !oIDs.isEmpty() )
 				lencoders.add(new EncoderOmit(jSpec, colnames, schema.length, minCol, maxCol));
 			if( !mvIDs.isEmpty() ) {
-				EncoderMVImpute ma = new EncoderMVImpute(jSpec, colnames, schema.length);
+				EncoderMVImpute ma = new EncoderMVImpute(jSpec, colnames, schema.length, minCol, maxCol);
 				ma.initRecodeIDList(rcIDs);
 				lencoders.add(ma);
 			}
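
Threading minCol/maxCol into EncoderMVImpute lets a federated worker keep
only the impute-spec entries that fall into its local column slice, re-based
to local 1-based indices (see parseMethodsAndReplacements below). A rough
standalone sketch of that projection, assuming minCol/maxCol of -1 means
"no slicing" as in the non-federated case:

    import java.util.Arrays;

    // Illustrative only, not the SystemDS API: project global 1-based column
    // IDs from a transform spec onto a worker's slice [minCol, maxCol].
    class SpecSlicing {
        static int[] projectIds(int[] globalIds, int minCol, int maxCol) {
            if (minCol == -1 && maxCol == -1)
                return globalIds; // no slicing, keep global numbering
            return Arrays.stream(globalIds)
                .filter(id -> id >= minCol && id <= maxCol)
                .map(id -> id - minCol + 1) // re-base to local 1-based indices
                .toArray();
        }
    }
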
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
index 9317dfb..3b6503b 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
@@ -110,14 +110,14 @@ public class EncoderFeatureHash extends Encoder
 	}
 	
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if(other instanceof EncoderFeatureHash) {
 			mergeColumnInfo(other, col);
 			if (((EncoderFeatureHash) other)._K != 0 && _K == 0)
 				_K = ((EncoderFeatureHash) other)._K;
 			return;
 		}
-		super.mergeAt(other, col);
+		super.mergeAt(other, row, col);
 	}
 	
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
index 56749a2..534d16c 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
@@ -19,263 +19,119 @@
 
 package org.apache.sysds.runtime.transform.encode;
 
-import java.io.IOException;
-import java.util.BitSet;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Set;
+import java.util.stream.Collectors;
 
 import org.apache.wink.json4j.JSONArray;
 import org.apache.wink.json4j.JSONException;
 import org.apache.wink.json4j.JSONObject;
-import org.apache.sysds.runtime.functionobjects.CM;
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.functionobjects.KahanPlus;
 import org.apache.sysds.runtime.functionobjects.Mean;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
 import org.apache.sysds.runtime.instructions.cp.KahanObject;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
-import org.apache.sysds.runtime.transform.TfUtils;
 import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
 import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
+import org.apache.sysds.runtime.util.IndexRange;
 import org.apache.sysds.runtime.util.UtilFunctions;
 
-public class EncoderMVImpute extends Encoder 
+public class EncoderMVImpute extends Encoder
 {
 	private static final long serialVersionUID = 9057868620144662194L;
 
 	public enum MVMethod { INVALID, GLOBAL_MEAN, GLOBAL_MODE, CONSTANT }
 	
 	private MVMethod[] _mvMethodList = null;
-	private MVMethod[] _mvscMethodList = null; // scaling methods for attributes that are imputed and also scaled
-	
-	private BitSet _isMVScaled = null;
-	private CM _varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE);		// function object that understands variance computation
 	
 	// objects required to compute mean and variance of all non-missing entries 
-	private Mean _meanFn = Mean.getMeanFnObject();  // function object that understands mean computation
+	private final Mean _meanFn = Mean.getMeanFnObject();  // function object that understands mean computation
 	private KahanObject[] _meanList = null;         // column-level means, computed so far
 	private long[] _countList = null;               // #of non-missing values
 	
-	private CM_COV_Object[] _varList = null;        // column-level variances, computed so far (for scaling)
-
-	private int[]           _scnomvList = null;        // List of attributes that are scaled but not imputed
-	private MVMethod[]      _scnomvMethodList = null;  // scaling methods: 0 for invalid; 1 for mean-subtraction; 2 for z-scoring
-	private KahanObject[]   _scnomvMeanList = null;    // column-level means, for attributes scaled but not imputed
-	private long[]          _scnomvCountList = null;   // #of non-missing values, for attributes scaled but not imputed
-	private CM_COV_Object[] _scnomvVarList = null;     // column-level variances, computed so far
-	
 	private String[] _replacementList = null; // replacements: for global_mean, mean; and for global_mode, recode id of mode category
-	private String[] _NAstrings = null;
 	private List<Integer> _rcList = null;
 	private HashMap<Integer,HashMap<String,Long>> _hist = null;
 	
 	public String[] getReplacements() { return _replacementList; }
 	public KahanObject[] getMeans()   { return _meanList; }
-	public CM_COV_Object[] getVars()  { return _varList; }
-	public KahanObject[] getMeans_scnomv()   { return _scnomvMeanList; }
-	public CM_COV_Object[] getVars_scnomv()  { return _scnomvVarList; }
 	
-	public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, int clen) 
+	public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol)
 		throws JSONException
 	{
 		super(null, clen);
 		
 		//handle column list
-		int[] collist = TfMetaUtils.parseJsonObjectIDList(parsedSpec, colnames, TfMethod.IMPUTE.toString(), -1, -1);
+		int[] collist = TfMetaUtils
+			.parseJsonObjectIDList(parsedSpec, colnames, TfMethod.IMPUTE.toString(), minCol, maxCol);
 		initColList(collist);
 	
 		//handle method list
-		parseMethodsAndReplacments(parsedSpec);
+		parseMethodsAndReplacements(parsedSpec, colnames, minCol);
 		
 		//create reuse histograms
 		_hist = new HashMap<>();
 	}
 	
-	public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, String[] NAstrings, int clen)
-		throws JSONException 
-	{
-		super(null, clen);
-		boolean isMV = parsedSpec.containsKey(TfMethod.IMPUTE.toString());
-		boolean isSC = parsedSpec.containsKey(TfMethod.SCALE.toString());
-		_NAstrings = NAstrings;
-		
-		if(!isMV) {
-			// MV Impute is not applicable
-			_colList = null;
-			_mvMethodList = null;
-			_meanList = null;
-			_countList = null;
-			_replacementList = null;
-		}
-		else {
-			JSONObject mvobj = (JSONObject) parsedSpec.get(TfMethod.IMPUTE.toString());
-			JSONArray mvattrs = (JSONArray) mvobj.get(TfUtils.JSON_ATTRS);
-			JSONArray mvmthds = (JSONArray) mvobj.get(TfUtils.JSON_MTHD);
-			int mvLength = mvattrs.size();
-			
-			_colList = new int[mvLength];
-			_mvMethodList = new MVMethod[mvLength];
-			
-			_meanList = new KahanObject[mvLength];
-			_countList = new long[mvLength];
-			_varList = new CM_COV_Object[mvLength];
-			
-			_isMVScaled = new BitSet(_colList.length);
-			_isMVScaled.clear();
-			
-			for(int i=0; i < _colList.length; i++) {
-				_colList[i] = UtilFunctions.toInt(mvattrs.get(i));
-				_mvMethodList[i] = MVMethod.values()[UtilFunctions.toInt(mvmthds.get(i))]; 
-				_meanList[i] = new KahanObject(0, 0);
-			}
-			
-			_replacementList = new String[mvLength]; 	// contains replacements for all columns (scale and categorical)
-			
-			JSONArray constants = (JSONArray)mvobj.get(TfUtils.JSON_CONSTS);
-			for(int i=0; i < constants.size(); i++) {
-				if ( constants.get(i) == null )
-					_replacementList[i] = "NaN";
-				else
-					_replacementList[i] = constants.get(i).toString();
-			}
-		}
-		
-		// Handle scaled attributes
-		if ( !isSC )
-		{
-			// scaling is not applicable
-			_scnomvCountList = null;
-			_scnomvMeanList = null;
-			_scnomvVarList = null;
-		}
-		else
-		{
-			if ( _colList != null ) 
-				_mvscMethodList = new MVMethod[_colList.length];
-			
-			JSONObject scobj = (JSONObject) parsedSpec.get(TfMethod.SCALE.toString());
-			JSONArray scattrs = (JSONArray) scobj.get(TfUtils.JSON_ATTRS);
-			JSONArray scmthds = (JSONArray) scobj.get(TfUtils.JSON_MTHD);
-			int scLength = scattrs.size();
-			
-			int[] _allscaled = new int[scLength];
-			int scnomv = 0, colID;
-			byte mthd;
-			for(int i=0; i < scLength; i++)
-			{
-				colID = UtilFunctions.toInt(scattrs.get(i));
-				mthd = (byte) UtilFunctions.toInt(scmthds.get(i));
-				
-				_allscaled[i] = colID;
-				
-				// check if the attribute is also MV imputed
-				int mvidx = isApplicable(colID);
-				if(mvidx != -1)
-				{
-					_isMVScaled.set(mvidx);
-					_mvscMethodList[mvidx] = MVMethod.values()[mthd];
-					_varList[mvidx] = new CM_COV_Object();
-				}
-				else
-					scnomv++;	// count of scaled but not imputed 
-			}
-			
-			if(scnomv > 0)
-			{
-				_scnomvList = new int[scnomv];
-				_scnomvMethodList = new MVMethod[scnomv];
+	public EncoderMVImpute() {
+		super(new int[0], 0);
+	}
 	
-				_scnomvMeanList = new KahanObject[scnomv];
-				_scnomvCountList = new long[scnomv];
-				_scnomvVarList = new CM_COV_Object[scnomv];
-				
-				for(int i=0, idx=0; i < scLength; i++)
-				{
-					colID = UtilFunctions.toInt(scattrs.get(i));
-					mthd = (byte)UtilFunctions.toInt(scmthds.get(i));
-							
-					if(isApplicable(colID) == -1)
-					{	// scaled but not imputed
-						_scnomvList[idx] = colID;
-						_scnomvMethodList[idx] = MVMethod.values()[mthd];
-						_scnomvMeanList[idx] = new KahanObject(0, 0);
-						_scnomvVarList[idx] = new CM_COV_Object();
-						idx++;
-					}
-				}
-			}
-		}
+	
+	public EncoderMVImpute(int[] colList, MVMethod[] mvMethodList, String[] replacementList, KahanObject[] meanList,
+			long[] countList, List<Integer> rcList, int clen) {
+		super(colList, clen);
+		_mvMethodList = mvMethodList;
+		_replacementList = replacementList;
+		_meanList = meanList;
+		_countList = countList;
+		_rcList = rcList;
 	}
-
-	private void parseMethodsAndReplacments(JSONObject parsedSpec) throws JSONException {
+	
+	private void parseMethodsAndReplacements(JSONObject parsedSpec, String[] colnames, int offset) throws JSONException {
 		JSONArray mvspec = (JSONArray) parsedSpec.get(TfMethod.IMPUTE.toString());
+		boolean ids = parsedSpec.containsKey("ids") && parsedSpec.getBoolean("ids");
+		// make space for all elements
 		_mvMethodList = new MVMethod[mvspec.size()];
 		_replacementList = new String[mvspec.size()];
 		_meanList = new KahanObject[mvspec.size()];
 		_countList = new long[mvspec.size()];
-		for(int i=0; i < mvspec.size(); i++) {
-			JSONObject mvobj = (JSONObject)mvspec.get(i);
-			_mvMethodList[i] = MVMethod.valueOf(mvobj.get("method").toString().toUpperCase()); 
-			if( _mvMethodList[i] == MVMethod.CONSTANT ) {
-				_replacementList[i] = mvobj.getString("value").toString();
-			}
-			_meanList[i] = new KahanObject(0, 0);
-		}
-	}
-		
-	public void prepare(String[] words) throws IOException {
+		// sort for binary search
+		Arrays.sort(_colList);
 		
-		try {
-			String w = null;
-			if(_colList != null)
-			for(int i=0; i <_colList.length; i++) {
-				int colID = _colList[i];
-				w = UtilFunctions.unquote(words[colID-1].trim());
-				
-				try {
-				if(!TfUtils.isNA(_NAstrings, w)) {
-					_countList[i]++;
-					
-					boolean computeMean = (_mvMethodList[i] == MVMethod.GLOBAL_MEAN || _isMVScaled.get(i) );
-					if(computeMean) {
-						// global_mean
-						double d = UtilFunctions.parseToDouble(w, UtilFunctions.defaultNaString);
-						_meanFn.execute2(_meanList[i], d, _countList[i]);
-						
-						if (_isMVScaled.get(i) && _mvscMethodList[i] == MVMethod.GLOBAL_MODE)
-							_varFn.execute(_varList[i], d);
-					}
-					else {
-						// global_mode or constant
-						// Nothing to do here. Mode is computed using recode maps.
-					}
-				}
-				} catch (NumberFormatException e) 
-				{
-					throw new RuntimeException("Encountered \"" + w + "\" in column ID \"" + colID + "\", when expecting a numeric value. Consider adding \"" + w + "\" to na.strings, along with an appropriate imputation method.");
+		int listIx = 0;
+		for(Object o : mvspec) {
+			JSONObject mvobj = (JSONObject) o;
+			int ixOffset = offset == -1 ? 0 : offset - 1;
+			// check for position -> -1 if not present
+			int pos = Arrays.binarySearch(_colList,
+				ids ? mvobj.getInt("id") - ixOffset : ArrayUtils.indexOf(colnames, mvobj.get("name")) + 1);
+			if(pos >= 0) {
+				// add to arrays
+				_mvMethodList[listIx] = MVMethod.valueOf(mvobj.get("method").toString().toUpperCase());
+				if(_mvMethodList[listIx] == MVMethod.CONSTANT) {
+					_replacementList[listIx] = mvobj.getString("value");
 				}
+				_meanList[listIx++] = new KahanObject(0, 0);
 			}
-			
-			// Compute mean and variance for attributes that are scaled but not imputed
-			if(_scnomvList != null)
-			for(int i=0; i < _scnomvList.length; i++) 
-			{
-				int colID = _scnomvList[i];
-				w = UtilFunctions.unquote(words[colID-1].trim());
-				double d = UtilFunctions.parseToDouble(w, UtilFunctions.defaultNaString);
-				_scnomvCountList[i]++; 		// not required, this is always equal to total #records processed
-				_meanFn.execute2(_scnomvMeanList[i], d, _scnomvCountList[i]);
-				if(_scnomvMethodList[i] == MVMethod.GLOBAL_MODE)
-					_varFn.execute(_scnomvVarList[i], d);
-			}
-		} catch(Exception e) {
-			throw new IOException(e);
 		}
+		// make arrays required size
+		_mvMethodList = Arrays.copyOf(_mvMethodList, listIx);
+		_replacementList = Arrays.copyOf(_replacementList, listIx);
+		_meanList = Arrays.copyOf(_meanList, listIx);
+		_countList = Arrays.copyOf(_countList, listIx);
 	}
 	
 	public MVMethod getMethod(int colID) {
-		int idx = isApplicable(colID);		
+		int idx = isApplicable(colID);
 		if(idx == -1)
 			return MVMethod.INVALID;
 		else
@@ -287,8 +143,8 @@ public class EncoderMVImpute extends Encoder
 		return (idx == -1) ? 0 : _countList[idx];
 	}
 	
-	public String getReplacement(int colID)  {
-		int idx = isApplicable(colID);		
+	public String getReplacement(int colID) {
+		int idx = isApplicable(colID);
 		return (idx == -1) ? null : _replacementList[idx];
 	}
 	
@@ -321,7 +177,7 @@ public class EncoderMVImpute extends Encoder
 						if( key != null && !key.isEmpty() ) {
 							Long val = hist.get(key);
 							hist.put(key, (val!=null) ? val+1 : 1);
-						}	
+						}
 					}
 					_hist.put(colID, hist);
 					long max = Long.MIN_VALUE; 
@@ -349,12 +205,98 @@ public class EncoderMVImpute extends Encoder
 		}
 		return out;
 	}
+
+	@Override
+	public Encoder subRangeEncoder(IndexRange ixRange) {
+		Map<Integer, ColInfo> map = new HashMap<>();
+		for(int i = 0; i < _colList.length; i++) {
+			int col = _colList[i];
+			if(ixRange.inColRange(col))
+				map.put((int) (_colList[i] - (ixRange.colStart - 1)),
+					new ColInfo(_mvMethodList[i], _replacementList[i], _meanList[i], _countList[i], _hist.get(col)));
+		}
+		if(map.size() == 0)
+			// empty encoder -> sub range encoder does not exist
+			return null;
+
+		int[] colList = new int[map.size()];
+		MVMethod[] mvMethodList = new MVMethod[map.size()];
+		String[] replacementList = new String[map.size()];
+		KahanObject[] meanList = new KahanObject[map.size()];
+		long[] countList = new long[map.size()];
+
+		fillListsFromMap(map, colList, mvMethodList, replacementList, meanList, countList, _hist);
+
+		if(_rcList == null)
+			_rcList = new ArrayList<>();
+		List<Integer> rcList = _rcList.stream().filter(ixRange::inColRange).map(i -> (int) (i - (ixRange.colStart - 1)))
+			.collect(Collectors.toList());
+
+		return new EncoderMVImpute(colList, mvMethodList, replacementList, meanList, countList, rcList,
+			(int) ixRange.colSpan());
+	}
+
+	private static void fillListsFromMap(Map<Integer, ColInfo> map, int[] colList, MVMethod[] mvMethodList,
+		String[] replacementList, KahanObject[] meanList, long[] countList,
+		HashMap<Integer, HashMap<String, Long>> hist) {
+		int i = 0;
+		for(Entry<Integer, ColInfo> entry : map.entrySet()) {
+			colList[i] = entry.getKey();
+			mvMethodList[i] = entry.getValue()._method;
+			replacementList[i] = entry.getValue()._replacement;
+			meanList[i] = entry.getValue()._mean;
+			countList[i++] = entry.getValue()._count;
+
+			hist.put(entry.getKey(), entry.getValue()._hist);
+		}
+	}
+
+	@Override
+	public void mergeAt(Encoder other, int row, int col) {
+		if(other instanceof EncoderMVImpute) {
+			EncoderMVImpute otherImpute = (EncoderMVImpute) other;
+			Map<Integer, ColInfo> map = new HashMap<>();
+			for(int i = 0; i < _colList.length; i++) {
+				map.put(_colList[i],
+					new ColInfo(_mvMethodList[i], _replacementList[i], _meanList[i], _countList[i], _hist.get(_colList[i])));
+			}
+			for(int i = 0; i < other._colList.length; i++) {
+				int column = other._colList[i] + (col - 1);
+				ColInfo otherColInfo = new ColInfo(otherImpute._mvMethodList[i], otherImpute._replacementList[i],
+					otherImpute._meanList[i], otherImpute._countList[i], otherImpute._hist.get(otherImpute._colList[i]));
+				ColInfo colInfo = map.get(column);
+				if(colInfo == null)
+					map.put(column, otherColInfo);
+				else
+					colInfo.merge(otherColInfo);
+			}
+
+			_colList = new int[map.size()];
+			_mvMethodList = new MVMethod[map.size()];
+			_replacementList = new String[map.size()];
+			_meanList = new KahanObject[map.size()];
+			_countList = new long[map.size()];
+			_hist = new HashMap<>();
+
+			fillListsFromMap(map, _colList, _mvMethodList, _replacementList, _meanList, _countList, _hist);
+			// update number of columns
+			_clen = Math.max(_clen, col - 1 + other._clen);
+
+			if(_rcList == null)
+				_rcList = new ArrayList<>();
+			Set<Integer> rcSet = new HashSet<>(_rcList);
+			rcSet.addAll(otherImpute._rcList.stream().map(i -> i + (col - 1)).collect(Collectors.toSet()));
+			_rcList = new ArrayList<>(rcSet);
+			return;
+		}
+		super.mergeAt(other, row, col);
+	}
 	
 	@Override
 	public FrameBlock getMetaData(FrameBlock out) {
 		for( int j=0; j<_colList.length; j++ ) {
 			out.getColumnMetadata(_colList[j]-1)
-			   .setMvValue(_replacementList[j]);
+				.setMvValue(_replacementList[j]);
 		}
 		return out;
 	}
@@ -391,4 +333,48 @@ public class EncoderMVImpute extends Encoder
 	public HashMap<String,Long> getHistogram( int colID ) {
 		return _hist.get(colID);
 	}
+	
+	private static class ColInfo {
+		MVMethod _method;
+		String _replacement;
+		KahanObject _mean;
+		long _count;
+		HashMap<String, Long> _hist;
+
+		ColInfo(MVMethod method, String replacement, KahanObject mean, long count, HashMap<String, Long> hist) {
+			_method = method;
+			_replacement = replacement;
+			_mean = mean;
+			_count = count;
+			_hist = hist;
+		}
+
+		public void merge(ColInfo otherColInfo) {
+			if(_method != otherColInfo._method)
+				throw new DMLRuntimeException("Tried to merge two different impute methods: " + _method.name() + " vs. "
+					+ otherColInfo._method.name());
+			switch(_method) {
+				case CONSTANT:
+					assert _replacement.equals(otherColInfo._replacement);
+					break;
+				case GLOBAL_MEAN:
+					_mean._sum *= _count;
+					_mean._correction *= _count;
+					KahanPlus.getKahanPlusFnObject().execute(_mean, otherColInfo._mean._sum * otherColInfo._count);
+					KahanPlus.getKahanPlusFnObject().execute(_mean,
+						otherColInfo._mean._correction * otherColInfo._count);
+					_count += otherColInfo._count;
+					break;
+				case GLOBAL_MODE:
+					if (_hist == null)
+						_hist = new HashMap<>(otherColInfo._hist);
+					else
+						// add counts
+						_hist.replaceAll((key, count) -> count + otherColInfo._hist.getOrDefault(key, 0L));
+					break;
+				default:
+					throw new DMLRuntimeException("Method `" + _method.name() + "` not supported for federated impute");
+			}
+		}
+	}
 }
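
For GLOBAL_MEAN, ColInfo.merge rescales each partition's partial mean back to
a sum and accumulates with KahanPlus; the combined mean over both partitions
is the count-weighted average (n1*m1 + n2*m2) / (n1 + n2). A tiny standalone
illustration with plain doubles (the commit uses KahanObject/KahanPlus to
limit floating-point rounding error):

    // Count-weighted merge of two partial means m1 (over n1 values) and
    // m2 (over n2 values).
    class MeanMergeSketch {
        static double mergeMeans(double m1, long n1, double m2, long n2) {
            return (n1 * m1 + n2 * m2) / (double) (n1 + n2);
        }
        public static void main(String[] args) {
            // worker A saw {1,2,3} (mean 2, n=3), worker B saw {10} (mean 10, n=1)
            System.out.println(mergeMeans(2.0, 3, 10.0, 1)); // 4.0 = (6+10)/4
        }
    }
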
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
index 26ba4e4..bbc83e4 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
@@ -19,9 +19,7 @@
 
 package org.apache.sysds.runtime.transform.encode;
 
-import java.util.TreeSet;
-import java.util.stream.Collectors;
-
+import java.util.Arrays;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -38,8 +36,7 @@ public class EncoderOmit extends Encoder
 	private static final long serialVersionUID = 1978852120416654195L;
 
 	private boolean _federated = false;
-	//TODO perf replace with boolean[rlen] similar to removeEmpty
-	private TreeSet<Integer> _rmRows = new TreeSet<>();
+	private boolean[] _rmRows = new boolean[0];
 
 	public EncoderOmit(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol)
 		throws JSONException 
@@ -61,19 +58,24 @@ public class EncoderOmit extends Encoder
 		_federated = federated;
 	}
 	
-	
-	private EncoderOmit(int[] colList, int clen, TreeSet<Integer> rmRows) {
+	private EncoderOmit(int[] colList, int clen, boolean[] rmRows) {
 		super(colList, clen);
 		_rmRows = rmRows;
 		_federated = true;
 	}
-	
+
+	public int getNumRemovedRows(boolean[] rmRows) {
+		int cnt = 0;
+		for(boolean v : rmRows)
+			cnt += v ? 1 : 0;
+		return cnt;
+	}
+
 	public int getNumRemovedRows() {
-		return _rmRows.size();
+		return getNumRemovedRows(_rmRows);
 	}
 	
-	public boolean omit(String[] words, TfUtils agents) 
-	{
+	public boolean omit(String[] words, TfUtils agents) {
 		if( !isApplicable() )
 			return false;
 		
@@ -99,21 +101,21 @@ public class EncoderOmit extends Encoder
 	@Override
 	public MatrixBlock apply(FrameBlock in, MatrixBlock out) {
 		// local rmRows for broadcasting encoder in spark
-		TreeSet<Integer> rmRows;
+		boolean[] rmRows;
 		if(_federated)
 			rmRows = _rmRows;
 		else
 			rmRows = computeRmRows(in);
 
 		// determine output size
-		int numRows = out.getNumRows() - rmRows.size();
+		int numRows = out.getNumRows() - getNumRemovedRows(rmRows);
 
 		// copy over valid rows into the output
 		MatrixBlock ret = new MatrixBlock(numRows, out.getNumColumns(), false);
 		int pos = 0;
 		for(int i = 0; i < in.getNumRows(); i++) {
 			// copy row if necessary
-			if(!rmRows.contains(i)) {
+			if(!rmRows[i]) {
 				for(int j = 0; j < out.getNumColumns(); j++)
 					ret.quickSetValue(pos, j, out.quickGetValue(i, j));
 				pos++;
@@ -125,17 +127,19 @@ public class EncoderOmit extends Encoder
 		return ret;
 	}
 
-	private TreeSet<Integer> computeRmRows(FrameBlock in) {
-		TreeSet<Integer> rmRows = new TreeSet<>();
+	private boolean[] computeRmRows(FrameBlock in) {
+		boolean[] rmRows = new boolean[in.getNumRows()];
 		ValueType[] schema = in.getSchema();
+		//TODO perf evaluate if column-wise scan more efficient
+		//  (sequential but less impact of early abort)
 		for(int i = 0; i < in.getNumRows(); i++) {
-			boolean valid = true;
 			for(int colID : _colList) {
 				Object val = in.get(i, colID - 1);
-				valid &= !(val == null || (schema[colID - 1] == ValueType.STRING && val.toString().isEmpty()));
+				if (val == null || (schema[colID - 1] == ValueType.STRING && val.toString().isEmpty())) {
+					rmRows[i] = true;
+					break; // early abort
+				}
 			}
-			if(!valid)
-				rmRows.add(i);
 		}
 		return rmRows;
 	}
@@ -146,38 +150,38 @@ public class EncoderOmit extends Encoder
 		if(colList.length == 0)
 			// empty encoder -> sub range encoder does not exist
 			return null;
-
-		TreeSet<Integer> rmRows = _rmRows.stream().filter((row) -> ixRange.inRowRange(row + 1))
-			.map((row) -> (int) (row - (ixRange.rowStart - 1))).collect(Collectors.toCollection(TreeSet::new));
+		boolean[] rmRows = _rmRows;
+		if (_rmRows.length > 0)
+			rmRows = Arrays.copyOfRange(rmRows, (int) ixRange.rowStart - 1, (int) ixRange.rowEnd - 1);
 
 		return new EncoderOmit(colList, (int) (ixRange.colSpan()), rmRows);
 	}
 
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if(other instanceof EncoderOmit) {
 			mergeColumnInfo(other, col);
-			_rmRows.addAll(((EncoderOmit) other)._rmRows);
+			EncoderOmit otherOmit = (EncoderOmit) other;
+			_rmRows = Arrays.copyOf(_rmRows, Math.max(_rmRows.length, (row - 1) + otherOmit._rmRows.length));
+			for (int i = 0; i < otherOmit._rmRows.length; i++)
+				_rmRows[(row - 1) + i] |= otherOmit._rmRows[i];
 			return;
 		}
-		super.mergeAt(other, col);
+		super.mergeAt(other, row, col);
 	}
 	
 	@Override
 	public void updateIndexRanges(long[] beginDims, long[] endDims) {
 		// first update begin dims
 		int numRowsToRemove = 0;
-		Integer removedRow = _rmRows.ceiling(0);
-		while(removedRow != null && removedRow < beginDims[0]) {
-			numRowsToRemove++;
-			removedRow = _rmRows.ceiling(removedRow + 1);
-		}
+		for (int i = 0; i < beginDims[0] - 1 && i < _rmRows.length; i++)
+			if (_rmRows[i])
+				numRowsToRemove++;
 		beginDims[0] -= numRowsToRemove;
 		// update end dims
-		while(removedRow != null && removedRow < endDims[0]) {
-			numRowsToRemove++;
-			removedRow = _rmRows.ceiling(removedRow + 1);
-		}
+		for (int i = 0; i < endDims[0] - 1 && i < _rmRows.length; i++)
+			if (_rmRows[i])
+				numRowsToRemove++;
 		endDims[0] -= numRowsToRemove;
 	}
 	
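
EncoderOmit now tracks removals as a boolean[] over rows instead of a
TreeSet, and updateIndexRanges shifts a row range down by the number of rows
removed before each boundary. A standalone sketch of that bookkeeping
(illustrative only, mirroring the loops above):

    // Shift a 1-based row range {begin, end} down by the number of removed
    // rows before each boundary.
    class OmitRangeSketch {
        static long[] shiftRowRange(boolean[] rmRows, long begin, long end) {
            long removedBeforeBegin = 0, removedBeforeEnd = 0;
            for (int i = 0; i < rmRows.length && i < end - 1; i++) {
                if (rmRows[i]) {
                    removedBeforeEnd++;
                    if (i < begin - 1)
                        removedBeforeBegin++;
                }
            }
            return new long[] {begin - removedBeforeBegin, end - removedBeforeEnd};
        }
    }
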
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
index ccd235d..ac414e9 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
@@ -88,12 +88,12 @@ public class EncoderPassThrough extends Encoder
 	}
 	
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if(other instanceof EncoderPassThrough) {
 			mergeColumnInfo(other, col);
 			return;
 		}
-		super.mergeAt(other, col);
+		super.mergeAt(other, row, col);
 	}
 
 	@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
index d8d524a..6a1ea0b 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
@@ -189,7 +189,7 @@ public class EncoderRecode extends Encoder
 	}
 
 	@Override
-	public void mergeAt(Encoder other, int col) {
+	public void mergeAt(Encoder other, int row, int col) {
 		if(other instanceof EncoderRecode) {
 			mergeColumnInfo(other, col);
 			
@@ -214,7 +214,7 @@ public class EncoderRecode extends Encoder
 			}
 			return;
 		}
-		super.mergeAt(other, col);
+		super.mergeAt(other, row, col);
 	}
 	
 	public int[] numDistinctValues() {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
index 622e6e0..3aa0981 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.test.functions.federated.transform;
 
+import java.io.IOException;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.common.Types.FileFormat;
@@ -42,8 +43,8 @@ import org.junit.Test;
 public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 	private final static String TEST_NAME1 = "TransformFederatedEncodeApply";
 	private final static String TEST_DIR = "functions/transform/";
-	private final static String TEST_CLASS_DIR = TEST_DIR + TransformFederatedEncodeApplyTest.class.getSimpleName()
-		+ "/";
+	private final static String TEST_CLASS_DIR = TEST_DIR 
+		+ TransformFederatedEncodeApplyTest.class.getSimpleName() + "/";
 
 	// dataset and transform tasks without missing values
 	private final static String DATASET1 = "homes3/homes.csv";
@@ -64,8 +65,8 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 
 	// dataset and transform tasks with missing values
 	private final static String DATASET2 = "homes/homes.csv";
-	// private final static String SPEC4 = "homes3/homes.tfspec_impute.json";
-	// private final static String SPEC4b = "homes3/homes.tfspec_impute2.json";
+	private final static String SPEC4 = "homes3/homes.tfspec_impute.json";
+	private final static String SPEC4b = "homes3/homes.tfspec_impute2.json";
 	private final static String SPEC5 = "homes3/homes.tfspec_omit.json";
 	private final static String SPEC5b = "homes3/homes.tfspec_omit2.json";
 
@@ -73,11 +74,7 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 	private static final int[] BIN_col8 = new int[] {1, 2, 2, 2, 2, 2, 3};
 
 	public enum TransformType {
-		RECODE, DUMMY, RECODE_DUMMY, BIN, BIN_DUMMY,
-		// IMPUTE,
-		OMIT,
-		HASH,
-		HASH_RECODE,
+		RECODE, DUMMY, RECODE_DUMMY, BIN, BIN_DUMMY, IMPUTE, OMIT, HASH, HASH_RECODE,
 	}
 
 	@Override
@@ -116,10 +113,10 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 		runTransformTest(TransformType.OMIT, false);
 	}
 
-	// @Test
-	// public void testHomesImputeIDsCSV() {
-	// runTransformTest(TransformType.IMPUTE, false);
-	// }
+	@Test
+	public void testHomesImputeIDsCSV() {
+		runTransformTest(TransformType.IMPUTE, false);
+	}
 
 	@Test
 	public void testHomesRecodeColnamesCSV() {
@@ -151,10 +148,10 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 		runTransformTest(TransformType.OMIT, true);
 	}
 
-	// @Test
-	// public void testHomesImputeColnamesCSV() {
-	// runTransformTest(TransformType.IMPUTE, true);
-	// }
+	@Test
+	public void testHomesImputeColnamesCSV() {
+		runTransformTest(TransformType.IMPUTE, true);
+	}
 
 	@Test
 	public void testHomesHashColnamesCSV() {
@@ -186,7 +183,7 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 			case RECODE: SPEC = colnames ? SPEC1b : SPEC1; DATASET = DATASET1; break;
 			case DUMMY: SPEC = colnames ? SPEC2b : SPEC2; DATASET = DATASET1; break;
 			case BIN: SPEC = colnames ? SPEC3b : SPEC3; DATASET = DATASET1; break;
-			// case IMPUTE: SPEC = colnames ? SPEC4b : SPEC4; DATASET = DATASET2; break;
+			case IMPUTE: SPEC = colnames ? SPEC4b : SPEC4; DATASET = DATASET2; break;
 			case OMIT: SPEC = colnames ? SPEC5b : SPEC5; DATASET = DATASET2; break;
 			case RECODE_DUMMY: SPEC = colnames ? SPEC6b : SPEC6; DATASET = DATASET1; break;
 			case BIN_DUMMY: SPEC = colnames ? SPEC7b : SPEC7; DATASET = DATASET1; break;
@@ -194,7 +191,7 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 			case HASH_RECODE: SPEC = colnames ? SPEC9b : SPEC9; DATASET = DATASET1; break;
 		}
 
-		Thread t1 = null, t2 = null;
+		Thread t1 = null, t2 = null, t3 = null, t4 = null;
 		try {
 			getAndLoadTestConfiguration(TEST_NAME1);
 
@@ -202,11 +199,14 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 			t1 = startLocalFedWorkerThread(port1);
 			int port2 = getRandomAvailablePort();
 			t2 = startLocalFedWorkerThread(port2);
+			int port3 = getRandomAvailablePort();
+			t3 = startLocalFedWorkerThread(port3);
+			int port4 = getRandomAvailablePort();
+			t4 = startLocalFedWorkerThread(port4);
 
 			FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(true, DataExpression.DEFAULT_DELIM_DELIMITER,
-				DataExpression.DEFAULT_DELIM_FILL, DataExpression.DEFAULT_DELIM_FILL_VALUE,
-				DATASET.equals(DATASET1) ? DataExpression.DEFAULT_NA_STRINGS : "NA" + DataExpression.DELIM_NA_STRING_SEP
-					+ "");
+				DataExpression.DEFAULT_DELIM_FILL, DataExpression.DEFAULT_DELIM_FILL_VALUE, DATASET.equals(DATASET1) ?
+				DataExpression.DEFAULT_NA_STRINGS : "NA" + DataExpression.DELIM_NA_STRING_SEP + "");
 			String HOME = SCRIPT_DIR + TEST_DIR;
 			// split up dataset
 			FrameBlock dataset = FrameReaderFactory.createFrameReader(FileFormat.CSV, ffpCSV)
@@ -216,23 +216,37 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 			ffpCSV.setNAStrings(UtilFunctions.defaultNaString);
 			FrameWriter fw = FrameWriterFactory.createFrameWriter(FileFormat.CSV, ffpCSV);
 
-			FrameBlock A = new FrameBlock();
-			dataset.slice(0, dataset.getNumRows() - 1, 0, dataset.getNumColumns() / 2 - 1, A);
-			fw.writeFrameToHDFS(A, input("A"), A.getNumRows(), A.getNumColumns());
-			HDFSTool.writeMetaDataFile(input("A.mtd"), null, A.getSchema(), Types.DataType.FRAME,
-				new MatrixCharacteristics(A.getNumRows(), A.getNumColumns()), FileFormat.CSV, ffpCSV);
-
-			FrameBlock B = new FrameBlock();
-			dataset.slice(0, dataset.getNumRows() - 1, dataset.getNumColumns() / 2, dataset.getNumColumns() - 1, B);
-			fw.writeFrameToHDFS(B, input("B"), B.getNumRows(), B.getNumColumns());
-			HDFSTool.writeMetaDataFile(input("B.mtd"), null, B.getSchema(), Types.DataType.FRAME,
-				new MatrixCharacteristics(B.getNumRows(), B.getNumColumns()), FileFormat.CSV, ffpCSV);
+			writeDatasetSlice(dataset, fw, ffpCSV, "AH",
+				0,
+				dataset.getNumRows() / 2 - 1,
+				0,
+				dataset.getNumColumns() / 2 - 1);
+
+			writeDatasetSlice(dataset, fw, ffpCSV, "AL",
+				dataset.getNumRows() / 2,
+				dataset.getNumRows() - 1,
+				0,
+				dataset.getNumColumns() / 2 - 1);
+
+			writeDatasetSlice(dataset, fw, ffpCSV, "BH",
+				0,
+				dataset.getNumRows() / 2 - 1,
+				dataset.getNumColumns() / 2,
+				dataset.getNumColumns() - 1);
+
+			writeDatasetSlice(dataset, fw, ffpCSV, "BL",
+				dataset.getNumRows() / 2,
+				dataset.getNumRows() - 1,
+				dataset.getNumColumns() / 2,
+				dataset.getNumColumns() - 1);
 
 			fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
-			programArgs = new String[] {"-nvargs", "in_A=" + TestUtils.federatedAddress(port1, input("A")),
-				"in_B=" + TestUtils.federatedAddress(port2, input("B")), "rows=" + dataset.getNumRows(),
-				"cols_A=" + A.getNumColumns(), "cols_B=" + B.getNumColumns(), "TFSPEC=" + HOME + "input/" + SPEC,
-				"TFDATA1=" + output("tfout1"), "TFDATA2=" + output("tfout2"), "OFMT=csv"};
+			programArgs = new String[] {"-nvargs", "in_AH=" + TestUtils.federatedAddress(port1, input("AH")),
+				"in_AL=" + TestUtils.federatedAddress(port2, input("AL")),
+				"in_BH=" + TestUtils.federatedAddress(port3, input("BH")),
+				"in_BL=" + TestUtils.federatedAddress(port4, input("BL")), "rows=" + dataset.getNumRows(),
+				"cols=" + dataset.getNumColumns(), "TFSPEC=" + HOME + "input/" + SPEC, "TFDATA1=" + output("tfout1"),
+				"TFDATA2=" + output("tfout2"), "OFMT=csv"};
 
 			runTest(true, false, null, -1);
 
@@ -266,8 +280,18 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
 			throw new RuntimeException(ex);
 		}
 		finally {
-			TestUtils.shutdownThreads(t1, t2);
+			TestUtils.shutdownThreads(t1, t2, t3, t4);
 			resetExecMode(rtold);
 		}
 	}
+
+	private void writeDatasetSlice(FrameBlock dataset, FrameWriter fw, FileFormatPropertiesCSV ffpCSV, String name,
+		int rl, int ru, int cl, int cu) throws IOException {
+		FrameBlock block = new FrameBlock();
+		dataset.slice(rl, ru, cl, cu, block);
+		fw.writeFrameToHDFS(block, input(name), block.getNumRows(), block.getNumColumns());
+		HDFSTool.writeMetaDataFile(input(DataExpression.getMTDFileName(name)), null, block.getSchema(),
+			Types.DataType.FRAME, new MatrixCharacteristics(block.getNumRows(), block.getNumColumns()),
+			FileFormat.CSV, ffpCSV);
+	}
 }
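
The test now shards the dataset into a 2x2 grid across four federated
workers; the quadrant boundaries mirror the ranges declared in
TransformFederatedEncodeApply.dml below. A small sketch of the boundary
arithmetic (0-based inclusive indices, as FrameBlock.slice expects):

    // Quadrant bounds {rl, ru, cl, cu} for "AH" | "AL" | "BH" | "BL":
    // A/B selects the left/right column half, H/L the upper/lower row half.
    class QuadrantSketch {
        static int[] bounds(String q, int rows, int cols) {
            boolean upper = q.charAt(1) == 'H';
            boolean left  = q.charAt(0) == 'A';
            int rl = upper ? 0 : rows / 2, ru = upper ? rows / 2 - 1 : rows - 1;
            int cl = left  ? 0 : cols / 2, cu = left  ? cols / 2 - 1 : cols - 1;
            return new int[] {rl, ru, cl, cu};
        }
    }
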
diff --git a/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml b/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
index 921242b..28cdcda 100644
--- a/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
+++ b/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
@@ -19,9 +19,11 @@
 #
 #-------------------------------------------------------------
 
-F1 = federated(type="frame", addresses=list($in_A, $in_B), ranges=
-    list(list(0,0), list($rows, $cols_A), # A range
-    list(0, $cols_A), list($rows, $cols_A + $cols_B))); # B range
+F1 = federated(type="frame", addresses=list($in_AH, $in_AL, $in_BH, $in_BL), ranges=list(
+    list(0,0), list($rows / 2, $cols / 2), # AH range
+    list($rows / 2,0), list($rows, $cols / 2), # AL range
+    list(0,$cols / 2), list($rows / 2, $cols), # BH range
+    list($rows / 2,$cols / 2), list($rows, $cols))); # BL range
 
 jspec = read($TFSPEC, data_type="scalar", value_type="string");