Posted to commits@systemds.apache.org by mb...@apache.org on 2020/08/29 21:58:18 UTC
[systemds] branch master updated: [SYSTEMDS-2554-7] Federated frame transformapply, incl binning, omit
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 3743282 [SYSTEMDS-2554-7] Federated frame transformapply, incl binning, omit
3743282 is described below
commit 37432820067175b8b49f21754bc4df7959971d7b
Author: Kevin Innerebner <ke...@yahoo.com>
AuthorDate: Sat Aug 29 22:33:58 2020 +0200
[SYSTEMDS-2554-7] Federated frame transformapply, incl binning, omit
Closes #1032.
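
For readers unfamiliar with the transform path this commit extends, the following is a minimal, hypothetical DML sketch of transformencode/transformapply with a binning and omit spec. It is not taken from the commit or from the new TransformFederatedEncodeApply.dml test; file paths, column ids, and the spec are illustrative placeholders.

    # read an input frame (placeholder path and schema)
    F = read("data/F.csv", data_type="frame", format="csv", header=TRUE);

    # transform spec: omit rows with missing values in column 1,
    # equi-width binning of column 2 into 4 bins (illustrative only)
    jspec = "{ ids: true, omit: [1], bin: [{ id: 2, method: equi-width, numbins: 4 }] }";

    # build the encoding and its metadata frame once ...
    [X, M] = transformencode(target=F, spec=jspec);

    # ... and re-apply the same encoding to new frame data
    F2 = read("data/F2.csv", data_type="frame", format="csv", header=TRUE);
    X2 = transformapply(target=F2, spec=jspec, meta=M);

    write(X2, "data/X2", format="binary");

With this commit, the same transformapply call is also supported when the target frame is federated: bin boundaries are merged across workers and rows to omit are collected per worker and combined via the federated EncoderOmit.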
---
.../controlprogram/federated/FederatedRange.java | 6 +
.../federated/FederatedWorkerHandler.java | 10 +-
.../controlprogram/federated/FederationMap.java | 7 +-
.../cp/ParameterizedBuiltinCPInstruction.java | 7 +-
.../instructions/fed/FEDInstructionUtils.java | 3 +-
...tiReturnParameterizedBuiltinFEDInstruction.java | 58 +++--
.../fed/ParameterizedBuiltinFEDInstruction.java | 152 ++++++++++--
.../sysds/runtime/io/FileFormatPropertiesCSV.java | 8 +-
.../sysds/runtime/transform/encode/Encoder.java | 27 +-
.../sysds/runtime/transform/encode/EncoderBin.java | 120 ++++++++-
.../runtime/transform/encode/EncoderComposite.java | 14 +-
.../runtime/transform/encode/EncoderDummycode.java | 25 +-
.../runtime/transform/encode/EncoderFactory.java | 10 +-
.../transform/encode/EncoderFeatureHash.java | 38 ++-
.../runtime/transform/encode/EncoderMVImpute.java | 2 +-
.../runtime/transform/encode/EncoderOmit.java | 145 ++++++++---
.../transform/encode/EncoderPassThrough.java | 14 +-
.../runtime/transform/encode/EncoderRecode.java | 15 +-
.../sysds/runtime/transform/meta/TfMetaUtils.java | 79 +++---
.../org/apache/sysds/runtime/util/HDFSTool.java | 7 +
.../org/apache/sysds/runtime/util/IndexRange.java | 18 +-
.../TransformFederatedEncodeApplyTest.java | 273 +++++++++++++++++++++
.../TransformFederatedEncodeDecodeTest.java | 15 +-
.../transform/TransformFederatedEncodeApply.dml | 36 +++
24 files changed, 904 insertions(+), 185 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
index 23d0269..4289cfe 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRange.java
@@ -21,6 +21,8 @@ package org.apache.sysds.runtime.controlprogram.federated;
import java.util.Arrays;
+import org.apache.sysds.runtime.util.IndexRange;
+
public class FederatedRange implements Comparable<FederatedRange> {
private long[] _beginDims;
private long[] _endDims;
@@ -119,4 +121,8 @@ public class FederatedRange implements Comparable<FederatedRange> {
_endDims[1] = tmpEnd;
return this;
}
+
+ public IndexRange asIndexRange() {
+ return new IndexRange(_beginDims[0], _endDims[0], _beginDims[1], _endDims[1]);
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index a2e62fe..b5f0ec8 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -45,6 +45,7 @@ import org.apache.sysds.runtime.instructions.InstructionParser;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
@@ -184,10 +185,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
// read metadata
FileFormat fmt = null;
+ boolean header = false;
try {
String mtdname = DataExpression.getMTDFileName(filename);
Path path = new Path(mtdname);
-
FileSystem fs = IOUtilFunctions.getFileSystem(mtdname); //no auto-close
try (BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) {
JSONObject mtd = JSONHelper.parse(br);
@@ -198,7 +199,8 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
mc.setCols(mtd.getLong(DataExpression.READCOLPARAM));
if(mtd.containsKey(DataExpression.READNNZPARAM))
mc.setNonZeros(mtd.getLong(DataExpression.READNNZPARAM));
-
+ if (mtd.has(DataExpression.DELIM_HAS_HEADER_ROW))
+ header = mtd.getBoolean(DataExpression.DELIM_HAS_HEADER_ROW);
cd = (CacheableData<?>) PrivacyPropagator.parseAndSetPrivacyConstraint(cd, mtd);
fmt = FileFormat.safeValueOf(mtd.getString(DataExpression.FORMAT_TYPE));
}
@@ -209,6 +211,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
//put meta data object in symbol table, read on first operation
cd.setMetaData(new MetaDataFormat(mc, fmt));
+ // TODO send FileFormatProperties with request and use them for CSV, this is currently a workaround so reading
+ // of CSV files works
+ cd.setFileFormatProperties(new FileFormatPropertiesCSV(header, DataExpression.DEFAULT_DELIM_DELIMITER,
+ DataExpression.DEFAULT_DELIM_SPARSE));
cd.enableCleanup(false); //guard against deletion
_ecm.get(tid).setVariable(String.valueOf(id), cd);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index ea8aa29..7d537c9 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -242,10 +242,9 @@ public class FederationMap
public long getMaxIndexInRange(int dim) {
- long maxIx = 0;
- for(FederatedRange range : _fedMap.keySet())
- maxIx = Math.max(range.getEndDims()[dim], maxIx);
- return maxIx;
+ return _fedMap.keySet().stream()
+ .mapToLong(range -> range.getEndDims()[dim]).max()
+ .orElse(-1L);
}
/**
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index cfb20e3..5c71780 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -37,6 +37,7 @@ import org.apache.sysds.parser.ParameterizedBuiltinFunctionExpression;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.caching.TensorObject;
@@ -304,7 +305,7 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
//get input spec and path
String spec = getParameterMap().get("spec");
String path = getParameterMap().get(ParameterizedBuiltinFunctionExpression.TF_FN_PARAM_MTD);
- String delim = getParameterMap().containsKey("sep") ? getParameterMap().get("sep") : TfUtils.TXMTD_SEP;
+ String delim = getParameterMap().getOrDefault("sep", TfUtils.TXMTD_SEP);
//execute transform meta data read
FrameBlock meta = null;
@@ -457,8 +458,8 @@ public class ParameterizedBuiltinCPInstruction extends ComputationCPInstruction
}
}
- public MatrixObject getTarget(ExecutionContext ec) {
- return ec.getMatrixObject(params.get("target"));
+ public CacheableData<?> getTarget(ExecutionContext ec) {
+ return ec.getCacheableData(params.get("target"));
}
private CPOperand getTargetOperand() {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index a1b0a08..2e41aa5 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -79,7 +79,8 @@ public class FEDInstructionUtils {
if(pinst.getOpcode().equals("replace") && pinst.getTarget(ec).isFederated()) {
fedinst = ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
}
- else if(pinst.getOpcode().equals("transformdecode") && pinst.getTarget(ec).isFederated()) {
+ else if((pinst.getOpcode().equals("transformdecode") || pinst.getOpcode().equals("transformapply")) &&
+ pinst.getTarget(ec).isFederated()) {
return ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index b9b6203..0fe12b9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -43,11 +43,15 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.privacy.PrivacyMonitor;
import org.apache.sysds.runtime.transform.encode.Encoder;
+import org.apache.sysds.runtime.transform.encode.EncoderBin;
import org.apache.sysds.runtime.transform.encode.EncoderComposite;
import org.apache.sysds.runtime.transform.encode.EncoderDummycode;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.transform.encode.EncoderFeatureHash;
+import org.apache.sysds.runtime.transform.encode.EncoderOmit;
import org.apache.sysds.runtime.transform.encode.EncoderPassThrough;
import org.apache.sysds.runtime.transform.encode.EncoderRecode;
+import org.apache.sysds.runtime.util.IndexRange;
public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFEDInstruction {
protected final ArrayList<CPOperand> _outputs;
@@ -86,10 +90,19 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
// obtain and pin input frame
FrameObject fin = ec.getFrameObject(input1.getName());
String spec = ec.getScalarInput(input2).getStringValue();
+
+ String[] colNames = new String[(int) fin.getNumColumns()];
+ Arrays.fill(colNames, "");
// the encoder in which the complete encoding information will be aggregated
EncoderComposite globalEncoder = new EncoderComposite(
- Arrays.asList(new EncoderRecode(), new EncoderPassThrough(), new EncoderDummycode()));
+ // IMPORTANT: Encoder order matters
+ Arrays.asList(new EncoderRecode(),
+ new EncoderFeatureHash(),
+ new EncoderPassThrough(),
+ new EncoderBin(),
+ new EncoderDummycode(),
+ new EncoderOmit(true)));
// first create encoders at the federated workers, then collect them and aggregate them to a single large
// encoder
FederationMap fedMapping = fin.getFedMapping();
@@ -98,39 +111,55 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
// create an encoder with the given spec. The columnOffset (which is 1 based) has to be used to
// tell the federated worker how much the indexes in the spec have to be offset.
- Future<FederatedResponse> response = data.executeFederatedOperation(
- new FederatedRequest(RequestType.EXEC_UDF, data.getVarID(),
+ Future<FederatedResponse> responseFuture = data.executeFederatedOperation(
+ new FederatedRequest(RequestType.EXEC_UDF, -1,
new CreateFrameEncoder(data.getVarID(), spec, columnOffset)));
// collect responses with encoders
try {
- Encoder encoder = (Encoder) response.get().getData()[0];
+ FederatedResponse response = responseFuture.get();
+ Encoder encoder = (Encoder) response.getData()[0];
// merge this encoder into a composite encoder
synchronized(globalEncoder) {
globalEncoder.mergeAt(encoder, columnOffset);
}
+ // no synchronization necessary since names should anyway match
+ String[] subRangeColNames = (String[]) response.getData()[1];
+ System.arraycopy(subRangeColNames, 0, colNames, (int) range.getBeginDims()[1], subRangeColNames.length);
}
catch(Exception e) {
throw new DMLRuntimeException("Federated encoder creation failed: " + e.getMessage());
}
return null;
});
+ FrameBlock meta = new FrameBlock((int) fin.getNumColumns(), Types.ValueType.STRING);
+ meta.setColumnNames(colNames);
+ globalEncoder.getMetaData(meta);
+ globalEncoder.initMetaData(meta);
+
+ encodeFederatedFrames(fedMapping, globalEncoder, ec.getMatrixObject(getOutput(0)));
+
+ // release input and outputs
+ ec.setFrameOutput(getOutput(1).getName(), meta);
+ }
+
+ public static void encodeFederatedFrames(FederationMap fedMapping, Encoder globalEncoder,
+ MatrixObject transformedMat) {
long varID = FederationUtils.getNextFedDataID();
FederationMap transformedFedMapping = fedMapping.mapParallel(varID, (range, data) -> {
// copy because we reuse it
long[] beginDims = range.getBeginDims();
long[] endDims = range.getEndDims();
- int colStart = (int) beginDims[1] + 1;
- int colEnd = (int) endDims[1] + 1;
+ IndexRange ixRange = new IndexRange(beginDims[0], endDims[0], beginDims[1], endDims[1]).add(1);// make 1-based
// update begin end dims (column part) considering columns added by dummycoding
globalEncoder.updateIndexRanges(beginDims, endDims);
// get the encoder segment that is relevant for this federated worker
- Encoder encoder = globalEncoder.subRangeEncoder(colStart, colEnd);
+ Encoder encoder = globalEncoder.subRangeEncoder(ixRange);
try {
FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(RequestType.EXEC_UDF,
- varID, new ExecuteFrameEncoder(data.getVarID(), varID, encoder))).get();
+ -1, new ExecuteFrameEncoder(data.getVarID(), varID, encoder))).get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
}
@@ -141,18 +170,11 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
});
// construct a federated matrix with the encoded data
- MatrixObject transformedMat = ec.getMatrixObject(getOutput(0));
- transformedMat.getDataCharacteristics().setRows(transformedFedMapping.getMaxIndexInRange(0));
- transformedMat.getDataCharacteristics().setCols(transformedFedMapping.getMaxIndexInRange(1));
- // set the federated mapping for the matrix
+ transformedMat.getDataCharacteristics().setDimension(
+ transformedFedMapping.getMaxIndexInRange(0), transformedFedMapping.getMaxIndexInRange(1));
transformedMat.setFedMapping(transformedFedMapping);
-
- // release input and outputs
- ec.setFrameOutput(getOutput(1).getName(),
- globalEncoder.getMetaData(new FrameBlock((int) fin.getNumColumns(), Types.ValueType.STRING)));
}
-
public static class CreateFrameEncoder extends FederatedUDF {
private static final long serialVersionUID = 2376756757742169692L;
private final String _spec;
@@ -179,7 +201,7 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
fo.release();
// create federated response
- return new FederatedResponse(ResponseType.SUCCESS, encoder);
+ return new FederatedResponse(ResponseType.SUCCESS, new Object[] {encoder, fb.getColumnNames()});
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index 47f912d..204019f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -23,17 +23,20 @@ import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
+import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.ResponseType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
@@ -51,6 +54,10 @@ import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.privacy.PrivacyMonitor;
import org.apache.sysds.runtime.transform.decode.Decoder;
import org.apache.sysds.runtime.transform.decode.DecoderFactory;
+import org.apache.sysds.runtime.transform.encode.Encoder;
+import org.apache.sysds.runtime.transform.encode.EncoderComposite;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.transform.encode.EncoderOmit;
public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstruction {
protected final LinkedHashMap<String, String> params;
@@ -113,7 +120,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
if(opcode.equalsIgnoreCase("replace")) {
// similar to unary federated instructions, get federated input
// execute instruction, and derive federated output matrix
- MatrixObject mo = getTarget(ec);
+ MatrixObject mo = (MatrixObject) getTarget(ec);
FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
new CPOperand[] {getTargetOperand()}, new long[] {mo.getFedMapping().getID()});
mo.getFedMapping().execute(getTID(), true, fr1);
@@ -125,22 +132,24 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
}
else if(opcode.equalsIgnoreCase("transformdecode"))
transformDecode(ec);
+ else if(opcode.equalsIgnoreCase("transformapply"))
+ transformApply(ec);
else {
throw new DMLRuntimeException("Unknown opcode : " + opcode);
}
}
-
+
private void transformDecode(ExecutionContext ec) {
// acquire locks
MatrixObject mo = ec.getMatrixObject(params.get("target"));
FrameBlock meta = ec.getFrameInput(params.get("meta"));
String spec = params.get("spec");
-
+
Decoder globalDecoder = DecoderFactory
- .createDecoder(spec, meta.getColumnNames(), null, meta, (int) mo.getNumColumns());
-
+ .createDecoder(spec, meta.getColumnNames(), null, meta, (int) mo.getNumColumns());
+
FederationMap fedMapping = mo.getFedMapping();
-
+
ValueType[] schema = new ValueType[(int) mo.getNumColumns()];
long varID = FederationUtils.getNextFedDataID();
FederationMap decodedMapping = fedMapping.mapParallel(varID, (range, data) -> {
@@ -153,22 +162,21 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
// get the decoder segment that is relevant for this federated worker
Decoder decoder = globalDecoder
- .subRangeDecoder((int) beginDims[1] + 1, (int) endDims[1] + 1, colStartBefore);
+ .subRangeDecoder((int) beginDims[1] + 1, (int) endDims[1] + 1, colStartBefore);
FrameBlock metaSlice = new FrameBlock();
synchronized(meta) {
meta.slice(0, meta.getNumRows() - 1, (int) beginDims[1], (int) endDims[1] - 1, metaSlice);
}
-
-
+
FederatedResponse response;
try {
response = data.executeFederatedOperation(
- new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, varID,
+ new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
new DecodeMatrix(data.getVarID(), varID, metaSlice, decoder))).get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
-
+
ValueType[] subSchema = (ValueType[]) response.getData()[0];
synchronized(schema) {
// It would be possible to assert that different federated workers don't give different value
@@ -181,7 +189,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
}
return null;
});
-
+
// construct a federated matrix with the encoded data
FrameObject decodedFrame = ec.getFrameObject(output);
decodedFrame.setSchema(globalDecoder.getSchema());
@@ -189,19 +197,94 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
decodedFrame.getDataCharacteristics().setCols(globalDecoder.getSchema().length);
// set the federated mapping for the matrix
decodedFrame.setFedMapping(decodedMapping);
-
+
// release locks
ec.releaseFrameInput(params.get("meta"));
}
- public MatrixObject getTarget(ExecutionContext ec) {
- return ec.getMatrixObject(params.get("target"));
+ private void transformApply(ExecutionContext ec) {
+ // acquire locks
+ FrameObject fo = ec.getFrameObject(params.get("target"));
+ FrameBlock meta = ec.getFrameInput(params.get("meta"));
+ String spec = params.get("spec");
+
+ FederationMap fedMapping = fo.getFedMapping();
+
+ // get column names for the EncoderFactory
+ String[] colNames = new String[(int) fo.getNumColumns()];
+ Arrays.fill(colNames, "");
+
+ fedMapping.forEachParallel((range, data) -> {
+ try {
+ FederatedResponse response = data
+ .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+ new GetColumnNames(data.getVarID()))).get();
+
+ // no synchronization necessary since names should anyway match
+ String[] subRangeColNames = (String[]) response.getData()[0];
+ System.arraycopy(subRangeColNames, 0, colNames, (int) range.getBeginDims()[1], subRangeColNames.length);
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ return null;
+ });
+
+ Encoder globalEncoder = EncoderFactory.createEncoder(spec, colNames, colNames.length, meta);
+
+ // check if EncoderOmit exists
+ List<Encoder> encoders = ((EncoderComposite) globalEncoder).getEncoders();
+ int omitIx = -1;
+ for(int i = 0; i < encoders.size(); i++) {
+ if(encoders.get(i) instanceof EncoderOmit) {
+ omitIx = i;
+ break;
+ }
+ }
+ if(omitIx != -1) {
+ // extra step, build the omit encoder: we need information about all the rows to omit, if our federated
+ // ranges are split up row-wise we need to build the encoder separately and combine it
+ buildOmitEncoder(fedMapping, encoders, omitIx);
+ }
+
+ MultiReturnParameterizedBuiltinFEDInstruction
+ .encodeFederatedFrames(fedMapping, globalEncoder, ec.getMatrixObject(getOutputVariableName()));
+
+ // release locks
+ ec.releaseFrameInput(params.get("meta"));
+ }
+
+ private static void buildOmitEncoder(FederationMap fedMapping, List<Encoder> encoders, int omitIx) {
+ Encoder omitEncoder = encoders.get(omitIx);
+ EncoderOmit newOmit = new EncoderOmit(true);
+ fedMapping.forEachParallel((range, data) -> {
+ try {
+ EncoderOmit subRangeEncoder = (EncoderOmit) omitEncoder.subRangeEncoder(range.asIndexRange().add(1));
+ FederatedResponse response = data
+ .executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+ new InitRowsToRemoveOmit(data.getVarID(), subRangeEncoder))).get();
+
+ // no synchronization necessary since names should anyway match
+ Encoder builtEncoder = (Encoder) response.getData()[0];
+ newOmit.mergeAt(builtEncoder, (int) (range.getBeginDims()[1] + 1));
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ return null;
+ });
+ encoders.remove(omitIx);
+ encoders.add(omitIx, newOmit);
+ }
+
+ public CacheableData<?> getTarget(ExecutionContext ec) {
+ return ec.getCacheableData(params.get("target"));
}
private CPOperand getTargetOperand() {
return new CPOperand(params.get("target"), ValueType.FP64, DataType.MATRIX);
}
-
+
public static class DecodeMatrix extends FederatedUDF {
private static final long serialVersionUID = 2376756757742169692L;
private final long _outputID;
@@ -235,7 +318,42 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
// add it to the list of variables
ec.setVariable(String.valueOf(_outputID), fo);
// return schema
- return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[] {fo.getSchema()});
+ return new FederatedResponse(ResponseType.SUCCESS, new Object[] {fo.getSchema()});
+ }
+ }
+
+ private static class GetColumnNames extends FederatedUDF {
+ private static final long serialVersionUID = -7831469841164270004L;
+
+ public GetColumnNames(long varID) {
+ super(new long[] {varID});
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ FrameObject fo = (FrameObject) PrivacyMonitor.handlePrivacy(data[0]);
+ FrameBlock fb = fo.acquireReadAndRelease();
+ // return column names
+ return new FederatedResponse(ResponseType.SUCCESS, new Object[] {fb.getColumnNames()});
+ }
+ }
+
+ private static class InitRowsToRemoveOmit extends FederatedUDF {
+ private static final long serialVersionUID = -8196730717390438411L;
+
+ EncoderOmit _encoder;
+
+ public InitRowsToRemoveOmit(long varID, EncoderOmit encoder) {
+ super(new long[] {varID});
+ _encoder = encoder;
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data... data) {
+ FrameObject fo = (FrameObject) PrivacyMonitor.handlePrivacy(data[0]);
+ FrameBlock fb = fo.acquireReadAndRelease();
+ _encoder.build(fb);
+ return new FederatedResponse(ResponseType.SUCCESS, new Object[] {_encoder});
}
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/io/FileFormatPropertiesCSV.java b/src/main/java/org/apache/sysds/runtime/io/FileFormatPropertiesCSV.java
index 7049918..7b20e38 100644
--- a/src/main/java/org/apache/sysds/runtime/io/FileFormatPropertiesCSV.java
+++ b/src/main/java/org/apache/sysds/runtime/io/FileFormatPropertiesCSV.java
@@ -54,6 +54,7 @@ public class FileFormatPropertiesCSV extends FileFormatProperties implements Ser
}
public FileFormatPropertiesCSV(boolean hasHeader, String delim, boolean fill, double fillValue, String naStrings) {
+ this();
this.header = hasHeader;
this.delim = delim;
this.fill = fill;
@@ -68,6 +69,7 @@ public class FileFormatPropertiesCSV extends FileFormatProperties implements Ser
}
public FileFormatPropertiesCSV(boolean hasHeader, String delim, boolean sparse) {
+ this();
this.header = hasHeader;
this.delim = delim;
this.sparse = sparse;
@@ -88,7 +90,11 @@ public class FileFormatPropertiesCSV extends FileFormatProperties implements Ser
public String getDelim() {
return delim;
}
-
+
+ public void setNAStrings(HashSet<String> naStrings) {
+ this.naStrings = naStrings;
+ }
+
public HashSet<String> getNAStrings() {
return naStrings;
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
index 19271f8..7f47192 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
@@ -20,17 +20,20 @@
package org.apache.sysds.runtime.transform.encode;
import java.io.Serializable;
+import java.util.ArrayList;
import java.util.Arrays;
-
import java.util.HashSet;
+import java.util.List;
import java.util.Set;
+
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
-import org.apache.wink.json4j.JSONArray;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.wink.json4j.JSONArray;
/**
* Base class for all transform encoders providing both a row and block
@@ -125,14 +128,26 @@ public abstract class Encoder implements Serializable
*/
public abstract MatrixBlock apply(FrameBlock in, MatrixBlock out);
+ protected int[] subRangeColList(IndexRange ixRange) {
+ List<Integer> cols = new ArrayList<>();
+ for(int col : _colList) {
+ if(ixRange.inColRange(col)) {
+ // add the correct column, removed columns before start
+ // colStart - 1 because colStart is 1-based
+ int corrColumn = (int) (col - (ixRange.colStart - 1));
+ cols.add(corrColumn);
+ }
+ }
+ return cols.stream().mapToInt(i -> i).toArray();
+ }
+
/**
* Returns a new Encoder that only handles a sub range of columns.
*
- * @param colStart the start index of the sub-range (1-based, inclusive)
- * @param colEnd the end index of the sub-range (1-based, exclusive)
+ * @param ixRange the range (1-based, begin inclusive, end exclusive)
* @return an encoder of the same type, just for the sub-range
*/
- public Encoder subRangeEncoder(int colStart, int colEnd) {
+ public Encoder subRangeEncoder(IndexRange ixRange) {
throw new DMLRuntimeException(
this.getClass().getSimpleName() + " does not support the creation of a sub-range encoder");
}
@@ -166,7 +181,7 @@ public abstract class Encoder implements Serializable
*/
public void mergeAt(Encoder other, int col) {
throw new DMLRuntimeException(
- this.getClass().getName() + " does not support merging with " + other.getClass().getName());
+ this.getClass().getSimpleName() + " does not support merging with " + other.getClass().getSimpleName());
}
/**
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
index 3be9ed9..351f68d 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
@@ -20,19 +20,24 @@
package org.apache.sysds.runtime.transform.encode;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import org.apache.commons.lang.ArrayUtils;
-import org.apache.wink.json4j.JSONArray;
-import org.apache.wink.json4j.JSONException;
-import org.apache.wink.json4j.JSONObject;
+import org.apache.commons.lang3.tuple.MutableTriple;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
+import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.wink.json4j.JSONArray;
+import org.apache.wink.json4j.JSONException;
+import org.apache.wink.json4j.JSONObject;
public class EncoderBin extends Encoder
{
@@ -49,7 +54,7 @@ public class EncoderBin extends Encoder
private double[][] _binMins = null;
private double[][] _binMaxs = null;
- public EncoderBin(JSONObject parsedSpec, String[] colnames, int clen)
+ public EncoderBin(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol)
throws JSONException, IOException
{
super( null, clen );
@@ -57,22 +62,35 @@ public class EncoderBin extends Encoder
return;
//parse column names or column ids
- List<Integer> collist = TfMetaUtils.parseBinningColIDs(parsedSpec, colnames);
+ List<Integer> collist = TfMetaUtils.parseBinningColIDs(parsedSpec, colnames, minCol, maxCol);
initColList(ArrayUtils.toPrimitive(collist.toArray(new Integer[0])));
//parse number of bins per column
boolean ids = parsedSpec.containsKey("ids") && parsedSpec.getBoolean("ids");
JSONArray group = (JSONArray) parsedSpec.get(TfMethod.BIN.toString());
_numBins = new int[collist.size()];
- for(int i=0; i < _numBins.length; i++) {
- JSONObject colspec = (JSONObject) group.get(i);
- int pos = collist.indexOf(ids ? colspec.getInt("id") :
- ArrayUtils.indexOf(colnames, colspec.get("name"))+1);
- _numBins[pos] = colspec.containsKey("numbins") ?
- colspec.getInt("numbins"): 1;
+ for (Object o : group) {
+ JSONObject colspec = (JSONObject) o;
+ int ixOffset = minCol == -1 ? 0 : minCol - 1;
+ int pos = collist.indexOf(ids ? colspec.getInt("id") - ixOffset :
+ ArrayUtils.indexOf(colnames, colspec.get("name")) + 1);
+ if(pos >= 0)
+ _numBins[pos] = colspec.containsKey("numbins") ? colspec.getInt("numbins") : 1;
}
}
+ public EncoderBin() {
+ super(new int[0], 0);
+ _numBins = new int[0];
+ }
+
+ private EncoderBin(int[] colList, int clen, int[] numBins, double[][] binMins, double[][] binMaxs) {
+ super(colList, clen);
+ _numBins = numBins;
+ _binMins = binMins;
+ _binMaxs = binMaxs;
+ }
+
@Override
public MatrixBlock encode(FrameBlock in, MatrixBlock out) {
build(in);
@@ -121,7 +139,87 @@ public class EncoderBin extends Encoder
}
return out;
}
+
+ @Override
+ public Encoder subRangeEncoder(IndexRange ixRange) {
+ List<Integer> colsList = new ArrayList<>();
+ List<Integer> numBinsList = new ArrayList<>();
+ List<double[]> binMinsList = new ArrayList<>();
+ List<double[]> binMaxsList = new ArrayList<>();
+ for(int i = 0; i < _colList.length; i++) {
+ int col = _colList[i];
+ if(col >= ixRange.colStart && col < ixRange.colEnd) {
+ // add the correct column, removed columns before start
+ // colStart - 1 because colStart is 1-based
+ int corrColumn = (int) (col - (ixRange.colStart - 1));
+ colsList.add(corrColumn);
+ numBinsList.add(_numBins[i]);
+ binMinsList.add(_binMins[i]);
+ binMaxsList.add(_binMaxs[i]);
+ }
+ }
+ if(colsList.isEmpty())
+ // empty encoder -> sub range encoder does not exist
+ return null;
+ int[] colList = colsList.stream().mapToInt(i -> i).toArray();
+ return new EncoderBin(colList, (int) (ixRange.colEnd - ixRange.colStart),
+ numBinsList.stream().mapToInt((i) -> i).toArray(), binMinsList.toArray(new double[0][0]),
+ binMaxsList.toArray(new double[0][0]));
+ }
+
+ @Override
+ public void mergeAt(Encoder other, int col) {
+ if(other instanceof EncoderBin) {
+ EncoderBin otherBin = (EncoderBin) other;
+
+ // save the min, max as well as the number of bins for the column indexes
+ Map<Integer, MutableTriple<Integer, Double, Double>> ixBinsMap = new HashMap<>();
+ for(int i = 0; i < _colList.length; i++) {
+ ixBinsMap.put(_colList[i],
+ new MutableTriple<>(_numBins[i], _binMins[i][0], _binMaxs[i][_binMaxs[i].length - 1]));
+ }
+ for(int i = 0; i < otherBin._colList.length; i++) {
+ int column = otherBin._colList[i] + (col - 1);
+ MutableTriple<Integer, Double, Double> entry = ixBinsMap.get(column);
+ if(entry == null) {
+ ixBinsMap.put(column,
+ new MutableTriple<>(otherBin._numBins[i], otherBin._binMins[i][0],
+ otherBin._binMaxs[i][otherBin._binMaxs[i].length - 1]));
+ }
+ else {
+ // num bins will match
+ entry.middle = Math.min(entry.middle, otherBin._binMins[i][0]);
+ entry.right = Math.max(entry.right, otherBin._binMaxs[i][otherBin._binMaxs[i].length - 1]);
+ }
+ }
+
+ mergeColumnInfo(other, col);
+
+ // use the saved values to fill the arrays again
+ _numBins = new int[_colList.length];
+ _binMins = new double[_colList.length][];
+ _binMaxs = new double[_colList.length][];
+
+ for(int i = 0; i < _colList.length; i++) {
+ int column = _colList[i];
+ MutableTriple<Integer, Double, Double> entry = ixBinsMap.get(column);
+ _numBins[i] = entry.left;
+
+ double min = entry.middle;
+ double max = entry.right;
+ _binMins[i] = new double[_numBins[i]];
+ _binMaxs[i] = new double[_numBins[i]];
+ for(int j = 0; j < _numBins[i]; j++) {
+ _binMins[i][j] = min + j * (max - min) / _numBins[i];
+ _binMaxs[i][j] = min + (j + 1) * (max - min) / _numBins[i];
+ }
+ }
+ return;
+ }
+ super.mergeAt(other, col);
+ }
+
@Override
public FrameBlock getMetaData(FrameBlock meta) {
//allocate frame if necessary
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
index cd21f45..c494676 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
@@ -27,6 +27,7 @@ import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.IndexRange;
/**
* Simple composite encoder that applies a list of encoders
@@ -104,10 +105,10 @@ public class EncoderComposite extends Encoder
}
@Override
- public Encoder subRangeEncoder(int colStart, int colEnd) {
+ public Encoder subRangeEncoder(IndexRange ixRange) {
List<Encoder> subRangeEncoders = new ArrayList<>();
for (Encoder encoder : _encoders) {
- Encoder subEncoder = encoder.subRangeEncoder(colStart, colEnd);
+ Encoder subEncoder = encoder.subRangeEncoder(ixRange);
if (subEncoder != null) {
subRangeEncoders.add(subEncoder);
}
@@ -131,7 +132,7 @@ public class EncoderComposite extends Encoder
}
if(!mergedIn) {
throw new DMLRuntimeException("Tried to merge in encoder of class that is not present in "
- + "CompositeEncoder: " + otherEnc.getClass().getSimpleName());
+ + "EncoderComposite: " + otherEnc.getClass().getSimpleName());
}
}
// update dummycode encoder domain sizes based on distinctness information from other encoders
@@ -147,8 +148,11 @@ public class EncoderComposite extends Encoder
if (encoder.getClass() == other.getClass()) {
encoder.mergeAt(other, col);
// update dummycode encoder domain sizes based on distinctness information from other encoders
- if (encoder instanceof EncoderDummycode) {
- ((EncoderDummycode) encoder).updateDomainSizes(_encoders);
+ for (Encoder encDummy : _encoders) {
+ if (encDummy instanceof EncoderDummycode) {
+ ((EncoderDummycode) encDummy).updateDomainSizes(_encoders);
+ return;
+ }
}
return;
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
index 8ff5e57..19d41ea 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
@@ -29,6 +29,7 @@ import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
+import org.apache.sysds.runtime.util.IndexRange;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;
@@ -54,6 +55,12 @@ public class EncoderDummycode extends Encoder
super(new int[0], 0);
}
+ public EncoderDummycode(int[] colList, int clen, int[] domainSizes, long dummycodedLength) {
+ super(colList, clen);
+ _domainSizes = domainSizes;
+ _dummycodedLength = dummycodedLength;
+ }
+
@Override
public int getNumCols() {
return (int)_dummycodedLength;
@@ -97,16 +104,16 @@ public class EncoderDummycode extends Encoder
}
@Override
- public Encoder subRangeEncoder(int colStart, int colEnd) {
+ public Encoder subRangeEncoder(IndexRange ixRange) {
List<Integer> cols = new ArrayList<>();
List<Integer> domainSizes = new ArrayList<>();
- int newDummycodedLength = colEnd - colStart;
- for(int i = 0; i < _colList.length; i++){
+ int newDummycodedLength = (int) ixRange.colSpan();
+ for(int i = 0; i < _colList.length; i++) {
int col = _colList[i];
- if(col >= colStart && col < colEnd) {
+ if(ixRange.inColRange(col)) {
// add the correct column, removed columns before start
// colStart - 1 because colStart is 1-based
- int corrColumn = col - (colStart - 1);
+ int corrColumn = (int) (col - (ixRange.colStart - 1));
cols.add(corrColumn);
domainSizes.add(_domainSizes[i]);
newDummycodedLength += _domainSizes[i] - 1;
@@ -116,12 +123,8 @@ public class EncoderDummycode extends Encoder
// empty encoder -> sub range encoder does not exist
return null;
- EncoderDummycode subRangeEncoder = new EncoderDummycode();
- subRangeEncoder._clen = colEnd - colStart;
- subRangeEncoder._colList = cols.stream().mapToInt(i -> i).toArray();
- subRangeEncoder._domainSizes = domainSizes.stream().mapToInt(i -> i).toArray();
- subRangeEncoder._dummycodedLength = newDummycodedLength;
- return subRangeEncoder;
+ return new EncoderDummycode(cols.stream().mapToInt(i -> i).toArray(), (int) ixRange.colSpan(),
+ domainSizes.stream().mapToInt(i -> i).toArray(), newDummycodedLength);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 57f7102..313e5b2 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -73,7 +73,7 @@ public class EncoderFactory
TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.HASH.toString(), minCol, maxCol)));
List<Integer> dcIDs = Arrays.asList(ArrayUtils.toObject(
TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol)));
- List<Integer> binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames);
+ List<Integer> binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol);
//note: any dummycode column requires recode as preparation, unless it follows binning
rcIDs = except(unionDistinct(rcIDs, except(dcIDs, binIDs)), haIDs);
List<Integer> ptIDs = except(except(UtilFunctions.getSeqList(1, clen, 1),
@@ -81,7 +81,7 @@ public class EncoderFactory
List<Integer> oIDs = Arrays.asList(ArrayUtils.toObject(
TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.OMIT.toString(), minCol, maxCol)));
List<Integer> mvIDs = Arrays.asList(ArrayUtils.toObject(
- TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, TfMethod.IMPUTE.toString())));
+ TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, TfMethod.IMPUTE.toString(), minCol, maxCol)));
//create individual encoders
if( !rcIDs.isEmpty() ) {
@@ -90,7 +90,7 @@ public class EncoderFactory
lencoders.add(ra);
}
if( !haIDs.isEmpty() ) {
- EncoderFeatureHash ha = new EncoderFeatureHash(jSpec, colnames, clen);
+ EncoderFeatureHash ha = new EncoderFeatureHash(jSpec, colnames, clen, minCol, maxCol);
ha.setColList(ArrayUtils.toPrimitive(haIDs.toArray(new Integer[0])));
lencoders.add(ha);
}
@@ -98,11 +98,11 @@ public class EncoderFactory
lencoders.add(new EncoderPassThrough(
ArrayUtils.toPrimitive(ptIDs.toArray(new Integer[0])), clen));
if( !binIDs.isEmpty() )
- lencoders.add(new EncoderBin(jSpec, colnames, schema.length));
+ lencoders.add(new EncoderBin(jSpec, colnames, schema.length, minCol, maxCol));
if( !dcIDs.isEmpty() )
lencoders.add(new EncoderDummycode(jSpec, colnames, schema.length, minCol, maxCol));
if( !oIDs.isEmpty() )
- lencoders.add(new EncoderOmit(jSpec, colnames, schema.length));
+ lencoders.add(new EncoderOmit(jSpec, colnames, schema.length, minCol, maxCol));
if( !mvIDs.isEmpty() ) {
EncoderMVImpute ma = new EncoderMVImpute(jSpec, colnames, schema.length);
ma.initRecodeIDList(rcIDs);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
index 85c408b..9317dfb 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.transform.encode;
+import org.apache.sysds.runtime.util.IndexRange;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
@@ -35,10 +36,21 @@ public class EncoderFeatureHash extends Encoder
private static final long serialVersionUID = 7435806042138687342L;
private long _K;
- public EncoderFeatureHash(JSONObject parsedSpec, String[] colnames, int clen) throws JSONException {
+ public EncoderFeatureHash(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol)
+ throws JSONException {
super(null, clen);
- _colList = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfMethod.HASH.toString());
- _K = getK(parsedSpec);
+ _colList = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfMethod.HASH.toString(), minCol, maxCol);
+ _K = getK(parsedSpec);
+ }
+
+ public EncoderFeatureHash(int[] colList, int clen, long K) {
+ super(colList, clen);
+ _K = K;
+ }
+
+ public EncoderFeatureHash() {
+ super(new int[0], 0);
+ _K = 0;
}
/**
@@ -89,6 +101,26 @@ public class EncoderFeatureHash extends Encoder
}
@Override
+ public Encoder subRangeEncoder(IndexRange ixRange) {
+ int[] colList = subRangeColList(ixRange);
+ if(colList.length == 0)
+ // empty encoder -> sub range encoder does not exist
+ return null;
+ return new EncoderFeatureHash(colList, (int) ixRange.colSpan(), _K);
+ }
+
+ @Override
+ public void mergeAt(Encoder other, int col) {
+ if(other instanceof EncoderFeatureHash) {
+ mergeColumnInfo(other, col);
+ if (((EncoderFeatureHash) other)._K != 0 && _K == 0)
+ _K = ((EncoderFeatureHash) other)._K;
+ return;
+ }
+ super.mergeAt(other, col);
+ }
+
+ @Override
public FrameBlock getMetaData(FrameBlock meta) {
if( !isApplicable() )
return meta;
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
index deba22f..56749a2 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
@@ -82,7 +82,7 @@ public class EncoderMVImpute extends Encoder
super(null, clen);
//handle column list
- int[] collist = TfMetaUtils.parseJsonObjectIDList(parsedSpec, colnames, TfMethod.IMPUTE.toString());
+ int[] collist = TfMetaUtils.parseJsonObjectIDList(parsedSpec, colnames, TfMethod.IMPUTE.toString(), -1, -1);
initColList(collist);
//handle method list
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
index 283c196..26ba4e4 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
@@ -19,33 +19,57 @@
package org.apache.sysds.runtime.transform.encode;
-import org.apache.wink.json4j.JSONException;
-import org.apache.wink.json4j.JSONObject;
+import java.util.TreeSet;
+import java.util.stream.Collectors;
+
+import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
+import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.wink.json4j.JSONException;
+import org.apache.wink.json4j.JSONObject;
public class EncoderOmit extends Encoder
{
private static final long serialVersionUID = 1978852120416654195L;
- private int _rmRows = 0;
+ private boolean _federated = false;
+ //TODO perf replace with boolean[rlen] similar to removeEmpty
+ private TreeSet<Integer> _rmRows = new TreeSet<>();
- public EncoderOmit(JSONObject parsedSpec, String[] colnames, int clen)
+ public EncoderOmit(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol)
throws JSONException
{
super(null, clen);
if (!parsedSpec.containsKey(TfMethod.OMIT.toString()))
return;
- int[] collist = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfMethod.OMIT.toString());
+ int[] collist = TfMetaUtils.parseJsonIDList(parsedSpec, colnames, TfMethod.OMIT.toString(), minCol, maxCol);
initColList(collist);
+ _federated = minCol != -1 || maxCol != -1;
+ }
+
+ public EncoderOmit() {
+ super(new int[0], 0);
+ }
+
+ public EncoderOmit(boolean federated) {
+ this();
+ _federated = federated;
+ }
+
+
+ private EncoderOmit(int[] colList, int clen, TreeSet<Integer> rmRows) {
+ super(colList, clen);
+ _rmRows = rmRows;
+ _federated = true;
}
public int getNumRemovedRows() {
- return _rmRows;
+ return _rmRows.size();
}
public boolean omit(String[] words, TfUtils agents)
@@ -67,45 +91,97 @@ public class EncoderOmit extends Encoder
}
@Override
- public void build(FrameBlock in) {
- //do nothing
+ public void build(FrameBlock in) {
+ if(_federated)
+ _rmRows = computeRmRows(in);
}
-
+
@Override
- public MatrixBlock apply(FrameBlock in, MatrixBlock out)
- {
- //determine output size
- int numRows = 0;
- for(int i=0; i<out.getNumRows(); i++) {
- boolean valid = true;
- for(int j=0; j<_colList.length; j++)
- valid &= !Double.isNaN(out.quickGetValue(i, _colList[j]-1));
- numRows += valid ? 1 : 0;
- }
-
- //copy over valid rows into the output
+ public MatrixBlock apply(FrameBlock in, MatrixBlock out) {
+ // local rmRows for broadcasting encoder in spark
+ TreeSet<Integer> rmRows;
+ if(_federated)
+ rmRows = _rmRows;
+ else
+ rmRows = computeRmRows(in);
+
+ // determine output size
+ int numRows = out.getNumRows() - rmRows.size();
+
+ // copy over valid rows into the output
MatrixBlock ret = new MatrixBlock(numRows, out.getNumColumns(), false);
int pos = 0;
- for(int i=0; i<in.getNumRows(); i++) {
- //determine if valid row or omit
- boolean valid = true;
- for(int j=0; j<_colList.length; j++)
- valid &= !Double.isNaN(out.quickGetValue(i, _colList[j]-1));
- //copy row if necessary
- if( valid ) {
- for(int j=0; j<out.getNumColumns(); j++)
+ for(int i = 0; i < in.getNumRows(); i++) {
+ // copy row if necessary
+ if(!rmRows.contains(i)) {
+ for(int j = 0; j < out.getNumColumns(); j++)
ret.quickSetValue(pos, j, out.quickGetValue(i, j));
pos++;
}
}
-
- //keep info an remove rows
- _rmRows = out.getNumRows() - pos;
-
- return ret;
+
+ _rmRows = rmRows;
+
+ return ret;
+ }
+
+ private TreeSet<Integer> computeRmRows(FrameBlock in) {
+ TreeSet<Integer> rmRows = new TreeSet<>();
+ ValueType[] schema = in.getSchema();
+ for(int i = 0; i < in.getNumRows(); i++) {
+ boolean valid = true;
+ for(int colID : _colList) {
+ Object val = in.get(i, colID - 1);
+ valid &= !(val == null || (schema[colID - 1] == ValueType.STRING && val.toString().isEmpty()));
+ }
+ if(!valid)
+ rmRows.add(i);
+ }
+ return rmRows;
}
@Override
+ public Encoder subRangeEncoder(IndexRange ixRange) {
+ int[] colList = subRangeColList(ixRange);
+ if(colList.length == 0)
+ // empty encoder -> sub range encoder does not exist
+ return null;
+
+ TreeSet<Integer> rmRows = _rmRows.stream().filter((row) -> ixRange.inRowRange(row + 1))
+ .map((row) -> (int) (row - (ixRange.rowStart - 1))).collect(Collectors.toCollection(TreeSet::new));
+
+ return new EncoderOmit(colList, (int) (ixRange.colSpan()), rmRows);
+ }
+
+ @Override
+ public void mergeAt(Encoder other, int col) {
+ if(other instanceof EncoderOmit) {
+ mergeColumnInfo(other, col);
+ _rmRows.addAll(((EncoderOmit) other)._rmRows);
+ return;
+ }
+ super.mergeAt(other, col);
+ }
+
+ @Override
+ public void updateIndexRanges(long[] beginDims, long[] endDims) {
+ // first update begin dims
+ int numRowsToRemove = 0;
+ Integer removedRow = _rmRows.ceiling(0);
+ while(removedRow != null && removedRow < beginDims[0]) {
+ numRowsToRemove++;
+ removedRow = _rmRows.ceiling(removedRow + 1);
+ }
+ beginDims[0] -= numRowsToRemove;
+ // update end dims
+ while(removedRow != null && removedRow < endDims[0]) {
+ numRowsToRemove++;
+ removedRow = _rmRows.ceiling(removedRow + 1);
+ }
+ endDims[0] -= numRowsToRemove;
+ }
+
+ @Override
public FrameBlock getMetaData(FrameBlock out) {
//do nothing
return out;
@@ -116,4 +192,3 @@ public class EncoderOmit extends Encoder
//do nothing
}
}
-
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
index d6ceb15..ccd235d 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
@@ -25,6 +25,7 @@ import java.util.List;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
/**
@@ -72,17 +73,18 @@ public class EncoderPassThrough extends Encoder
}
@Override
- public Encoder subRangeEncoder(int colStart, int colEnd) {
+ public Encoder subRangeEncoder(IndexRange ixRange) {
List<Integer> colList = new ArrayList<>();
- for (int col : _colList) {
- if (col >= colStart && col < colEnd)
+ for(int col : _colList) {
+ if(col >= ixRange.colStart && col < ixRange.colEnd)
// add the correct column, removed columns before start
- colList.add(col - (colStart - 1));
+ colList.add((int) (col - (ixRange.colStart - 1)));
}
- if (colList.isEmpty())
+ if(colList.isEmpty())
// empty encoder -> return null
return null;
- return new EncoderPassThrough(colList.stream().mapToInt(i -> i).toArray(), colEnd - colStart);
+ return new EncoderPassThrough(colList.stream().mapToInt(i -> i).toArray(),
+ (int) (ixRange.colEnd - ixRange.colStart));
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
index be1cba9..e195835 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
@@ -27,6 +27,7 @@ import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
+import org.apache.sysds.runtime.util.IndexRange;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;
import org.apache.sysds.lops.Lop;
@@ -164,25 +165,25 @@ public class EncoderRecode extends Encoder
}
@Override
- public Encoder subRangeEncoder(int colStart, int colEnd) {
+ public Encoder subRangeEncoder(IndexRange ixRange) {
List<Integer> cols = new ArrayList<>();
HashMap<Integer, HashMap<String, Long>> rcdMaps = new HashMap<>();
- for (int col : _colList) {
- if (col >= colStart && col < colEnd) {
+ for(int col : _colList) {
+ if(ixRange.inColRange(col)) {
// add the correct column, removed columns before start
// colStart - 1 because colStart is 1-based
- int corrColumn = col - (colStart - 1);
+ int corrColumn = (int) (col - (ixRange.colStart - 1));
cols.add(corrColumn);
// copy rcdMap for column
rcdMaps.put(corrColumn, new HashMap<>(_rcdMaps.get(col)));
}
}
- if (cols.isEmpty())
+ if(cols.isEmpty())
// empty encoder -> sub range encoder does not exist
return null;
-
+
int[] colList = cols.stream().mapToInt(i -> i).toArray();
- return new EncoderRecode(colList, colEnd - colStart, rcdMaps);
+ return new EncoderRecode(colList, (int) ixRange.colSpan(), rcdMaps);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
index 72fab7a..3f1a37b 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/meta/TfMetaUtils.java
@@ -133,16 +133,11 @@ public class TfMetaUtils
else {
ix = ArrayUtils.indexOf(colnames, attrs.get(i)) + 1;
}
- if(ix <= 0) {
- if (minCol == -1 && maxCol == -1) {
- // only if we remove some columns, ix -1 is expected
- throw new RuntimeException("Specified column '"
- + attrs.get(i)+"' does not exist.");
- }
- else // ignore column
- continue;
- }
- colList.add(ix);
+ if(ix > 0)
+ colList.add(ix);
+ else if(minCol == -1 && maxCol == -1)
+ // only if we remove some columns, ix -1 is expected
+ throw new RuntimeException("Specified column '" + attrs.get(i) + "' does not exist.");
}
//ensure ascending order of column IDs
@@ -152,33 +147,41 @@ public class TfMetaUtils
return arr;
}
- public static int[] parseJsonObjectIDList(JSONObject spec, String[] colnames, String group)
- throws JSONException
- {
- int[] colList = new int[0];
+ public static int[] parseJsonObjectIDList(JSONObject spec, String[] colnames, String group, int minCol, int maxCol)
+ throws JSONException {
+ List<Integer> colList = new ArrayList<>();
+ int[] arr = new int[0];
boolean ids = spec.containsKey("ids") && spec.getBoolean("ids");
-
- if( spec.containsKey(group) && spec.get(group) instanceof JSONArray )
- {
- JSONArray colspecs = (JSONArray)spec.get(group);
- colList = new int[colspecs.size()];
- for(int j=0; j<colspecs.size(); j++) {
- JSONObject colspec = (JSONObject) colspecs.get(j);
- colList[j] = ids ? colspec.getInt("id") :
- (ArrayUtils.indexOf(colnames, colspec.get("name")) + 1);
- if( colList[j] <= 0 ) {
- throw new RuntimeException("Specified column '" +
- colspec.get(ids?"id":"name")+"' does not exist.");
+
+ if(spec.containsKey(group) && spec.get(group) instanceof JSONArray) {
+ JSONArray colspecs = (JSONArray) spec.get(group);
+ for(Object o : colspecs) {
+ JSONObject colspec = (JSONObject) o;
+ int ix;
+ if(ids) {
+ ix = colspec.getInt("id");
+ if(maxCol != -1 && ix >= maxCol)
+ ix = -1;
+ if(minCol != -1 && ix >= 0)
+ ix -= minCol - 1;
+ }
+ else {
+ ix = ArrayUtils.indexOf(colnames, colspec.get("name")) + 1;
}
+ if(ix > 0)
+ colList.add(ix);
+ else if(minCol == -1 && maxCol == -1)
+ throw new RuntimeException(
+ "Specified column '" + colspec.get(ids ? "id" : "name") + "' does not exist.");
}
-
- //ensure ascending order of column IDs
- Arrays.sort(colList);
+
+ // ensure ascending order of column IDs
+ arr = colList.stream().mapToInt((i) -> i).sorted().toArray();
}
-
- return colList;
+
+ return arr;
}
-
+
/**
* Reads transform meta data from an HDFS file path and converts it into an in-memory
* FrameBlock object.
@@ -227,7 +230,7 @@ public class TfMetaUtils
//get list of recode ids
List<Integer> recodeIDs = parseRecodeColIDs(spec, colnames);
- List<Integer> binIDs = parseBinningColIDs(spec, colnames);
+ List<Integer> binIDs = parseBinningColIDs(spec, colnames, -1, -1);
//create frame block from in-memory strings
return convertToTransformMetaDataFrame(rows, colnames, recodeIDs, binIDs, meta, mvmeta);
@@ -282,7 +285,7 @@ public class TfMetaUtils
//get list of recode ids
List<Integer> recodeIDs = parseRecodeColIDs(spec, colnames);
- List<Integer> binIDs = parseBinningColIDs(spec, colnames);
+ List<Integer> binIDs = parseBinningColIDs(spec, colnames, -1, -1);
//create frame block from in-memory strings
return convertToTransformMetaDataFrame(rows, colnames, recodeIDs, binIDs, meta, mvmeta);
@@ -390,26 +393,26 @@ public class TfMetaUtils
return specRecodeIDs;
}
- public static List<Integer> parseBinningColIDs(String spec, String[] colnames)
+ public static List<Integer> parseBinningColIDs(String spec, String[] colnames, int minCol, int maxCol)
throws IOException
{
try {
JSONObject jSpec = new JSONObject(spec);
- return parseBinningColIDs(jSpec, colnames);
+ return parseBinningColIDs(jSpec, colnames, minCol, maxCol);
}
catch(JSONException ex) {
throw new IOException(ex);
}
}
- public static List<Integer> parseBinningColIDs(JSONObject jSpec, String[] colnames)
+ public static List<Integer> parseBinningColIDs(JSONObject jSpec, String[] colnames, int minCol, int maxCol)
throws IOException
{
try {
String binKey = TfMethod.BIN.toString();
if( jSpec.containsKey(binKey) && jSpec.get(binKey) instanceof JSONArray ) {
return Arrays.asList(ArrayUtils.toObject(
- TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, binKey)));
+ TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, binKey, minCol, maxCol)));
}
else { //internally generates
return Arrays.asList(ArrayUtils.toObject(
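To make the new minCol/maxCol semantics concrete, a hedged sketch (made-up spec string and ranges, assuming the standard lowercase "bin" spec key): IDs outside [minCol, maxCol) are silently dropped and the remaining IDs are shifted to the partition-local 1-based numbering, while with minCol = maxCol = -1 a non-existing column still raises an error:

    // global 1-based binning columns 2 and 7
    String spec = "{\"ids\": true, \"bin\": [{\"id\": 2, \"numbins\": 4}, {\"id\": 7, \"numbins\": 3}]}";
    List<Integer> all  = TfMetaUtils.parseBinningColIDs(spec, null, -1, -1); // -> [2, 7]
    // worker holding global columns [1,5): column 7 is cut off, column 2 keeps local id 2
    List<Integer> part = TfMetaUtils.parseBinningColIDs(spec, null, 1, 5);   // -> [2]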
diff --git a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java
index 8b1e42e..af7471c 100644
--- a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java
@@ -31,6 +31,7 @@ import org.apache.hadoop.fs.FileUtil;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.permission.FsPermission;
import org.apache.hadoop.mapred.JobConf;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.OrderedJSONObject;
import org.apache.sysds.common.Types.DataType;
@@ -467,6 +468,12 @@ public class HDFSTool
} else {
mtd.put(DataExpression.AUTHORPARAM, "SystemDS");
}
+
+ if (formatProperties instanceof FileFormatPropertiesCSV) {
+ FileFormatPropertiesCSV csvProps = (FileFormatPropertiesCSV) formatProperties;
+ mtd.put(DataExpression.DELIM_HAS_HEADER_ROW, csvProps.hasHeader());
+ mtd.put(DataExpression.DELIM_DELIMITER, csvProps.getDelim());
+ }
SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss z");
mtd.put(DataExpression.CREATEDPARAM, sdf.format(new Date()));
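A rough usage sketch (hypothetical schema, rows and cols variables) of the effect of this change: writing frame metadata with CSV properties, as the new test below does, now also records the header flag and delimiter in the .mtd file, e.g. for re-reading the CSV partitions on the federated workers:

    FileFormatPropertiesCSV ffp = new FileFormatPropertiesCSV(true,
        DataExpression.DEFAULT_DELIM_DELIMITER, DataExpression.DEFAULT_DELIM_FILL,
        DataExpression.DEFAULT_DELIM_FILL_VALUE, DataExpression.DEFAULT_NA_STRINGS);
    HDFSTool.writeMetaDataFile("A.mtd", null, schema, Types.DataType.FRAME,
        new MatrixCharacteristics(rows, cols), FileFormat.CSV, ffp);
    // A.mtd now additionally carries the CSV header flag and delimiter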
diff --git a/src/main/java/org/apache/sysds/runtime/util/IndexRange.java b/src/main/java/org/apache/sysds/runtime/util/IndexRange.java
index 69ada3b..4a8d999 100644
--- a/src/main/java/org/apache/sysds/runtime/util/IndexRange.java
+++ b/src/main/java/org/apache/sysds/runtime/util/IndexRange.java
@@ -51,7 +51,23 @@ public class IndexRange implements Serializable
rowStart + delta, rowEnd + delta,
colStart + delta, colEnd + delta);
}
-
+
+ public boolean inColRange(long col) {
+ return col >= colStart && col < colEnd;
+ }
+
+ public boolean inRowRange(long row) {
+ return row >= rowStart && row < rowEnd;
+ }
+
+ public long colSpan() {
+ return colEnd - colStart;
+ }
+
+ public long rowSpan() {
+ return rowEnd - rowStart;
+ }
+
@Override
public String toString() {
return "["+rowStart+":"+rowEnd+","+colStart+":"+colEnd+"]";
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
new file mode 100644
index 0000000..622e6e0
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.transform;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
+import org.apache.sysds.runtime.io.FrameReaderFactory;
+import org.apache.sysds.runtime.io.FrameWriter;
+import org.apache.sysds.runtime.io.FrameWriterFactory;
+import org.apache.sysds.runtime.io.MatrixReaderFactory;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TransformFederatedEncodeApplyTest extends AutomatedTestBase {
+ private final static String TEST_NAME1 = "TransformFederatedEncodeApply";
+ private final static String TEST_DIR = "functions/transform/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + TransformFederatedEncodeApplyTest.class.getSimpleName()
+ + "/";
+
+ // dataset and transform tasks without missing values
+ private final static String DATASET1 = "homes3/homes.csv";
+ private final static String SPEC1 = "homes3/homes.tfspec_recode.json";
+ private final static String SPEC1b = "homes3/homes.tfspec_recode2.json";
+ private final static String SPEC2 = "homes3/homes.tfspec_dummy.json";
+ private final static String SPEC2b = "homes3/homes.tfspec_dummy2.json";
+ private final static String SPEC3 = "homes3/homes.tfspec_bin.json"; // recode
+ private final static String SPEC3b = "homes3/homes.tfspec_bin2.json"; // recode
+ private final static String SPEC6 = "homes3/homes.tfspec_recode_dummy.json";
+ private final static String SPEC6b = "homes3/homes.tfspec_recode_dummy2.json";
+ private final static String SPEC7 = "homes3/homes.tfspec_binDummy.json"; // recode+dummy
+ private final static String SPEC7b = "homes3/homes.tfspec_binDummy2.json"; // recode+dummy
+ private final static String SPEC8 = "homes3/homes.tfspec_hash.json";
+ private final static String SPEC8b = "homes3/homes.tfspec_hash2.json";
+ private final static String SPEC9 = "homes3/homes.tfspec_hash_recode.json";
+ private final static String SPEC9b = "homes3/homes.tfspec_hash_recode2.json";
+
+ // dataset and transform tasks with missing values
+ private final static String DATASET2 = "homes/homes.csv";
+ // private final static String SPEC4 = "homes3/homes.tfspec_impute.json";
+ // private final static String SPEC4b = "homes3/homes.tfspec_impute2.json";
+ private final static String SPEC5 = "homes3/homes.tfspec_omit.json";
+ private final static String SPEC5b = "homes3/homes.tfspec_omit2.json";
+
+ private static final int[] BIN_col3 = new int[] {1, 4, 2, 3, 3, 2, 4};
+ private static final int[] BIN_col8 = new int[] {1, 2, 2, 2, 2, 2, 3};
+
+ public enum TransformType {
+ RECODE, DUMMY, RECODE_DUMMY, BIN, BIN_DUMMY,
+ // IMPUTE,
+ OMIT,
+ HASH,
+ HASH_RECODE,
+ }
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"y"}));
+ }
+
+ @Test
+ public void testHomesRecodeIDsCSV() {
+ runTransformTest(TransformType.RECODE, false);
+ }
+
+ @Test
+ public void testHomesDummycodeIDsCSV() {
+ runTransformTest(TransformType.DUMMY, false);
+ }
+
+ @Test
+ public void testHomesRecodeDummycodeIDsCSV() {
+ runTransformTest(TransformType.RECODE_DUMMY, false);
+ }
+
+ @Test
+ public void testHomesBinningIDsCSV() {
+ runTransformTest(TransformType.BIN, false);
+ }
+
+ @Test
+ public void testHomesBinningDummyIDsCSV() {
+ runTransformTest(TransformType.BIN_DUMMY, false);
+ }
+
+ @Test
+ public void testHomesOmitIDsCSV() {
+ runTransformTest(TransformType.OMIT, false);
+ }
+
+ // @Test
+ // public void testHomesImputeIDsCSV() {
+ // runTransformTest(TransformType.IMPUTE, false);
+ // }
+
+ @Test
+ public void testHomesRecodeColnamesCSV() {
+ runTransformTest(TransformType.RECODE, true);
+ }
+
+ @Test
+ public void testHomesDummycodeColnamesCSV() {
+ runTransformTest(TransformType.DUMMY, true);
+ }
+
+ @Test
+ public void testHomesRecodeDummycodeColnamesCSV() {
+ runTransformTest(TransformType.RECODE_DUMMY, true);
+ }
+
+ @Test
+ public void testHomesBinningColnamesCSV() {
+ runTransformTest(TransformType.BIN, true);
+ }
+
+ @Test
+ public void testHomesBinningDummyColnamesCSV() {
+ runTransformTest(TransformType.BIN_DUMMY, true);
+ }
+
+ @Test
+ public void testHomesOmitColnamesCSV() {
+ runTransformTest(TransformType.OMIT, true);
+ }
+
+ // @Test
+ // public void testHomesImputeColnamesCSV() {
+ // runTransformTest(TransformType.IMPUTE, true);
+ // }
+
+ @Test
+ public void testHomesHashColnamesCSV() {
+ runTransformTest(TransformType.HASH, true);
+ }
+
+ @Test
+ public void testHomesHashIDsCSV() {
+ runTransformTest(TransformType.HASH, false);
+ }
+
+ @Test
+ public void testHomesHashRecodeColnamesCSV() {
+ runTransformTest(TransformType.HASH_RECODE, true);
+ }
+
+ @Test
+ public void testHomesHashRecodeIDsCSV() {
+ runTransformTest(TransformType.HASH_RECODE, false);
+ }
+
+ private void runTransformTest(TransformType type, boolean colnames) {
+ ExecMode rtold = setExecMode(ExecMode.SINGLE_NODE);
+
+ // set transform specification
+ String SPEC = null;
+ String DATASET = null;
+ switch(type) {
+ case RECODE: SPEC = colnames ? SPEC1b : SPEC1; DATASET = DATASET1; break;
+ case DUMMY: SPEC = colnames ? SPEC2b : SPEC2; DATASET = DATASET1; break;
+ case BIN: SPEC = colnames ? SPEC3b : SPEC3; DATASET = DATASET1; break;
+ // case IMPUTE: SPEC = colnames ? SPEC4b : SPEC4; DATASET = DATASET2; break;
+ case OMIT: SPEC = colnames ? SPEC5b : SPEC5; DATASET = DATASET2; break;
+ case RECODE_DUMMY: SPEC = colnames ? SPEC6b : SPEC6; DATASET = DATASET1; break;
+ case BIN_DUMMY: SPEC = colnames ? SPEC7b : SPEC7; DATASET = DATASET1; break;
+ case HASH: SPEC = colnames ? SPEC8b : SPEC8; DATASET = DATASET1; break;
+ case HASH_RECODE: SPEC = colnames ? SPEC9b : SPEC9; DATASET = DATASET1; break;
+ }
+
+ Thread t1 = null, t2 = null;
+ try {
+ getAndLoadTestConfiguration(TEST_NAME1);
+
+ int port1 = getRandomAvailablePort();
+ t1 = startLocalFedWorkerThread(port1);
+ int port2 = getRandomAvailablePort();
+ t2 = startLocalFedWorkerThread(port2);
+
+ FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(true, DataExpression.DEFAULT_DELIM_DELIMITER,
+ DataExpression.DEFAULT_DELIM_FILL, DataExpression.DEFAULT_DELIM_FILL_VALUE,
+ DATASET.equals(DATASET1) ? DataExpression.DEFAULT_NA_STRINGS : "NA" + DataExpression.DELIM_NA_STRING_SEP
+ + "");
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ // split the dataset column-wise into two partitions
+ FrameBlock dataset = FrameReaderFactory.createFrameReader(FileFormat.CSV, ffpCSV)
+ .readFrameFromHDFS(HOME + "input/" + DATASET, -1, -1);
+
+ // reset NA strings to the default for writing the partitions
+ ffpCSV.setNAStrings(UtilFunctions.defaultNaString);
+ FrameWriter fw = FrameWriterFactory.createFrameWriter(FileFormat.CSV, ffpCSV);
+
+ FrameBlock A = new FrameBlock();
+ dataset.slice(0, dataset.getNumRows() - 1, 0, dataset.getNumColumns() / 2 - 1, A);
+ fw.writeFrameToHDFS(A, input("A"), A.getNumRows(), A.getNumColumns());
+ HDFSTool.writeMetaDataFile(input("A.mtd"), null, A.getSchema(), Types.DataType.FRAME,
+ new MatrixCharacteristics(A.getNumRows(), A.getNumColumns()), FileFormat.CSV, ffpCSV);
+
+ FrameBlock B = new FrameBlock();
+ dataset.slice(0, dataset.getNumRows() - 1, dataset.getNumColumns() / 2, dataset.getNumColumns() - 1, B);
+ fw.writeFrameToHDFS(B, input("B"), B.getNumRows(), B.getNumColumns());
+ HDFSTool.writeMetaDataFile(input("B.mtd"), null, B.getSchema(), Types.DataType.FRAME,
+ new MatrixCharacteristics(B.getNumRows(), B.getNumColumns()), FileFormat.CSV, ffpCSV);
+
+ fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+ programArgs = new String[] {"-nvargs", "in_A=" + TestUtils.federatedAddress(port1, input("A")),
+ "in_B=" + TestUtils.federatedAddress(port2, input("B")), "rows=" + dataset.getNumRows(),
+ "cols_A=" + A.getNumColumns(), "cols_B=" + B.getNumColumns(), "TFSPEC=" + HOME + "input/" + SPEC,
+ "TFDATA1=" + output("tfout1"), "TFDATA2=" + output("tfout2"), "OFMT=csv"};
+
+ runTest(true, false, null, -1);
+
+ // read both transform outputs and compare
+ double[][] R1 = DataConverter.convertToDoubleMatrix(MatrixReaderFactory.createMatrixReader(FileFormat.CSV)
+ .readMatrixFromHDFS(output("tfout1"), -1L, -1L, 1000, -1));
+ double[][] R2 = DataConverter.convertToDoubleMatrix(MatrixReaderFactory.createMatrixReader(FileFormat.CSV)
+ .readMatrixFromHDFS(output("tfout2"), -1L, -1L, 1000, -1));
+ TestUtils.compareMatrices(R1, R2, R1.length, R1[0].length, 0);
+
+ // additional checks for binning, since an encode-decode round trip is not possible
+ if(type == TransformType.BIN) {
+ for(int i = 0; i < 7; i++) {
+ Assert.assertEquals(BIN_col3[i], R1[i][2], 1e-8);
+ Assert.assertEquals(BIN_col8[i], R1[i][7], 1e-8);
+ }
+ }
+ else if(type == TransformType.BIN_DUMMY) {
+ Assert.assertEquals(14, R1[0].length);
+ for(int i = 0; i < 7; i++) {
+ for(int j = 0; j < 4; j++) { // check dummy coded
+ Assert.assertEquals((j == BIN_col3[i] - 1) ? 1 : 0, R1[i][2 + j], 1e-8);
+ }
+ for(int j = 0; j < 3; j++) { // check dummy coded
+ Assert.assertEquals((j == BIN_col8[i] - 1) ? 1 : 0, R1[i][10 + j], 1e-8);
+ }
+ }
+ }
+ }
+ catch(Exception ex) {
+ throw new RuntimeException(ex);
+ }
+ finally {
+ TestUtils.shutdownThreads(t1, t2);
+ resetExecMode(rtold);
+ }
+ }
+}
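A quick sanity check on the magic numbers in the BIN_DUMMY assertions above, assuming the homes dataset has 9 columns: binning column 3 into 4 bins and column 8 into 3 bins and dummy-coding them yields 9 - 2 + 4 + 3 = 14 output columns; the bins of column 3 land in output columns 3-6 (index 2 + j), and the bins of column 8, shifted by the 3 extra columns, land in output columns 11-13 (index 10 + j).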
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
index 29afa5b..c45be72 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java
@@ -115,11 +115,9 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase {
runTransformEncodeDecodeTest(false, true, Types.FileFormat.BINARY);
}
- private void runTransformEncodeDecodeTest(boolean recode, boolean sparse,
- Types.FileFormat format) {
- ExecMode platformOld = rtplatform;
- rtplatform = ExecMode.SINGLE_NODE;
-
+ private void runTransformEncodeDecodeTest(boolean recode, boolean sparse, Types.FileFormat format) {
+ ExecMode rtold = setExecMode(ExecMode.SINGLE_NODE);
+
Thread t1 = null, t2 = null, t3 = null, t4 = null;
try {
getAndLoadTestConfiguration(TEST_NAME_RECODE);
@@ -197,11 +195,8 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase {
Assert.fail(ex.getMessage());
}
finally {
- TestUtils.shutdownThread(t1);
- TestUtils.shutdownThread(t2);
- TestUtils.shutdownThread(t3);
- TestUtils.shutdownThread(t4);
- rtplatform = platformOld;
+ TestUtils.shutdownThreads(t1, t2, t3, t4);
+ resetExecMode(rtold);
}
}
diff --git a/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml b/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
new file mode 100644
index 0000000..921242b
--- /dev/null
+++ b/src/test/scripts/functions/transform/TransformFederatedEncodeApply.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+F1 = federated(type="frame", addresses=list($in_A, $in_B), ranges=
+ list(list(0,0), list($rows, $cols_A), # A range
+ list(0, $cols_A), list($rows, $cols_A + $cols_B))); # B range
+
+jspec = read($TFSPEC, data_type="scalar", value_type="string");
+
+[X, M] = transformencode(target=F1, spec=jspec);
+
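+# while(FALSE){} cuts the script into separate statement blocks (barrier between encode and apply)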
+while(FALSE){}
+
+X2 = transformapply(target=F1, spec=jspec, meta=M);
+
+write(X, $TFDATA1, format="csv");
+write(X2, $TFDATA2, format="csv");
+