You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2022/08/24 16:09:59 UTC
[systemds] branch main updated: [SYSTEMDS-3386] Refactor replacement CP Cleanup
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new cc24dc36c1 [SYSTEMDS-3386] Refactor replacement CP Cleanup
cc24dc36c1 is described below
commit cc24dc36c1d026f3ce96ef0c0285f6fd21c2e2d9
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Thu Aug 18 18:37:35 2022 +0200
[SYSTEMDS-3386] Refactor replacement CP Cleanup
This commit moves the logic for converting individual CP or SP
instructions into federated (FED) instructions out of FEDInstructionUtils
and into the respective FED instruction classes, giving a cleaner design.
Closes #1680
---
.../fed/AggregateBinaryFEDInstruction.java | 22 +-
.../fed/AggregateTernaryFEDInstruction.java | 24 +-
.../fed/AggregateUnaryFEDInstruction.java | 4 +-
.../instructions/fed/BinaryFEDInstruction.java | 133 +++++-
.../fed/BinaryMatrixMatrixFEDInstruction.java | 4 +-
.../instructions/fed/CastFEDInstruction.java | 14 +-
.../instructions/fed/CtableFEDInstruction.java | 46 +-
.../instructions/fed/FEDInstructionUtils.java | 479 +++------------------
.../instructions/fed/MMChainFEDInstruction.java | 19 +-
...tiReturnParameterizedBuiltinFEDInstruction.java | 28 +-
.../fed/ParameterizedBuiltinFEDInstruction.java | 71 +--
.../fed/QuantileSortFEDInstruction.java | 1 +
.../instructions/fed/QuaternaryFEDInstruction.java | 43 +-
.../instructions/fed/ReorgFEDInstruction.java | 29 +-
.../instructions/fed/SpoofFEDInstruction.java | 24 +-
.../instructions/fed/TernaryFEDInstruction.java | 51 ++-
.../instructions/fed/TsmmFEDInstruction.java | 14 +-
.../instructions/fed/UnaryFEDInstruction.java | 143 +++++-
.../fed/UnaryMatrixFEDInstruction.java | 8 +-
.../instructions/fed/VariableFEDInstruction.java | 19 +-
20 files changed, 629 insertions(+), 547 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
index 6128df20e0..9340e9fb12 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.java
@@ -38,7 +38,6 @@ import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
-import org.apache.sysds.runtime.instructions.spark.AggregateBinarySPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -55,14 +54,23 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
super(FEDType.AggregateBinary, op, in1, in2, out, opcode, istr, fedOut);
}
- public static AggregateBinaryFEDInstruction parseInstruction(AggregateBinaryCPInstruction instr) {
- return new AggregateBinaryFEDInstruction(instr.getOperator(), instr.input1, instr.input2, instr.output,
- instr.getOpcode(), instr.getInstructionString(), FederatedOutput.NONE);
+ public static AggregateBinaryFEDInstruction parseInstruction(AggregateBinaryCPInstruction inst,
+ ExecutionContext ec) {
+ if(inst.input1.isMatrix() && inst.input2.isMatrix()) {
+ MatrixObject mo1 = ec.getMatrixObject(inst.input1);
+ MatrixObject mo2 = ec.getMatrixObject(inst.input2);
+ if((mo1.isFederated(FType.ROW) && mo1.isFederatedExcept(FType.BROADCAST)) ||
+ (mo2.isFederated(FType.ROW) && mo2.isFederatedExcept(FType.BROADCAST)) ||
+ (mo1.isFederated(FType.COL) && mo1.isFederatedExcept(FType.BROADCAST))) {
+ return AggregateBinaryFEDInstruction.parseInstruction(inst);
+ }
+ }
+ return null;
}
- public static AggregateBinaryFEDInstruction parseInstruction(AggregateBinarySPInstruction instr) {
+ private static AggregateBinaryFEDInstruction parseInstruction(AggregateBinaryCPInstruction instr) {
return new AggregateBinaryFEDInstruction(instr.getOperator(), instr.input1, instr.input2, instr.output,
- instr.getOpcode(), instr.getInstructionString(), FederatedOutput.NONE);
+ instr.getOpcode(), instr.getInstructionString(), FederatedOutput.NONE);
}
public static AggregateBinaryFEDInstruction parseInstruction(String str) {
@@ -70,7 +78,7 @@ public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
String opcode = parts[0];
if(!opcode.equalsIgnoreCase("ba+*"))
throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
-
+
InstructionUtils.checkNumFields(parts, 5);
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
index 5a54fc6374..f8e8f8ad22 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.java
@@ -48,12 +48,30 @@ public class AggregateTernaryFEDInstruction extends ComputationFEDInstruction {
super(FEDType.AggregateTernary, op, in1, in2, in3, out, opcode, istr, fedOut);
}
- public static AggregateTernaryFEDInstruction parseInstruction(AggregateTernaryCPInstruction instr) {
+ public static AggregateTernaryFEDInstruction parseInstruction(AggregateTernaryCPInstruction inst,
+ ExecutionContext ec) {
+ if(inst.input1.isMatrix() && ec.getCacheableData(inst.input1).isFederatedExcept(FType.BROADCAST) &&
+ inst.input2.isMatrix() && ec.getCacheableData(inst.input2).isFederatedExcept(FType.BROADCAST)) {
+ return parseInstruction(inst);
+ }
+ return null;
+ }
+
+ public static AggregateTernaryFEDInstruction parseInstruction(AggregateTernarySPInstruction inst,
+ ExecutionContext ec) {
+ if(inst.input1.isMatrix() && ec.getCacheableData(inst.input1).isFederatedExcept(FType.BROADCAST) &&
+ inst.input2.isMatrix() && ec.getCacheableData(inst.input2).isFederatedExcept(FType.BROADCAST)) {
+ return parseInstruction(inst);
+ }
+ return null;
+ }
+
+ private static AggregateTernaryFEDInstruction parseInstruction(AggregateTernaryCPInstruction instr) {
return new AggregateTernaryFEDInstruction(instr.getOperator(), instr.input1, instr.input2, instr.input3,
instr.output, instr.getOpcode(), instr.getInstructionString(), FederatedOutput.NONE);
}
- public static AggregateTernaryFEDInstruction parseInstruction(AggregateTernarySPInstruction instr) {
+ private static AggregateTernaryFEDInstruction parseInstruction(AggregateTernarySPInstruction instr) {
return new AggregateTernaryFEDInstruction(instr.getOperator(), instr.input1, instr.input2, instr.input3,
instr.output, instr.getOpcode(), instr.getInstructionString(), FederatedOutput.NONE);
}
@@ -79,8 +97,8 @@ public class AggregateTernaryFEDInstruction extends ComputationFEDInstruction {
}
else {
throw new DMLRuntimeException("AggregateTernaryInstruction.parseInstruction():: Unknown opcode " + opcode);
- }
}
+}
@Override
public void processInstruction(ExecutionContext ec) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index b4b729a96d..55554240b9 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -78,7 +78,7 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
return new AggregateUnaryFEDInstruction(instr.getOperator(), instr.input1, instr.input2, instr.input3,
instr.output, instr.getOpcode(), instr.getInstructionString());
}
-
+
public static AggregateUnaryFEDInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
@@ -101,7 +101,7 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
fedOut = FederatedOutput.valueOf(parts[5]);
return new AggregateUnaryFEDInstruction(aggun, in1, out, opcode, str, fedOut);
}
-
+
@Override
public void processInstruction(ExecutionContext ec) {
if (getOpcode().contains("var")) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
index 20b378ac8e..10a907d78a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryFEDInstruction.java
@@ -21,31 +21,126 @@ package org.apache.sysds.runtime.instructions.fed;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.BinaryM.VectorType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.AppendCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.BinaryMatrixMatrixCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.BinaryMatrixScalarCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.CovarianceCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.QuantilePickCPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AggregateBinarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendRSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AppendSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CovarianceSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CpmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CumulativeOffsetSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.QuantilePickSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.RmmSPInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
- protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
- CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, FederatedOutput fedOut) {
+ protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+ String opcode, String istr, FederatedOutput fedOut) {
super(type, op, in1, in2, out, opcode, istr, fedOut);
}
- protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
- CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
+ protected BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out,
+ String opcode, String istr) {
this(type, op, in1, in2, out, opcode, istr, FederatedOutput.NONE);
}
- public BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op,
- CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
+ public BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3,
+ CPOperand out, String opcode, String istr) {
super(type, op, in1, in2, in3, out, opcode, istr);
}
+ public static BinaryFEDInstruction parseInstruction(BinaryCPInstruction inst, ExecutionContext ec) {
+ if((inst.input1.isMatrix() && ec.getMatrixObject(inst.input1).isFederatedExcept(FType.BROADCAST)) ||
+ (inst.input2 != null && inst.input2.isMatrix() &&
+ ec.getMatrixObject(inst.input2).isFederatedExcept(FType.BROADCAST))) {
+ if(inst instanceof AppendCPInstruction)
+ return AppendFEDInstruction.parseInstruction((AppendCPInstruction) inst);
+ else if(inst instanceof QuantilePickCPInstruction)
+ return QuantilePickFEDInstruction.parseInstruction((QuantilePickCPInstruction) inst);
+ else if(inst instanceof CovarianceCPInstruction && (ec.getMatrixObject(inst.input1).isFederated(FType.ROW) ||
+ ec.getMatrixObject(inst.input2).isFederated(FType.ROW)))
+ return CovarianceFEDInstruction.parseInstruction((CovarianceCPInstruction) inst);
+ else if(inst instanceof BinaryMatrixMatrixCPInstruction)
+ return BinaryMatrixMatrixFEDInstruction.parseInstruction((BinaryMatrixMatrixCPInstruction) inst);
+ else if(inst instanceof BinaryMatrixScalarCPInstruction)
+ return BinaryMatrixScalarFEDInstruction.parseInstruction((BinaryMatrixScalarCPInstruction) inst);
+ }
+ return null;
+ }
+
+ public static BinaryFEDInstruction parseInstruction(BinarySPInstruction inst, ExecutionContext ec) {
+ if(inst instanceof MapmmSPInstruction || inst instanceof CpmmSPInstruction || inst instanceof RmmSPInstruction) {
+ Data data = ec.getVariable(inst.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST)) {
+ return MMFEDInstruction.parseInstruction((AggregateBinarySPInstruction) inst);
+ }
+ }
+ else if(inst instanceof QuantilePickSPInstruction) {
+ QuantilePickSPInstruction qinstruction = (QuantilePickSPInstruction) inst;
+ Data data = ec.getVariable(qinstruction.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
+ return QuantilePickFEDInstruction.parseInstruction(qinstruction);
+ }
+ else if(inst instanceof AppendGAlignedSPInstruction || inst instanceof AppendGSPInstruction ||
+ inst instanceof AppendMSPInstruction || inst instanceof AppendRSPInstruction) {
+ BinarySPInstruction ainstruction = (BinarySPInstruction) inst;
+ Data data1 = ec.getVariable(ainstruction.input1);
+ Data data2 = ec.getVariable(ainstruction.input2);
+ if((data1 instanceof MatrixObject && ((MatrixObject) data1).isFederatedExcept(FType.BROADCAST)) ||
+ (data2 instanceof MatrixObject && ((MatrixObject) data2).isFederatedExcept(FType.BROADCAST))) {
+ return AppendFEDInstruction.parseInstruction((AppendSPInstruction) inst);
+ }
+ }
+ else if(inst instanceof BinaryMatrixScalarSPInstruction) {
+ Data data = ec.getVariable(inst.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST)) {
+ return BinaryMatrixScalarFEDInstruction.parseInstruction((BinaryMatrixScalarSPInstruction) inst);
+ }
+ }
+ else if(inst instanceof BinaryMatrixMatrixSPInstruction) {
+ Data data = ec.getVariable(inst.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST)) {
+ return BinaryMatrixMatrixFEDInstruction.parseInstruction((BinaryMatrixMatrixSPInstruction) inst);
+ }
+ }
+ else if((inst.input1.isMatrix() && ec.getCacheableData(inst.input1).isFederatedExcept(FType.BROADCAST)) ||
+ (inst.input2.isMatrix() && ec.getMatrixObject(inst.input2).isFederatedExcept(FType.BROADCAST))) {
+ if(inst instanceof CovarianceSPInstruction && (ec.getMatrixObject(inst.input1).isFederated(FType.ROW) ||
+ ec.getMatrixObject(inst.input2).isFederated(FType.ROW)))
+ return CovarianceFEDInstruction.parseInstruction((CovarianceSPInstruction) inst);
+ else if(inst instanceof CumulativeOffsetSPInstruction) {
+ return CumulativeOffsetFEDInstruction.parseInstruction((CumulativeOffsetSPInstruction) inst);
+ }
+ else
+ return BinaryFEDInstruction.parseInstruction(InstructionUtils.concatOperands(inst.getInstructionString(),
+ FEDInstruction.FederatedOutput.NONE.name()));
+ }
+ return null;
+ }
+
public static BinaryFEDInstruction parseInstruction(String str) {
+ // TODO remove
if(str.startsWith(ExecType.SPARK.name())) {
// rewrite the spark instruction to a cp instruction
str = rewriteSparkInstructionToCP(str);
@@ -57,20 +152,20 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
- FederatedOutput fedOut = FederatedOutput.valueOf(parts[parts.length-1]);
+ FederatedOutput fedOut = FederatedOutput.valueOf(parts[parts.length - 1]);
checkOutputDataType(in1, in2, out);
Operator operator = InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
- //Operator operator = InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
+ // Operator operator = InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);
// TODO different binary instructions
- if( in1.getDataType() == DataType.SCALAR && in2.getDataType() == DataType.SCALAR )
+ if(in1.getDataType() == DataType.SCALAR && in2.getDataType() == DataType.SCALAR)
throw new DMLRuntimeException("Federated binary scalar scalar operations not yet supported");
- else if( in1.getDataType() == DataType.MATRIX && in2.getDataType() == DataType.MATRIX )
+ else if(in1.getDataType() == DataType.MATRIX && in2.getDataType() == DataType.MATRIX)
return new BinaryMatrixMatrixFEDInstruction(operator, in1, in2, out, opcode, str, fedOut);
- else if( in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.TENSOR )
+ else if(in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.TENSOR)
throw new DMLRuntimeException("Federated binary tensor tensor operations not yet supported");
- else if( in1.isMatrix() && in2.isScalar() || in2.isMatrix() && in1.isScalar() )
+ else if(in1.isMatrix() && in2.isScalar() || in2.isMatrix() && in1.isScalar())
return new BinaryMatrixScalarFEDInstruction(operator, in1, in2, out, opcode, str, fedOut);
else
throw new DMLRuntimeException("Federated binary operations not yet supported:" + opcode);
@@ -78,7 +173,7 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
- InstructionUtils.checkNumFields ( parts, 3, 4 );
+ InstructionUtils.checkNumFields(parts, 3, 4);
String opcode = parts[0];
in1.split(parts[1]);
in2.split(parts[2]);
@@ -86,9 +181,10 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
return opcode;
}
- protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out) {
+ protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3,
+ CPOperand out) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr);
- InstructionUtils.checkNumFields ( parts, 4 );
+ InstructionUtils.checkNumFields(parts, 4);
String opcode = parts[0];
in1.split(parts[1]);
in2.split(parts[2]);
@@ -99,9 +195,10 @@ public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {
protected static void checkOutputDataType(CPOperand in1, CPOperand in2, CPOperand out) {
// check for valid data type of output
- if( (in1.getDataType() == DataType.MATRIX || in2.getDataType() == DataType.MATRIX) && out.getDataType() != DataType.MATRIX )
- throw new DMLRuntimeException("Element-wise matrix operations between variables " + in1.getName() +
- " and " + in2.getName() + " must produce a matrix, which " + out.getName() + " is not");
+ if((in1.getDataType() == DataType.MATRIX || in2.getDataType() == DataType.MATRIX) &&
+ out.getDataType() != DataType.MATRIX)
+ throw new DMLRuntimeException("Element-wise matrix operations between variables " + in1.getName() + " and "
+ + in2.getName() + " must produce a matrix, which " + out.getName() + " is not");
}
protected static String rewriteSparkInstructionToCP(String inst_str) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 250b5d193a..fb8455ea9e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -42,12 +42,12 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
super(FEDType.Binary, op, in1, in2, out, opcode, istr, fedOut);
}
- public static BinaryMatrixMatrixFEDInstruction parseInstruction(BinaryMatrixMatrixCPInstruction instr) {
+ protected static BinaryMatrixMatrixFEDInstruction parseInstruction(BinaryMatrixMatrixCPInstruction instr) {
return new BinaryMatrixMatrixFEDInstruction(instr.getOperator(), instr.input1, instr.input2, instr.output,
instr.getOpcode(), instr.getInstructionString(), FederatedOutput.NONE);
}
- public static BinaryMatrixMatrixFEDInstruction parseInstruction(BinaryMatrixMatrixSPInstruction instr) {
+ protected static BinaryMatrixMatrixFEDInstruction parseInstruction(BinaryMatrixMatrixSPInstruction instr) {
String instrStr = rewriteSparkInstructionToCP(instr.getInstructionString());
String opcode = InstructionUtils.getInstructionPartsWithValueType(instrStr)[0];
return new BinaryMatrixMatrixFEDInstruction(instr.getOperator(), instr.input1, instr.input2, instr.output,
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CastFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CastFEDInstruction.java
index df2fe11e12..89edb9a4dc 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/CastFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CastFEDInstruction.java
@@ -28,6 +28,7 @@ import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -49,12 +50,21 @@ public class CastFEDInstruction extends UnaryFEDInstruction {
super(FEDInstruction.FEDType.Cast, op, in, out, opcode, istr);
}
- public static CastFEDInstruction parseInstruction(CastSPInstruction spInstruction) {
+ public static CastFEDInstruction parseInstruction(CastSPInstruction inst, ExecutionContext ec) {
+ if((inst.getOpcode().equalsIgnoreCase(OpOp1.CAST_AS_FRAME.toString()) ||
+ inst.getOpcode().equalsIgnoreCase(OpOp1.CAST_AS_MATRIX.toString())) && inst.input1.isMatrix() &&
+ ec.getCacheableData(inst.input1).isFederatedExcept(FType.BROADCAST)) {
+ return CastFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ private static CastFEDInstruction parseInstruction(CastSPInstruction spInstruction) {
return new CastFEDInstruction(spInstruction.getOperator(), spInstruction.input1, spInstruction.output,
spInstruction.getOpcode(), spInstruction.getInstructionString());
}
- public static CastFEDInstruction parseInstruction ( String str ) {
+ public static CastFEDInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
InstructionUtils.checkNumFields(parts, 2);
String opcode = parts[0];
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index 3f87668492..0ca04788e2 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -21,16 +21,17 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.Arrays;
import java.util.Collections;
-import java.util.concurrent.Future;
import java.util.Iterator;
import java.util.SortedMap;
-import java.util.stream.IntStream;
import java.util.TreeMap;
+import java.util.concurrent.Future;
+import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -56,25 +57,38 @@ import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
public class CtableFEDInstruction extends ComputationFEDInstruction {
private final CPOperand _outDim1;
private final CPOperand _outDim2;
- //private final boolean _isExpand;
- //private final boolean _ignoreZeros;
private CtableFEDInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, CPOperand outDim1,
CPOperand outDim2, boolean isExpand, boolean ignoreZeros, String opcode, String istr) {
super(FEDType.Ctable, null, in1, in2, in3, out, opcode, istr);
_outDim1 = outDim1;
_outDim2 = outDim2;
- // _isExpand = isExpand;
- // _ignoreZeros = ignoreZeros;
}
- public static CtableFEDInstruction parseInstruction(CtableCPInstruction instr) {
+ public static CtableFEDInstruction parseInstruction(CtableCPInstruction inst, ExecutionContext ec) {
+ if((inst.getOpcode().equalsIgnoreCase("ctable") || inst.getOpcode().equalsIgnoreCase("ctableexpand")) &&
+ (ec.getCacheableData(inst.input1).isFederated(FType.ROW) ||
+ (inst.input2.isMatrix() && ec.getCacheableData(inst.input2).isFederated(FType.ROW)) ||
+ (inst.input3.isMatrix() && ec.getCacheableData(inst.input3).isFederated(FType.ROW))))
+ return CtableFEDInstruction.parseInstruction(inst);
+ return null;
+ }
+
+ private static CtableFEDInstruction parseInstruction(CtableCPInstruction instr) {
return new CtableFEDInstruction(instr.input1, instr.input2, instr.input3, instr.output, instr.getOutDim1(),
instr.getOutDim2(), instr.getIsExpand(), instr.getIgnoreZeros(), instr.getOpcode(),
instr.getInstructionString());
}
+
+ public static CtableFEDInstruction parseInstruction(CtableSPInstruction inst, ExecutionContext ec) {
+ if(inst.getOpcode().equalsIgnoreCase("ctable") && (ec.getCacheableData(inst.input1).isFederated(FType.ROW) ||
+ (inst.input2.isMatrix() && ec.getCacheableData(inst.input2).isFederated(FType.ROW)) ||
+ (inst.input3.isMatrix() && ec.getCacheableData(inst.input3).isFederated(FType.ROW))))
+ return CtableFEDInstruction.parseInstruction(inst);
+ return null;
+ }
- public static CtableFEDInstruction parseInstruction(CtableSPInstruction instr) {
+ private static CtableFEDInstruction parseInstruction(CtableSPInstruction instr) {
return new CtableFEDInstruction(instr.input1, instr.input2, instr.input3, instr.output, instr.getOutDim1(),
instr.getOutDim2(), instr.getIsExpand(), instr.getIgnoreZeros(), instr.getOpcode(),
instr.getInstructionString());
@@ -83,33 +97,27 @@ public class CtableFEDInstruction extends ComputationFEDInstruction {
public static CtableFEDInstruction parseInstruction(String inst) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst);
InstructionUtils.checkNumFields(parts, 7);
-
String opcode = parts[0];
-
- //handle opcode
+ // handle opcode
if(!(opcode.equalsIgnoreCase("ctable")) && !(opcode.equalsIgnoreCase("ctableexpand"))) {
throw new DMLRuntimeException("Unexpected opcode in CtableFEDInstruction: " + inst);
}
-
- //handle operands
+ // handle operands
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand in3 = new CPOperand(parts[3]);
-
- //handle known dimension information
+ // handle known dimension information
String[] dim1Fields = parts[4].split(Instruction.LITERAL_PREFIX);
String[] dim2Fields = parts[5].split(Instruction.LITERAL_PREFIX);
-
CPOperand out = new CPOperand(parts[6]);
boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
-
+
boolean dim1Literal = Boolean.parseBoolean(dim1Fields[1]);
CPOperand outDim1 = new CPOperand(dim1Fields[0], ValueType.FP64, DataType.SCALAR, dim1Literal);
boolean dim2Literal = Boolean.parseBoolean(dim2Fields[1]);
CPOperand outDim2 = new CPOperand(dim2Fields[0], ValueType.FP64, DataType.SCALAR, dim2Literal);
// ctable does not require any operator, so we simply pass-in a dummy operator with null functionobject
- return new CtableFEDInstruction(in1,
- in2, in3, out, outDim1, outDim2, false, ignoreZeros, opcode, inst);
+ return new CtableFEDInstruction(in1, in2, in3, out, outDim1, outDim2, false, ignoreZeros, opcode, inst);
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index eb93ff9d59..c5ada38032 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -19,476 +19,127 @@
package org.apache.sysds.runtime.instructions.fed;
-import org.apache.commons.lang3.ArrayUtils;
-import org.apache.sysds.common.Types.OpOp1;
-import org.apache.sysds.hops.fedplanner.FTypes.FType;
-import org.apache.sysds.runtime.codegen.SpoofCellwise;
-import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
-import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
-import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
-import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
-import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.AppendCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.BinaryMatrixMatrixCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.BinaryMatrixScalarCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.CovarianceCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CtableCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
-import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.QuantilePickCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.QuantileSortCPInstruction;
import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.ReshapeCPInstruction;
import org.apache.sysds.runtime.instructions.cp.SpoofCPInstruction;
import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.TernaryFrameScalarCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOperationCode;
-import org.apache.sysds.runtime.instructions.spark.AggregateBinarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AggregateTernarySPInstruction;
-import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
-import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.AppendRSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.AppendSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.BinaryMatrixMatrixSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.BinarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.CastSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.CovarianceSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.CpmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CtableSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.CumulativeOffsetSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.MultiReturnParameterizedBuiltinSPInstruction;
import org.apache.sysds.runtime.instructions.spark.ParameterizedBuiltinSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.QuantilePickSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
-import org.apache.sysds.runtime.instructions.spark.ReblockSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.ReorgSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.RmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SpoofSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.TernaryFrameScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.TernarySPInstruction;
-import org.apache.sysds.runtime.instructions.spark.UnaryMatrixSPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
public class FEDInstructionUtils {
-
- private static final String[] PARAM_BUILTINS = new String[]{
- "replace", "rmempty", "lowertri", "uppertri", "transformdecode", "transformapply", "tokenize"};
public static boolean noFedRuntimeConversion = false;
-
- // private static final Log LOG = LogFactory.getLog(FEDInstructionUtils.class.getName());
-
- // This is currently a rather simplistic to our solution of replacing instructions with their correct federated
- // counterpart, since we do not propagate the information that a matrix is federated, therefore we can not decide
- // to choose a federated instruction earlier.
/**
* Check and replace CP instructions with federated instructions if the instruction match criteria.
*
- * @param inst The instruction to analyse
- * @param ec The Execution Context
+ * @param inst The instruction to analyze
+ * @param ec The Execution Context
* @return The potentially modified instruction
*/
public static Instruction checkAndReplaceCP(Instruction inst, ExecutionContext ec) {
- if ( !noFedRuntimeConversion ){
- FEDInstruction fedinst = null;
- if (inst instanceof AggregateBinaryCPInstruction) {
- AggregateBinaryCPInstruction instruction = (AggregateBinaryCPInstruction) inst;
- if( instruction.input1.isMatrix() && instruction.input2.isMatrix()) {
- MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
- MatrixObject mo2 = ec.getMatrixObject(instruction.input2);
- if ( (mo1.isFederated(FType.ROW) && mo1.isFederatedExcept(FType.BROADCAST))
- || (mo2.isFederated(FType.ROW) && mo2.isFederatedExcept(FType.BROADCAST))
- || (mo1.isFederated(FType.COL) && mo1.isFederatedExcept(FType.BROADCAST))) {
- fedinst = AggregateBinaryFEDInstruction.parseInstruction(instruction);
- }
- }
- }
- else if( inst instanceof MMChainCPInstruction) {
- MMChainCPInstruction linst = (MMChainCPInstruction) inst;
- MatrixObject mo = ec.getMatrixObject(linst.input1);
- if( mo.isFederated(FType.ROW) )
- fedinst = MMChainFEDInstruction.parseInstruction(linst);
- }
- else if( inst instanceof MMTSJCPInstruction ) {
- MMTSJCPInstruction linst = (MMTSJCPInstruction) inst;
- MatrixObject mo = ec.getMatrixObject(linst.input1);
- if( (mo.isFederated(FType.ROW) && mo.isFederatedExcept(FType.BROADCAST) && linst.getMMTSJType().isLeft()) ||
- (mo.isFederated(FType.COL) && mo.isFederatedExcept(FType.BROADCAST) && linst.getMMTSJType().isRight()))
- fedinst = TsmmFEDInstruction.parseInstruction(linst);
- }
- else if (inst instanceof UnaryCPInstruction && ! (inst instanceof IndexingCPInstruction)) {
- UnaryCPInstruction instruction = (UnaryCPInstruction) inst;
- if(inst instanceof ReorgCPInstruction && (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
- || inst.getOpcode().equals("rev"))) {
- ReorgCPInstruction rinst = (ReorgCPInstruction) inst;
- CacheableData<?> mo = ec.getCacheableData(rinst.input1);
+ if(noFedRuntimeConversion)
+ return inst;
- if((mo instanceof MatrixObject || mo instanceof FrameObject)
- && mo.isFederatedExcept(FType.BROADCAST) )
- fedinst = ReorgFEDInstruction.parseInstruction(rinst);
- }
- else if(instruction.input1 != null && instruction.input1.isMatrix()
- && ec.containsVariable(instruction.input1)) {
+ FEDInstruction fedinst = null;
- MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
- if( mo1.isFederatedExcept(FType.BROADCAST) ) {
- if(instruction instanceof CentralMomentCPInstruction)
- fedinst = CentralMomentFEDInstruction.parseInstruction((CentralMomentCPInstruction) inst);
- else if(inst instanceof QuantileSortCPInstruction) {
- if(mo1.isFederated(FType.ROW) || mo1.getFedMapping().getFederatedRanges().length == 1 && mo1.isFederated(FType.COL))
- fedinst = QuantileSortFEDInstruction.parseInstruction((QuantileSortCPInstruction) inst);
- }
- else if(inst instanceof ReshapeCPInstruction)
- fedinst = ReshapeFEDInstruction.parseInstruction((ReshapeCPInstruction) inst);
- else if(inst instanceof AggregateUnaryCPInstruction &&
- ((AggregateUnaryCPInstruction) instruction).getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT)
- fedinst = AggregateUnaryFEDInstruction.parseInstruction((AggregateUnaryCPInstruction) inst);
- else if(inst instanceof UnaryMatrixCPInstruction) {
- if(UnaryMatrixFEDInstruction.isValidOpcode(inst.getOpcode()) &&
- !(inst.getOpcode().equalsIgnoreCase("ucumk+*") && mo1.isFederated(FType.COL)))
- fedinst = UnaryMatrixFEDInstruction.parseInstruction((UnaryMatrixCPInstruction) inst);
- }
- }
- }
- }
- else if (inst instanceof BinaryCPInstruction) {
- BinaryCPInstruction instruction = (BinaryCPInstruction) inst;
- if((instruction.input1.isMatrix() &&
- ec.getMatrixObject(instruction.input1).isFederatedExcept(FType.BROADCAST)) ||
- (instruction.input2 != null && instruction.input2.isMatrix() &&
- ec.getMatrixObject(instruction.input2).isFederatedExcept(FType.BROADCAST))) {
- if(instruction instanceof AppendCPInstruction)
- fedinst = AppendFEDInstruction.parseInstruction((AppendCPInstruction) inst);
- else if(instruction instanceof QuantilePickCPInstruction)
- fedinst = QuantilePickFEDInstruction.parseInstruction((QuantilePickCPInstruction) inst);
- else if(instruction instanceof CovarianceCPInstruction &&
- (ec.getMatrixObject(instruction.input1).isFederated(FType.ROW) ||
- ec.getMatrixObject(instruction.input2).isFederated(FType.ROW)))
- fedinst = CovarianceFEDInstruction.parseInstruction((CovarianceCPInstruction) inst);
- else if(instruction instanceof BinaryMatrixMatrixCPInstruction)
- fedinst = BinaryMatrixMatrixFEDInstruction
- .parseInstruction((BinaryMatrixMatrixCPInstruction) inst);
- else if(instruction instanceof BinaryMatrixScalarCPInstruction)
- fedinst = BinaryMatrixScalarFEDInstruction
- .parseInstruction((BinaryMatrixScalarCPInstruction) inst);
- }
- }
- else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
- ParameterizedBuiltinCPInstruction pinst = (ParameterizedBuiltinCPInstruction) inst;
- if( ArrayUtils.contains(PARAM_BUILTINS, pinst.getOpcode()) && pinst.getTarget(ec).isFederatedExcept(FType.BROADCAST) )
- fedinst = ParameterizedBuiltinFEDInstruction.parseInstruction(pinst);
- }
- else if (inst instanceof MultiReturnParameterizedBuiltinCPInstruction) {
- MultiReturnParameterizedBuiltinCPInstruction minst = (MultiReturnParameterizedBuiltinCPInstruction) inst;
- if(minst.getOpcode().equals("transformencode") && minst.input1.isFrame()) {
- CacheableData<?> fo = ec.getCacheableData(minst.input1);
- if(fo.isFederatedExcept(FType.BROADCAST)) {
- fedinst = MultiReturnParameterizedBuiltinFEDInstruction.parseInstruction(minst);
- }
- }
- }
- else if(inst instanceof IndexingCPInstruction) {
- // matrix and frame indexing
- IndexingCPInstruction minst = (IndexingCPInstruction) inst;
- if((minst.input1.isMatrix() || minst.input1.isFrame())
- && ec.getCacheableData(minst.input1).isFederatedExcept(FType.BROADCAST)) {
- fedinst = IndexingFEDInstruction.parseInstruction(minst);
- }
- }
- else if(inst instanceof TernaryCPInstruction) {
- TernaryCPInstruction tinst = (TernaryCPInstruction) inst;
- if(inst.getOpcode().equals("_map") && inst instanceof TernaryFrameScalarCPInstruction && !inst.getInstructionString().contains("UtilFunctions")
- && tinst.input1.isFrame() && ec.getFrameObject(tinst.input1).isFederated()) {
- long margin = ec.getScalarInput(tinst.input3).getLongValue();
- FrameObject fo = ec.getFrameObject(tinst.input1);
- if(margin == 0 || (fo.isFederated(FType.ROW) && margin == 1) || (fo.isFederated(FType.COL) && margin == 2))
- fedinst = TernaryFrameScalarFEDInstruction.parseInstruction((TernaryFrameScalarCPInstruction) inst);
- }
- else if((tinst.input1.isMatrix() && ec.getCacheableData(tinst.input1).isFederatedExcept(FType.BROADCAST))
- || (tinst.input2.isMatrix() && ec.getCacheableData(tinst.input2).isFederatedExcept(FType.BROADCAST))
- || (tinst.input3.isMatrix() && ec.getCacheableData(tinst.input3).isFederatedExcept(FType.BROADCAST))) {
- fedinst = TernaryFEDInstruction.parseInstruction(tinst);
- }
- }
- else if(inst instanceof VariableCPInstruction ){
- VariableCPInstruction ins = (VariableCPInstruction) inst;
- if(ins.getVariableOpcode() == VariableOperationCode.Write
- && ins.getInput1().isMatrix()
- && ins.getInput3().getName().contains("federated")){
- fedinst = VariableFEDInstruction.parseInstruction(ins);
- }
- else if(ins.getVariableOpcode() == VariableOperationCode.CastAsFrameVariable
- && ins.getInput1().isMatrix()
- && ec.getCacheableData(ins.getInput1()).isFederatedExcept(FType.BROADCAST)){
- fedinst = VariableFEDInstruction.parseInstruction(ins);
- }
- else if(ins.getVariableOpcode() == VariableOperationCode.CastAsMatrixVariable
- && ins.getInput1().isFrame()
- && ec.getCacheableData(ins.getInput1()).isFederatedExcept(FType.BROADCAST)){
- fedinst = VariableFEDInstruction.parseInstruction(ins);
- }
- }
- else if(inst instanceof AggregateTernaryCPInstruction){
- AggregateTernaryCPInstruction ins = (AggregateTernaryCPInstruction) inst;
- if(ins.input1.isMatrix() && ec.getCacheableData(ins.input1).isFederatedExcept(FType.BROADCAST)
- && ins.input2.isMatrix() && ec.getCacheableData(ins.input2).isFederatedExcept(FType.BROADCAST)) {
- fedinst = AggregateTernaryFEDInstruction.parseInstruction(ins);
- }
- }
- else if(inst instanceof QuaternaryCPInstruction) {
- QuaternaryCPInstruction instruction = (QuaternaryCPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if(data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
- fedinst = QuaternaryFEDInstruction.parseInstruction(instruction);
- }
- else if(inst instanceof SpoofCPInstruction) {
- SpoofCPInstruction ins = (SpoofCPInstruction) inst;
- Class<?> scla = ins.getOperatorClass().getSuperclass();
- if(((scla == SpoofCellwise.class || scla == SpoofMultiAggregate.class || scla == SpoofOuterProduct.class)
- && SpoofFEDInstruction.isFederated(ec, ins.getInputs(), scla))
- || (scla == SpoofRowwise.class && SpoofFEDInstruction.isFederated(ec, FType.ROW, ins.getInputs(), scla))) {
- fedinst = SpoofFEDInstruction.parseInstruction(ins);
- }
- }
- else if(inst instanceof CtableCPInstruction) {
- CtableCPInstruction cinst = (CtableCPInstruction) inst;
- if((inst.getOpcode().equalsIgnoreCase("ctable") || inst.getOpcode().equalsIgnoreCase("ctableexpand"))
- && ( ec.getCacheableData(cinst.input1).isFederated(FType.ROW)
- || (cinst.input2.isMatrix() && ec.getCacheableData(cinst.input2).isFederated(FType.ROW))
- || (cinst.input3.isMatrix() && ec.getCacheableData(cinst.input3).isFederated(FType.ROW))))
- fedinst = CtableFEDInstruction.parseInstruction(cinst);
- }
+ if(inst instanceof AggregateBinaryCPInstruction)
+ fedinst = AggregateBinaryFEDInstruction.parseInstruction((AggregateBinaryCPInstruction) inst, ec);
+ else if(inst instanceof MMChainCPInstruction)
+ fedinst = MMChainFEDInstruction.parseInstruction((MMChainCPInstruction) inst, ec);
+ else if(inst instanceof MMTSJCPInstruction)
+ fedinst = TsmmFEDInstruction.parseInstruction((MMTSJCPInstruction) inst, ec);
+ else if(inst instanceof UnaryCPInstruction)
+ fedinst = UnaryFEDInstruction.parseInstruction((UnaryCPInstruction) inst, ec);
+ else if(inst instanceof BinaryCPInstruction)
+ fedinst = BinaryFEDInstruction.parseInstruction((BinaryCPInstruction) inst, ec);
+ else if(inst instanceof ParameterizedBuiltinCPInstruction)
+ fedinst = ParameterizedBuiltinFEDInstruction.parseInstruction((ParameterizedBuiltinCPInstruction) inst, ec);
+ else if(inst instanceof MultiReturnParameterizedBuiltinCPInstruction)
+ fedinst = MultiReturnParameterizedBuiltinFEDInstruction
+ .parseInstruction((MultiReturnParameterizedBuiltinCPInstruction) inst, ec);
+ else if(inst instanceof TernaryCPInstruction)
+ fedinst = TernaryFEDInstruction.parseInstruction((TernaryCPInstruction) inst, ec);
+ else if(inst instanceof VariableCPInstruction)
+ fedinst = VariableFEDInstruction.parseInstruction((VariableCPInstruction) inst, ec);
+ else if(inst instanceof AggregateTernaryCPInstruction)
+ fedinst = AggregateTernaryFEDInstruction.parseInstruction((AggregateTernaryCPInstruction) inst, ec);
+ else if(inst instanceof QuaternaryCPInstruction)
+ fedinst = QuaternaryFEDInstruction.parseInstruction((QuaternaryCPInstruction) inst, ec);
+ else if(inst instanceof SpoofCPInstruction)
+ fedinst = SpoofFEDInstruction.parseInstruction((SpoofCPInstruction) inst, ec);
+ else if(inst instanceof CtableCPInstruction)
+ fedinst = CtableFEDInstruction.parseInstruction((CtableCPInstruction) inst, ec);
- //set thread id for federated context management
- if( fedinst != null ) {
- fedinst.setTID(ec.getTID());
- return fedinst;
- }
+ // set thread id for federated context management
+ if(fedinst != null) {
+ fedinst.setTID(ec.getTID());
+ return fedinst;
}
-
+
return inst;
+
}
public static Instruction checkAndReplaceSP(Instruction inst, ExecutionContext ec) {
+ if(noFedRuntimeConversion)
+ return inst;
FEDInstruction fedinst = null;
- if(inst instanceof CastSPInstruction){
- CastSPInstruction ins = (CastSPInstruction) inst;
- if((ins.getOpcode().equalsIgnoreCase(OpOp1.CAST_AS_FRAME.toString())
- || ins.getOpcode().equalsIgnoreCase(OpOp1.CAST_AS_MATRIX.toString()))
- && ins.input1.isMatrix() && ec.getCacheableData(ins.input1).isFederatedExcept(FType.BROADCAST)){
- fedinst = CastFEDInstruction.parseInstruction(ins);
- }
- }
- else if (inst instanceof WriteSPInstruction) {
+ if(inst instanceof CastSPInstruction)
+ fedinst = CastFEDInstruction.parseInstruction((CastSPInstruction) inst, ec);
+ else if(inst instanceof WriteSPInstruction) {
WriteSPInstruction instruction = (WriteSPInstruction) inst;
Data data = ec.getVariable(instruction.input1);
- if (data instanceof CacheableData && ((CacheableData<?>) data).isFederated()) {
+ if(data instanceof CacheableData && ((CacheableData<?>) data).isFederated()) {
// Write spark instruction can not be executed for federated matrix objects (tries to get rdds which do
// not exist), therefore we replace the instruction with the VariableCPInstruction.
return VariableCPInstruction.parseInstruction(instruction.getInstructionString());
}
}
- else if(inst instanceof QuaternarySPInstruction) {
- QuaternarySPInstruction instruction = (QuaternarySPInstruction) inst;
- Data data = ec.getVariable(instruction.input1);
- if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
- fedinst = QuaternaryFEDInstruction.parseInstruction(instruction);
- }
- else if(inst instanceof SpoofSPInstruction) {
- SpoofSPInstruction ins = (SpoofSPInstruction) inst;
- Class<?> scla = ins.getOperatorClass().getSuperclass();
- if(((scla == SpoofCellwise.class || scla == SpoofMultiAggregate.class || scla == SpoofOuterProduct.class)
- && SpoofFEDInstruction.isFederated(ec, ins.getInputs(), scla))
- || (scla == SpoofRowwise.class && SpoofFEDInstruction.isFederated(ec, FType.ROW, ins.getInputs(), scla))) {
- fedinst = SpoofFEDInstruction.parseInstruction(ins);
- }
- }
- else if (inst instanceof UnarySPInstruction && ! (inst instanceof IndexingSPInstruction)) {
- UnarySPInstruction instruction = (UnarySPInstruction) inst;
- if (inst instanceof CentralMomentSPInstruction) {
- CentralMomentSPInstruction cinstruction = (CentralMomentSPInstruction) inst;
- Data data = ec.getVariable(cinstruction.input1);
- if (data instanceof MatrixObject && ((MatrixObject) data).isFederated() && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
- fedinst = CentralMomentFEDInstruction.parseInstruction(cinstruction);
- } else if (inst instanceof QuantileSortSPInstruction) {
- QuantileSortSPInstruction qinstruction = (QuantileSortSPInstruction) inst;
- Data data = ec.getVariable(qinstruction.input1);
- if (data instanceof MatrixObject && ((MatrixObject) data).isFederated() && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
- fedinst = QuantileSortFEDInstruction.parseInstruction(qinstruction);
- }
- else if (inst instanceof AggregateUnarySPInstruction) {
- AggregateUnarySPInstruction auinstruction = (AggregateUnarySPInstruction) inst;
- Data data = ec.getVariable(auinstruction.input1);
- if(data instanceof MatrixObject && ((MatrixObject) data).isFederated() && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
- if(ArrayUtils.contains(new String[]{"uarimin", "uarimax"}, auinstruction.getOpcode())) {
- if(((MatrixObject) data).getFedMapping().getType() == FType.ROW)
- fedinst = AggregateUnaryFEDInstruction.parseInstruction(auinstruction);
- }
- else
- fedinst = AggregateUnaryFEDInstruction.parseInstruction(auinstruction);
- }
- else if(inst instanceof ReorgSPInstruction && (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag")
- || inst.getOpcode().equals("rev"))) {
- ReorgSPInstruction rinst = (ReorgSPInstruction) inst;
- CacheableData<?> mo = ec.getCacheableData(rinst.input1);
- if((mo instanceof MatrixObject || mo instanceof FrameObject) && mo.isFederated() && mo.isFederatedExcept(FType.BROADCAST))
- fedinst = ReorgFEDInstruction.parseInstruction(rinst);
- }
- else if(inst instanceof ReblockSPInstruction && instruction.input1 != null && (instruction.input1.isFrame() || instruction.input1.isMatrix())) {
- ReblockSPInstruction rinst = (ReblockSPInstruction) instruction;
- CacheableData<?> data = ec.getCacheableData(rinst.input1);
- if(data.isFederatedExcept(FType.BROADCAST))
- fedinst = ReblockFEDInstruction.parseInstruction((ReblockSPInstruction) inst);
- }
- else if(instruction.input1 != null && instruction.input1.isMatrix() && ec.containsVariable(instruction.input1)) {
- MatrixObject mo1 = ec.getMatrixObject(instruction.input1);
- if(mo1.isFederatedExcept(FType.BROADCAST)) {
- if(instruction.getOpcode().equalsIgnoreCase("cm"))
- fedinst = CentralMomentFEDInstruction.parseInstruction((CentralMomentCPInstruction)inst);
- else if(inst.getOpcode().equalsIgnoreCase("qsort")) {
- if(mo1.getFedMapping().getFederatedRanges().length == 1)
- fedinst = QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString(), false);
- }
- else if(inst.getOpcode().equalsIgnoreCase("rshape")) {
- fedinst = ReshapeFEDInstruction.parseInstruction(inst.getInstructionString());
- }
- else if(inst instanceof UnaryMatrixSPInstruction) {
- if(UnaryMatrixFEDInstruction.isValidOpcode(inst.getOpcode()))
- fedinst = UnaryMatrixFEDInstruction.parseInstruction((UnaryMatrixSPInstruction) inst);
- }
- }
- }
- }
- else if (inst instanceof BinarySPInstruction) {
- BinarySPInstruction instruction = (BinarySPInstruction) inst;
- if (inst instanceof MapmmSPInstruction || inst instanceof CpmmSPInstruction || inst instanceof RmmSPInstruction) {
- Data data = ec.getVariable(instruction.input1);
- if (data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST)) {
- fedinst = MMFEDInstruction.parseInstruction((AggregateBinarySPInstruction) instruction);
- }
- }
- else
- if(inst instanceof QuantilePickSPInstruction) {
- QuantilePickSPInstruction qinstruction = (QuantilePickSPInstruction) inst;
- Data data = ec.getVariable(qinstruction.input1);
- if(data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
- fedinst = QuantilePickFEDInstruction.parseInstruction(qinstruction);
- }
- else if (inst instanceof AppendGAlignedSPInstruction || inst instanceof AppendGSPInstruction
- || inst instanceof AppendMSPInstruction || inst instanceof AppendRSPInstruction) {
- BinarySPInstruction ainstruction = (BinarySPInstruction) inst;
- Data data1 = ec.getVariable(ainstruction.input1);
- Data data2 = ec.getVariable(ainstruction.input2);
- if ((data1 instanceof MatrixObject && ((MatrixObject) data1).isFederatedExcept(FType.BROADCAST))
- || (data2 instanceof MatrixObject && ((MatrixObject) data2).isFederatedExcept(FType.BROADCAST))) {
- fedinst = AppendFEDInstruction.parseInstruction((AppendSPInstruction) instruction);
- }
- }
- else if (inst instanceof BinaryMatrixScalarSPInstruction) {
- Data data = ec.getVariable(instruction.input1);
- if(data instanceof MatrixObject && ((MatrixObject)data).isFederatedExcept(FType.BROADCAST)) {
- fedinst = BinaryMatrixScalarFEDInstruction.parseInstruction((BinaryMatrixScalarSPInstruction) inst);
- }
- }
- else if (inst instanceof BinaryMatrixMatrixSPInstruction) {
- Data data = ec.getVariable(instruction.input1);
- if(data instanceof MatrixObject && ((MatrixObject)data).isFederatedExcept(FType.BROADCAST)) {
- fedinst = BinaryMatrixMatrixFEDInstruction.parseInstruction((BinaryMatrixMatrixSPInstruction) inst);
- }
- }
- else if( (instruction.input1.isMatrix() && ec.getCacheableData(instruction.input1).isFederatedExcept(FType.BROADCAST))
- || (instruction.input2.isMatrix() && ec.getMatrixObject(instruction.input2).isFederatedExcept(FType.BROADCAST))) {
- if(inst instanceof CovarianceSPInstruction && (ec.getMatrixObject(instruction.input1)
- .isFederated(FType.ROW) || ec.getMatrixObject(instruction.input2).isFederated(FType.ROW)))
- fedinst = CovarianceFEDInstruction.parseInstruction((CovarianceSPInstruction) inst);
- else if(inst instanceof CumulativeOffsetSPInstruction) {
- fedinst = CumulativeOffsetFEDInstruction.parseInstruction((CumulativeOffsetSPInstruction) inst);
- }
- else
- fedinst = BinaryFEDInstruction.parseInstruction(InstructionUtils
- .concatOperands(inst.getInstructionString(), FEDInstruction.FederatedOutput.NONE.name()));
- }
- }
- else if( inst instanceof ParameterizedBuiltinSPInstruction) {
- ParameterizedBuiltinSPInstruction pinst = (ParameterizedBuiltinSPInstruction) inst;
- if( pinst.getOpcode().equalsIgnoreCase("replace") && pinst.getTarget(ec).isFederatedExcept(FType.BROADCAST) )
- fedinst = ParameterizedBuiltinFEDInstruction.parseInstruction(pinst);
- }
- else if (inst instanceof MultiReturnParameterizedBuiltinSPInstruction) {
- MultiReturnParameterizedBuiltinSPInstruction minst = (MultiReturnParameterizedBuiltinSPInstruction) inst;
- if(minst.getOpcode().equals("transformencode") && minst.input1.isFrame()) {
- CacheableData<?> fo = ec.getCacheableData(minst.input1);
- if(fo.isFederatedExcept(FType.BROADCAST)) {
- fedinst = MultiReturnParameterizedBuiltinFEDInstruction.parseInstruction(minst);
- }
- }
- }
- else if(inst instanceof IndexingSPInstruction) {
- // matrix and frame indexing
- IndexingSPInstruction minst = (IndexingSPInstruction) inst;
- if((minst.input1.isMatrix() || minst.input1.isFrame())
- && ec.getCacheableData(minst.input1).isFederatedExcept(FType.BROADCAST)) {
- fedinst = IndexingFEDInstruction.parseInstruction(minst);
- }
- }
- else if(inst instanceof TernarySPInstruction) {
- TernarySPInstruction tinst = (TernarySPInstruction) inst;
- if(inst.getOpcode().equals("_map") && inst instanceof TernaryFrameScalarSPInstruction && !inst.getInstructionString().contains("UtilFunctions")
- && tinst.input1.isFrame() && ec.getFrameObject(tinst.input1).isFederated()) {
- long margin = ec.getScalarInput(tinst.input3).getLongValue();
- FrameObject fo = ec.getFrameObject(tinst.input1);
- if(margin == 0 || (fo.isFederated(FType.ROW) && margin == 1) || (fo.isFederated(FType.COL) && margin == 2))
- fedinst = TernaryFrameScalarFEDInstruction.parseInstruction((TernaryFrameScalarSPInstruction) tinst);
- } else if((tinst.input1.isMatrix() && ec.getCacheableData(tinst.input1).isFederatedExcept(FType.BROADCAST))
- || (tinst.input2.isMatrix() && ec.getCacheableData(tinst.input2).isFederatedExcept(FType.BROADCAST))
- || (tinst.input3.isMatrix() && ec.getCacheableData(tinst.input3).isFederatedExcept(FType.BROADCAST))) {
- fedinst = TernaryFEDInstruction.parseInstruction(tinst);
- }
- }
- else if(inst instanceof AggregateTernarySPInstruction){
- AggregateTernarySPInstruction ins = (AggregateTernarySPInstruction) inst;
- if(ins.input1.isMatrix() && ec.getCacheableData(ins.input1).isFederatedExcept(FType.BROADCAST) && ins.input2.isMatrix() &&
- ec.getCacheableData(ins.input2).isFederatedExcept(FType.BROADCAST)) {
- fedinst = AggregateTernaryFEDInstruction.parseInstruction(ins);
- }
- }
- else if(inst instanceof CtableSPInstruction) {
- CtableSPInstruction cinst = (CtableSPInstruction) inst;
- if(inst.getOpcode().equalsIgnoreCase("ctable")
- && ( ec.getCacheableData(cinst.input1).isFederated(FType.ROW)
- || (cinst.input2.isMatrix() && ec.getCacheableData(cinst.input2).isFederated(FType.ROW))
- || (cinst.input3.isMatrix() && ec.getCacheableData(cinst.input3).isFederated(FType.ROW))))
- fedinst = CtableFEDInstruction.parseInstruction(cinst);
- }
+ else if(inst instanceof QuaternarySPInstruction)
+ fedinst = QuaternaryFEDInstruction.parseInstruction((QuaternarySPInstruction) inst, ec);
+ else if(inst instanceof SpoofSPInstruction)
+ fedinst = SpoofFEDInstruction.parseInstruction((SpoofSPInstruction) inst, ec);
+ else if(inst instanceof UnarySPInstruction)
+ fedinst = UnaryFEDInstruction.parseInstruction((UnarySPInstruction) inst, ec);
+ else if(inst instanceof BinarySPInstruction)
+ fedinst = BinaryFEDInstruction.parseInstruction((BinarySPInstruction) inst, ec);
+ else if(inst instanceof ParameterizedBuiltinSPInstruction)
+ fedinst = ParameterizedBuiltinFEDInstruction.parseInstruction((ParameterizedBuiltinSPInstruction) inst, ec);
+ else if(inst instanceof MultiReturnParameterizedBuiltinSPInstruction)
+ fedinst = MultiReturnParameterizedBuiltinFEDInstruction
+ .parseInstruction((MultiReturnParameterizedBuiltinSPInstruction) inst, ec);
+ else if(inst instanceof TernarySPInstruction)
+ fedinst = TernaryFEDInstruction.parseInstruction((TernarySPInstruction) inst, ec);
+ else if(inst instanceof AggregateTernarySPInstruction)
+ fedinst = AggregateTernaryFEDInstruction.parseInstruction((AggregateTernarySPInstruction) inst, ec);
+ else if(inst instanceof CtableSPInstruction)
+ fedinst = CtableFEDInstruction.parseInstruction((CtableSPInstruction) inst, ec);
- //set thread id for federated context management
- if( fedinst != null ) {
+ // set thread id for federated context management
+ if(fedinst != null) {
fedinst.setTID(ec.getTID());
return fedinst;
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
index 5ddc46d899..cf5af3d9c8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MMChainFEDInstruction.java
@@ -19,23 +19,24 @@
package org.apache.sysds.runtime.instructions.fed;
+import java.util.concurrent.Future;
+
import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.MapMultChain.ChainType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import java.util.concurrent.Future;
-
public class MMChainFEDInstruction extends UnaryFEDInstruction {
public MMChainFEDInstruction(CPOperand in1, CPOperand in2, CPOperand in3,
@@ -50,7 +51,15 @@ public class MMChainFEDInstruction extends UnaryFEDInstruction {
return _type;
}
- public static MMChainFEDInstruction parseInstruction(MMChainCPInstruction instr) {
+ public static MMChainFEDInstruction parseInstruction(MMChainCPInstruction inst, ExecutionContext ec) {
+ MMChainCPInstruction linst = (MMChainCPInstruction) inst;
+ MatrixObject mo = ec.getMatrixObject(linst.input1);
+ if( mo.isFederated(FType.ROW) )
+ return MMChainFEDInstruction.parseInstruction(linst);
+ return null;
+ }
+
+ private static MMChainFEDInstruction parseInstruction(MMChainCPInstruction instr) {
return new MMChainFEDInstruction(instr.input1, instr.input2, instr.input3, instr.output, instr.getMMChainType(),
instr.getNumThreads(), instr.getOpcode(), instr.getInstructionString());
}
@@ -62,7 +71,7 @@ public class MMChainFEDInstruction extends UnaryFEDInstruction {
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
-
+
if( parts.length==6 ) {
CPOperand out= new CPOperand(parts[3]);
ChainType type = ChainType.valueOf(parts[4]);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index 16471a1497..c9135eb013 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -35,8 +35,10 @@ import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.hops.fedplanner.FTypes;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -79,15 +81,35 @@ public class MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
}
public static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(
+ MultiReturnParameterizedBuiltinCPInstruction inst, ExecutionContext ec) {
+ if(inst.getOpcode().equals("transformencode") && inst.input1.isFrame()) {
+ CacheableData<?> fo = ec.getCacheableData(inst.input1);
+ if(fo.isFederatedExcept(FType.BROADCAST))
+ return MultiReturnParameterizedBuiltinFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ public static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(
+ MultiReturnParameterizedBuiltinSPInstruction inst, ExecutionContext ec) {
+ if(inst.getOpcode().equals("transformencode") && inst.input1.isFrame()) {
+ CacheableData<?> fo = ec.getCacheableData(inst.input1);
+ if(fo.isFederatedExcept(FType.BROADCAST))
+ return MultiReturnParameterizedBuiltinFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ private static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(
MultiReturnParameterizedBuiltinCPInstruction instr) {
return new MultiReturnParameterizedBuiltinFEDInstruction(instr.getOperator(), instr.input1, instr.input2,
instr.getOutputs(), instr.getOpcode(), instr.getInstructionString());
}
- public static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(
- MultiReturnParameterizedBuiltinSPInstruction instr) {
+ private static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(
+ MultiReturnParameterizedBuiltinSPInstruction instr) {
return new MultiReturnParameterizedBuiltinFEDInstruction(instr.getOperator(), instr.input1, instr.input2,
- instr.getOutputs(), instr.getOpcode(), instr.getInstructionString());
+ instr.getOutputs(), instr.getOpcode(), instr.getInstructionString());
}
public static MultiReturnParameterizedBuiltinFEDInstruction parseInstruction(String str) {
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index c0b60c557a..b9794b413e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -84,12 +84,40 @@ import org.apache.sysds.runtime.util.UtilFunctions;
public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstruction {
protected final HashMap<String, String> params;
+ private static final String[] PARAM_BUILTINS = new String[]{
+ "replace", "rmempty", "lowertri", "uppertri", "transformdecode", "transformapply", "tokenize"};
+
+
protected ParameterizedBuiltinFEDInstruction(Operator op, HashMap<String, String> paramsMap, CPOperand out,
String opcode, String istr) {
super(FEDType.ParameterizedBuiltin, op, null, null, out, opcode, istr);
params = paramsMap;
}
+ public static ParameterizedBuiltinFEDInstruction parseInstruction(String str) {
+ String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+ // first part is always the opcode
+ String opcode = parts[0];
+ // last part is always the output
+ CPOperand out = new CPOperand(parts[parts.length - 1]);
+
+ // process remaining parts and build a hash map
+ LinkedHashMap<String, String> paramsMap = constructParameterMap(parts);
+
+ // determine the appropriate value function
+ if(opcode.equalsIgnoreCase("replace") || opcode.equalsIgnoreCase("rmempty") ||
+ opcode.equalsIgnoreCase("lowertri") || opcode.equalsIgnoreCase("uppertri")) {
+ ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
+ return new ParameterizedBuiltinFEDInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
+ }
+ else if(opcode.equals("transformapply") || opcode.equals("transformdecode") || opcode.equals("tokenize")) {
+ return new ParameterizedBuiltinFEDInstruction(null, paramsMap, out, opcode, str);
+ }
+ else {
+ throw new DMLRuntimeException("Unsupported opcode (" + opcode + ") for ParameterizedBuiltinFEDInstruction.");
+ }
+ }
+
public HashMap<String, String> getParameterMap() {
return params;
}
@@ -112,39 +140,28 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
return paramMap;
}
- public static ParameterizedBuiltinFEDInstruction parseInstruction(ParameterizedBuiltinCPInstruction instr) {
- return new ParameterizedBuiltinFEDInstruction(instr.getOperator(), instr.getParameterMap(), instr.output,
- instr.getOpcode(), instr.getInstructionString());
+ public static ParameterizedBuiltinFEDInstruction parseInstruction(ParameterizedBuiltinCPInstruction inst,
+ ExecutionContext ec) {
+ if(ArrayUtils.contains(PARAM_BUILTINS, inst.getOpcode()) && inst.getTarget(ec).isFederatedExcept(FType.BROADCAST))
+ return ParameterizedBuiltinFEDInstruction.parseInstruction(inst);
+ return null;
+ }
+
+ public static ParameterizedBuiltinFEDInstruction parseInstruction(ParameterizedBuiltinSPInstruction inst,
+ ExecutionContext ec) {
+ if( inst.getOpcode().equalsIgnoreCase("replace") && inst.getTarget(ec).isFederatedExcept(FType.BROADCAST) )
+ return ParameterizedBuiltinFEDInstruction.parseInstruction(inst);
+ return null;
}
- public static ParameterizedBuiltinFEDInstruction parseInstruction(ParameterizedBuiltinSPInstruction instr) {
+ private static ParameterizedBuiltinFEDInstruction parseInstruction(ParameterizedBuiltinCPInstruction instr) {
return new ParameterizedBuiltinFEDInstruction(instr.getOperator(), instr.getParameterMap(), instr.output,
instr.getOpcode(), instr.getInstructionString());
}
- public static ParameterizedBuiltinFEDInstruction parseInstruction(String str) {
- String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
- // first part is always the opcode
- String opcode = parts[0];
- // last part is always the output
- CPOperand out = new CPOperand(parts[parts.length - 1]);
-
- // process remaining parts and build a hash map
- LinkedHashMap<String, String> paramsMap = constructParameterMap(parts);
-
- // determine the appropriate value function
- if(opcode.equalsIgnoreCase("replace") || opcode.equalsIgnoreCase("rmempty") ||
- opcode.equalsIgnoreCase("lowertri") || opcode.equalsIgnoreCase("uppertri")) {
- ValueFunction func = ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
- return new ParameterizedBuiltinFEDInstruction(new SimpleOperator(func), paramsMap, out, opcode, str);
- }
- else if(opcode.equals("transformapply") || opcode.equals("transformdecode") || opcode.equals("tokenize")) {
- return new ParameterizedBuiltinFEDInstruction(null, paramsMap, out, opcode, str);
- }
- else {
- throw new DMLRuntimeException(
- "Unsupported opcode (" + opcode + ") for ParameterizedBuiltinFEDInstruction.");
- }
+ private static ParameterizedBuiltinFEDInstruction parseInstruction(ParameterizedBuiltinSPInstruction instr) {
+ return new ParameterizedBuiltinFEDInstruction(instr.getOperator(), instr.getParameterMap(), instr.output,
+ instr.getOpcode(), instr.getInstructionString());
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
index f817c4c2a6..871128a83f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
@@ -127,6 +127,7 @@ public class QuantileSortFEDInstruction extends UnaryFEDInstruction {
inst._fedOut = fedOut;
return inst;
}
+
@Override
public void processInstruction(ExecutionContext ec) {
if(ec.getMatrixObject(input1).isFederated(FType.COL) || ec.getMatrixObject(input1).isFederated(FType.FULL))
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
index 9b5014e6d7..ee89c72485 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.java
@@ -21,17 +21,18 @@ package org.apache.sysds.runtime.instructions.fed;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.WeightedCrossEntropy;
import org.apache.sysds.lops.WeightedCrossEntropy.WCeMMType;
import org.apache.sysds.lops.WeightedDivMM;
-import org.apache.sysds.lops.WeightedDivMMR;
import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
+import org.apache.sysds.lops.WeightedDivMMR;
import org.apache.sysds.lops.WeightedSigmoid;
import org.apache.sysds.lops.WeightedSigmoid.WSigmoidType;
import org.apache.sysds.lops.WeightedSquaredLoss;
-import org.apache.sysds.lops.WeightedSquaredLossR;
import org.apache.sysds.lops.WeightedSquaredLoss.WeightsType;
+import org.apache.sysds.lops.WeightedSquaredLossR;
import org.apache.sysds.lops.WeightedUnaryMM;
import org.apache.sysds.lops.WeightedUnaryMM.WUMMType;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -39,6 +40,7 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -58,7 +60,21 @@ public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction
_input4 = in4;
}
- public static QuaternaryFEDInstruction parseInstruction(QuaternaryCPInstruction instr) {
+ public static QuaternaryFEDInstruction parseInstruction(QuaternaryCPInstruction inst, ExecutionContext ec) {
+ Data data = ec.getVariable(inst.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
+ return QuaternaryFEDInstruction.parseInstruction(inst);
+ return null;
+ }
+
+ public static QuaternaryFEDInstruction parseInstruction(QuaternarySPInstruction inst, ExecutionContext ec) {
+ Data data = ec.getVariable(inst.input1);
+ if(data instanceof MatrixObject && ((MatrixObject) data).isFederated())
+ return QuaternaryFEDInstruction.parseInstruction(inst);
+ return null;
+ }
+
+ private static QuaternaryFEDInstruction parseInstruction(QuaternaryCPInstruction instr) {
QuaternaryOperator qop = (QuaternaryOperator) instr.getOperator();
if(qop.wtype1 != null)
return QuaternaryWSLossFEDInstruction.parseInstruction(instr);
@@ -74,7 +90,7 @@ public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction
return null;
}
- public static QuaternaryFEDInstruction parseInstruction(QuaternarySPInstruction instr) {
+ private static QuaternaryFEDInstruction parseInstruction(QuaternarySPInstruction instr) {
QuaternaryOperator qop = (QuaternaryOperator) instr.getOperator();
if(qop.wtype1 != null)
return QuaternaryWSLossFEDInstruction.parseInstruction(instr);
@@ -99,7 +115,8 @@ public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
- int addInput4 = (opcode.equals(WeightedCrossEntropy.OPCODE_CP) || opcode.equals(WeightedSquaredLoss.OPCODE_CP) || opcode.equals(WeightedDivMM.OPCODE_CP)) ? 1 : 0;
+ int addInput4 = (opcode.equals(WeightedCrossEntropy.OPCODE_CP) || opcode.equals(WeightedSquaredLoss.OPCODE_CP) ||
+ opcode.equals(WeightedDivMM.OPCODE_CP)) ? 1 : 0;
int addUOpcode = (opcode.equals(WeightedUnaryMM.OPCODE_CP) ? 1 : 0);
InstructionUtils.checkNumFields(parts, 6 + addInput4 + addUOpcode);
@@ -124,11 +141,10 @@ public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction
Double.parseDouble(in4.getName())) : new QuaternaryOperator(wcemm_type));
return new QuaternaryWCeMMFEDInstruction(qop, in1, in2, in3, in4, out, opcode, str);
}
- else if(opcode.equals(WeightedDivMM.OPCODE_CP))
- {
+ else if(opcode.equals(WeightedDivMM.OPCODE_CP)) {
final WDivMMType wdivmm_type = WDivMMType.valueOf(parts[6]);
if(wdivmm_type.hasFourInputs())
- checkDataTypes(new DataType[]{DataType.SCALAR, DataType.MATRIX}, in4);
+ checkDataTypes(new DataType[] {DataType.SCALAR, DataType.MATRIX}, in4);
qop = new QuaternaryOperator(wdivmm_type);
return new QuaternaryWDivMMFEDInstruction(qop, in1, in2, in3, in4, out, opcode, str);
}
@@ -145,8 +161,7 @@ public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction
qop = new QuaternaryOperator(wsigmoid_type);
return new QuaternaryWSigmoidFEDInstruction(qop, in1, in2, in3, out, opcode, str);
}
- else if(opcode.equals(WeightedUnaryMM.OPCODE_CP))
- {
+ else if(opcode.equals(WeightedUnaryMM.OPCODE_CP)) {
final WUMMType wumm_type = WUMMType.valueOf(parts[6]);
String uopcode = parts[1];
qop = new QuaternaryOperator(wumm_type, uopcode);
@@ -179,7 +194,7 @@ public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction
protected static String rewriteSparkInstructionToCP(String inst_str) {
// TODO: don't perform replacement over the whole instruction string, possibly changing string literals,
- // instead only at positions of ExecType and Opcode
+ // instead only at positions of ExecType and Opcode
// rewrite the spark instruction to a cp instruction
inst_str = inst_str.replace(ExecType.SPARK.name(), ExecType.CP.name());
if(inst_str.contains(WeightedCrossEntropy.OPCODE))
@@ -203,11 +218,11 @@ public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction
return inst_str;
}
-
+
protected void setOutputDataCharacteristics(MatrixObject X, MatrixObject U, MatrixObject V, ExecutionContext ec) {
long rows = X.getNumRows() > 1 ? X.getNumRows() : U.getNumRows();
- long cols = X.getNumColumns() > 1 ? X.getNumColumns()
- : (U.getNumColumns() == V.getNumRows() ? V.getNumColumns() : V.getNumRows());
+ long cols = X.getNumColumns() > 1 ? X
+ .getNumColumns() : (U.getNumColumns() == V.getNumRows() ? V.getNumColumns() : V.getNumRows());
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(rows, cols, (int) X.getBlocksize());
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index bf3632f1dd..0b173d7fe8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -76,37 +76,40 @@ public class ReorgFEDInstruction extends UnaryFEDInstruction {
rinst.getInstructionString(), FederatedOutput.NONE);
}
- public static ReorgFEDInstruction parseInstruction ( String str ) {
+ public static ReorgFEDInstruction parseInstruction(String str) {
CPOperand in = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
FederatedOutput fedOut;
- if ( opcode.equalsIgnoreCase("r'") ) {
+ if(opcode.equalsIgnoreCase("r'")) {
InstructionUtils.checkNumFields(str, 2, 3, 4);
in.split(parts[1]);
out.split(parts[2]);
int k = str.startsWith(Types.ExecMode.SPARK.name()) ? 0 : Integer.parseInt(parts[3]);
- fedOut = str.startsWith(Types.ExecMode.SPARK.name()) ?
- FederatedOutput.valueOf(parts[3]) : FederatedOutput.valueOf(parts[4]);
- return new ReorgFEDInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str, fedOut);
+ fedOut = str.startsWith(Types.ExecMode.SPARK.name()) ? FederatedOutput.valueOf(parts[3]) : FederatedOutput
+ .valueOf(parts[4]);
+ return new ReorgFEDInstruction(new ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str,
+ fedOut);
}
- else if ( opcode.equalsIgnoreCase("rdiag") ) {
- parseUnaryInstruction(str, in, out); //max 2 operands
+ else if(opcode.equalsIgnoreCase("rdiag")) {
+ parseUnaryInstruction(str, in, out); // max 2 operands
fedOut = parseFedOutFlag(str, 3);
- return new ReorgFEDInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str, fedOut);
+ return new ReorgFEDInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str,
+ fedOut);
}
- else if ( opcode.equalsIgnoreCase("rev") ) {
- parseUnaryInstruction(str, in, out); //max 2 operands
+ else if(opcode.equalsIgnoreCase("rev")) {
+ parseUnaryInstruction(str, in, out); // max 2 operands
fedOut = parseFedOutFlag(str, 3);
- return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str, fedOut);
+ return new ReorgFEDInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str,
+ fedOut);
}
else {
- throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: "+opcode);
+ throw new DMLRuntimeException("ReorgFEDInstruction: unsupported opcode: " + opcode);
}
}
-
+
@Override
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
index e5af25ef02..14b16111e8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/SpoofFEDInstruction.java
@@ -69,12 +69,32 @@ public class SpoofFEDInstruction extends FEDInstruction
_output = out;
}
- public static SpoofFEDInstruction parseInstruction(SpoofCPInstruction instr) {
+ public static SpoofFEDInstruction parseInstruction(SpoofCPInstruction inst, ExecutionContext ec){
+ Class<?> scla = inst.getOperatorClass().getSuperclass();
+ if(((scla == SpoofCellwise.class || scla == SpoofMultiAggregate.class || scla == SpoofOuterProduct.class)
+ && SpoofFEDInstruction.isFederated(ec, inst.getInputs(), scla))
+ || (scla == SpoofRowwise.class && SpoofFEDInstruction.isFederated(ec, FType.ROW, inst.getInputs(), scla))) {
+ return SpoofFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ public static SpoofFEDInstruction parseInstruction(SpoofSPInstruction inst, ExecutionContext ec){
+ Class<?> scla = inst.getOperatorClass().getSuperclass();
+ if(((scla == SpoofCellwise.class || scla == SpoofMultiAggregate.class || scla == SpoofOuterProduct.class)
+ && SpoofFEDInstruction.isFederated(ec, inst.getInputs(), scla))
+ || (scla == SpoofRowwise.class && SpoofFEDInstruction.isFederated(ec, FType.ROW, inst.getInputs(), scla))) {
+ return SpoofFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ private static SpoofFEDInstruction parseInstruction(SpoofCPInstruction instr) {
return new SpoofFEDInstruction(instr.getSpoofOperator(), instr.getInputs(), instr.getOutput(),
instr.getOpcode(), instr.getInstructionString());
}
- public static SpoofFEDInstruction parseInstruction(SpoofSPInstruction instr) {
+ private static SpoofFEDInstruction parseInstruction(SpoofSPInstruction instr) {
SpoofOperator op = CodegenUtils.createInstance(instr.getOperatorClass());
return new SpoofFEDInstruction(op, instr.getInputs(), instr.getOutput(), instr.getOpcode(),
instr.getInstructionString());
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index 342faf3296..0883e6fe02 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -25,6 +25,7 @@ import java.util.concurrent.Future;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
@@ -34,6 +35,8 @@ import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.TernaryFrameScalarCPInstruction;
+import org.apache.sysds.runtime.instructions.spark.TernaryFrameScalarSPInstruction;
import org.apache.sysds.runtime.instructions.spark.TernarySPInstruction;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -46,12 +49,46 @@ public class TernaryFEDInstruction extends ComputationFEDInstruction {
super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str, fedOut);
}
- public static TernaryFEDInstruction parseInstruction(TernaryCPInstruction instr) {
+ public static TernaryFEDInstruction parseInstruction(TernaryCPInstruction inst, ExecutionContext ec) {
+ if(inst.getOpcode().equals("_map") && inst instanceof TernaryFrameScalarCPInstruction &&
+ !inst.getInstructionString().contains("UtilFunctions") && inst.input1.isFrame() &&
+ ec.getFrameObject(inst.input1).isFederated()) {
+ long margin = ec.getScalarInput(inst.input3).getLongValue();
+ FrameObject fo = ec.getFrameObject(inst.input1);
+ if(margin == 0 || (fo.isFederated(FType.ROW) && margin == 1) || (fo.isFederated(FType.COL) && margin == 2))
+ return TernaryFrameScalarFEDInstruction.parseInstruction((TernaryFrameScalarCPInstruction) inst);
+ }
+ else if((inst.input1.isMatrix() && ec.getCacheableData(inst.input1).isFederatedExcept(FType.BROADCAST)) ||
+ (inst.input2.isMatrix() && ec.getCacheableData(inst.input2).isFederatedExcept(FType.BROADCAST)) ||
+ (inst.input3.isMatrix() && ec.getCacheableData(inst.input3).isFederatedExcept(FType.BROADCAST))) {
+ return TernaryFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ public static TernaryFEDInstruction parseInstruction(TernarySPInstruction inst, ExecutionContext ec) {
+ if(inst.getOpcode().equals("_map") && inst instanceof TernaryFrameScalarSPInstruction &&
+ !inst.getInstructionString().contains("UtilFunctions") && inst.input1.isFrame() &&
+ ec.getFrameObject(inst.input1).isFederated()) {
+ long margin = ec.getScalarInput(inst.input3).getLongValue();
+ FrameObject fo = ec.getFrameObject(inst.input1);
+ if(margin == 0 || (fo.isFederated(FType.ROW) && margin == 1) || (fo.isFederated(FType.COL) && margin == 2))
+ return TernaryFrameScalarFEDInstruction.parseInstruction((TernaryFrameScalarSPInstruction) inst);
+ }
+ else if((inst.input1.isMatrix() && ec.getCacheableData(inst.input1).isFederatedExcept(FType.BROADCAST)) ||
+ (inst.input2.isMatrix() && ec.getCacheableData(inst.input2).isFederatedExcept(FType.BROADCAST)) ||
+ (inst.input3.isMatrix() && ec.getCacheableData(inst.input3).isFederatedExcept(FType.BROADCAST))) {
+ return TernaryFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+ private static TernaryFEDInstruction parseInstruction(TernaryCPInstruction instr) {
return new TernaryFEDInstruction((TernaryOperator) instr.getOperator(), instr.input1, instr.input2,
instr.input3, instr.output, instr.getOpcode(), instr.getInstructionString(), FederatedOutput.NONE);
}
- public static TernaryFEDInstruction parseInstruction(TernarySPInstruction instr) {
+ private static TernaryFEDInstruction parseInstruction(TernarySPInstruction instr) {
return new TernaryFEDInstruction((TernaryOperator) instr.getOperator(), instr.input1, instr.input2,
instr.input3, instr.output, instr.getOpcode(), instr.getInstructionString(), FederatedOutput.NONE);
}
@@ -63,11 +100,13 @@ public class TernaryFEDInstruction extends ComputationFEDInstruction {
CPOperand operand2 = new CPOperand(parts[2]);
CPOperand operand3 = new CPOperand(parts[3]);
CPOperand outOperand = new CPOperand(parts[4]);
- int numThreads = parts.length>5 & !opcode.contains("map") ? Integer.parseInt(parts[5]) : 1;
- FederatedOutput fedOut = parts.length>=7 && !opcode.contains("map") ? FederatedOutput.valueOf(parts[6]) : FederatedOutput.NONE;
+ int numThreads = parts.length > 5 & !opcode.contains("map") ? Integer.parseInt(parts[5]) : 1;
+ FederatedOutput fedOut = parts.length >= 7 && !opcode.contains("map") ? FederatedOutput
+ .valueOf(parts[6]) : FederatedOutput.NONE;
TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode, numThreads);
- if( operand1.isFrame() && operand2.isScalar() || operand2.isFrame() && operand1.isScalar() )
- return new TernaryFrameScalarFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, InstructionUtils.removeFEDOutputFlag(str), fedOut);
+ if(operand1.isFrame() && operand2.isScalar() || operand2.isFrame() && operand1.isScalar())
+ return new TernaryFrameScalarFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode,
+ InstructionUtils.removeFEDOutputFlag(str), fedOut);
return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str, fedOut);
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
index 3b15b273db..3d34338049 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/TsmmFEDInstruction.java
@@ -51,7 +51,15 @@ public class TsmmFEDInstruction extends BinaryFEDInstruction {
this(in, out, type, k, opcode, istr, FederatedOutput.NONE);
}
- public static TsmmFEDInstruction parseInstruction(MMTSJCPInstruction instr) {
+ public static TsmmFEDInstruction parseInstruction(MMTSJCPInstruction inst, ExecutionContext ec) {
+ MatrixObject mo = ec.getMatrixObject(inst.input1);
+ if( (mo.isFederated(FType.ROW) && mo.isFederatedExcept(FType.BROADCAST) && inst.getMMTSJType().isLeft()) ||
+ (mo.isFederated(FType.COL) && mo.isFederatedExcept(FType.BROADCAST) && inst.getMMTSJType().isRight()))
+ return parseInstruction(inst);
+ return null;
+ }
+
+ private static TsmmFEDInstruction parseInstruction(MMTSJCPInstruction instr) {
return new TsmmFEDInstruction(instr.input1, instr.getOutput(), instr.getMMTSJType(), instr.getNumThreads(),
instr.getOpcode(), instr.getInstructionString());
}
@@ -61,7 +69,7 @@ public class TsmmFEDInstruction extends BinaryFEDInstruction {
String opcode = parts[0];
if(!opcode.equalsIgnoreCase("tsmm"))
throw new DMLRuntimeException("TsmmFedInstruction.parseInstruction():: Unknown opcode " + opcode);
-
+
InstructionUtils.checkNumFields(parts, 3, 4, 5);
CPOperand in = new CPOperand(parts[1]);
CPOperand out = new CPOperand(parts[2]);
@@ -70,7 +78,7 @@ public class TsmmFEDInstruction extends BinaryFEDInstruction {
FederatedOutput fedOut = (parts.length > 5) ? FederatedOutput.valueOf(parts[5]) : FederatedOutput.NONE;
return new TsmmFEDInstruction(in, out, type, k, opcode, str, fedOut);
}
-
+
@Override
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
index 1c66e77768..623872e963 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
@@ -19,9 +19,32 @@
package org.apache.sysds.runtime.instructions.fed;
+import org.apache.commons.lang.ArrayUtils;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.QuantileSortCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.ReshapeCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
+import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
+import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ReblockSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.ReorgSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.UnaryMatrixSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
public abstract class UnaryFEDInstruction extends ComputationFEDInstruction {
@@ -55,7 +78,121 @@ public abstract class UnaryFEDInstruction extends ComputationFEDInstruction {
super(type, op, in1, in2, in3, out, opcode, instr, fedOut);
}
- static String parseUnaryInstruction(String instr, CPOperand in, CPOperand out) {
+ public static UnaryFEDInstruction parseInstruction(UnaryCPInstruction inst, ExecutionContext ec) {
+ if(inst instanceof IndexingCPInstruction) {
+ // matrix and frame indexing
+ IndexingCPInstruction minst = (IndexingCPInstruction) inst;
+ if((minst.input1.isMatrix() || minst.input1.isFrame()) &&
+ ec.getCacheableData(minst.input1).isFederatedExcept(FType.BROADCAST)) {
+ return IndexingFEDInstruction.parseInstruction(minst);
+ }
+ }
+ else if(inst instanceof ReorgCPInstruction &&
+ (inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) {
+ ReorgCPInstruction rinst = (ReorgCPInstruction) inst;
+ CacheableData<?> mo = ec.getCacheableData(rinst.input1);
+
+ if((mo instanceof MatrixObject || mo instanceof FrameObject) && mo.isFederatedExcept(FType.BROADCAST))
+ return ReorgFEDInstruction.parseInstruction(rinst);
+ }
+ else if(inst.input1 != null && inst.input1.isMatrix() && ec.containsVariable(inst.input1)) {
+
+ MatrixObject mo1 = ec.getMatrixObject(inst.input1);
+ if(mo1.isFederatedExcept(FType.BROADCAST)) {
+ if(inst instanceof CentralMomentCPInstruction)
+ return CentralMomentFEDInstruction.parseInstruction((CentralMomentCPInstruction) inst);
+ else if(inst instanceof QuantileSortCPInstruction) {
+ if(mo1.isFederated(FType.ROW) ||
+ mo1.getFedMapping().getFederatedRanges().length == 1 && mo1.isFederated(FType.COL))
+ return QuantileSortFEDInstruction.parseInstruction((QuantileSortCPInstruction) inst);
+ }
+ else if(inst instanceof ReshapeCPInstruction)
+ return ReshapeFEDInstruction.parseInstruction((ReshapeCPInstruction) inst);
+ else if(inst instanceof AggregateUnaryCPInstruction &&
+ ((AggregateUnaryCPInstruction) inst).getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT)
+ return AggregateUnaryFEDInstruction.parseInstruction((AggregateUnaryCPInstruction) inst);
+ else if(inst instanceof UnaryMatrixCPInstruction) {
+ if(UnaryMatrixFEDInstruction.isValidOpcode(inst.getOpcode()) &&
+ !(inst.getOpcode().equalsIgnoreCase("ucumk+*") && mo1.isFederated(FType.COL)))
+ return UnaryMatrixFEDInstruction.parseInstruction((UnaryMatrixCPInstruction) inst);
+ }
+ }
+ }
+ return null;
+ }
+
+	/**
+	 * Tries to replace a Spark unary instruction with its federated counterpart.
+	 * <p>
+	 * A federated instruction is only produced when the relevant input is federated with a type other than
+	 * {@link FType#BROADCAST} (as reported by {@code isFederatedExcept}) and a federated implementation exists
+	 * for the concrete instruction type; in every other case {@code null} is returned.
+	 *
+	 * @param inst the Spark unary instruction to inspect
+	 * @param ec   execution context used to look up the instruction's input data objects
+	 * @return the federated replacement, or {@code null} if none applies
+	 */
+	public static UnaryFEDInstruction parseInstruction(UnarySPInstruction inst, ExecutionContext ec) {
+		if(inst instanceof IndexingSPInstruction) {
+			// matrix and frame indexing
+			IndexingSPInstruction minst = (IndexingSPInstruction) inst;
+			if((minst.input1.isMatrix() || minst.input1.isFrame()) &&
+				ec.getCacheableData(minst.input1).isFederatedExcept(FType.BROADCAST)) {
+				return IndexingFEDInstruction.parseInstruction(minst);
+			}
+		}
+		else if(inst instanceof CentralMomentSPInstruction) {
+			CentralMomentSPInstruction cinstruction = (CentralMomentSPInstruction) inst;
+			Data data = ec.getVariable(cinstruction.input1);
+			if(data instanceof MatrixObject && ((MatrixObject) data).isFederated() &&
+				((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
+				return CentralMomentFEDInstruction.parseInstruction(cinstruction);
+		}
+		else if(inst instanceof QuantileSortSPInstruction) {
+			QuantileSortSPInstruction qinstruction = (QuantileSortSPInstruction) inst;
+			Data data = ec.getVariable(qinstruction.input1);
+			if(data instanceof MatrixObject && ((MatrixObject) data).isFederated() &&
+				((MatrixObject) data).isFederatedExcept(FType.BROADCAST))
+				return QuantileSortFEDInstruction.parseInstruction(qinstruction);
+		}
+		else if(inst instanceof AggregateUnarySPInstruction) {
+			AggregateUnarySPInstruction auinstruction = (AggregateUnarySPInstruction) inst;
+			Data data = ec.getVariable(auinstruction.input1);
+			if(data instanceof MatrixObject && ((MatrixObject) data).isFederated() &&
+				((MatrixObject) data).isFederatedExcept(FType.BROADCAST)) {
+				// uarimin/uarimax are only converted for ROW-federated inputs; every other aggregate is
+				// converted for any non-broadcast federation (explicit braces replace the former dangling else)
+				boolean indexMinMax = ArrayUtils.contains(new String[] {"uarimin", "uarimax"},
+					auinstruction.getOpcode());
+				if(!indexMinMax || ((MatrixObject) data).getFedMapping().getType() == FType.ROW)
+					return AggregateUnaryFEDInstruction.parseInstruction(auinstruction);
+			}
+		}
+		else if(inst instanceof ReorgSPInstruction &&
+			(inst.getOpcode().equals("r'") || inst.getOpcode().equals("rdiag") || inst.getOpcode().equals("rev"))) {
+			ReorgSPInstruction rinst = (ReorgSPInstruction) inst;
+			CacheableData<?> mo = ec.getCacheableData(rinst.input1);
+			if((mo instanceof MatrixObject || mo instanceof FrameObject) && mo.isFederated() &&
+				mo.isFederatedExcept(FType.BROADCAST))
+				return ReorgFEDInstruction.parseInstruction(rinst);
+		}
+		else if(inst instanceof ReblockSPInstruction && inst.input1 != null &&
+			(inst.input1.isFrame() || inst.input1.isMatrix())) {
+			ReblockSPInstruction rinst = (ReblockSPInstruction) inst;
+			CacheableData<?> data = ec.getCacheableData(rinst.input1);
+			if(data.isFederatedExcept(FType.BROADCAST))
+				return ReblockFEDInstruction.parseInstruction(rinst);
+		}
+		else if(inst.input1 != null && inst.input1.isMatrix() && ec.containsVariable(inst.input1)) {
+			MatrixObject mo1 = ec.getMatrixObject(inst.input1);
+			if(mo1.isFederatedExcept(FType.BROADCAST)) {
+				// No "cm" case here: every CentralMomentSPInstruction is consumed by its instanceof branch
+				// above, so casting another instruction type with a "cm" opcode could only throw a
+				// ClassCastException.
+				if(inst.getOpcode().equalsIgnoreCase("qsort")) {
+					// NOTE(review): QuantileSortSPInstruction is also handled above; this string-based path
+					// presumably only applies to other instruction types carrying a "qsort" opcode -- confirm.
+					if(mo1.getFedMapping().getFederatedRanges().length == 1)
+						return QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString(), false);
+				}
+				else if(inst.getOpcode().equalsIgnoreCase("rshape")) {
+					return ReshapeFEDInstruction.parseInstruction(inst.getInstructionString());
+				}
+				else if(inst instanceof UnaryMatrixSPInstruction) {
+					if(UnaryMatrixFEDInstruction.isValidOpcode(inst.getOpcode()))
+						return UnaryMatrixFEDInstruction.parseInstruction((UnaryMatrixSPInstruction) inst);
+				}
+			}
+		}
+		return null;
+	}
+
+ protected static String parseUnaryInstruction(String instr, CPOperand in, CPOperand out) {
//TODO: simplify once all fed instructions have consistent flags
int num = InstructionUtils.checkNumFields(instr, 2, 3, 4);
if(num == 2)
@@ -69,12 +206,12 @@ public abstract class UnaryFEDInstruction extends ComputationFEDInstruction {
}
}
- static String parseUnaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) {
+	/**
+	 * Parses a unary instruction string into two inputs and one output operand.
+	 * <p>
+	 * The field count is validated with {@code InstructionUtils.checkNumFields(instr, 3)} before the
+	 * operands are filled via {@code parse} (third input passed as {@code null}).
+	 * Visibility was widened from package-private to {@code protected} in this change.
+	 *
+	 * @return the value returned by {@code parse} -- NOTE(review): presumably the opcode; confirm
+	 */
+ protected static String parseUnaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) {
 InstructionUtils.checkNumFields(instr, 3);
 return parse(instr, in1, in2, null, out);
 }
- static String parseUnaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out) {
+	/**
+	 * Parses a unary instruction string into three inputs and one output operand.
+	 * <p>
+	 * The field count is validated with {@code InstructionUtils.checkNumFields(instr, 4)} before the
+	 * operands are filled via {@code parse}.
+	 * Visibility was widened from package-private to {@code protected} in this change.
+	 *
+	 * @return the value returned by {@code parse} -- NOTE(review): presumably the opcode; confirm
+	 */
+ protected static String parseUnaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out) {
 InstructionUtils.checkNumFields(instr, 4);
 return parse(instr, in1, in2, in3, out);
 }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
index c3c2111641..890b681cef 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
@@ -70,12 +70,14 @@ public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
- if(parts.length == 5 && (opcode.equalsIgnoreCase("exp") || opcode.equalsIgnoreCase("log") || opcode.startsWith("ucum"))) {
+ if(parts.length == 5 &&
+ (opcode.equalsIgnoreCase("exp") || opcode.equalsIgnoreCase("log") || opcode.startsWith("ucum"))) {
in.split(parts[1]);
out.split(parts[2]);
ValueFunction func = Builtin.getBuiltinFnObject(opcode);
- if( Arrays.asList(new String[]{"ucumk+","ucum*","ucumk+*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode) ){
- UnaryOperator op = new UnaryOperator(func,Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4]));
+ if(Arrays.asList(new String[] {"ucumk+", "ucum*", "ucumk+*", "ucummin", "ucummax", "exp", "log", "sigmoid"})
+ .contains(opcode)) {
+ UnaryOperator op = new UnaryOperator(func, Integer.parseInt(parts[3]), Boolean.parseBoolean(parts[4]));
return new UnaryMatrixFEDInstruction(op, in, out, opcode, str);
}
else
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
index 4a51f49083..f89c32f374 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
@@ -29,6 +29,7 @@ import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
@@ -55,7 +56,23 @@ public class VariableFEDInstruction extends FEDInstruction implements LineageTra
_in = in;
}
- public static VariableFEDInstruction parseInstruction(VariableCPInstruction cpInstruction) {
+ public static VariableFEDInstruction parseInstruction(VariableCPInstruction inst, ExecutionContext ec) {
+ if(inst.getVariableOpcode() == VariableOperationCode.Write && inst.getInput1().isMatrix() &&
+ inst.getInput3().getName().contains("federated")) {
+ return VariableFEDInstruction.parseInstruction(inst);
+ }
+ else if(inst.getVariableOpcode() == VariableOperationCode.CastAsFrameVariable && inst.getInput1().isMatrix() &&
+ ec.getCacheableData(inst.getInput1()).isFederatedExcept(FType.BROADCAST)) {
+ return VariableFEDInstruction.parseInstruction(inst);
+ }
+ else if(inst.getVariableOpcode() == VariableOperationCode.CastAsMatrixVariable && inst.getInput1().isFrame() &&
+ ec.getCacheableData(inst.getInput1()).isFederatedExcept(FType.BROADCAST)) {
+ return VariableFEDInstruction.parseInstruction(inst);
+ }
+ return null;
+ }
+
+	/**
+	 * Wraps the given CP variable instruction in a {@link VariableFEDInstruction} without further checks.
+	 * Made private in this change; it is only called from the public overload after that overload's
+	 * federation checks succeed.
+	 */
+ private static VariableFEDInstruction parseInstruction(VariableCPInstruction cpInstruction) {
 return new VariableFEDInstruction(cpInstruction);
 }