You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/08/28 23:02:23 UTC
[systemds] branch master updated: [SYSTEMDS-2555] Federated
transform dummycode encode/decode
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 9efd321 [SYSTEMDS-2555] Federated transform dummycode encode/decode
9efd321 is described below
commit 9efd32125ac0142fb108f8b25f5b8cbecea1c06c
Author: Kevin Innerebner <ke...@yahoo.com>
AuthorDate: Sat Aug 29 00:17:51 2020 +0200
[SYSTEMDS-2555] Federated transform dummycode encode/decode
Closes #1031.
---
.../controlprogram/federated/FederationMap.java | 7 ++
...tiReturnParameterizedBuiltinFEDInstruction.java | 25 +++--
.../fed/ParameterizedBuiltinFEDInstruction.java | 62 ++++++-----
.../sysds/runtime/transform/decode/Decoder.java | 27 ++++-
.../runtime/transform/decode/DecoderComposite.java | 27 ++++-
.../runtime/transform/decode/DecoderDummycode.java | 59 +++++++++-
.../transform/decode/DecoderPassThrough.java | 40 ++++++-
.../runtime/transform/decode/DecoderRecode.java | 29 +++++
.../sysds/runtime/transform/encode/Encoder.java | 10 ++
.../sysds/runtime/transform/encode/EncoderBin.java | 2 +-
.../runtime/transform/encode/EncoderComposite.java | 20 +++-
.../runtime/transform/encode/EncoderDummycode.java | 122 +++++++++++++++++++--
.../runtime/transform/encode/EncoderFactory.java | 2 +-
.../transform/encode/EncoderPassThrough.java | 3 -
.../runtime/transform/encode/EncoderRecode.java | 13 ++-
.../TransformFederatedEncodeDecodeTest.java | 96 +++++++++++-----
...dml => TransformDummyFederatedEncodeDecode.dml} | 5 -
.../transform/TransformEncodeDecodeDummySpec.json | 5 +
...ml => TransformRecodeFederatedEncodeDecode.dml} | 0
19 files changed, 457 insertions(+), 97 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 72d1196..ea8aa29 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -240,6 +240,13 @@ public class FederationMap
return this;
}
+
+ public long getMaxIndexInRange(int dim) {
+ long maxIx = 0;
+ for(FederatedRange range : _fedMap.keySet())
+ maxIx = Math.max(range.getEndDims()[dim], maxIx);
+ return maxIx;
+ }
/**
* Execute a function for each <code>FederatedRange</code> + <code>FederatedData</code> pair. The function should
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index 5d25729..b9b6203 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -44,6 +44,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.privacy.PrivacyMonitor;
import org.apache.sysds.runtime.transform.encode.Encoder;
import org.apache.sysds.runtime.transform.encode.EncoderComposite;
+import org.apache.sysds.runtime.transform.encode.EncoderDummycode;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.EncoderPassThrough;
import org.apache.sysds.runtime.transform.encode.EncoderRecode;
@@ -88,7 +89,7 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
// the encoder in which the complete encoding information will be aggregated
EncoderComposite globalEncoder = new EncoderComposite(
- Arrays.asList(new EncoderRecode(), new EncoderPassThrough()));
+ Arrays.asList(new EncoderRecode(), new EncoderPassThrough(), new EncoderDummycode()));
// first create encoders at the federated workers, then collect them and aggregate them to a single large
// encoder
FederationMap fedMapping = fin.getFedMapping();
@@ -115,14 +116,21 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
});
long varID = FederationUtils.getNextFedDataID();
FederationMap transformedFedMapping = fedMapping.mapParallel(varID, (range, data) -> {
- int colStart = (int) range.getBeginDims()[1] + 1;
- int colEnd = (int) range.getEndDims()[1] + 1;
+ // copy because we reuse it
+ long[] beginDims = range.getBeginDims();
+ long[] endDims = range.getEndDims();
+ int colStart = (int) beginDims[1] + 1;
+ int colEnd = (int) endDims[1] + 1;
+
+ // update begin end dims (column part) considering columns added by dummycoding
+ globalEncoder.updateIndexRanges(beginDims, endDims);
+
// get the encoder segment that is relevant for this federated worker
Encoder encoder = globalEncoder.subRangeEncoder(colStart, colEnd);
+
try {
- FederatedResponse response = data.executeFederatedOperation(
- new FederatedRequest(RequestType.EXEC_UDF, varID,
- new ExecuteFrameEncoder(data.getVarID(), varID, encoder))).get();
+ FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(RequestType.EXEC_UDF,
+ varID, new ExecuteFrameEncoder(data.getVarID(), varID, encoder))).get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
}
@@ -134,13 +142,14 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
// construct a federated matrix with the encoded data
MatrixObject transformedMat = ec.getMatrixObject(getOutput(0));
- transformedMat.getDataCharacteristics().set(fin.getDataCharacteristics());
+ transformedMat.getDataCharacteristics().setRows(transformedFedMapping.getMaxIndexInRange(0));
+ transformedMat.getDataCharacteristics().setCols(transformedFedMapping.getMaxIndexInRange(1));
// set the federated mapping for the matrix
transformedMat.setFedMapping(transformedFedMapping);
// release input and outputs
ec.setFrameOutput(getOutput(1).getName(),
- globalEncoder.getMetaData(new FrameBlock(globalEncoder.getNumCols(), Types.ValueType.STRING)));
+ globalEncoder.getMetaData(new FrameBlock((int) fin.getNumColumns(), Types.ValueType.STRING)));
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index e3523ed..47f912d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -102,7 +102,8 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
return new ParameterizedBuiltinFEDInstruction(null, paramsMap, out, opcode, str);
}
else {
- throw new DMLRuntimeException("Unsupported opcode (" + opcode + ") for ParameterizedBuiltinFEDInstruction.");
+ throw new DMLRuntimeException(
+ "Unsupported opcode (" + opcode + ") for ParameterizedBuiltinFEDInstruction.");
}
}
@@ -135,22 +136,36 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
FrameBlock meta = ec.getFrameInput(params.get("meta"));
String spec = params.get("spec");
+ Decoder globalDecoder = DecoderFactory
+ .createDecoder(spec, meta.getColumnNames(), null, meta, (int) mo.getNumColumns());
+
FederationMap fedMapping = mo.getFedMapping();
ValueType[] schema = new ValueType[(int) mo.getNumColumns()];
long varID = FederationUtils.getNextFedDataID();
FederationMap decodedMapping = fedMapping.mapParallel(varID, (range, data) -> {
- int columnOffset = (int) range.getBeginDims()[1] + 1;
+ long[] beginDims = range.getBeginDims();
+ long[] endDims = range.getEndDims();
+ int colStartBefore = (int) beginDims[1];
+
+ // update begin/end dims (column part) to the decoded ranges, i.e., without the columns added by dummycoding
+ globalDecoder.updateIndexRanges(beginDims, endDims);
- FrameBlock subMeta = new FrameBlock();
+ // get the decoder segment that is relevant for this federated worker
+ Decoder decoder = globalDecoder
+ .subRangeDecoder((int) beginDims[1] + 1, (int) endDims[1] + 1, colStartBefore);
+
+ FrameBlock metaSlice = new FrameBlock();
synchronized(meta) {
- meta.slice(0, meta.getNumRows() - 1, columnOffset - 1, (int) range.getEndDims()[1] - 1, subMeta);
+ meta.slice(0, meta.getNumRows() - 1, (int) beginDims[1], (int) endDims[1] - 1, metaSlice);
}
+
FederatedResponse response;
try {
- response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
- varID, new DecodeMatrix(data.getVarID(), varID, subMeta, spec, columnOffset))).get();
+ response = data.executeFederatedOperation(
+ new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, varID,
+ new DecodeMatrix(data.getVarID(), varID, metaSlice, decoder))).get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
@@ -158,7 +173,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
synchronized(schema) {
// It would be possible to assert that different federated workers don't give different value
// types for the same columns, but the performance impact is not worth the effort
- System.arraycopy(subSchema, 0, schema, columnOffset - 1, subSchema.length);
+ System.arraycopy(subSchema, 0, schema, colStartBefore, subSchema.length);
}
}
catch(Exception e) {
@@ -169,8 +184,9 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
// construct a federated matrix with the encoded data
FrameObject decodedFrame = ec.getFrameObject(output);
- decodedFrame.setSchema(schema);
+ decodedFrame.setSchema(globalDecoder.getSchema());
decodedFrame.getDataCharacteristics().set(mo.getDataCharacteristics());
+ decodedFrame.getDataCharacteristics().setCols(globalDecoder.getSchema().length);
// set the federated mapping for the matrix
decodedFrame.setFedMapping(decodedMapping);
@@ -185,34 +201,28 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
private CPOperand getTargetOperand() {
return new CPOperand(params.get("target"), ValueType.FP64, DataType.MATRIX);
}
-
+
public static class DecodeMatrix extends FederatedUDF {
private static final long serialVersionUID = 2376756757742169692L;
- private final long _outputID;
+ private final long _outputID;
private final FrameBlock _meta;
- private final String _spec;
- private final int _globalOffset;
-
- public DecodeMatrix(long input, long outputID, FrameBlock meta, String spec, int globalOffset) {
- super(new long[]{input});
+ private final Decoder _decoder;
+
+ public DecodeMatrix(long input, long outputID, FrameBlock meta, Decoder decoder) {
+ super(new long[] {input});
_outputID = outputID;
_meta = meta;
- _spec = spec;
- _globalOffset = globalOffset;
+ _decoder = decoder;
}
-
- @Override
+
public FederatedResponse execute(ExecutionContext ec, Data... data) {
MatrixObject mo = (MatrixObject) PrivacyMonitor.handlePrivacy(data[0]);
MatrixBlock mb = mo.acquireRead();
String[] colNames = _meta.getColumnNames();
-
- // compute transformdecode
- Decoder decoder = DecoderFactory.createDecoder(_spec, colNames, null,
- _meta, mb.getNumColumns(), _globalOffset, _globalOffset + mb.getNumColumns());
- FrameBlock fbout = decoder.decode(mb, new FrameBlock(decoder.getSchema()));
+
+ FrameBlock fbout = _decoder.decode(mb, new FrameBlock(_decoder.getSchema()));
fbout.setColumnNames(Arrays.copyOfRange(colNames, 0, fbout.getNumColumns()));
-
+
// copy characteristics
MatrixCharacteristics mc = new MatrixCharacteristics(mo.getDataCharacteristics());
FrameObject fo = new FrameObject(OptimizerUtils.getUniqueTempFileName(),
@@ -221,7 +231,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
fo.acquireModify(fbout);
fo.release();
mo.release();
-
+
// add it to the list of variables
ec.setVariable(String.valueOf(_outputID), fo);
// return schema
diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java
index 2aeda0f..4417387 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.transform.decode;
import java.io.Serializable;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -64,6 +65,30 @@ public abstract class Decoder implements Serializable
* @return returns given output frame block for convenience
*/
public abstract FrameBlock decode(MatrixBlock in, FrameBlock out);
-
+
+ /**
+ * Returns a new Decoder that only handles a sub range of columns. The sub-range refers to the columns after
+ * decoding.
+ *
+ * @param colStart the start index of the sub-range (1-based, inclusive)
+ * @param colEnd the end index of the sub-range (1-based, exclusive)
+ * @param dummycodedOffset the offset of dummycoded segments before colStart
+ * @return a decoder of the same type, just for the sub-range
+ */
+ public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) {
+ throw new DMLRuntimeException(
+ getClass().getSimpleName() + " does not support the creation of a sub-range decoder");
+ }
+
+ /**
+ * Update index-ranges to after decoding. Note that only Dummycoding changes the ranges.
+ *
+ * @param beginDims the begin indexes before decoding
+ * @param endDims the end indexes before decoding
+ */
+ public void updateIndexRanges(long[] beginDims, long[] endDims) {
+ // do nothing - default
+ }
+
public abstract void initMetaData(FrameBlock meta);
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java
index 69fcb41..263e064 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.transform.decode;
+import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import org.apache.sysds.common.Types.ValueType;
@@ -26,9 +28,9 @@ import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
/**
- * Simple composite decoder that applies a list of decoders
+ * Simple composite decoder that applies a list of decoders
* in specified order. By implementing the default decoder API
- * it can be used as a drop-in replacement for any other decoder.
+ * it can be used as a drop-in replacement for any other decoder.
*
*/
public class DecoderComposite extends Decoder
@@ -45,13 +47,30 @@ public class DecoderComposite extends Decoder
@Override
public FrameBlock decode(MatrixBlock in, FrameBlock out) {
for( Decoder decoder : _decoders )
- out = decoder.decode(in, out);
+ out = decoder.decode(in, out);
return out;
}
@Override
+ public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) {
+ List<Decoder> subRangeDecoders = new ArrayList<>();
+ for (Decoder decoder : _decoders) {
+ Decoder subDecoder = decoder.subRangeDecoder(colStart, colEnd, dummycodedOffset);
+ if (subDecoder != null)
+ subRangeDecoders.add(subDecoder);
+ }
+ return new DecoderComposite(Arrays.copyOfRange(_schema, colStart-1, colEnd-1), subRangeDecoders);
+ }
+
+ @Override
+ public void updateIndexRanges(long[] beginDims, long[] endDims) {
+ for(Decoder dec : _decoders)
+ dec.updateIndexRanges(beginDims, endDims);
+ }
+
+ @Override
public void initMetaData(FrameBlock meta) {
for( Decoder decoder : _decoders )
- decoder.initMetaData(meta);
+ decoder.initMetaData(meta);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java
index 0ad2187..ab1fbc8 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java
@@ -19,6 +19,9 @@
package org.apache.sysds.runtime.transform.decode;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -43,6 +46,7 @@ public class DecoderDummycode extends Decoder
@Override
public FrameBlock decode(MatrixBlock in, FrameBlock out) {
+ //TODO perf (exploit sparse representation for better asymptotic behavior)
out.ensureAllocatedColumns(in.getNumRows());
for( int i=0; i<in.getNumRows(); i++ )
for( int j=0; j<_colList.length; j++ )
@@ -50,11 +54,60 @@ public class DecoderDummycode extends Decoder
if( in.quickGetValue(i, k-1) != 0 ) {
int col = _colList[j] - 1;
out.set(i, col, UtilFunctions.doubleToObject(
- out.getSchema()[col], k-_clPos[j]+1));
- }
+ out.getSchema()[col], k-_clPos[j]+1));
+ }
return out;
}
-
+
+ @Override
+ public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) {
+ List<Integer> dcList = new ArrayList<>();
+ List<Integer> clPosList = new ArrayList<>();
+ List<Integer> cuPosList = new ArrayList<>();
+
+ // get the column IDs for the sub range of the dummycode columns and their destination positions,
+ // where they will be decoded to
+ for( int j=0; j<_colList.length; j++ ) {
+ int colID = _colList[j];
+ if (colID >= colStart && colID < colEnd) {
+ dcList.add(colID - (colStart - 1));
+ clPosList.add(_clPos[j] - dummycodedOffset);
+ cuPosList.add(_cuPos[j] - dummycodedOffset);
+ }
+ }
+ if (dcList.isEmpty())
+ return null;
+ // create sub-range decoder
+ int[] colList = dcList.stream().mapToInt(i -> i).toArray();
+ DecoderDummycode subRangeDecoder = new DecoderDummycode(
+ Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), colList);
+ subRangeDecoder._clPos = clPosList.stream().mapToInt(i -> i).toArray();
+ subRangeDecoder._cuPos = cuPosList.stream().mapToInt(i -> i).toArray();
+ return subRangeDecoder;
+ }
+
+ @Override
+ public void updateIndexRanges(long[] beginDims, long[] endDims) {
+ if(_colList == null)
+ return;
+
+ long lowerColDest = beginDims[1];
+ long upperColDest = endDims[1];
+ for(int i = 0; i < _colList.length; i++) {
+ long numDistinct = _cuPos[i] - _clPos[i];
+
+ if(_cuPos[i] <= beginDims[1] + 1)
+ if(numDistinct > 0)
+ lowerColDest -= numDistinct - 1;
+
+ if(_cuPos[i] <= endDims[1] + 1)
+ if(numDistinct > 0)
+ upperColDest -= numDistinct - 1;
+ }
+ beginDims[1] = lowerColDest;
+ endDims[1] = upperColDest;
+ }
+
@Override
public void initMetaData(FrameBlock meta) {
_clPos = new int[_colList.length]; //col lower pos
diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java
index 206ac74..753c666 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java
@@ -19,6 +19,10 @@
package org.apache.sysds.runtime.transform.decode;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -50,14 +54,43 @@ public class DecoderPassThrough extends Decoder
int srcColID = _srcCols[j];
int tgtColID = _colList[j];
double val = in.quickGetValue(i, srcColID-1);
- out.set(i, tgtColID-1, UtilFunctions.doubleToObject(
- _schema[tgtColID-1], val));
+ out.set(i, tgtColID-1,
+ UtilFunctions.doubleToObject(_schema[tgtColID-1], val));
}
}
return out;
}
@Override
+ public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) {
+ List<Integer> colList = new ArrayList<>();
+ List<Integer> dcList = new ArrayList<>();
+ List<Integer> srcList = new ArrayList<>();
+
+ for (int i = 0; i < _colList.length; i++) {
+ int colID = _colList[i];
+ if (colID >= colStart && colID < colEnd) {
+ colList.add(colID - (colStart - 1));
+ srcList.add(_srcCols[i] - dummycodedOffset);
+ }
+ }
+
+ Arrays.stream(_dcCols)
+ .filter(c -> c >= colStart && c < colEnd)
+ .forEach(c -> dcList.add(c));
+
+ if (colList.isEmpty())
+ // empty decoder -> return null
+ return null;
+
+ DecoderPassThrough decoder = new DecoderPassThrough(Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1),
+ colList.stream().mapToInt(i -> i).toArray(),
+ dcList.stream().mapToInt(i -> i).toArray());
+ decoder._srcCols = srcList.stream().mapToInt(i -> i).toArray();
+ return decoder;
+ }
+
+ @Override
public void initMetaData(FrameBlock meta) {
if( _dcCols.length > 0 ) {
//prepare source column id mapping w/ dummy coding
@@ -69,8 +102,7 @@ public class DecoderPassThrough extends Decoder
ix1 ++;
}
else { //_colList[ix1] > _dcCols[ix2]
- off += (int)meta.getColumnMetadata()[_dcCols[ix2]-1]
- .getNumDistinct() - 1;
+ off += (int)meta.getColumnMetadata()[_dcCols[ix2]-1].getNumDistinct() - 1;
ix2 ++;
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java
index 5ebb8cc..9ae315f 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java
@@ -19,8 +19,11 @@
package org.apache.sysds.runtime.transform.decode;
+import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
+import java.util.List;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -74,6 +77,32 @@ public class DecoderRecode extends Decoder
@Override
@SuppressWarnings("unchecked")
+ public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) {
+ List<Integer> cols = new ArrayList<>();
+ List<HashMap<Long, Object>> rcMaps = new ArrayList<>();
+ for(int i = 0; i < _colList.length; i++) {
+ int col = _colList[i];
+ if(col >= colStart && col < colEnd) {
+ // add the column ID relative to the sub-range, i.e., shifted by the columns before colStart
+ // (colStart - 1 because colStart is 1-based)
+ int corrColumn = col - (colStart - 1);
+ cols.add(corrColumn);
+ rcMaps.add(new HashMap<>(_rcMaps[i]));
+ }
+ }
+ if(cols.isEmpty())
+ // empty decoder -> sub range decoder does not exist
+ return null;
+
+ int[] colList = cols.stream().mapToInt(i -> i).toArray();
+ DecoderRecode subRangeDecoder = new DecoderRecode(
+ Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), _onOut, colList);
+ subRangeDecoder._rcMaps = rcMaps.toArray(new HashMap[0]);
+ return subRangeDecoder;
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
public void initMetaData(FrameBlock meta) {
//initialize recode maps according to schema
_rcMaps = new HashMap[_colList.length];
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
index 5945e27..19271f8 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
@@ -168,6 +168,16 @@ public abstract class Encoder implements Serializable
throw new DMLRuntimeException(
this.getClass().getName() + " does not support merging with " + other.getClass().getName());
}
+
+ /**
+ * Update index-ranges to after encoding. Note that only Dummycoding changes the ranges.
+ *
+ * @param beginDims the begin indexes before encoding
+ * @param endDims the end indexes before encoding
+ */
+ public void updateIndexRanges(long[] beginDims, long[] endDims) {
+ // do nothing - default
+ }
/**
* Construct a frame block out of the transform meta data.
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
index b170e22..3be9ed9 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
@@ -42,7 +42,7 @@ public class EncoderBin extends Encoder
public static final String MAX_PREFIX = "max";
public static final String NBINS_PREFIX = "nbins";
- private int[] _numBins = null;
+ protected int[] _numBins = null;
//frame transform-apply attributes
//TODO binMins is redundant and could be removed
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
index e653307..cd21f45 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
@@ -134,17 +134,35 @@ public class EncoderComposite extends Encoder
+ "CompositeEncoder: " + otherEnc.getClass().getSimpleName());
}
}
+ // update dummycode encoder domain sizes based on distinctness information from other encoders
+ for (Encoder encoder : _encoders) {
+ if (encoder instanceof EncoderDummycode) {
+ ((EncoderDummycode) encoder).updateDomainSizes(_encoders);
+ return;
+ }
+ }
return;
}
for (Encoder encoder : _encoders) {
if (encoder.getClass() == other.getClass()) {
encoder.mergeAt(other, col);
+ // update dummycode encoder domain sizes based on distinctness information from other encoders
+ if (encoder instanceof EncoderDummycode) {
+ ((EncoderDummycode) encoder).updateDomainSizes(_encoders);
+ }
return;
}
}
super.mergeAt(other, col);
}
-
+
+ @Override
+ public void updateIndexRanges(long[] beginDims, long[] endDims) {
+ for(Encoder enc : _encoders) {
+ enc.updateIndexRanges(beginDims, endDims);
+ }
+ }
+
@Override
public FrameBlock getMetaData(FrameBlock out) {
if( _meta != null )
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
index ea66479..8ff5e57 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
@@ -19,28 +19,40 @@
package org.apache.sysds.runtime.transform.encode;
-import org.apache.wink.json4j.JSONException;
-import org.apache.wink.json4j.JSONObject;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
+import org.apache.wink.json4j.JSONException;
+import org.apache.wink.json4j.JSONObject;
public class EncoderDummycode extends Encoder
{
private static final long serialVersionUID = 5832130477659116489L;
- private int[] _domainSizes = null; // length = #of dummycoded columns
+ public int[] _domainSizes = null; // length = #of dummycoded columns
private long _dummycodedLength = 0; // #of columns after dummycoded
- public EncoderDummycode(JSONObject parsedSpec, String[] colnames, int clen) throws JSONException {
+ public EncoderDummycode(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol)
+ throws JSONException {
super(null, clen);
-
- if ( parsedSpec.containsKey(TfMethod.DUMMYCODE.toString()) ) {
- int[] collist = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfMethod.DUMMYCODE.toString());
+
+ if(parsedSpec.containsKey(TfMethod.DUMMYCODE.toString())) {
+ int[] collist = TfMetaUtils
+ .parseJsonIDList(parsedSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol);
initColList(collist);
}
}
+
+ public EncoderDummycode() {
+ super(new int[0], 0);
+ }
@Override
public int getNumCols() {
@@ -85,6 +97,102 @@ public class EncoderDummycode extends Encoder
}
@Override
+ public Encoder subRangeEncoder(int colStart, int colEnd) {
+ List<Integer> cols = new ArrayList<>();
+ List<Integer> domainSizes = new ArrayList<>();
+ int newDummycodedLength = colEnd - colStart;
+ for(int i = 0; i < _colList.length; i++){
+ int col = _colList[i];
+ if(col >= colStart && col < colEnd) {
+ // add the column ID relative to the sub-range, i.e., shifted by the columns before colStart
+ // (colStart - 1 because colStart is 1-based)
+ int corrColumn = col - (colStart - 1);
+ cols.add(corrColumn);
+ domainSizes.add(_domainSizes[i]);
+ newDummycodedLength += _domainSizes[i] - 1;
+ }
+ }
+ if(cols.isEmpty())
+ // empty encoder -> sub range encoder does not exist
+ return null;
+
+ EncoderDummycode subRangeEncoder = new EncoderDummycode();
+ subRangeEncoder._clen = colEnd - colStart;
+ subRangeEncoder._colList = cols.stream().mapToInt(i -> i).toArray();
+ subRangeEncoder._domainSizes = domainSizes.stream().mapToInt(i -> i).toArray();
+ subRangeEncoder._dummycodedLength = newDummycodedLength;
+ return subRangeEncoder;
+ }
+
+ @Override
+ public void mergeAt(Encoder other, int col) {
+ if(other instanceof EncoderDummycode) {
+ mergeColumnInfo(other, col);
+
+ _domainSizes = new int[_colList.length];
+ _dummycodedLength = _clen;
+ // temporary, will be updated later
+ Arrays.fill(_domainSizes, 0, _colList.length, 1);
+ return;
+ }
+ super.mergeAt(other, col);
+ }
+
+ @Override
+ public void updateIndexRanges(long[] beginDims, long[] endDims) {
+ long[] initialBegin = Arrays.copyOf(beginDims, beginDims.length);
+ long[] initialEnd = Arrays.copyOf(endDims, endDims.length);
+ for(int i = 0; i < _colList.length; i++) {
+ // 1-based vs 0-based
+ if(_colList[i] < initialBegin[1] + 1) {
+ // new columns inserted left of the columns of this partial (federated) block
+ beginDims[1] += _domainSizes[i] - 1;
+ endDims[1] += _domainSizes[i] - 1;
+ }
+ else if(_colList[i] < initialEnd[1] + 1) {
+ // new columns inserted in this (federated) block
+ endDims[1] += _domainSizes[i] - 1;
+ }
+ }
+ }
+
+ public void updateDomainSizes(List<Encoder> encoders) {
+ if(_colList == null)
+ return;
+
+ // maps the column ids of the columns encoded by this Dummycode Encoder to their respective indexes
+ // in the _colList
+ Map<Integer, Integer> colIDToIxMap = new HashMap<>();
+ for (int i = 0; i < _colList.length; i++)
+ colIDToIxMap.put(_colList[i], i);
+
+ _dummycodedLength = _clen;
+ for (Encoder encoder : encoders) {
+ int[] distinct = null;
+ if (encoder instanceof EncoderRecode) {
+ EncoderRecode encoderRecode = (EncoderRecode) encoder;
+ distinct = encoderRecode.numDistinctValues();
+ }
+ else if (encoder instanceof EncoderBin) {
+ distinct = ((EncoderBin) encoder)._numBins;
+ }
+
+ if (distinct != null) {
+ // search for match of encoded columns
+ for (int i = 0; i < encoder._colList.length; i++) {
+ Integer ix = colIDToIxMap.get(encoder._colList[i]);
+
+ if (ix != null) {
+ // set size
+ _domainSizes[ix] = distinct[i];
+ _dummycodedLength += _domainSizes[ix] - 1;
+ }
+ }
+ }
+ }
+ }
+
+ @Override
public FrameBlock getMetaData(FrameBlock out) {
return out;
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 2070485..57f7102 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -100,7 +100,7 @@ public class EncoderFactory
if( !binIDs.isEmpty() )
lencoders.add(new EncoderBin(jSpec, colnames, schema.length));
if( !dcIDs.isEmpty() )
- lencoders.add(new EncoderDummycode(jSpec, colnames, schema.length));
+ lencoders.add(new EncoderDummycode(jSpec, colnames, schema.length, minCol, maxCol));
if( !oIDs.isEmpty() )
lencoders.add(new EncoderOmit(jSpec, colnames, schema.length));
if( !mvIDs.isEmpty() ) {
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
index 8b3d36a..d6ceb15 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
@@ -73,9 +73,6 @@ public class EncoderPassThrough extends Encoder
@Override
public Encoder subRangeEncoder(int colStart, int colEnd) {
- if (colStart - 1 >= _clen)
- return null;
-
List<Integer> colList = new ArrayList<>();
for (int col : _colList) {
if (col >= colStart && col < colEnd)
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
index d4b201e..be1cba9 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
@@ -165,9 +165,6 @@ public class EncoderRecode extends Encoder
@Override
public Encoder subRangeEncoder(int colStart, int colEnd) {
- if (colStart - 1 >= _clen)
- return null;
-
List<Integer> cols = new ArrayList<>();
HashMap<Integer, HashMap<String, Long>> rcdMaps = new HashMap<>();
for (int col : _colList) {
@@ -216,6 +213,16 @@ public class EncoderRecode extends Encoder
}
super.mergeAt(other, col);
}
+
+ /**
+  * Returns the size of the recode map of every recoded column.
+  *
+  * @return array aligned with _colList; entry j is the number of entries in the
+  *         recode map of column _colList[j]
+  */
+ public int[] numDistinctValues() {
+     final int[] counts = new int[_colList.length];
+     int pos = 0;
+     // column ids are 1-based and serve as keys of _rcdMaps
+     for(int colID : _colList)
+         counts[pos++] = _rcdMaps.get(colID).size();
+     return counts;
+ }
@Override
public FrameBlock getMetaData(FrameBlock meta) {
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
index ceee7da..29afa5b 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
@@ -35,11 +35,13 @@ import org.junit.Assert;
import org.junit.Test;
public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase {
- private static final String TEST_NAME1 = "TransformFederatedEncodeDecode";
+ private static final String TEST_NAME_RECODE = "TransformRecodeFederatedEncodeDecode";
+ private static final String TEST_NAME_DUMMY = "TransformDummyFederatedEncodeDecode";
private static final String TEST_DIR = "functions/transform/";
private static final String TEST_CLASS_DIR = TEST_DIR+TransformFederatedEncodeDecodeTest.class.getSimpleName()+"/";
- private static final String SPEC = "TransformEncodeDecodeSpec.json";
+ private static final String SPEC_RECODE = "TransformEncodeDecodeSpec.json";
+ private static final String SPEC_DUMMYCODE = "TransformEncodeDecodeDummySpec.json";
private static final int rows = 1234;
private static final int cols = 2;
@@ -49,47 +51,78 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase {
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- addTestConfiguration(TEST_NAME1,
- new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"FO1", "FO2"}));
+ addTestConfiguration(TEST_NAME_RECODE,
+ new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_RECODE, new String[] {"FO1", "FO2"}));
}
@Test
- public void runTestCSVDenseCP() {
- runTransformEncodeDecodeTest(false, Types.FileFormat.CSV);
+ public void runComplexRecodeTestCSVDenseCP() {
+ runTransformEncodeDecodeTest(true, false, Types.FileFormat.CSV);
}
@Test
- public void runTestCSVSparseCP() {
- runTransformEncodeDecodeTest(true, Types.FileFormat.CSV);
+ public void runComplexRecodeTestCSVSparseCP() {
+ runTransformEncodeDecodeTest(true, true, Types.FileFormat.CSV);
}
@Test
- public void runTestTextcellDenseCP() {
- runTransformEncodeDecodeTest(false, Types.FileFormat.TEXT);
+ public void runComplexRecodeTestTextcellDenseCP() {
+ runTransformEncodeDecodeTest(true, false, Types.FileFormat.TEXT);
}
@Test
- public void runTestTextcellSparseCP() {
- runTransformEncodeDecodeTest(true, Types.FileFormat.TEXT);
+ public void runComplexRecodeTestTextcellSparseCP() {
+ runTransformEncodeDecodeTest(true, true, Types.FileFormat.TEXT);
}
@Test
- public void runTestBinaryDenseCP() {
- runTransformEncodeDecodeTest(false, Types.FileFormat.BINARY);
+ public void runComplexRecodeTestBinaryDenseCP() {
+ runTransformEncodeDecodeTest(true, false, Types.FileFormat.BINARY);
}
@Test
- public void runTestBinarySparseCP() {
- runTransformEncodeDecodeTest(true, Types.FileFormat.BINARY);
+ public void runComplexRecodeTestBinarySparseCP() {
+ runTransformEncodeDecodeTest(true, true, Types.FileFormat.BINARY);
+ }
+
+ // recode=false selects the dummycode script and spec; sparse=false; CSV file format.
+ @Test
+ public void runSimpleDummycodeTestCSVDenseCP() {
+ runTransformEncodeDecodeTest(false, false, Types.FileFormat.CSV);
+ }
+
+ // recode=false selects the dummycode script and spec; sparse=true; CSV file format.
+ @Test
+ public void runSimpleDummycodeTestCSVSparseCP() {
+ runTransformEncodeDecodeTest(false, true, Types.FileFormat.CSV);
+ }
+
+ // recode=false selects the dummycode script and spec; sparse=false; TEXT file format.
+ @Test
+ public void runSimpleDummycodeTestTextDenseCP() {
+ runTransformEncodeDecodeTest(false, false, Types.FileFormat.TEXT);
+ }
+
+ // recode=false selects the dummycode script and spec; sparse=true; TEXT file format.
+ @Test
+ public void runSimpleDummycodeTestTextSparseCP() {
+ runTransformEncodeDecodeTest(false, true, Types.FileFormat.TEXT);
+ }
+
+ // recode=false selects the dummycode script and spec; sparse=false; BINARY file format.
+ @Test
+ public void runSimpleDummycodeTestBinaryDenseCP() {
+ runTransformEncodeDecodeTest(false, false, Types.FileFormat.BINARY);
+ }
+
+ // recode=false selects the dummycode script and spec; sparse=true; BINARY file format.
+ @Test
+ public void runSimpleDummycodeTestBinarySparseCP() {
+ runTransformEncodeDecodeTest(false, true, Types.FileFormat.BINARY);
+ }
- private void runTransformEncodeDecodeTest(boolean sparse, Types.FileFormat format) {
+ private void runTransformEncodeDecodeTest(boolean recode, boolean sparse,
+ Types.FileFormat format) {
ExecMode platformOld = rtplatform;
rtplatform = ExecMode.SINGLE_NODE;
Thread t1 = null, t2 = null, t3 = null, t4 = null;
try {
- getAndLoadTestConfiguration(TEST_NAME1);
+ getAndLoadTestConfiguration(TEST_NAME_RECODE);
int port1 = getRandomAvailablePort();
t1 = startLocalFedWorkerThread(port1);
@@ -120,14 +153,15 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase {
writeInputFrameWithMTD("BU", BUpper, false, schema, format);
writeInputFrameWithMTD("BL", BLower, false, schema, format);
- fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml";
+ fullDMLScriptName = SCRIPT_DIR + TEST_DIR + (recode ? TEST_NAME_RECODE : TEST_NAME_DUMMY) + ".dml";
+ String spec_file = recode ? SPEC_RECODE : SPEC_DUMMYCODE;
programArgs = new String[] {"-nvargs",
"in_AU=" + TestUtils.federatedAddress("localhost", port1, input("AU")),
"in_AL=" + TestUtils.federatedAddress("localhost", port2, input("AL")),
"in_BU=" + TestUtils.federatedAddress("localhost", port3, input("BU")),
"in_BL=" + TestUtils.federatedAddress("localhost", port4, input("BL")), "rows=" + rows, "cols=" + cols,
- "spec_file=" + SCRIPT_DIR + TEST_DIR + SPEC, "out1=" + output("FO1"), "out2=" + output("FO2"),
+ "spec_file=" + SCRIPT_DIR + TEST_DIR + spec_file, "out1=" + output("FO1"), "out2=" + output("FO2"),
"format=" + format.toString()};
// run test
@@ -144,16 +178,18 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase {
+ val, expected, val);
}
}
- // TODO federate the aggregated result so that the decode is applied in a federated environment
- // compare matrices (values recoded to identical codes)
- FrameBlock FO = reader.readFrameFromHDFS(output("FO1"), 15, 2);
- HashMap<String, Long> cFA = getCounts(A, B);
- Iterator<String[]> iterFO = FO.getStringRowIterator();
- while(iterFO.hasNext()) {
- String[] row = iterFO.next();
- Double expected = (double) cFA.get(row[1]);
- Double val = (row[0] != null) ? Double.parseDouble(row[0]) : 0;
- Assert.assertEquals("Output aggregates don't match: " + expected + " vs " + val, expected, val);
+ if(recode) {
+ // TODO federate the aggregated result so that the decode is applied in a federated environment
+ // compare matrices (values recoded to identical codes)
+ FrameBlock FO = reader.readFrameFromHDFS(output("FO1"), 15, 2);
+ HashMap<String, Long> cFA = getCounts(A, B);
+ Iterator<String[]> iterFO = FO.getStringRowIterator();
+ while(iterFO.hasNext()) {
+ String[] row = iterFO.next();
+ Double expected = (double) cFA.get(row[1]);
+ Double val = (row[0] != null) ? Double.parseDouble(row[0]) : 0;
+ Assert.assertEquals("Output aggregates don't match: " + expected + " vs " + val, expected, val);
+ }
}
}
catch(Exception ex) {
diff --git a/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml b/src/test/scripts/functions/transform/TransformDummyFederatedEncodeDecode.dml
similarity index 89%
copy from src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml
copy to src/test/scripts/functions/transform/TransformDummyFederatedEncodeDecode.dml
index 50174d7..f029719 100644
--- a/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml
+++ b/src/test/scripts/functions/transform/TransformDummyFederatedEncodeDecode.dml
@@ -28,11 +28,6 @@ jspec = read($spec_file, data_type="scalar", value_type="string");
[X, M] = transformencode(target=F, spec=jspec);
-A = aggregate(target=X[,1], groups=X[,2], fn="count");
-Ag = cbind(A, seq(1,nrow(A)));
-
-FO1 = transformdecode(target=Ag, spec=jspec, meta=M);
FO2 = transformdecode(target=X, spec=jspec, meta=M);
-write(FO1, $out1, format=$format);
write(FO2, $out2, format=$format);
diff --git a/src/test/scripts/functions/transform/TransformEncodeDecodeDummySpec.json b/src/test/scripts/functions/transform/TransformEncodeDecodeDummySpec.json
new file mode 100644
index 0000000..5f4aa12
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformEncodeDecodeDummySpec.json
@@ -0,0 +1,5 @@
+{
+ "ids": true
+ ,"dummycode": [ 2 ]
+
+}
diff --git a/src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml b/src/test/scripts/functions/transform/TransformRecodeFederatedEncodeDecode.dml
similarity index 100%
rename from src/test/scripts/functions/transform/TransformFederatedEncodeDecode.dml
rename to src/test/scripts/functions/transform/TransformRecodeFederatedEncodeDecode.dml