You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/09/06 17:22:45 UTC
[systemds] branch master updated: [SYSTEMDS-2556,2560] Federated transform impute, improved omit
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new d15fa6e [SYSTEMDS-2556,2560] Federated transform impute, improved omit
d15fa6e is described below
commit d15fa6e55682b8c8a1e85740b8f7cbffab4dd151
Author: Kevin Innerebner <ke...@yahoo.com>
AuthorDate: Sun Sep 6 19:21:57 2020 +0200
[SYSTEMDS-2556,2560] Federated transform impute, improved omit
Closes #1046.
---
.../federated/FederatedWorkerHandler.java | 4 +-
...tiReturnParameterizedBuiltinFEDInstruction.java | 6 +-
.../fed/ParameterizedBuiltinFEDInstruction.java | 2 +-
.../sysds/runtime/transform/encode/Encoder.java | 9 +-
.../sysds/runtime/transform/encode/EncoderBin.java | 4 +-
.../runtime/transform/encode/EncoderComposite.java | 8 +-
.../runtime/transform/encode/EncoderDummycode.java | 4 +-
.../runtime/transform/encode/EncoderFactory.java | 2 +-
.../transform/encode/EncoderFeatureHash.java | 4 +-
.../runtime/transform/encode/EncoderMVImpute.java | 388 ++++++++++-----------
.../runtime/transform/encode/EncoderOmit.java | 74 ++--
.../transform/encode/EncoderPassThrough.java | 4 +-
.../runtime/transform/encode/EncoderRecode.java | 4 +-
.../TransformFederatedEncodeApplyTest.java | 100 ++++--
.../transform/TransformFederatedEncodeApply.dml | 8 +-
15 files changed, 320 insertions(+), 301 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index b5f0ec8..2690ee6 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -100,9 +100,9 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
//select the response for the entire batch of requests
if (!tmp.isSuccessful()) {
- log.error("Command " + request.getType() + " failed: "
+ log.error("Command " + request.getType() + " failed: "
+ tmp.getErrorMessage() + "full command: \n" + request.toString());
- response = (response == null || response.isSuccessful())
+ response = (response == null || response.isSuccessful())
? tmp : response; //return first error
}
else if( request.getType() == RequestType.GET_VAR ) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index 0fe12b9..047aff3 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -48,6 +48,7 @@ import org.apache.sysds.runtime.transform.encode.EncoderComposite;
import org.apache.sysds.runtime.transform.encode.EncoderDummycode;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.EncoderFeatureHash;
+import org.apache.sysds.runtime.transform.encode.EncoderMVImpute;
import org.apache.sysds.runtime.transform.encode.EncoderOmit;
import org.apache.sysds.runtime.transform.encode.EncoderPassThrough;
import org.apache.sysds.runtime.transform.encode.EncoderRecode;
@@ -102,7 +103,8 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
new EncoderPassThrough(),
new EncoderBin(),
new EncoderDummycode(),
- new EncoderOmit(true)));
+ new EncoderOmit(true),
+ new EncoderMVImpute()));
// first create encoders at the federated workers, then collect them and aggregate them to a single large
// encoder
FederationMap fedMapping = fin.getFedMapping();
@@ -120,7 +122,7 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
Encoder encoder = (Encoder) response.getData()[0];
// merge this encoder into a composite encoder
synchronized(globalEncoder) {
- globalEncoder.mergeAt(encoder, columnOffset);
+ globalEncoder.mergeAt(encoder, (int) (range.getBeginDims()[0] + 1), columnOffset);
}
// no synchronization necessary since names should anyway match
String[] subRangeColNames = (String[]) response.getData()[1];
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index 204019f..4f31d4f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -266,7 +266,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
// no synchronization necessary since names should anyway match
Encoder builtEncoder = (Encoder) response.getData()[0];
- newOmit.mergeAt(builtEncoder, (int) (range.getBeginDims()[1] + 1));
+ newOmit.mergeAt(builtEncoder, (int) (range.getBeginDims()[0] + 1), (int) (range.getBeginDims()[1] + 1));
}
catch(Exception e) {
throw new DMLRuntimeException(e);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
index 7f47192..0758620 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
@@ -177,9 +177,10 @@ public abstract class Encoder implements Serializable
* other <code>Encoder</code>.
*
* @param other the encoder that should be merged in
- * @param col the position where it should be placed (1-based)
+ * @param row the row where it should be placed (1-based)
+ * @param col the col where it should be placed (1-based)
*/
- public void mergeAt(Encoder other, int col) {
+ public void mergeAt(Encoder other, int row, int col) {
throw new DMLRuntimeException(
this.getClass().getSimpleName() + " does not support merging with " + other.getClass().getSimpleName());
}
@@ -187,8 +188,8 @@ public abstract class Encoder implements Serializable
/**
* Update index-ranges to after encoding. Note that only Dummycoding changes the ranges.
*
- * @param beginDims the begin indexes before encoding
- * @param endDims the end indexes before encoding
+ * @param beginDims begin dimensions of range
+ * @param endDims end dimensions of range
*/
public void updateIndexRanges(long[] beginDims, long[] endDims) {
// do nothing - default
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
index 351f68d..4caee9b 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
@@ -169,7 +169,7 @@ public class EncoderBin extends Encoder
}
@Override
- public void mergeAt(Encoder other, int col) {
+ public void mergeAt(Encoder other, int row, int col) {
if(other instanceof EncoderBin) {
EncoderBin otherBin = (EncoderBin) other;
@@ -217,7 +217,7 @@ public class EncoderBin extends Encoder
}
return;
}
- super.mergeAt(other, col);
+ super.mergeAt(other, row, col);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
index c494676..cc59932 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
@@ -117,7 +117,7 @@ public class EncoderComposite extends Encoder
}
@Override
- public void mergeAt(Encoder other, int col) {
+ public void mergeAt(Encoder other, int row, int col) {
if (other instanceof EncoderComposite) {
EncoderComposite otherComposite = (EncoderComposite) other;
// TODO maybe assert that the _encoders never have the same type of encoder twice or more
@@ -125,7 +125,7 @@ public class EncoderComposite extends Encoder
boolean mergedIn = false;
for (Encoder encoder : _encoders) {
if (encoder.getClass() == otherEnc.getClass()) {
- encoder.mergeAt(otherEnc, col);
+ encoder.mergeAt(otherEnc, row, col);
mergedIn = true;
break;
}
@@ -146,7 +146,7 @@ public class EncoderComposite extends Encoder
}
for (Encoder encoder : _encoders) {
if (encoder.getClass() == other.getClass()) {
- encoder.mergeAt(other, col);
+ encoder.mergeAt(other, row, col);
// update dummycode encoder domain sizes based on distinctness information from other encoders
for (Encoder encDummy : _encoders) {
if (encDummy instanceof EncoderDummycode) {
@@ -157,7 +157,7 @@ public class EncoderComposite extends Encoder
return;
}
}
- super.mergeAt(other, col);
+ super.mergeAt(other, row, col);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
index 19d41ea..f590a04 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
@@ -128,7 +128,7 @@ public class EncoderDummycode extends Encoder
}
@Override
- public void mergeAt(Encoder other, int col) {
+ public void mergeAt(Encoder other, int row, int col) {
if(other instanceof EncoderDummycode) {
mergeColumnInfo(other, col);
@@ -138,7 +138,7 @@ public class EncoderDummycode extends Encoder
Arrays.fill(_domainSizes, 0, _colList.length, 1);
return;
}
- super.mergeAt(other, col);
+ super.mergeAt(other, row, col);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 313e5b2..af929be 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -104,7 +104,7 @@ public class EncoderFactory
if( !oIDs.isEmpty() )
lencoders.add(new EncoderOmit(jSpec, colnames, schema.length, minCol, maxCol));
if( !mvIDs.isEmpty() ) {
- EncoderMVImpute ma = new EncoderMVImpute(jSpec, colnames, schema.length);
+ EncoderMVImpute ma = new EncoderMVImpute(jSpec, colnames, schema.length, minCol, maxCol);
ma.initRecodeIDList(rcIDs);
lencoders.add(ma);
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
index 9317dfb..3b6503b 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
@@ -110,14 +110,14 @@ public class EncoderFeatureHash extends Encoder
}
@Override
- public void mergeAt(Encoder other, int col) {
+ public void mergeAt(Encoder other, int row, int col) {
if(other instanceof EncoderFeatureHash) {
mergeColumnInfo(other, col);
if (((EncoderFeatureHash) other)._K != 0 && _K == 0)
_K = ((EncoderFeatureHash) other)._K;
return;
}
- super.mergeAt(other, col);
+ super.mergeAt(other, row, col);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
index 56749a2..534d16c 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
@@ -19,263 +19,119 @@
package org.apache.sysds.runtime.transform.encode;
-import java.io.IOException;
-import java.util.BitSet;
+import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
+import java.util.Map;
import java.util.Map.Entry;
+import java.util.Set;
+import java.util.stream.Collectors;
import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;
-import org.apache.sysds.runtime.functionobjects.CM;
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.Mean;
-import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
-import org.apache.sysds.runtime.transform.TfUtils;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
+import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
-public class EncoderMVImpute extends Encoder
+public class EncoderMVImpute extends Encoder
{
private static final long serialVersionUID = 9057868620144662194L;
public enum MVMethod { INVALID, GLOBAL_MEAN, GLOBAL_MODE, CONSTANT }
private MVMethod[] _mvMethodList = null;
- private MVMethod[] _mvscMethodList = null; // scaling methods for attributes that are imputed and also scaled
-
- private BitSet _isMVScaled = null;
- private CM _varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE); // function object that understands variance computation
// objects required to compute mean and variance of all non-missing entries
- private Mean _meanFn = Mean.getMeanFnObject(); // function object that understands mean computation
+ private final Mean _meanFn = Mean.getMeanFnObject(); // function object that understands mean computation
private KahanObject[] _meanList = null; // column-level means, computed so far
private long[] _countList = null; // #of non-missing values
- private CM_COV_Object[] _varList = null; // column-level variances, computed so far (for scaling)
-
- private int[] _scnomvList = null; // List of attributes that are scaled but not imputed
- private MVMethod[] _scnomvMethodList = null; // scaling methods: 0 for invalid; 1 for mean-subtraction; 2 for z-scoring
- private KahanObject[] _scnomvMeanList = null; // column-level means, for attributes scaled but not imputed
- private long[] _scnomvCountList = null; // #of non-missing values, for attributes scaled but not imputed
- private CM_COV_Object[] _scnomvVarList = null; // column-level variances, computed so far
-
private String[] _replacementList = null; // replacements: for global_mean, mean; and for global_mode, recode id of mode category
- private String[] _NAstrings = null;
private List<Integer> _rcList = null;
private HashMap<Integer,HashMap<String,Long>> _hist = null;
public String[] getReplacements() { return _replacementList; }
public KahanObject[] getMeans() { return _meanList; }
- public CM_COV_Object[] getVars() { return _varList; }
- public KahanObject[] getMeans_scnomv() { return _scnomvMeanList; }
- public CM_COV_Object[] getVars_scnomv() { return _scnomvVarList; }
- public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, int clen)
+ public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol)
throws JSONException
{
super(null, clen);
//handle column list
- int[] collist = TfMetaUtils.parseJsonObjectIDList(parsedSpec, colnames, TfMethod.IMPUTE.toString(), -1, -1);
+ int[] collist = TfMetaUtils
+ .parseJsonObjectIDList(parsedSpec, colnames, TfMethod.IMPUTE.toString(), minCol, maxCol);
initColList(collist);
//handle method list
- parseMethodsAndReplacments(parsedSpec);
+ parseMethodsAndReplacements(parsedSpec, colnames, minCol);
//create reuse histograms
_hist = new HashMap<>();
}
- public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, String[] NAstrings, int clen)
- throws JSONException
- {
- super(null, clen);
- boolean isMV = parsedSpec.containsKey(TfMethod.IMPUTE.toString());
- boolean isSC = parsedSpec.containsKey(TfMethod.SCALE.toString());
- _NAstrings = NAstrings;
-
- if(!isMV) {
- // MV Impute is not applicable
- _colList = null;
- _mvMethodList = null;
- _meanList = null;
- _countList = null;
- _replacementList = null;
- }
- else {
- JSONObject mvobj = (JSONObject) parsedSpec.get(TfMethod.IMPUTE.toString());
- JSONArray mvattrs = (JSONArray) mvobj.get(TfUtils.JSON_ATTRS);
- JSONArray mvmthds = (JSONArray) mvobj.get(TfUtils.JSON_MTHD);
- int mvLength = mvattrs.size();
-
- _colList = new int[mvLength];
- _mvMethodList = new MVMethod[mvLength];
-
- _meanList = new KahanObject[mvLength];
- _countList = new long[mvLength];
- _varList = new CM_COV_Object[mvLength];
-
- _isMVScaled = new BitSet(_colList.length);
- _isMVScaled.clear();
-
- for(int i=0; i < _colList.length; i++) {
- _colList[i] = UtilFunctions.toInt(mvattrs.get(i));
- _mvMethodList[i] = MVMethod.values()[UtilFunctions.toInt(mvmthds.get(i))];
- _meanList[i] = new KahanObject(0, 0);
- }
-
- _replacementList = new String[mvLength]; // contains replacements for all columns (scale and categorical)
-
- JSONArray constants = (JSONArray)mvobj.get(TfUtils.JSON_CONSTS);
- for(int i=0; i < constants.size(); i++) {
- if ( constants.get(i) == null )
- _replacementList[i] = "NaN";
- else
- _replacementList[i] = constants.get(i).toString();
- }
- }
-
- // Handle scaled attributes
- if ( !isSC )
- {
- // scaling is not applicable
- _scnomvCountList = null;
- _scnomvMeanList = null;
- _scnomvVarList = null;
- }
- else
- {
- if ( _colList != null )
- _mvscMethodList = new MVMethod[_colList.length];
-
- JSONObject scobj = (JSONObject) parsedSpec.get(TfMethod.SCALE.toString());
- JSONArray scattrs = (JSONArray) scobj.get(TfUtils.JSON_ATTRS);
- JSONArray scmthds = (JSONArray) scobj.get(TfUtils.JSON_MTHD);
- int scLength = scattrs.size();
-
- int[] _allscaled = new int[scLength];
- int scnomv = 0, colID;
- byte mthd;
- for(int i=0; i < scLength; i++)
- {
- colID = UtilFunctions.toInt(scattrs.get(i));
- mthd = (byte) UtilFunctions.toInt(scmthds.get(i));
-
- _allscaled[i] = colID;
-
- // check if the attribute is also MV imputed
- int mvidx = isApplicable(colID);
- if(mvidx != -1)
- {
- _isMVScaled.set(mvidx);
- _mvscMethodList[mvidx] = MVMethod.values()[mthd];
- _varList[mvidx] = new CM_COV_Object();
- }
- else
- scnomv++; // count of scaled but not imputed
- }
-
- if(scnomv > 0)
- {
- _scnomvList = new int[scnomv];
- _scnomvMethodList = new MVMethod[scnomv];
+ public EncoderMVImpute() {
+ super(new int[0], 0);
+ }
- _scnomvMeanList = new KahanObject[scnomv];
- _scnomvCountList = new long[scnomv];
- _scnomvVarList = new CM_COV_Object[scnomv];
-
- for(int i=0, idx=0; i < scLength; i++)
- {
- colID = UtilFunctions.toInt(scattrs.get(i));
- mthd = (byte)UtilFunctions.toInt(scmthds.get(i));
-
- if(isApplicable(colID) == -1)
- { // scaled but not imputed
- _scnomvList[idx] = colID;
- _scnomvMethodList[idx] = MVMethod.values()[mthd];
- _scnomvMeanList[idx] = new KahanObject(0, 0);
- _scnomvVarList[idx] = new CM_COV_Object();
- idx++;
- }
- }
- }
- }
+
+ public EncoderMVImpute(int[] colList, MVMethod[] mvMethodList, String[] replacementList, KahanObject[] meanList,
+ long[] countList, List<Integer> rcList, int clen) {
+ super(colList, clen);
+ _mvMethodList = mvMethodList;
+ _replacementList = replacementList;
+ _meanList = meanList;
+ _countList = countList;
+ _rcList = rcList;
}
-
- private void parseMethodsAndReplacments(JSONObject parsedSpec) throws JSONException {
+
+ private void parseMethodsAndReplacements(JSONObject parsedSpec, String[] colnames, int offset) throws JSONException {
JSONArray mvspec = (JSONArray) parsedSpec.get(TfMethod.IMPUTE.toString());
+ boolean ids = parsedSpec.containsKey("ids") && parsedSpec.getBoolean("ids");
+ // make space for all elements
_mvMethodList = new MVMethod[mvspec.size()];
_replacementList = new String[mvspec.size()];
_meanList = new KahanObject[mvspec.size()];
_countList = new long[mvspec.size()];
- for(int i=0; i < mvspec.size(); i++) {
- JSONObject mvobj = (JSONObject)mvspec.get(i);
- _mvMethodList[i] = MVMethod.valueOf(mvobj.get("method").toString().toUpperCase());
- if( _mvMethodList[i] == MVMethod.CONSTANT ) {
- _replacementList[i] = mvobj.getString("value").toString();
- }
- _meanList[i] = new KahanObject(0, 0);
- }
- }
-
- public void prepare(String[] words) throws IOException {
+ // sort for binary search
+ Arrays.sort(_colList);
- try {
- String w = null;
- if(_colList != null)
- for(int i=0; i <_colList.length; i++) {
- int colID = _colList[i];
- w = UtilFunctions.unquote(words[colID-1].trim());
-
- try {
- if(!TfUtils.isNA(_NAstrings, w)) {
- _countList[i]++;
-
- boolean computeMean = (_mvMethodList[i] == MVMethod.GLOBAL_MEAN || _isMVScaled.get(i) );
- if(computeMean) {
- // global_mean
- double d = UtilFunctions.parseToDouble(w, UtilFunctions.defaultNaString);
- _meanFn.execute2(_meanList[i], d, _countList[i]);
-
- if (_isMVScaled.get(i) && _mvscMethodList[i] == MVMethod.GLOBAL_MODE)
- _varFn.execute(_varList[i], d);
- }
- else {
- // global_mode or constant
- // Nothing to do here. Mode is computed using recode maps.
- }
- }
- } catch (NumberFormatException e)
- {
- throw new RuntimeException("Encountered \"" + w + "\" in column ID \"" + colID + "\", when expecting a numeric value. Consider adding \"" + w + "\" to na.strings, along with an appropriate imputation method.");
+ int listIx = 0;
+ for(Object o : mvspec) {
+ JSONObject mvobj = (JSONObject) o;
+ int ixOffset = offset == -1 ? 0 : offset - 1;
+ // check for position -> -1 if not present
+ int pos = Arrays.binarySearch(_colList,
+ ids ? mvobj.getInt("id") - ixOffset : ArrayUtils.indexOf(colnames, mvobj.get("name")) + 1);
+ if(pos >= 0) {
+ // add to arrays
+ _mvMethodList[listIx] = MVMethod.valueOf(mvobj.get("method").toString().toUpperCase());
+ if(_mvMethodList[listIx] == MVMethod.CONSTANT) {
+ _replacementList[listIx] = mvobj.getString("value");
}
+ _meanList[listIx++] = new KahanObject(0, 0);
}
-
- // Compute mean and variance for attributes that are scaled but not imputed
- if(_scnomvList != null)
- for(int i=0; i < _scnomvList.length; i++)
- {
- int colID = _scnomvList[i];
- w = UtilFunctions.unquote(words[colID-1].trim());
- double d = UtilFunctions.parseToDouble(w, UtilFunctions.defaultNaString);
- _scnomvCountList[i]++; // not required, this is always equal to total #records processed
- _meanFn.execute2(_scnomvMeanList[i], d, _scnomvCountList[i]);
- if(_scnomvMethodList[i] == MVMethod.GLOBAL_MODE)
- _varFn.execute(_scnomvVarList[i], d);
- }
- } catch(Exception e) {
- throw new IOException(e);
}
+ // make arrays required size
+ _mvMethodList = Arrays.copyOf(_mvMethodList, listIx);
+ _replacementList = Arrays.copyOf(_replacementList, listIx);
+ _meanList = Arrays.copyOf(_meanList, listIx);
+ _countList = Arrays.copyOf(_countList, listIx);
}
public MVMethod getMethod(int colID) {
- int idx = isApplicable(colID);
+ int idx = isApplicable(colID);
if(idx == -1)
return MVMethod.INVALID;
else
@@ -287,8 +143,8 @@ public class EncoderMVImpute extends Encoder
return (idx == -1) ? 0 : _countList[idx];
}
- public String getReplacement(int colID) {
- int idx = isApplicable(colID);
+ public String getReplacement(int colID) {
+ int idx = isApplicable(colID);
return (idx == -1) ? null : _replacementList[idx];
}
@@ -321,7 +177,7 @@ public class EncoderMVImpute extends Encoder
if( key != null && !key.isEmpty() ) {
Long val = hist.get(key);
hist.put(key, (val!=null) ? val+1 : 1);
- }
+ }
}
_hist.put(colID, hist);
long max = Long.MIN_VALUE;
@@ -349,12 +205,98 @@ public class EncoderMVImpute extends Encoder
}
return out;
}
+
+ @Override
+ public Encoder subRangeEncoder(IndexRange ixRange) { // returns an encoder for the columns inside ixRange (column IDs re-based to the range start), or null if none apply
+ Map<Integer, ColInfo> map = new HashMap<>();
+ for(int i = 0; i < _colList.length; i++) {
+ int col = _colList[i];
+ if(ixRange.inColRange(col))
+ map.put((int) (_colList[i] - (ixRange.colStart - 1)),
+ new ColInfo(_mvMethodList[i], _replacementList[i], _meanList[i], _countList[i], _hist == null ? null : _hist.get(col))); // histograms are keyed by column ID, not by list position
+ }
+ if(map.size() == 0)
+ // empty encoder -> sub range encoder does not exist
+ return null;
+
+ int[] colList = new int[map.size()];
+ MVMethod[] mvMethodList = new MVMethod[map.size()];
+ String[] replacementList = new String[map.size()];
+ KahanObject[] meanList = new KahanObject[map.size()];
+ long[] countList = new long[map.size()];
+ HashMap<Integer, HashMap<String, Long>> hist = new HashMap<>(); // separate map: the re-based keys must not overwrite this encoder's own histograms
+ fillListsFromMap(map, colList, mvMethodList, replacementList, meanList, countList, hist);
+ if(_rcList == null)
+ _rcList = new ArrayList<>();
+ List<Integer> rcList = _rcList.stream().filter(ixRange::inColRange).map(i -> (int) (i - (ixRange.colStart - 1)))
+ .collect(Collectors.toList());
+ EncoderMVImpute ret = new EncoderMVImpute(colList, mvMethodList, replacementList, meanList, countList, rcList,
+ (int) ixRange.colSpan());
+ ret._hist = hist; // the list-based constructor leaves _hist unset, so hand over the re-keyed histograms
+ return ret;
+ }
+
+ private static void fillListsFromMap(Map<Integer, ColInfo> map, int[] colList, MVMethod[] mvMethodList, // copies the per-column entries of map into the parallel output arrays and into hist
+ String[] replacementList, KahanObject[] meanList, long[] countList,
+ HashMap<Integer, HashMap<String, Long>> hist) {
+ int i = 0;
+ for(Entry<Integer, ColInfo> entry : new java.util.TreeMap<Integer, ColInfo>(map).entrySet()) { // ascending column IDs: HashMap order is unspecified, but the column list is binary-searched
+ colList[i] = entry.getKey();
+ mvMethodList[i] = entry.getValue()._method;
+ replacementList[i] = entry.getValue()._replacement;
+ meanList[i] = entry.getValue()._mean;
+ countList[i++] = entry.getValue()._count;
+
+ hist.put(entry.getKey(), entry.getValue()._hist); // histogram stays keyed by column ID
+ }
+ }
+
+ @Override
+ public void mergeAt(Encoder other, int row, int col) { // merges other's per-column impute state into this encoder, shifting other's column IDs by (col - 1); row is only forwarded to super
+ if(other instanceof EncoderMVImpute) {
+ EncoderMVImpute otherImpute = (EncoderMVImpute) other;
+ Map<Integer, ColInfo> map = new HashMap<>();
+ for(int i = 0; i < _colList.length; i++) {
+ map.put(_colList[i],
+ new ColInfo(_mvMethodList[i], _replacementList[i], _meanList[i], _countList[i], _hist == null ? null : _hist.get(_colList[i]))); // histograms are keyed by column ID, not by list position
+ }
+ for(int i = 0; i < other._colList.length; i++) {
+ int column = other._colList[i] + (col - 1);
+ ColInfo otherColInfo = new ColInfo(otherImpute._mvMethodList[i], otherImpute._replacementList[i],
+ otherImpute._meanList[i], otherImpute._countList[i], otherImpute._hist == null ? null : otherImpute._hist.get(other._colList[i])); // other's histograms use its own, unshifted column IDs
+ ColInfo colInfo = map.get(column);
+ if(colInfo == null)
+ map.put(column, otherColInfo);
+ else
+ colInfo.merge(otherColInfo);
+ }
+
+ _colList = new int[map.size()];
+ _mvMethodList = new MVMethod[map.size()];
+ _replacementList = new String[map.size()];
+ _meanList = new KahanObject[map.size()];
+ _countList = new long[map.size()];
+ _hist = new HashMap<>();
+
+ fillListsFromMap(map, _colList, _mvMethodList, _replacementList, _meanList, _countList, _hist);
+ // update number of columns
+ _clen = Math.max(_clen, col - 1 + other._clen);
+
+ if(_rcList == null)
+ _rcList = new ArrayList<>();
+ Set<Integer> rcSet = new HashSet<>(_rcList);
+ if(otherImpute._rcList != null) rcSet.addAll(otherImpute._rcList.stream().map(i -> i + (col - 1)).collect(Collectors.toSet())); // _rcList defaults to null, so other may not carry one
+ _rcList = new ArrayList<>(rcSet);
+ return;
+ }
+ super.mergeAt(other, row, col);
+ }
@Override
public FrameBlock getMetaData(FrameBlock out) {
for( int j=0; j<_colList.length; j++ ) {
out.getColumnMetadata(_colList[j]-1)
- .setMvValue(_replacementList[j]);
+ .setMvValue(_replacementList[j]);
}
return out;
}
@@ -391,4 +333,48 @@ public class EncoderMVImpute extends Encoder
public HashMap<String,Long> getHistogram( int colID ) {
return _hist.get(colID);
}
+
+ private static class ColInfo { // per-column impute state (method, replacement, mean, count, histogram) used while splitting and merging encoders
+ MVMethod _method;
+ String _replacement;
+ KahanObject _mean;
+ long _count;
+ HashMap<String, Long> _hist;
+
+ ColInfo(MVMethod method, String replacement, KahanObject mean, long count, HashMap<String, Long> hist) {
+ _method = method;
+ _replacement = replacement;
+ _mean = mean;
+ _count = count;
+ _hist = hist;
+ }
+
+ public void merge(ColInfo otherColInfo) { // folds otherColInfo into this entry; throws DMLRuntimeException if the impute methods differ or the method is unsupported
+ if(_method != otherColInfo._method)
+ throw new DMLRuntimeException("Tried to merge two different impute methods: " + _method.name() + " vs. "
+ + otherColInfo._method.name());
+ switch(_method) {
+ case CONSTANT:
+ assert _replacement.equals(otherColInfo._replacement); // only checked when assertions are enabled (-ea)
+ break;
+ case GLOBAL_MEAN:
+ _mean._sum *= _count;
+ _mean._correction *= _count;
+ KahanPlus.getKahanPlusFnObject().execute(_mean, otherColInfo._mean._sum * otherColInfo._count);
+ KahanPlus.getKahanPlusFnObject().execute(_mean,
+ otherColInfo._mean._correction * otherColInfo._count);
+ _count += otherColInfo._count; // NOTE(review): if _sum held a mean, _mean is now a count-weighted total that is not re-divided by _count; confirm this is intended
+ break;
+ case GLOBAL_MODE:
+ if (_hist == null)
+ _hist = new HashMap<>(); // start empty; filled by the merge below
+ if (otherColInfo._hist != null)
+ // add counts, keeping categories that only the other side has seen (replaceAll visited existing keys only)
+ otherColInfo._hist.forEach((key, count) -> _hist.merge(key, count, Long::sum));
+ break;
+ default:
+ throw new DMLRuntimeException("Method `" + _method.name() + "` not supported for federated impute");
+ }
+ }
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
index 26ba4e4..bbc83e4 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
@@ -19,9 +19,7 @@
package org.apache.sysds.runtime.transform.encode;
-import java.util.TreeSet;
-import java.util.stream.Collectors;
-
+import java.util.Arrays;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -38,8 +36,7 @@ public class EncoderOmit extends Encoder
private static final long serialVersionUID = 1978852120416654195L;
private boolean _federated = false;
- //TODO perf replace with boolean[rlen] similar to removeEmpty
- private TreeSet<Integer> _rmRows = new TreeSet<>();
+ private boolean[] _rmRows = new boolean[0];
public EncoderOmit(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol)
throws JSONException
@@ -61,19 +58,24 @@ public class EncoderOmit extends Encoder
_federated = federated;
}
-
- private EncoderOmit(int[] colList, int clen, TreeSet<Integer> rmRows) {
+ private EncoderOmit(int[] colList, int clen, boolean[] rmRows) {
super(colList, clen);
_rmRows = rmRows;
_federated = true;
}
-
+
+ public int getNumRemovedRows(boolean[] rmRows) {
+ int cnt = 0;
+ for(boolean v : rmRows)
+ cnt += v ? 1 : 0;
+ return cnt;
+ }
+
public int getNumRemovedRows() {
- return _rmRows.size();
+ return getNumRemovedRows(_rmRows);
}
- public boolean omit(String[] words, TfUtils agents)
- {
+ public boolean omit(String[] words, TfUtils agents) {
if( !isApplicable() )
return false;
@@ -99,21 +101,21 @@ public class EncoderOmit extends Encoder
@Override
public MatrixBlock apply(FrameBlock in, MatrixBlock out) {
// local rmRows for broadcasting encoder in spark
- TreeSet<Integer> rmRows;
+ boolean[] rmRows;
if(_federated)
rmRows = _rmRows;
else
rmRows = computeRmRows(in);
// determine output size
- int numRows = out.getNumRows() - rmRows.size();
+ int numRows = out.getNumRows() - getNumRemovedRows(rmRows);
// copy over valid rows into the output
MatrixBlock ret = new MatrixBlock(numRows, out.getNumColumns(), false);
int pos = 0;
for(int i = 0; i < in.getNumRows(); i++) {
// copy row if necessary
- if(!rmRows.contains(i)) {
+ if(!rmRows[i]) {
for(int j = 0; j < out.getNumColumns(); j++)
ret.quickSetValue(pos, j, out.quickGetValue(i, j));
pos++;
@@ -125,17 +127,19 @@ public class EncoderOmit extends Encoder
return ret;
}
- private TreeSet<Integer> computeRmRows(FrameBlock in) {
- TreeSet<Integer> rmRows = new TreeSet<>();
+ private boolean[] computeRmRows(FrameBlock in) { // returns one flag per row of in; true marks a row to omit
+ boolean[] rmRows = new boolean[in.getNumRows()]; // all false initially, i.e. every row is kept
ValueType[] schema = in.getSchema();
+ //TODO perf evaluate if column-wise scan more efficient
+ // (sequential but less impact of early abort)
for(int i = 0; i < in.getNumRows(); i++) {
- boolean valid = true;
for(int colID : _colList) {
Object val = in.get(i, colID - 1);
- valid &= !(val == null || (schema[colID - 1] == ValueType.STRING && val.toString().isEmpty()));
+ if (val == null || (schema[colID - 1] == ValueType.STRING && val.toString().isEmpty())) { // missing: null, or an empty string in a STRING column
+ rmRows[i] = true;
+ break; // early abort: one missing value in the omit columns already marks the row
+ }
}
- if(!valid)
- rmRows.add(i);
}
return rmRows;
}
@@ -146,38 +150,38 @@ public class EncoderOmit extends Encoder
if(colList.length == 0)
// empty encoder -> sub range encoder does not exist
return null;
-
- TreeSet<Integer> rmRows = _rmRows.stream().filter((row) -> ixRange.inRowRange(row + 1))
- .map((row) -> (int) (row - (ixRange.rowStart - 1))).collect(Collectors.toCollection(TreeSet::new));
+ boolean[] rmRows = _rmRows;
+ if (_rmRows.length > 0)
+ rmRows = Arrays.copyOfRange(rmRows, (int) ixRange.rowStart - 1, (int) ixRange.rowEnd - 1);
return new EncoderOmit(colList, (int) (ixRange.colSpan()), rmRows);
}
@Override
- public void mergeAt(Encoder other, int col) {
+ public void mergeAt(Encoder other, int row, int col) { // row locates other's rows in this encoder; the offset used below is row - 1
if(other instanceof EncoderOmit) {
mergeColumnInfo(other, col);
- _rmRows.addAll(((EncoderOmit) other)._rmRows);
+ EncoderOmit otherOmit = (EncoderOmit) other;
+ _rmRows = Arrays.copyOf(_rmRows, Math.max(_rmRows.length, (row - 1) + otherOmit._rmRows.length)); // grow (never shrink) so other's shifted flags fit
+ for (int i = 0; i < otherOmit._rmRows.length; i++)
+ _rmRows[(row - 1) + i] |= otherOmit._rmRows[i]; // OR flag i of other into its own shifted slot (was a constant slot: "+ 1" instead of "+ i")
return;
}
- super.mergeAt(other, col);
+ super.mergeAt(other, row, col); // any other encoder type is handled by the base implementation
}
@Override
public void updateIndexRanges(long[] beginDims, long[] endDims) {
// first update begin dims
int numRowsToRemove = 0;
- Integer removedRow = _rmRows.ceiling(0);
- while(removedRow != null && removedRow < beginDims[0]) {
- numRowsToRemove++;
- removedRow = _rmRows.ceiling(removedRow + 1);
- }
+ for (int i = 0; i < beginDims[0] - 1 && i < _rmRows.length; i++) // count the omitted rows that precede the begin row
+ if (_rmRows[i])
+ numRowsToRemove++;
beginDims[0] -= numRowsToRemove;
// update end dims
- while(removedRow != null && removedRow < endDims[0]) {
- numRowsToRemove++;
- removedRow = _rmRows.ceiling(removedRow + 1);
- }
+ for (int i = (int) Math.max(0, beginDims[0] + numRowsToRemove - 1); i < endDims[0] - 1 && i < _rmRows.length; i++) // resume at the original begin position: rows before it are already counted, restarting at 0 would subtract them twice
+ if (_rmRows[i])
+ numRowsToRemove++;
endDims[0] -= numRowsToRemove;
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
index ccd235d..ac414e9 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
@@ -88,12 +88,12 @@ public class EncoderPassThrough extends Encoder
}
@Override
- public void mergeAt(Encoder other, int col) {
+ public void mergeAt(Encoder other, int row, int col) { // row is not used by the pass-through merge itself, only forwarded
if(other instanceof EncoderPassThrough) {
mergeColumnInfo(other, col);
return;
}
- super.mergeAt(other, col);
+ super.mergeAt(other, row, col); // any other encoder type: defer to the base implementation
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
index d8d524a..6a1ea0b 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
@@ -189,7 +189,7 @@ public class EncoderRecode extends Encoder
}
@Override
- public void mergeAt(Encoder other, int col) {
+ public void mergeAt(Encoder other, int row, int col) {
if(other instanceof EncoderRecode) {
mergeColumnInfo(other, col);
@@ -214,7 +214,7 @@ public class EncoderRecode extends Encoder
}
return;
}
- super.mergeAt(other, col);
+ super.mergeAt(other, row, col);
}
public int[] numDistinctValues() {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
index 622e6e0..3aa0981 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
@@ -19,6 +19,7 @@
package org.apache.sysds.test.functions.federated.transform;
+import java.io.IOException;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.FileFormat;
@@ -42,8 +43,8 @@ import org.junit.Test;
public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
private final static String TEST_NAME1 = "TransformFederatedEncodeApply";
private final static String TEST_DIR = "functions/transform/";
- private final static String TEST_CLASS_DIR = TEST_DIR + TransformFederatedEncodeApplyTest.class.getSimpleName()
- + "/";
+ private final static String TEST_CLASS_DIR = TEST_DIR
+ + TransformFederatedEncodeApplyTest.class.getSimpleName() + "/";
// dataset and transform tasks without missing values
private final static String DATASET1 = "homes3/homes.csv";
@@ -64,8 +65,8 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
// dataset and transform tasks with missing values
private final static String DATASET2 = "homes/homes.csv";
- // private final static String SPEC4 = "homes3/homes.tfspec_impute.json";
- // private final static String SPEC4b = "homes3/homes.tfspec_impute2.json";
+ private final static String SPEC4 = "homes3/homes.tfspec_impute.json";
+ private final static String SPEC4b = "homes3/homes.tfspec_impute2.json";
private final static String SPEC5 = "homes3/homes.tfspec_omit.json";
private final static String SPEC5b = "homes3/homes.tfspec_omit2.json";
@@ -73,11 +74,7 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
private static final int[] BIN_col8 = new int[] {1, 2, 2, 2, 2, 2, 3};
public enum TransformType {
- RECODE, DUMMY, RECODE_DUMMY, BIN, BIN_DUMMY,
- // IMPUTE,
- OMIT,
- HASH,
- HASH_RECODE,
+ RECODE, DUMMY, RECODE_DUMMY, BIN, BIN_DUMMY, IMPUTE, OMIT, HASH, HASH_RECODE,
}
@Override
@@ -116,10 +113,10 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
runTransformTest(TransformType.OMIT, false);
}
- // @Test
- // public void testHomesImputeIDsCSV() {
- // runTransformTest(TransformType.IMPUTE, false);
- // }
+ @Test
+ public void testHomesImputeIDsCSV() {
+ runTransformTest(TransformType.IMPUTE, false);
+ }
@Test
public void testHomesRecodeColnamesCSV() {
@@ -151,10 +148,10 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
runTransformTest(TransformType.OMIT, true);
}
- // @Test
- // public void testHomesImputeColnamesCSV() {
- // runTransformTest(TransformType.IMPUTE, true);
- // }
+ @Test
+ public void testHomesImputeColnamesCSV() {
+ runTransformTest(TransformType.IMPUTE, true);
+ }
@Test
public void testHomesHashColnamesCSV() {
@@ -186,7 +183,7 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
case RECODE: SPEC = colnames ? SPEC1b : SPEC1; DATASET = DATASET1; break;
case DUMMY: SPEC = colnames ? SPEC2b : SPEC2; DATASET = DATASET1; break;
case BIN: SPEC = colnames ? SPEC3b : SPEC3; DATASET = DATASET1; break;
- // case IMPUTE: SPEC = colnames ? SPEC4b : SPEC4; DATASET = DATASET2; break;
+ case IMPUTE: SPEC = colnames ? SPEC4b : SPEC4; DATASET = DATASET2; break;
case OMIT: SPEC = colnames ? SPEC5b : SPEC5; DATASET = DATASET2; break;
case RECODE_DUMMY: SPEC = colnames ? SPEC6b : SPEC6; DATASET = DATASET1; break;
case BIN_DUMMY: SPEC = colnames ? SPEC7b : SPEC7; DATASET = DATASET1; break;
@@ -194,7 +191,7 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
case HASH_RECODE: SPEC = colnames ? SPEC9b : SPEC9; DATASET = DATASET1; break;
}
- Thread t1 = null, t2 = null;
+ Thread t1 = null, t2 = null, t3 = null, t4 = null;
try {
getAndLoadTestConfiguration(TEST_NAME1);
@@ -202,11 +199,14 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
t1 = startLocalFedWorkerThread(port1);
int port2 = getRandomAvailablePort();
t2 = startLocalFedWorkerThread(port2);
+ int port3 = getRandomAvailablePort();
+ t3 = startLocalFedWorkerThread(port3);
+ int port4 = getRandomAvailablePort();
+ t4 = startLocalFedWorkerThread(port4);
FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(true, DataExpression.DEFAULT_DELIM_DELIMITER,
- DataExpression.DEFAULT_DELIM_FILL, DataExpression.DEFAULT_DELIM_FILL_VALUE,
- DATASET.equals(DATASET1) ? DataExpression.DEFAULT_NA_STRINGS : "NA" + DataExpression.DELIM_NA_STRING_SEP
- + "");
+ DataExpression.DEFAULT_DELIM_FILL, DataExpression.DEFAULT_DELIM_FILL_VALUE, DATASET.equals(DATASET1) ?
+ DataExpression.DEFAULT_NA_STRINGS : "NA" + DataExpression.DELIM_NA_STRING_SEP + "");
String HOME = SCRIPT_DIR + TEST_DIR;
// split up dataset
FrameBlock dataset = FrameReaderFactory.createFrameReader(FileFormat.CSV, ffpCSV)
@@ -216,23 +216,37 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
ffpCSV.setNAStrings(UtilFunctions.defaultNaString);
FrameWriter fw = FrameWriterFactory.createFrameWriter(FileFormat.CSV, ffpCSV);
- FrameBlock A = new FrameBlock();
- dataset.slice(0, dataset.getNumRows() - 1, 0, dataset.getNumColumns() / 2 - 1, A);
- fw.writeFrameToHDFS(A, input("A"), A.getNumRows(), A.getNumColumns());
- HDFSTool.writeMetaDataFile(input("A.mtd"), null, A.getSchema(), Types.DataType.FRAME,
- new MatrixCharacteristics(A.getNumRows(), A.getNumColumns()), FileFormat.CSV, ffpCSV);
-
- FrameBlock B = new FrameBlock();
- dataset.slice(0, dataset.getNumRows() - 1, dataset.getNumColumns() / 2, dataset.getNumColumns() - 1, B);
- fw.writeFrameToHDFS(B, input("B"), B.getNumRows(), B.getNumColumns());
- HDFSTool.writeMetaDataFile(input("B.mtd"), null, B.getSchema(), Types.DataType.FRAME,
- new MatrixCharacteristics(B.getNumRows(), B.getNumColumns()), FileFormat.CSV, ffpCSV);
+ writeDatasetSlice(dataset, fw, ffpCSV, "AH",
+ 0,
+ dataset.getNumRows() / 2 - 1,
+ 0,
+ dataset.getNumColumns() / 2 - 1);
+
+ writeDatasetSlice(dataset, fw, ffpCSV, "AL",
+ dataset.getNumRows() / 2,
+ dataset.getNumRows() - 1,
+ 0,
+ dataset.getNumColumns() / 2 - 1);
+
+ writeDatasetSlice(dataset, fw, ffpCSV, "BH",
+ 0,
+ dataset.getNumRows() / 2 - 1,
+ dataset.getNumColumns() / 2,
+ dataset.getNumColumns() - 1);
+
+ writeDatasetSlice(dataset, fw, ffpCSV, "BL",
+ dataset.getNumRows() / 2,
+ dataset.getNumRows() - 1,
+ dataset.getNumColumns() / 2,
+ dataset.getNumColumns() - 1);
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
- programArgs = new String[] {"-nvargs", "in_A=" + TestUtils.federatedAddress(port1, input("A")),
- "in_B=" + TestUtils.federatedAddress(port2, input("B")), "rows=" + dataset.getNumRows(),
- "cols_A=" + A.getNumColumns(), "cols_B=" + B.getNumColumns(), "TFSPEC=" + HOME + "input/" + SPEC,
- "TFDATA1=" + output("tfout1"), "TFDATA2=" + output("tfout2"), "OFMT=csv"};
+ programArgs = new String[] {"-nvargs", "in_AH=" + TestUtils.federatedAddress(port1, input("AH")),
+ "in_AL=" + TestUtils.federatedAddress(port2, input("AL")),
+ "in_BH=" + TestUtils.federatedAddress(port3, input("BH")),
+ "in_BL=" + TestUtils.federatedAddress(port4, input("BL")), "rows=" + dataset.getNumRows(),
+ "cols=" + dataset.getNumColumns(), "TFSPEC=" + HOME + "input/" + SPEC, "TFDATA1=" + output("tfout1"),
+ "TFDATA2=" + output("tfout2"), "OFMT=csv"};
runTest(true, false, null, -1);
@@ -266,8 +280,18 @@ public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
throw new RuntimeException(ex);
}
finally {
- TestUtils.shutdownThreads(t1, t2);
+ TestUtils.shutdownThreads(t1, t2, t3, t4);
resetExecMode(rtold);
}
}
+
+ private void writeDatasetSlice(FrameBlock dataset, FrameWriter fw, FileFormatPropertiesCSV ffpCSV, String name,
+ int rl, int ru, int cl, int cu) throws IOException {
+ FrameBlock slice = new FrameBlock(); // receives the (rl, ru, cl, cu) region of dataset
+ dataset.slice(rl, ru, cl, cu, slice);
+ final int nrow = slice.getNumRows(), ncol = slice.getNumColumns();
+ fw.writeFrameToHDFS(slice, input(name), nrow, ncol); // data file of the test input called name
+ HDFSTool.writeMetaDataFile(input(DataExpression.getMTDFileName(name)), null, slice.getSchema(),
+ Types.DataType.FRAME, new MatrixCharacteristics(nrow, ncol), FileFormat.CSV, ffpCSV); // matching metadata file
+ }
}
diff --git a/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml b/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
index 921242b..28cdcda 100644
--- a/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
+++ b/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
@@ -19,9 +19,11 @@
#
#-------------------------------------------------------------
-F1 = federated(type="frame", addresses=list($in_A, $in_B), ranges=
- list(list(0,0), list($rows, $cols_A), # A range
- list(0, $cols_A), list($rows, $cols_A + $cols_B))); # B range
+F1 = federated(type="frame", addresses=list($in_AH, $in_AL, $in_BH, $in_BL), ranges=list(
+ list(0,0), list($rows / 2, $cols / 2), # AH range
+ list($rows / 2,0), list($rows, $cols / 2), # AL range
+ list(0,$cols / 2), list($rows / 2, $cols), # BH range
+ list($rows / 2,$cols / 2), list($rows, $cols))); # BL range
jspec = read($TFSPEC, data_type="scalar", value_type="string");