You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2020/04/24 19:10:04 UTC
[systemml] branch master updated: [SYSTEMDS-361] New privacy constraint meta data (compiler/runtime)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push:
new 013ca82 [SYSTEMDS-361] New privacy constraint meta data (compiler/runtime)
013ca82 is described below
commit 013ca8224c23b1d9f63e254162a56fb78bf74c96
Author: sebwrede <se...@hotmail.com>
AuthorDate: Fri Apr 24 20:00:13 2020 +0200
[SYSTEMDS-361] New privacy constraint meta data (compiler/runtime)
Closes #895.
---
docs/Tasks.txt | 7 +
.../java/org/apache/sysds/hops/AggBinaryOp.java | 3 +-
src/main/java/org/apache/sysds/hops/DataOp.java | 1 +
src/main/java/org/apache/sysds/hops/Hop.java | 21 ++-
src/main/java/org/apache/sysds/hops/LiteralOp.java | 1 +
src/main/java/org/apache/sysds/lops/DataGen.java | 4 +-
src/main/java/org/apache/sysds/lops/Lop.java | 18 +++
.../java/org/apache/sysds/lops/compile/Dag.java | 34 +++-
.../org/apache/sysds/parser/BinaryExpression.java | 26 ++--
.../org/apache/sysds/parser/DMLTranslator.java | 3 +
.../org/apache/sysds/parser/DataExpression.java | 135 +++++++---------
.../java/org/apache/sysds/parser/Identifier.java | 15 ++
.../controlprogram/caching/CacheableData.java | 16 +-
.../sysds/runtime/instructions/Instruction.java | 12 ++
.../instructions/cp/VariableCPInstruction.java | 2 +
.../org/apache/sysds/runtime/io/MatrixReader.java | 4 +-
.../sysds/runtime/privacy/PrivacyConstraint.java | 42 +++++
.../sysds/runtime/privacy/PrivacyPropagator.java | 38 +++++
.../org/apache/sysds/runtime/util/HDFSTool.java | 38 ++++-
.../org/apache/sysds/test/AutomatedTestBase.java | 43 +++++-
src/test/java/org/apache/sysds/test/TestUtils.java | 129 ++++++----------
.../test/functions/data/misc/WriteMMTest.java | 2 +-
.../MatrixMultiplicationPropagationTest.java | 171 +++++++++++++++++++++
.../MatrixMultiplicationPropagationTest.dml | 27 ++++
24 files changed, 591 insertions(+), 201 deletions(-)
diff --git a/docs/Tasks.txt b/docs/Tasks.txt
index 2283d57..d1e30c0 100644
--- a/docs/Tasks.txt
+++ b/docs/Tasks.txt
@@ -260,5 +260,12 @@ SYSTEMDS-340 Compiler Assisted Lineage Caching and Reuse
SYSTEMDS-350 Data Cleaning Framework
* 351 New builtin function for error correction by schema OK
+SYSTEMDS-360 Privacy/Data Exchange Constraints
+ * 361 Initial privacy meta data (compiler/runtime) OK
+ * 362 Runtime privacy propagation
+ * 363 Compile-time privacy propagation
+ * 364 Error handling violated privacy constraints
+ * 365 Extended privacy/data exchange constraints
+
Others:
* Break append instruction to cbind and rbind
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index b456cc8..a04d267 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -627,7 +627,8 @@ public class AggBinaryOp extends MultiThreadedHop
setOutputDimensions(matmultCP);
}
- setLineNumbers( matmultCP );
+ setLineNumbers(matmultCP);
+ setPrivacy(matmultCP);
setLops(matmultCP);
}
diff --git a/src/main/java/org/apache/sysds/hops/DataOp.java b/src/main/java/org/apache/sysds/hops/DataOp.java
index 7a22727..99cf91e 100644
--- a/src/main/java/org/apache/sysds/hops/DataOp.java
+++ b/src/main/java/org/apache/sysds/hops/DataOp.java
@@ -311,6 +311,7 @@ public class DataOp extends Hop
}
setLineNumbers(l);
+ setPrivacy(l);
setLops(l);
//add reblock/checkpoint lops if necessary
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index ba0dd03..79a251f 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -50,6 +50,7 @@ import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.runtime.util.UtilFunctions;
import java.util.ArrayList;
@@ -72,6 +73,7 @@ public abstract class Hop implements ParseInfo
protected ValueType _valueType;
protected boolean _visited = false;
protected DataCharacteristics _dc = new MatrixCharacteristics();
+ protected PrivacyConstraint _privacyConstraint = new PrivacyConstraint();
protected UpdateType _updateType = UpdateType.COPY;
protected ArrayList<Hop> _parent = new ArrayList<>();
@@ -317,9 +319,10 @@ public abstract class Hop implements ParseInfo
throw new HopsException(ex);
}
- setOutputDimensions( reblock );
- setLineNumbers( reblock );
- setLops( reblock );
+ setOutputDimensions(reblock);
+ setLineNumbers(reblock);
+ setPrivacy(reblock);
+ setLops(reblock);
}
}
@@ -764,6 +767,14 @@ public abstract class Hop implements ParseInfo
return _dc.getNonZeros();
}
+ public void setPrivacy(PrivacyConstraint privacy){
+ _privacyConstraint = privacy;
+ }
+
+ public PrivacyConstraint getPrivacy(){
+ return _privacyConstraint;
+ }
+
public void setUpdateType(UpdateType update){
_updateType = update;
}
@@ -1385,6 +1396,10 @@ public abstract class Hop implements ParseInfo
protected void setLineNumbers(Lop lop) {
lop.setAllPositions(getFilename(), getBeginLine(), getBeginColumn(), getEndLine(), getEndColumn());
}
+
+ protected void setPrivacy(Lop lop) {
+ lop.setPrivacyConstraint(getPrivacy());
+ }
/**
* Set parse information.
diff --git a/src/main/java/org/apache/sysds/hops/LiteralOp.java b/src/main/java/org/apache/sysds/hops/LiteralOp.java
index a7151de..61a7acb 100644
--- a/src/main/java/org/apache/sysds/hops/LiteralOp.java
+++ b/src/main/java/org/apache/sysds/hops/LiteralOp.java
@@ -112,6 +112,7 @@ public class LiteralOp extends Hop
l.getOutputParameters().setDimensions(0, 0, 0, -1);
setLineNumbers(l);
+ setPrivacy(l);
setLops(l);
}
catch(LopsException e) {
diff --git a/src/main/java/org/apache/sysds/lops/DataGen.java b/src/main/java/org/apache/sysds/lops/DataGen.java
index 27a634c..ddc1a8a 100644
--- a/src/main/java/org/apache/sysds/lops/DataGen.java
+++ b/src/main/java/org/apache/sysds/lops/DataGen.java
@@ -127,8 +127,8 @@ public class DataGen extends Lop
//sanity checks
if ( _op != OpOpDG.RAND )
throw new LopsException("Invalid instruction generation for data generation method " + _op);
- if( getInputs().size() != DataExpression.RAND_VALID_PARAM_NAMES.length - 2 && // tensor
- getInputs().size() != DataExpression.RAND_VALID_PARAM_NAMES.length - 1 ) { // matrix
+ if( getInputs().size() != DataExpression.RAND_VALID_PARAM_NAMES.size() - 2 && // tensor
+ getInputs().size() != DataExpression.RAND_VALID_PARAM_NAMES.size() - 1 ) { // matrix
throw new LopsException(printErrorLocation() + "Invalid number of operands ("
+ getInputs().size() + ") for a Rand operation");
}
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java
index 8bb7e1a..76f0caa 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -25,6 +25,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.lops.LopProperties.ExecType;
import org.apache.sysds.lops.compile.Dag;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
@@ -112,6 +113,11 @@ public abstract class Lop
*/
protected ArrayList<Lop> inputs;
protected ArrayList<Lop> outputs;
+
+ /**
+ * Privacy Constraint
+ */
+ protected PrivacyConstraint privacyConstraint;
/**
* refers to #lops whose input is equal to the output produced by this lop.
@@ -273,6 +279,18 @@ public abstract class Lop
public void addOutput(Lop op) {
outputs.add(op);
}
+
+ /**
+ * Method to set privacy constraint of Lop.
+ * @param privacy privacy constraint instance
+ */
+ public void setPrivacyConstraint(PrivacyConstraint privacy){
+ privacyConstraint = privacy;
+ }
+
+ public PrivacyConstraint getPrivacyConstraint(){
+ return privacyConstraint;
+ }
public void setConsumerCount(int cc) {
consumerCount = cc;
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index 6acfb87..65cbb99 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -352,6 +352,9 @@ public class Dag<N extends Lop>
String inst_string = n.getInstructions();
CPInstruction currInstr = CPInstructionParser.parseSingleInstruction(inst_string);
currInstr.setLocation(n);
+ // TODO find a more direct way of communicating the privacy constraints
+ // (visible to runtime explain); This change should apply to all occurrences.
+ currInstr.setPrivacyConstraint(n);
insts.add(currInstr);
} catch (DMLRuntimeException e) {
throw new LopsException(n.printErrorLocation() + "error generating instructions from input variables in Dag -- \n", e);
@@ -406,7 +409,10 @@ public class Dag<N extends Lop>
if (locationInfo != null)
currInstr.setLocation(locationInfo);
else
+ {
currInstr.setLocation(node);
+ currInstr.setPrivacyConstraint(node);
+ }
inst.add(currInstr);
excludeRemoveInstruction(label, deleteInst);
@@ -593,12 +599,21 @@ public class Dag<N extends Lop>
throw new LopsException("Error parsing the instruction:" + inst_string);
}
if (node._beginLine != 0)
+ {
currInstr.setLocation(node);
+ currInstr.setPrivacyConstraint(node);
+ }
else if ( !node.getOutputs().isEmpty() )
+ {
currInstr.setLocation(node.getOutputs().get(0));
+ currInstr.setPrivacyConstraint(node.getOutputs().get(0));
+ }
else if ( !node.getInputs().isEmpty() )
+ {
currInstr.setLocation(node.getInputs().get(0));
-
+ currInstr.setPrivacyConstraint(node.getInputs().get(0));
+ }
+
inst.add(currInstr);
} catch (Exception e) {
throw new LopsException(node.printErrorLocation() + "Problem generating simple inst - "
@@ -785,6 +800,7 @@ public class Dag<N extends Lop>
Instruction currInstr = VariableCPInstruction.prepareRemoveInstruction(oparams.getLabel());
currInstr.setLocation(node);
+ currInstr.setPrivacyConstraint(node);
out.addLastInstruction(currInstr);
}
@@ -806,6 +822,7 @@ public class Dag<N extends Lop>
oparams.getUpdateType());
createvarInst.setLocation(node);
+ createvarInst.setPrivacyConstraint(node);
out.addPreInstruction(createvarInst);
@@ -813,6 +830,7 @@ public class Dag<N extends Lop>
Instruction currInstr = VariableCPInstruction.prepareRemoveInstruction(oparams.getLabel());
currInstr.setLocation(node);
+ currInstr.setPrivacyConstraint(node);
out.addLastInstruction(currInstr);
}
@@ -832,10 +850,14 @@ public class Dag<N extends Lop>
new MatrixCharacteristics(fnOutParams.getNumRows(), fnOutParams.getNumCols(), (int)fnOutParams.getBlocksize(), fnOutParams.getNnz()),
oparams.getUpdateType());
- if (node._beginLine != 0)
+ if (node._beginLine != 0){
createvarInst.setLocation(node);
- else
+ createvarInst.setPrivacyConstraint(node);
+ }
+ else {
createvarInst.setLocation(fnOut);
+ createvarInst.setPrivacyConstraint(fnOut);
+ }
out.addPreInstruction(createvarInst);
}
@@ -985,8 +1007,10 @@ public class Dag<N extends Lop>
Instruction currInstr = (node.getExecType() == ExecType.SPARK) ?
SPInstructionParser.parseSingleInstruction(io_inst) :
CPInstructionParser.parseSingleInstruction(io_inst);
- currInstr.setLocation((!node.getInputs().isEmpty()
- && node.getInputs().get(0)._beginLine != 0) ? node.getInputs().get(0) : node);
+ Lop useNode = (!node.getInputs().isEmpty()
+ && node.getInputs().get(0)._beginLine != 0) ? node.getInputs().get(0) : node;
+ currInstr.setLocation(useNode);
+ currInstr.setPrivacyConstraint(useNode);
out.addLastInstruction(currInstr);
}
diff --git a/src/main/java/org/apache/sysds/parser/BinaryExpression.java b/src/main/java/org/apache/sysds/parser/BinaryExpression.java
index 86a4558..6c177e2 100644
--- a/src/main/java/org/apache/sysds/parser/BinaryExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BinaryExpression.java
@@ -23,6 +23,7 @@ import java.util.HashMap;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.privacy.PrivacyPropagator;
public class BinaryExpression extends Expression
@@ -126,25 +127,28 @@ public class BinaryExpression extends Expression
output.setValueType(resultVT);
checkAndSetDimensions(output, conditional);
- if (this.getOpCode() == Expression.BinaryOp.MATMULT) {
- if ((this.getLeft().getOutput().getDataType() != DataType.MATRIX) || (this.getRight().getOutput().getDataType() != DataType.MATRIX)) {
+ if (getOpCode() == Expression.BinaryOp.MATMULT) {
+ if ((getLeft().getOutput().getDataType() != DataType.MATRIX) || (getRight().getOutput().getDataType() != DataType.MATRIX)) {
// remove exception for now
// throw new LanguageException(
// "Matrix multiplication not supported for scalars",
// LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
}
- if (this.getLeft().getOutput().getDim2() != -1
- && this.getRight().getOutput().getDim1() != -1
- && this.getLeft().getOutput().getDim2() != this.getRight()
- .getOutput().getDim1())
+ if (getLeft().getOutput().getDim2() != -1 && getRight().getOutput().getDim1() != -1
+ && getLeft().getOutput().getDim2() != getRight().getOutput().getDim1())
{
- raiseValidateError("invalid dimensions for matrix multiplication (k1="+this.getLeft().getOutput().getDim2()+", k2="+this.getRight().getOutput().getDim1()+")",
- conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
+ raiseValidateError("invalid dimensions for matrix multiplication (k1="
+ +getLeft().getOutput().getDim2()+", k2="+getRight().getOutput().getDim1()+")",
+ conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
}
- output.setDimensions(this.getLeft().getOutput().getDim1(), this
- .getRight().getOutput().getDim2());
+ output.setDimensions(getLeft().getOutput().getDim1(),
+ getRight().getOutput().getDim2());
}
+ // Set privacy of output
+ output.setPrivacy(PrivacyPropagator.MergeBinary(
+ getLeft().getOutput().getPrivacy(), getRight().getOutput().getPrivacy()));
+
this.setOutput(output);
}
@@ -199,7 +203,6 @@ public class BinaryExpression extends Expression
}
return "(" + leftString + " " + _opcode.toString() + " "
+ rightString + ")";
-
}
@Override
@@ -223,6 +226,5 @@ public class BinaryExpression extends Expression
|| (op == BinaryOp.MULT) || (op == BinaryOp.DIV)
|| (op == BinaryOp.MODULUS) || (op == BinaryOp.INTDIV)
|| (op == BinaryOp.POW);
-
}
}
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index e61c928..f1f64c1 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2082,6 +2082,8 @@ public class DMLTranslator
setIdentifierParams(currBuiltinOp, source.getOutput());
if( source.getOpCode()==DataExpression.DataOp.READ )
((DataOp)currBuiltinOp).setInputBlocksize(target.getBlocksize());
+ else if ( source.getOpCode() == DataExpression.DataOp.WRITE )
+ ((DataOp)currBuiltinOp).setPrivacy(hops.get(target.getName()).getPrivacy());
currBuiltinOp.setParseInfo(source);
return currBuiltinOp;
@@ -2747,6 +2749,7 @@ public class DMLTranslator
if( id.getNnz()>= 0 )
h.setNnz(id.getNnz());
h.setBlocksize(id.getBlocksize());
+ h.setPrivacy(id.getPrivacy());
}
private boolean prepareReadAfterWrite( DMLProgram prog, HashMap<String, DataIdentifier> pWrites ) {
diff --git a/src/main/java/org/apache/sysds/parser/DataExpression.java b/src/main/java/org/apache/sysds/parser/DataExpression.java
index 1b7ddc4..baa2b48 100644
--- a/src/main/java/org/apache/sysds/parser/DataExpression.java
+++ b/src/main/java/org/apache/sysds/parser/DataExpression.java
@@ -45,6 +45,8 @@ import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Set;
import java.util.Map.Entry;
@@ -96,6 +98,8 @@ public class DataExpression extends DataIdentifier
public static final String SCHEMAPARAM = "schema";
public static final String CREATEDPARAM = "created";
+ public static final String PRIVACY = "privacy";
+
// Parameter names relevant to reading/writing delimited/csv files
public static final String DELIM_DELIMITER = "sep";
public static final String DELIM_HAS_HEADER_ROW = "header";
@@ -107,31 +111,34 @@ public class DataExpression extends DataIdentifier
public static final String DELIM_SPARSE = "sparse"; // applicable only for write
- public static final String[] RAND_VALID_PARAM_NAMES =
- {RAND_ROWS, RAND_COLS, RAND_DIMS, RAND_MIN, RAND_MAX, RAND_SPARSITY, RAND_SEED, RAND_PDF, RAND_LAMBDA};
+ public static final Set<String> RAND_VALID_PARAM_NAMES = new HashSet<>(
+ Arrays.asList(RAND_ROWS, RAND_COLS, RAND_DIMS,
+ RAND_MIN, RAND_MAX, RAND_SPARSITY, RAND_SEED, RAND_PDF, RAND_LAMBDA));
- public static final String[] RESHAPE_VALID_PARAM_NAMES =
- { RAND_BY_ROW, RAND_DIMNAMES, RAND_DATA, RAND_ROWS, RAND_COLS, RAND_DIMS};
+ public static final Set<String> RESHAPE_VALID_PARAM_NAMES = new HashSet<>(
+ Arrays.asList(RAND_BY_ROW, RAND_DIMNAMES, RAND_DATA, RAND_ROWS, RAND_COLS, RAND_DIMS));
- public static final String[] SQL_VALID_PARAM_NAMES = {SQL_CONN, SQL_USER, SQL_PASS, SQL_QUERY};
+ public static final Set<String> SQL_VALID_PARAM_NAMES = new HashSet<>(
+ Arrays.asList(SQL_CONN, SQL_USER, SQL_PASS, SQL_QUERY));
- public static final String[] FEDERATED_VALID_PARAM_NAMES = {FED_ADDRESSES, FED_RANGES};
+ public static final Set<String> FEDERATED_VALID_PARAM_NAMES = new HashSet<>(
+ Arrays.asList(FED_ADDRESSES, FED_RANGES));
// Valid parameter names in a metadata file
- public static final String[] READ_VALID_MTD_PARAM_NAMES =
- { IO_FILENAME, READROWPARAM, READCOLPARAM, READNNZPARAM, FORMAT_TYPE,
- ROWBLOCKCOUNTPARAM, COLUMNBLOCKCOUNTPARAM, DATATYPEPARAM, VALUETYPEPARAM, SCHEMAPARAM, DESCRIPTIONPARAM,
- AUTHORPARAM, CREATEDPARAM,
+ public static final Set<String> READ_VALID_MTD_PARAM_NAMES =new HashSet<>(
+ Arrays.asList(IO_FILENAME, READROWPARAM, READCOLPARAM, READNNZPARAM,
+ FORMAT_TYPE, ROWBLOCKCOUNTPARAM, COLUMNBLOCKCOUNTPARAM, DATATYPEPARAM,
+ VALUETYPEPARAM, SCHEMAPARAM, DESCRIPTIONPARAM, AUTHORPARAM, CREATEDPARAM,
// Parameters related to delimited/csv files.
- DELIM_FILL_VALUE, DELIM_DELIMITER, DELIM_FILL, DELIM_HAS_HEADER_ROW, DELIM_NA_STRINGS
- };
+ DELIM_FILL_VALUE, DELIM_DELIMITER, DELIM_FILL, DELIM_HAS_HEADER_ROW, DELIM_NA_STRINGS,
+ // Parameters related to privacy
+ PRIVACY));
- public static final String[] READ_VALID_PARAM_NAMES =
- { IO_FILENAME, READROWPARAM, READCOLPARAM, FORMAT_TYPE, DATATYPEPARAM, VALUETYPEPARAM, SCHEMAPARAM,
- ROWBLOCKCOUNTPARAM, COLUMNBLOCKCOUNTPARAM, READNNZPARAM,
+ public static final Set<String> READ_VALID_PARAM_NAMES = new HashSet<>(
+ Arrays.asList(IO_FILENAME, READROWPARAM, READCOLPARAM, FORMAT_TYPE, DATATYPEPARAM,
+ VALUETYPEPARAM, SCHEMAPARAM, ROWBLOCKCOUNTPARAM, COLUMNBLOCKCOUNTPARAM, READNNZPARAM,
// Parameters related to delimited/csv files.
- DELIM_FILL_VALUE, DELIM_DELIMITER, DELIM_FILL, DELIM_HAS_HEADER_ROW, DELIM_NA_STRINGS
- };
+ DELIM_FILL_VALUE, DELIM_DELIMITER, DELIM_FILL, DELIM_HAS_HEADER_ROW, DELIM_NA_STRINGS));
/* Default Values for delimited (CSV/LIBSVM) files */
public static final String DEFAULT_DELIM_DELIMITER = ",";
@@ -210,11 +217,8 @@ public class DataExpression extends DataIdentifier
return null;
}
// verify parameter names for read function
- boolean isValidName = false;
- for (String paramName : READ_VALID_PARAM_NAMES){
- if (paramName.equals(currName))
- isValidName = true;
- }
+ boolean isValidName = READ_VALID_PARAM_NAMES.contains(currName);
+
if (!isValidName){
errorListener.validationError(parseInfo, "attempted to add invalid read statement parameter " + currName);
return null;
@@ -466,15 +470,7 @@ public class DataExpression extends DataIdentifier
return;
}
// check name is valid
- boolean found = false;
- if (paramName != null ){
- for (String name : RAND_VALID_PARAM_NAMES){
- if (name.equals(paramName)) {
- found = true;
- break;
- }
- }
- }
+ boolean found = RAND_VALID_PARAM_NAMES.contains(paramName);
if (!found){
raiseValidateError("unexpected parameter \"" + paramName +
"\". Legal parameters for Rand statement are "
@@ -500,10 +496,7 @@ public class DataExpression extends DataIdentifier
public void addMatrixExprParam(String paramName, Expression paramValue)
{
// check name is valid
- boolean found = false;
- if (paramName != null ){
- found = Arrays.stream(RESHAPE_VALID_PARAM_NAMES).anyMatch((name) -> name.equals(paramName));
- }
+ boolean found = RESHAPE_VALID_PARAM_NAMES.contains(paramName);
if (!found){
raiseValidateError("unexpected parameter \"" + paramName +
@@ -529,10 +522,7 @@ public class DataExpression extends DataIdentifier
public void addTensorExprParam(String paramName, Expression paramValue)
{
// check name is valid
- boolean found = false;
- if (paramName != null ){
- found = Arrays.asList(RESHAPE_VALID_PARAM_NAMES).contains(paramName);
- }
+ boolean found = RESHAPE_VALID_PARAM_NAMES.contains(paramName);
if (!found){
raiseValidateError("unexpected parameter \"" + paramName + "\". Legal parameters for tensor statement are "
@@ -558,10 +548,7 @@ public class DataExpression extends DataIdentifier
public void addSqlExprParam(String paramName, Expression paramValue)
{
// check name is valid
- boolean found = false;
- if (paramName != null ){
- found = Arrays.asList(SQL_VALID_PARAM_NAMES).contains(paramName);
- }
+ boolean found = SQL_VALID_PARAM_NAMES.contains(paramName);
if (!found){
raiseValidateError("unexpected parameter \"" + paramName + "\". Legal parameters for sql statement are "
@@ -578,8 +565,7 @@ public class DataExpression extends DataIdentifier
public void addFederatedExprParam(String paramName, Expression paramValue) {
// check name is valid
- boolean found = (paramName != null ) &&
- Arrays.asList(FEDERATED_VALID_PARAM_NAMES).contains(paramName);
+ boolean found = FEDERATED_VALID_PARAM_NAMES.contains(paramName);
if (!found)
raiseValidateError("unexpected parameter \"" + paramName + "\". Legal parameters for federated statement are "
@@ -988,17 +974,11 @@ public class DataExpression extends DataIdentifier
|| key.equals(READNNZPARAM) || key.equals(DATATYPEPARAM) || key.equals(VALUETYPEPARAM)
|| key.equals(SCHEMAPARAM)) )
{
- String msg = "Only parameters allowed are: " + IO_FILENAME + ","
- + SCHEMAPARAM + ","
- + DELIM_HAS_HEADER_ROW + ","
- + DELIM_DELIMITER + ","
- + DELIM_FILL + ","
- + DELIM_FILL_VALUE + ","
- + READROWPARAM + ","
- + READCOLPARAM;
-
+ String msg = "Only parameters allowed are: " + Arrays.toString(new String[] {
+ IO_FILENAME, SCHEMAPARAM, DELIM_HAS_HEADER_ROW, DELIM_DELIMITER,
+ DELIM_FILL, DELIM_FILL_VALUE, READROWPARAM, READCOLPARAM});
raiseValidateError("Invalid parameter " + key + " in read statement: " +
- toString() + ". " + msg, conditional, LanguageErrorCodes.INVALID_PARAMETERS);
+ toString() + ". " + msg, conditional, LanguageErrorCodes.INVALID_PARAMETERS);
}
}
}
@@ -1087,18 +1067,25 @@ public class DataExpression extends DataIdentifier
isMatrix = true;
// set data type
- getOutput().setDataType(isMatrix ? DataType.MATRIX : DataType.FRAME);
-
- // set number non-zeros
- Expression ennz = this.getVarParam("nnz");
- long nnz = -1;
- if( ennz != null )
- {
- nnz = Long.valueOf(ennz.toString());
- getOutput().setNnz(nnz);
- }
-
- // Following dimension checks must be done when data type = MATRIX_DATA_TYPE
+ getOutput().setDataType(isMatrix ? DataType.MATRIX : DataType.FRAME);
+
+ // set number non-zeros
+ Expression ennz = getVarParam("nnz");
+ long nnz = -1;
+ if( ennz != null ) {
+ nnz = Long.valueOf(ennz.toString());
+ getOutput().setNnz(nnz);
+ }
+
+ // set privacy
+ Expression eprivacy = getVarParam("privacy");
+ boolean privacy = false;
+ if ( eprivacy != null ) {
+ privacy = Boolean.valueOf(eprivacy.toString());
+ getOutput().setPrivacy(privacy);
+ }
+
+ // Following dimension checks must be done when data type = MATRIX_DATA_TYPE
// initialize size of target data identifier to UNKNOWN
getOutput().setDimensions(-1, -1);
@@ -1919,13 +1906,10 @@ public class DataExpression extends DataIdentifier
}
}
- private void validateParams(boolean conditional, String[] validParamNames, String legalMessage) {
+ private void validateParams(boolean conditional, Set<String> validParamNames, String legalMessage) {
for( String key : _varParams.keySet() )
{
- boolean found = false;
- for (String name : validParamNames) {
- found |= name.equals(key);
- }
+ boolean found = validParamNames.contains(key);
if( !found ) {
raiseValidateError("unexpected parameter \"" + key + "\". "
+ legalMessage, conditional);
@@ -2061,11 +2045,7 @@ public class DataExpression extends DataIdentifier
Object key = e.getKey();
Object val = e.getValue();
- boolean isValidName = false;
- for (String paramName : READ_VALID_MTD_PARAM_NAMES){
- if (paramName.equals(key))
- isValidName = true;
- }
+ boolean isValidName = READ_VALID_MTD_PARAM_NAMES.contains(key);
if (!isValidName){ //wrong parameters always rejected
raiseValidateError("MTD file " + mtdFileName + " contains invalid parameter name: " + key, false);
@@ -2091,6 +2071,7 @@ public class DataExpression extends DataIdentifier
if ( key.toString().equalsIgnoreCase(DELIM_HAS_HEADER_ROW)
|| key.toString().equalsIgnoreCase(DELIM_FILL)
|| key.toString().equalsIgnoreCase(DELIM_SPARSE)
+ || key.toString().equalsIgnoreCase(PRIVACY)
) {
// parse these parameters as boolean values
BooleanIdentifier boolId = null;
diff --git a/src/main/java/org/apache/sysds/parser/Identifier.java b/src/main/java/org/apache/sysds/parser/Identifier.java
index 402069e..39da340 100644
--- a/src/main/java/org/apache/sysds/parser/Identifier.java
+++ b/src/main/java/org/apache/sysds/parser/Identifier.java
@@ -24,6 +24,7 @@ import java.util.HashMap;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.parser.LanguageException.LanguageErrorCodes;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
public abstract class Identifier extends Expression
{
@@ -34,6 +35,7 @@ public abstract class Identifier extends Expression
protected int _blocksize;
protected long _nnz;
protected FormatType _formatType;
+ protected PrivacyConstraint _privacy;
public Identifier() {
_dim1 = -1;
@@ -62,6 +64,7 @@ public abstract class Identifier extends Expression
_blocksize = i.getBlocksize();
_nnz = i.getNnz();
_formatType = i.getFormatType();
+ _privacy = i.getPrivacy();
}
public void setDimensionValueProperties(Identifier i) {
@@ -99,6 +102,14 @@ public abstract class Identifier extends Expression
public void setNnz(long nnzs){
_nnz = nnzs;
}
+
+ public void setPrivacy(boolean privacy){
+ _privacy = new PrivacyConstraint(privacy);
+ }
+
+ public void setPrivacy(PrivacyConstraint privacyConstraint){
+ _privacy = privacyConstraint;
+ }
public long getDim1(){
return _dim1;
@@ -131,6 +142,10 @@ public abstract class Identifier extends Expression
public long getNnz(){
return _nnz;
}
+
+ public PrivacyConstraint getPrivacy(){
+ return _privacy;
+ }
@Override
public void validateExpression(HashMap<String,DataIdentifier> ids, HashMap<String,ConstIdentifier> constVars, boolean conditional)
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index a27318a..32b6162 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -46,6 +46,7 @@ import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaData;
import org.apache.sysds.runtime.meta.MetaDataFormat;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.runtime.util.LocalFileUtils;
import org.apache.sysds.utils.Statistics;
@@ -161,6 +162,11 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
* must get the OutputInfo that matches with InputInfo stored inside _mtd.
*/
protected MetaData _metaData = null;
+
+ /**
+ * Object holding all privacy constraints associated with the cacheable data.
+ */
+ protected PrivacyConstraint _privacyConstraint = null;
/** The name of HDFS file in which the data is backed up. */
protected String _hdfsFileName = null; // file name and path
@@ -305,6 +311,14 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
public void removeMetaData() {
_metaData = null;
}
+
+ public void setPrivacyConstraints(PrivacyConstraint pc) {
+ _privacyConstraint = pc;
+ }
+
+ public PrivacyConstraint getPrivacyConstraint() {
+ return _privacyConstraint;
+ }
public DataCharacteristics getDataCharacteristics() {
return _metaData.getDataCharacteristics();
@@ -930,7 +944,7 @@ public abstract class CacheableData<T extends CacheBlock> extends Data
//write the actual meta data file
HDFSTool.writeMetaDataFile (filePathAndName + ".mtd", valueType,
- getSchema(), dataType, dc, oinfo, formatProperties);
+ getSchema(), dataType, dc, oinfo, formatProperties, _privacyConstraint);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
index c3adaeb..adcae38 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
@@ -27,6 +27,7 @@ import org.apache.sysds.api.DMLScript;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
public abstract class Instruction
{
@@ -69,6 +70,9 @@ public abstract class Instruction
protected int endLine = -1;
protected int beginCol = -1;
protected int endCol = -1;
+
+ //privacy meta data
+ protected PrivacyConstraint privacyConstraint = null;
public String getFilename() {
return filename;
@@ -129,6 +133,14 @@ public abstract class Instruction
this.endCol = oldInst.endCol;
}
}
+
+ public void setPrivacyConstraint(Lop lop){
+ privacyConstraint = lop.getPrivacyConstraint();
+ }
+
+ public PrivacyConstraint getPrivacyConstraint(){
+ return privacyConstraint;
+ }
/**
* Getter for instruction line number
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index 9176335..456b999 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -527,6 +527,7 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace
//clone meta data because it is updated on copy-on-write, otherwise there
//is potential for hidden side effects between variables.
obj.setMetaData((MetaData)metadata.clone());
+ obj.setPrivacyConstraints(getPrivacyConstraint());
obj.setFileFormatProperties(_formatProperties);
obj.setMarkForLinCache(true);
obj.enableCleanup(!getInput1().getName()
@@ -895,6 +896,7 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace
else {
// Default behavior
MatrixObject mo = ec.getMatrixObject(getInput1().getName());
+ mo.setPrivacyConstraints(getPrivacyConstraint());
mo.exportData(fname, outFmt, _formatProperties);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java b/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java
index 192a477..893d665 100644
--- a/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java
+++ b/src/main/java/org/apache/sysds/runtime/io/MatrixReader.java
@@ -58,8 +58,8 @@ public abstract class MatrixReader
public abstract MatrixBlock readMatrixFromHDFS( String fname, long rlen, long clen, int blen, long estnnz )
throws IOException, DMLRuntimeException;
- public abstract MatrixBlock readMatrixFromInputStream( InputStream is, long rlen, long clen, int blen, long estnnz )
- throws IOException, DMLRuntimeException;
+ public abstract MatrixBlock readMatrixFromInputStream( InputStream is, long rlen, long clen, int blen, long estnnz)
+ throws IOException, DMLRuntimeException;
/**
* NOTE: mallocDense controls if the output matrix blocks is fully allocated, this can be redundant
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java
new file mode 100644
index 0000000..2b32636
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.privacy;
+
+/**
+ * PrivacyConstraint holds all privacy constraints for data in the system at compile time and runtime.
+ */
+public class PrivacyConstraint
+{
+ protected boolean _privacy = false;
+
+ public PrivacyConstraint(){}
+
+ public PrivacyConstraint(boolean privacy) {
+ _privacy = privacy;
+ }
+
+ public void setPrivacy(boolean privacy){
+ _privacy = privacy;
+ }
+
+ public boolean getPrivacy(){
+ return _privacy;
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java
new file mode 100644
index 0000000..2070c99
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyPropagator.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.privacy;
+
+/**
+ * Class with static methods merging privacy constraints of operands
+ * in expressions to generate the privacy constraints of the output.
+ */
+public class PrivacyPropagator {
+
+ public static PrivacyConstraint MergeBinary(PrivacyConstraint privacyConstraint1, PrivacyConstraint privacyConstraint2) {
+ if (privacyConstraint1 != null && privacyConstraint2 != null)
+ return new PrivacyConstraint(
+ privacyConstraint1.getPrivacy() || privacyConstraint2.getPrivacy());
+ else if (privacyConstraint1 != null)
+ return privacyConstraint1;
+ else if (privacyConstraint2 != null)
+ return privacyConstraint2;
+ return null;
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java
index 643c509..bb37873 100644
--- a/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java
+++ b/src/main/java/org/apache/sysds/runtime/util/HDFSTool.java
@@ -48,6 +48,7 @@ import org.apache.sysds.runtime.matrix.data.InputInfo;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.OutputInfo;
import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import java.io.BufferedReader;
import java.io.BufferedWriter;
@@ -341,10 +342,20 @@ public class HDFSTool
throws IOException {
writeMetaDataFile(mtdfile, vt, null, DataType.MATRIX, mc, outinfo);
}
+
+ public static void writeMetaDataFile(String mtdfile, ValueType vt, DataCharacteristics mc, OutputInfo outinfo, PrivacyConstraint privacyConstraint)
+ throws IOException {
+ writeMetaDataFile(mtdfile, vt, null, DataType.MATRIX, mc, outinfo, null, privacyConstraint);
+ }
public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics mc, OutputInfo outinfo)
throws IOException {
- writeMetaDataFile(mtdfile, vt, schema, dt, mc, outinfo, null);
+ writeMetaDataFile(mtdfile, vt, schema, dt, mc, outinfo, (PrivacyConstraint) null);
+ }
+
+ public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics mc, OutputInfo outinfo, PrivacyConstraint privacyConstraint)
+ throws IOException {
+ writeMetaDataFile(mtdfile, vt, schema, dt, mc, outinfo, null, privacyConstraint);
}
public static void writeMetaDataFile(String mtdfile, ValueType vt, DataCharacteristics dc, OutputInfo outinfo, FileFormatProperties formatProperties)
@@ -356,10 +367,17 @@ public class HDFSTool
OutputInfo outinfo, FileFormatProperties formatProperties)
throws IOException
{
+ writeMetaDataFile(mtdfile, vt, schema, dt, dc, outinfo, formatProperties, null);
+ }
+
+ public static void writeMetaDataFile(String mtdfile, ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics dc,
+ OutputInfo outinfo, FileFormatProperties formatProperties, PrivacyConstraint privacyConstraint)
+ throws IOException
+ {
Path path = new Path(mtdfile);
FileSystem fs = IOUtilFunctions.getFileSystem(path);
try( BufferedWriter br = new BufferedWriter(new OutputStreamWriter(fs.create(path,true))) ) {
- String mtd = metaDataToString(vt, schema, dt, dc, outinfo, formatProperties);
+ String mtd = metaDataToString(vt, schema, dt, dc, outinfo, formatProperties, privacyConstraint);
br.write(mtd);
} catch (Exception e) {
throw new IOException("Error creating and writing metadata JSON file", e);
@@ -369,10 +387,16 @@ public class HDFSTool
public static void writeScalarMetaDataFile(String mtdfile, ValueType vt)
throws IOException
{
+ writeScalarMetaDataFile(mtdfile, vt, null);
+ }
+
+ public static void writeScalarMetaDataFile(String mtdfile, ValueType vt, PrivacyConstraint privacyConstraint)
+ throws IOException
+ {
Path path = new Path(mtdfile);
FileSystem fs = IOUtilFunctions.getFileSystem(path);
try( BufferedWriter br = new BufferedWriter(new OutputStreamWriter(fs.create(path,true))) ) {
- String mtd = metaDataToString(vt, null, DataType.SCALAR, null, OutputInfo.TextCellOutputInfo, null);
+ String mtd = metaDataToString(vt, null, DataType.SCALAR, null, OutputInfo.TextCellOutputInfo, null, privacyConstraint);
br.write(mtd);
}
catch (Exception e) {
@@ -381,7 +405,7 @@ public class HDFSTool
}
public static String metaDataToString(ValueType vt, ValueType[] schema, DataType dt, DataCharacteristics dc,
- OutputInfo outinfo, FileFormatProperties formatProperties) throws JSONException, DMLRuntimeException
+ OutputInfo outinfo, FileFormatProperties formatProperties, PrivacyConstraint privacyConstraint) throws JSONException, DMLRuntimeException
{
OrderedJSONObject mtd = new OrderedJSONObject(); // maintain order in output file
@@ -427,6 +451,12 @@ public class HDFSTool
}
}
+ //add privacy constraints
+ if ( privacyConstraint != null ){
+ mtd.put(DataExpression.PRIVACY, privacyConstraint.getPrivacy());
+ }
+
+ //add username and time
String userName = System.getProperty("user.name");
if (StringUtils.isNotEmpty(userName)) {
mtd.put(DataExpression.AUTHORPARAM, userName);
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index c3d224f..5217d0c 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -37,6 +37,7 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession.Builder;
+import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;
import org.junit.After;
import org.junit.Assert;
@@ -61,6 +62,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.OutputInfo;
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.utils.ParameterBuilder;
@@ -439,17 +441,32 @@ public abstract class AutomatedTestBase {
protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, long nnz, boolean bIncludeR) {
MatrixCharacteristics mc = new MatrixCharacteristics(matrix.length, matrix[0].length,
OptimizerUtils.DEFAULT_BLOCKSIZE, nnz);
- return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc);
+ return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc, null);
+ }
+
+ protected double [][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR,
+ MatrixCharacteristics mc) {
+ return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc, null);
+ }
+
+ protected double [][] writeInputMatrixWithMTD(String name, double[][] matrix, PrivacyConstraint privacyConstraint) {
+ return writeInputMatrixWithMTD(name, matrix, false, null, privacyConstraint);
+ }
+
+ protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR, PrivacyConstraint privacyConstraint) {
+ MatrixCharacteristics mc = new MatrixCharacteristics(matrix.length, matrix[0].length,
+ OptimizerUtils.DEFAULT_BLOCKSIZE, -1);
+ return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc, privacyConstraint);
}
protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR,
- MatrixCharacteristics mc) {
+ MatrixCharacteristics mc, PrivacyConstraint privacyConstraint) {
writeInputMatrix(name, matrix, bIncludeR);
// write metadata file
try {
String completeMTDPath = baseDirectory + INPUT_DIR + name + ".mtd";
- HDFSTool.writeMetaDataFile(completeMTDPath, ValueType.FP64, mc, OutputInfo.stringToOutputInfo("textcell"));
+ HDFSTool.writeMetaDataFile(completeMTDPath, ValueType.FP64, mc, OutputInfo.stringToOutputInfo("textcell"), privacyConstraint);
}
catch(IOException e) {
e.printStackTrace();
@@ -678,8 +695,7 @@ public abstract class AutomatedTestBase {
public static MatrixCharacteristics readDMLMetaDataFile(String fileName) {
try {
- String fname = baseDirectory + OUTPUT_DIR + fileName + ".mtd";
- JSONObject meta = new DataExpression().readMetadataFile(fname, false);
+ JSONObject meta = getMetaDataJSON(fileName);
long rlen = Long.parseLong(meta.get(DataExpression.READROWPARAM).toString());
long clen = Long.parseLong(meta.get(DataExpression.READCOLPARAM).toString());
return new MatrixCharacteristics(rlen, clen, -1, -1);
@@ -689,10 +705,23 @@ public abstract class AutomatedTestBase {
}
}
+ public static JSONObject getMetaDataJSON(String fileName) {
+ return getMetaDataJSON(fileName, OUTPUT_DIR);
+ }
+
+ public static JSONObject getMetaDataJSON(String fileName, String outputDir) {
+ String fname = baseDirectory + outputDir + fileName + ".mtd";
+ return new DataExpression().readMetadataFile(fname, false);
+ }
+
+ public static String readDMLMetaDataValue(String fileName, String outputDir, String key) throws JSONException {
+ JSONObject meta = getMetaDataJSON(fileName, outputDir);
+ return meta.get(key).toString();
+ }
+
public static ValueType readDMLMetaDataValueType(String fileName) {
try {
- String fname = baseDirectory + OUTPUT_DIR + fileName + ".mtd";
- JSONObject meta = new DataExpression().readMetadataFile(fname, false);
+ JSONObject meta = getMetaDataJSON(fileName);
return ValueType.fromExternalString(meta.get(DataExpression.VALUETYPEPARAM).toString());
}
catch(Exception ex) {
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java
index a843552..a23f7b6 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -118,16 +118,7 @@ public class TestUtils
Path compareFile = new Path(expectedFile);
FileSystem fs = IOUtilFunctions.getFileSystem(outDirectory, conf);
FSDataInputStream fsin = fs.open(compareFile);
- try( BufferedReader compareIn = new BufferedReader(new InputStreamReader(fsin)) ) {
- String line;
- while ((line = compareIn.readLine()) != null) {
- StringTokenizer st = new StringTokenizer(line, " ");
- int i = Integer.parseInt(st.nextToken());
- int j = Integer.parseInt(st.nextToken());
- double v = Double.parseDouble(st.nextToken());
- expectedValues.put(new CellIndex(i, j), v);
- }
- }
+ readValuesFromFileStream(fsin, expectedValues);
HashMap<CellIndex, Double> actualValues = new HashMap<>();
@@ -135,16 +126,7 @@ public class TestUtils
for (FileStatus file : outFiles) {
FSDataInputStream fsout = fs.open(file.getPath());
- try( BufferedReader outIn = new BufferedReader(new InputStreamReader(fsout)) ) {
- String line = null;
- while ((line = outIn.readLine()) != null) {
- StringTokenizer st = new StringTokenizer(line, " ");
- int i = Integer.parseInt(st.nextToken());
- int j = Integer.parseInt(st.nextToken());
- double v = Double.parseDouble(st.nextToken());
- actualValues.put(new CellIndex(i, j), v);
- }
- }
+ readValuesFromFileStream(fsout, actualValues);
}
ArrayList<Double> e_list = new ArrayList<>();
@@ -208,13 +190,7 @@ public class TestUtils
line = compareIn.readLine();
expRcn = line.split(" ");
- while ((line = compareIn.readLine()) != null) {
- StringTokenizer st = new StringTokenizer(line, " ");
- int i = Integer.parseInt(st.nextToken());
- int j = Integer.parseInt(st.nextToken());
- double v = Double.parseDouble(st.nextToken());
- expectedValues.put(new CellIndex(i, j), v);
- }
+ readValuesFromFileStreamAndPut(compareIn, expectedValues);
}
HashMap<CellIndex, Double> actualValues = new HashMap<>();
@@ -238,14 +214,8 @@ public class TestUtils
else if (Integer.parseInt(expRcn[2]) != Integer.parseInt(rcn[2])) {
System.out.println(" Nnz mismatch: expected " + Integer.parseInt(expRcn[2]) + ", actual " + Integer.parseInt(rcn[2]));
}
-
- while ((line = outIn.readLine()) != null) {
- StringTokenizer st = new StringTokenizer(line, " ");
- int i = Integer.parseInt(st.nextToken());
- int j = Integer.parseInt(st.nextToken());
- double v = Double.parseDouble(st.nextToken());
- actualValues.put(new CellIndex(i, j), v);
- }
+
+ readValuesFromFileStreamAndPut(outIn, actualValues);
}
@@ -270,6 +240,38 @@ public class TestUtils
}
/**
+ * Read cell values from the input stream, add them to the given hashmap of values, and close the stream.
+ * @param inputStream input stream of text lines of the form "row column value"
+ * @param values hashmap the parsed values are added to (need not be empty)
+ * @throws IOException if reading from the stream fails
+ */
+ public static void readValuesFromFileStream(FSDataInputStream inputStream, HashMap<CellIndex, Double> values)
+ throws IOException
+ {
+ try( BufferedReader inReader = new BufferedReader(new InputStreamReader(inputStream)) ) {
+ readValuesFromFileStreamAndPut(inReader, values);
+ }
+ }
+
+ /**
+ * Read values from file stream and put into hashmap
+ * @param inReader BufferedReader to read values from
+ * @param values hashmap where values are put
+ */
+ public static void readValuesFromFileStreamAndPut(BufferedReader inReader, HashMap<CellIndex, Double> values)
+ throws IOException
+ {
+ String line = null;
+ while ((line = inReader.readLine()) != null) {
+ StringTokenizer st = new StringTokenizer(line, " ");
+ int i = Integer.parseInt(st.nextToken());
+ int j = Integer.parseInt(st.nextToken());
+ double v = Double.parseDouble(st.nextToken());
+ values.put(new CellIndex(i, j), v);
+ }
+ }
+
+ /**
* <p>
* Compares the expected values calculated in Java by testcase and which are
* in the normal filesystem, with those calculated by SystemDS located in
@@ -289,37 +291,17 @@ public class TestUtils
Path outDirectory = new Path(actualDir);
Path compareFile = new Path(expectedFile);
FileSystem fs = IOUtilFunctions.getFileSystem(outDirectory, conf);
- FSDataInputStream fsin = fs.open(compareFile);
+ FSDataInputStream fsin = fs.open(compareFile);
HashMap<CellIndex, Double> expectedValues = new HashMap<>();
-
- try( BufferedReader compareIn = new BufferedReader(new InputStreamReader(fsin)) ) {
- String line;
- while ((line = compareIn.readLine()) != null) {
- StringTokenizer st = new StringTokenizer(line, " ");
- int i = Integer.parseInt(st.nextToken());
- int j = Integer.parseInt(st.nextToken());
- double v = Double.parseDouble(st.nextToken());
- expectedValues.put(new CellIndex(i, j), v);
- }
- }
+ readValuesFromFileStream(fsin, expectedValues);
HashMap<CellIndex, Double> actualValues = new HashMap<>();
-
FileStatus[] outFiles = fs.listStatus(outDirectory);
for (FileStatus file : outFiles) {
FSDataInputStream fsout = fs.open(file.getPath());
- try( BufferedReader outIn = new BufferedReader(new InputStreamReader(fsout)) ) {
- String line = null;
- while ((line = outIn.readLine()) != null) {
- StringTokenizer st = new StringTokenizer(line, " ");
- int i = Integer.parseInt(st.nextToken());
- int j = Integer.parseInt(st.nextToken());
- double v = Double.parseDouble(st.nextToken());
- actualValues.put(new CellIndex(i, j), v);
- }
- }
+ readValuesFromFileStream(fsout, actualValues);
}
int countErrors = 0;
@@ -378,20 +360,11 @@ public class TestUtils
{
Path outDirectory = new Path(filePath);
FileSystem fs = IOUtilFunctions.getFileSystem(outDirectory, conf);
- String line;
FileStatus[] outFiles = fs.listStatus(outDirectory);
for (FileStatus file : outFiles) {
FSDataInputStream outIn = fs.open(file.getPath());
- try(BufferedReader reader = new BufferedReader(new InputStreamReader(outIn)) ) {
- while ((line = reader.readLine()) != null) {
- StringTokenizer st = new StringTokenizer(line, " ");
- int i = Integer.parseInt(st.nextToken());
- int j = Integer.parseInt(st.nextToken());
- double v = Double.parseDouble(st.nextToken());
- expectedValues.put(new CellIndex(i,j), v);
- }
- }
+ readValuesFromFileStream(outIn, expectedValues);
}
}
catch (IOException e) {
@@ -1036,33 +1009,17 @@ public class TestUtils
HashMap<CellIndex, Double> expectedValues = new HashMap<>();
HashMap<CellIndex, Double> actualValues = new HashMap<>();
try(BufferedReader compareIn = new BufferedReader(new FileReader(rFile))) {
- String line;
// skip both R header lines
compareIn.readLine();
compareIn.readLine();
- while ((line = compareIn.readLine()) != null) {
- StringTokenizer st = new StringTokenizer(line, " ");
- int i = Integer.parseInt(st.nextToken());
- int j = Integer.parseInt(st.nextToken());
- double v = Double.parseDouble(st.nextToken());
- expectedValues.put(new CellIndex(i, j), v);
- }
+ readValuesFromFileStreamAndPut(compareIn, expectedValues);
}
FileStatus[] outFiles = fs.listStatus(outDirectory);
for (FileStatus file : outFiles) {
FSDataInputStream fsout = fs.open(file.getPath());
- try(BufferedReader outIn = new BufferedReader(new InputStreamReader(fsout))) {
- String line = null;
- while ((line = outIn.readLine()) != null) {
- StringTokenizer st = new StringTokenizer(line, " ");
- int i = Integer.parseInt(st.nextToken());
- int j = Integer.parseInt(st.nextToken());
- double v = Double.parseDouble(st.nextToken());
- actualValues.put(new CellIndex(i, j), v);
- }
- }
+ readValuesFromFileStream(fsout, actualValues);
}
int countErrors = 0;
diff --git a/src/test/java/org/apache/sysds/test/functions/data/misc/WriteMMTest.java b/src/test/java/org/apache/sysds/test/functions/data/misc/WriteMMTest.java
index 37fc799..e670eab 100644
--- a/src/test/java/org/apache/sysds/test/functions/data/misc/WriteMMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/data/misc/WriteMMTest.java
@@ -119,7 +119,7 @@ public class WriteMMTest extends AutomatedTestBase
input("A"), Integer.toString(rows), Integer.toString(cols), output("B") };
//generate actual dataset
- double[][] A = getRandomMatrix(rows, cols, -1, 1, 1, System.currentTimeMillis());
+ double[][] A = getRandomMatrix(rows, cols, -1, 1, 1, System.currentTimeMillis());
writeInputMatrixWithMTD("A", A, false, new MatrixCharacteristics(rows,cols, 1000, 1000));
writeExpectedMatrixMarket("B", A);
diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java
new file mode 100644
index 0000000..a16355a
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/MatrixMultiplicationPropagationTest.java
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.privacy;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+import org.apache.wink.json4j.JSONException;
+import org.junit.Test;
+import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.privacy.PrivacyConstraint;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+public class MatrixMultiplicationPropagationTest extends AutomatedTestBase {
+
+ private static final String TEST_DIR = "functions/privacy/";
+ private final static String TEST_CLASS_DIR = TEST_DIR + MatrixMultiplicationPropagationTest.class.getSimpleName() + "/";
+ private final int m = 20;
+ private final int n = 20;
+ private final int k = 20;
+
+ @Override
+ public void setUp() {
+ addTestConfiguration("MatrixMultiplicationPropagationTest",
+ new TestConfiguration(TEST_CLASS_DIR, "MatrixMultiplicationPropagationTest", new String[]{"c"}));
+ }
+
+ @Test
+ public void testMatrixMultiplicationPropagation() throws JSONException {
+ matrixMultiplicationPropagation(true, true);
+ }
+
+ @Test
+ public void testMatrixMultiplicationPropagationFalse() throws JSONException {
+ matrixMultiplicationPropagation(false, true);
+ }
+
+ @Test
+ public void testMatrixMultiplicationPropagationSecondOperand() throws JSONException {
+ matrixMultiplicationPropagation(true, false);
+ }
+
+ @Test
+ public void testMatrixMultiplicationPropagationSecondOperandFalse() throws JSONException {
+ matrixMultiplicationPropagation(false, false);
+ }
+
+ private void matrixMultiplicationPropagation(boolean privacy, boolean privateFirstOperand) throws JSONException {
+
+ TestConfiguration config = availableTestConfigurations.get("MatrixMultiplicationPropagationTest");
+ loadTestConfiguration(config);
+ fullDMLScriptName = SCRIPT_DIR + TEST_DIR + config.getTestScript() + ".dml";
+ programArgs = new String[]{"-nvargs",
+ "a=" + input("a"), "b=" + input("b"), "c=" + output("c"),
+ "m=" + m, "n=" + n, "k=" + k};
+
+ double[][] a = getRandomMatrix(m, n, -1, 1, 1, -1);
+ double[][] b = getRandomMatrix(n, k, -1, 1, 1, -1);
+ double[][] c = TestUtils.performMatrixMultiplication(a, b);
+
+ PrivacyConstraint privacyConstraint = new PrivacyConstraint(privacy);
+ MatrixCharacteristics dataCharacteristics = new MatrixCharacteristics(m,n,k,k);
+
+ if ( privateFirstOperand ) {
+ writeInputMatrixWithMTD("a", a, false, dataCharacteristics, privacyConstraint);
+ writeInputMatrix("b", b);
+ }
+ else {
+ writeInputMatrix("a", a);
+ writeInputMatrixWithMTD("b", b, false, dataCharacteristics, privacyConstraint);
+ }
+
+ writeExpectedMatrix("c", c);
+
+ runTest(true,false,null,-1);
+
+ // Check that the output data is correct
+ compareResults(1e-9);
+
+ // Check that the output metadata is correct
+ String actualPrivacyValue = readDMLMetaDataValue("c", OUTPUT_DIR, DataExpression.PRIVACY);
+ assertEquals(String.valueOf(privacy), actualPrivacyValue);
+ }
+
+ @Test
+ public void testMatrixMultiplicationNoPropagation() {
+ matrixMultiplicationNoPropagation();
+ }
+
+ private void matrixMultiplicationNoPropagation() {
+ TestConfiguration config = availableTestConfigurations.get("MatrixMultiplicationPropagationTest");
+ loadTestConfiguration(config);
+ fullDMLScriptName = SCRIPT_DIR + TEST_DIR + config.getTestScript() + ".dml";
+ programArgs = new String[]{ "-nvargs",
+ "a=" + input("a"), "b=" + input("b"), "c=" + output("c"),
+ "m=" + m, "n=" + n, "k=" + k};
+
+ double[][] a = getRandomMatrix(m, n, -1, 1, 1, -1);
+ double[][] b = getRandomMatrix(n, k, -1, 1, 1, -1);
+ double[][] c = TestUtils.performMatrixMultiplication(a, b);
+
+
+ writeInputMatrix("a", a);
+ writeInputMatrix("b", b);
+ writeExpectedMatrix("c", c);
+
+ runTest(true,false,null,-1);
+
+ // Check that the output data is correct
+ compareResults(1e-9);
+
+ // Check that a JSONException is thrown
+ // because no privacy metadata should be written to c
+ boolean JSONExceptionThrown = false;
+ try{
+ readDMLMetaDataValue("c", OUTPUT_DIR, DataExpression.PRIVACY);
+ } catch (JSONException e){
+ JSONExceptionThrown = true;
+ } catch (Exception e){
+ fail("Exception occured, but JSONException was expected. The exception thrown is: " + e.getMessage());
+ e.printStackTrace();
+ }
+ assert(JSONExceptionThrown);
+ }
+
+ @Test
+ public void testMatrixMultiplicationPrivacyInputTrue() throws JSONException {
+ testMatrixMultiplicationPrivacyInput(true);
+ }
+
+ @Test
+ public void testMatrixMultiplicationPrivacyInputFalse() throws JSONException {
+ testMatrixMultiplicationPrivacyInput(false);
+ }
+
+ private void testMatrixMultiplicationPrivacyInput(boolean privacy) throws JSONException {
+ TestConfiguration config = availableTestConfigurations.get("MatrixMultiplicationPropagationTest");
+ loadTestConfiguration(config);
+
+ double[][] a = getRandomMatrix(m, n, -1, 1, 1, -1);
+
+ PrivacyConstraint privacyConstraint = new PrivacyConstraint();
+ privacyConstraint.setPrivacy(privacy);
+ MatrixCharacteristics dataCharacteristics = new MatrixCharacteristics(m,n,k,k);
+
+ writeInputMatrixWithMTD("a", a, false, dataCharacteristics, privacyConstraint);
+
+ String actualPrivacyValue = readDMLMetaDataValue("a", INPUT_DIR, DataExpression.PRIVACY);
+ assertEquals(String.valueOf(privacy), actualPrivacyValue);
+ }
+}
diff --git a/src/test/scripts/functions/privacy/MatrixMultiplicationPropagationTest.dml b/src/test/scripts/functions/privacy/MatrixMultiplicationPropagationTest.dml
new file mode 100644
index 0000000..9705cef
--- /dev/null
+++ b/src/test/scripts/functions/privacy/MatrixMultiplicationPropagationTest.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# junit test class: org.apache.sysds.test.functions.privacy.MatrixMultiplicationPropagationTest.java
+
+A = read($a, rows=$m, cols=$n, format="text");
+B = read($b, rows=$n, cols=$k, format="text");
+C = A %*% B;
+write(C, $c, format="text");
\ No newline at end of file