You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@systemds.apache.org by GitBox <gi...@apache.org> on 2021/02/25 21:49:20 UTC

[GitHub] [systemds] OlgaOvcharenko opened a new pull request #1193: [WIP] Federated ternary instruction

OlgaOvcharenko opened a new pull request #1193:
URL: https://github.com/apache/systemds/pull/1193


   This PR adds a federated ternary instruction (TernaryFEDInstruction) and adds it to the test runs.


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] sebwrede commented on a change in pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
sebwrede commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r603053321



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}
+
+		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			inputArgs, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+
+		FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+	}
+
+
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		if(!mo1.isFederated())
+			if(mo2.isFederated()) {
+				mo1 = mo2;
+				mo2 = ec.getMatrixObject(input1);
+			} else {
+				mo1 = mo3;
+				mo3 = ec.getMatrixObject(input1);
+			}
+
+		FederatedRequest fr3;
+		// all 3 inputs aligned on the one worker
+		if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {

Review comment:
       I meant to write "+*" and "-*", but the asterisks were removed because the asterisk is a special character in the GitHub comment system 😄. I can see that "+*" is number 8 in your list of heavy-hitter instructions.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] OlgaOvcharenko commented on a change in pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
OlgaOvcharenko commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r603288340



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}
+
+		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			inputArgs, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+
+		FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+	}
+
+
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		if(!mo1.isFederated())
+			if(mo2.isFederated()) {
+				mo1 = mo2;
+				mo2 = ec.getMatrixObject(input1);
+			} else {
+				mo1 = mo3;
+				mo3 = ec.getMatrixObject(input1);
+			}
+
+		FederatedRequest fr3;
+		// all 3 inputs aligned on the one worker
+		if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {

Review comment:
       I fixed the L2SVM inner loop a few minutes ago, but I needed to modify a few federated instructions for this. So if you need it, you can use it, but I don't know whether other tests now fail 😄
   But yes, "+*" and "-*" are implemented and supported.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] OlgaOvcharenko commented on a change in pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
OlgaOvcharenko commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r603264651



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}
+
+		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			inputArgs, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+
+		FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+	}
+
+
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		if(!mo1.isFederated())
+			if(mo2.isFederated()) {
+				mo1 = mo2;
+				mo2 = ec.getMatrixObject(input1);
+			} else {
+				mo1 = mo3;
+				mo3 = ec.getMatrixObject(input1);
+			}
+
+		FederatedRequest fr3;
+		// all 3 inputs aligned on the one worker
+		if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {

Review comment:
       "+*" and "-*" are working, but the inputs in L2SVM are not federated. 
   In the while loop there is tmp_Xw = Xw + step_sz*Xd [line 115]. Xd = X %*% s should be federated, but in AggregateBinaryFEDInstruction.java the MV output is not federated [line 83]. 
   I fixed this line with if( mo2.getNumColumns() == 1 && mo2.getNumRows() != mo1.getNumColumns()), but then out = 1 - Y * (tmp_Xw) [line 84 in l2svm.dml] fails.
   This could be solved if "1-*" were commutative (I thought it was), but it is not.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] sebwrede commented on a change in pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
sebwrede commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r603283322



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}
+
+		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			inputArgs, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+
+		FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+	}
+
+
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		if(!mo1.isFederated())
+			if(mo2.isFederated()) {
+				mo1 = mo2;
+				mo2 = ec.getMatrixObject(input1);
+			} else {
+				mo1 = mo3;
+				mo3 = ec.getMatrixObject(input1);
+			}
+
+		FederatedRequest fr3;
+		// all 3 inputs aligned on the one worker
+		if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {

Review comment:
       Yes, this is also what I have been looking at. I am currently working on compiling federated instructions instead of converting them at runtime, and on how to control whether the output is federated. When doing this, I will look at the instructions in the inner loop of L2SVM. This is also the reason why I built a temporary TernaryFED and some additional cases for some of the other instructions. I am developing this in other branches, so it is not necessary to have it in your PR. What I think is relevant for your PR is just to ensure that the "+*" and "-*" operations are supported in TernaryFEDInstruction. Once this is implemented, I can later merge the compilation of FED instructions, including the federated output planning, with your version of TernaryFEDInstruction. If your changes break the L2SVM tests, you can decide not to convert the CP instruction to a FED instruction when it is a "+*" or "-*" operation.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] asfgit closed pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
asfgit closed pull request #1193:
URL: https://github.com/apache/systemds/pull/1193


   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] OlgaOvcharenko commented on a change in pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
OlgaOvcharenko commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r602794572



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}
+
+		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			inputArgs, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+
+		FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+	}
+
+
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		if(!mo1.isFederated())
+			if(mo2.isFederated()) {
+				mo1 = mo2;
+				mo2 = ec.getMatrixObject(input1);
+			} else {
+				mo1 = mo3;
+				mo3 = ec.getMatrixObject(input1);
+			}
+
+		FederatedRequest fr3;
+		// all 3 inputs aligned on the one worker
+		if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {
+			fr3 = FederationUtils.callInstruction(instString,
+				output,
+				new CPOperand[] {input1, input2, input3},
+				new long[] {mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
+			mo1.getFedMapping().execute(getTID(), fr3);
+		} else {
+			FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+			FederatedRequest[] fr2 = mo1.getFedMapping().broadcastSliced(mo3, false);

Review comment:
       Added the case where two matrices are federated and aligned and the third one is broadcast.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] OlgaOvcharenko commented on a change in pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
OlgaOvcharenko commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r603264651



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}
+
+		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			inputArgs, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+
+		FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+	}
+
+
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		if(!mo1.isFederated())
+			if(mo2.isFederated()) {
+				mo1 = mo2;
+				mo2 = ec.getMatrixObject(input1);
+			} else {
+				mo1 = mo3;
+				mo3 = ec.getMatrixObject(input1);
+			}
+
+		FederatedRequest fr3;
+		// all 3 inputs aligned on the one worker
+		if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {

Review comment:
       "+*" and "-*" are working, but the inputs in L2SVM are not federated. 
   In the while loop there is tmp_Xw = Xw + step_sz*Xd [line 115]; Xd = X %*% s should be federated, but in AggregateBinaryFEDInstruction.java the MV output is not federated [line 83]. I fixed this line with if( mo2.getNumColumns() == 1 && mo2.getNumRows() != mo1.getNumColumns()), but then out = 1 - Y * (tmp_Xw) [line 84 in l2svm.dml] fails.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] OlgaOvcharenko commented on a change in pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
OlgaOvcharenko commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r602795376



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}
+
+		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			inputArgs, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+
+		FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+	}
+
+
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		if(!mo1.isFederated())
+			if(mo2.isFederated()) {
+				mo1 = mo2;
+				mo2 = ec.getMatrixObject(input1);
+			} else {
+				mo1 = mo3;
+				mo3 = ec.getMatrixObject(input1);
+			}
+
+		FederatedRequest fr3;
+		// all 3 inputs aligned on the one worker
+		if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {

Review comment:
       > Thanks for the PR!
   > It looks like a good start.
   > 
   > I think we could use some more tests that cover all branches of the processing of federated ternary instructions. One of the comments I added is regarding whether a part of the code can even be reached, so maybe we could think about a test that could cover this part (or if this is not possible, this part of the code could be removed).
   > I think it would also be interesting to look at other ternary operations, for instance the "+_" and "-_". This is relevant for L2SVM and I have already looked at this in a separate branch with a solution that is targeted for this single purpose, but my approach is still incomplete, so it is more relevant to build this in your TernaryFEDInstruction version.
   
   What do you mean by "+_" and "-_"? 
   Heavy hitter instructions:
     #  Instruction  Time(s)  Count
     1  m_l2svm        0,137      1
     2  fed_ba+*       0,116     15
     3  fed_fedinit    0,061      1
     4  ba+*           0,022     29
     5  write          0,012      1
     6  rightIndex     0,011      1
     7  rmvar          0,003    459
     8  +*             0,002     35
     9  1-*            0,002     21
    10  createvar      0,002    227
    11  tak+*          0,002     28
    12  *              0,002     91
    13  list           0,001      4
    14  max            0,001     21
    15  tsmm           0,001     42
    16  r'             0,001     30
    17  >              0,001     14
    18  castdts        0,000     56
    19  -              0,000     35
    20  +              0,000     59




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] OlgaOvcharenko commented on a change in pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
OlgaOvcharenko commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r603288340



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}
+
+		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			inputArgs, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+
+		FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+	}
+
+
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		if(!mo1.isFederated())
+			if(mo2.isFederated()) {
+				mo1 = mo2;
+				mo2 = ec.getMatrixObject(input1);
+			} else {
+				mo1 = mo3;
+				mo3 = ec.getMatrixObject(input1);
+			}
+
+		FederatedRequest fr3;
+		// all 3 inputs aligned on the one worker
+		if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {

Review comment:
       I fixed the L2SVM inner loop a few minutes ago, but I needed to modify a few fed instructions for this. So if you need it, you can use it, but I do not know whether other tests fail now 😄
   But yes, "+*" and "-*" are implemented and supported.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] OlgaOvcharenko commented on a change in pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
OlgaOvcharenko commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r602794437



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}

Review comment:
       I modified it slightly so that it is more readable now. I also added the case where mo1 and mo2 are aligned and no broadcast is needed.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] OlgaOvcharenko commented on a change in pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
OlgaOvcharenko commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r602794437



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}

Review comment:
       I modified it slightly, so now it's more readable.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] sebwrede commented on a change in pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
sebwrede commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r602445163



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}
+
+		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			inputArgs, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+
+		FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+	}
+
+
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		if(!mo1.isFederated())
+			if(mo2.isFederated()) {
+				mo1 = mo2;
+				mo2 = ec.getMatrixObject(input1);
+			} else {
+				mo1 = mo3;
+				mo3 = ec.getMatrixObject(input1);
+			}
+
+		FederatedRequest fr3;
+		// all 3 inputs aligned on the one worker
+		if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {
+			fr3 = FederationUtils.callInstruction(instString,
+				output,
+				new CPOperand[] {input1, input2, input3},
+				new long[] {mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
+			mo1.getFedMapping().execute(getTID(), fr3);
+		} else {
+			FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+			FederatedRequest[] fr2 = mo1.getFedMapping().broadcastSliced(mo3, false);

Review comment:
       What if mo1 and mo2 are federated and aligned, but mo3 needs to be broadcast? 

##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}
+
+		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			inputArgs, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+
+		FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+	}
+
+
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		if(!mo1.isFederated())
+			if(mo2.isFederated()) {
+				mo1 = mo2;
+				mo2 = ec.getMatrixObject(input1);
+			} else {
+				mo1 = mo3;
+				mo3 = ec.getMatrixObject(input1);
+			}
+
+		FederatedRequest fr3;
+		// all 3 inputs aligned on the one worker
+		if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {
+			fr3 = FederationUtils.callInstruction(instString,
+				output,
+				new CPOperand[] {input1, input2, input3},
+				new long[] {mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
+			mo1.getFedMapping().execute(getTID(), fr3);
+		} else {
+			FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+			FederatedRequest[] fr2 = mo1.getFedMapping().broadcastSliced(mo3, false);
+
+			long vars[];
+			if(!mo1.isFederated())

Review comment:
       Isn't mo1 always federated? Either mo1, mo2, or mo3 has to be federated for this to be a federated instruction, and mo1 is always replaced by whichever mo is federated in lines 119-126.

##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}
+
+		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			inputArgs, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+
+		FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+	}
+
+
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		if(!mo1.isFederated())
+			if(mo2.isFederated()) {
+				mo1 = mo2;
+				mo2 = ec.getMatrixObject(input1);
+			} else {
+				mo1 = mo3;
+				mo3 = ec.getMatrixObject(input1);
+			}
+
+		FederatedRequest fr3;
+		// all 3 inputs aligned on the one worker
+		if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {

Review comment:
       I think this condition needs a line break to fit the code style (or perhaps move it into a method returning a boolean).

##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}

Review comment:
       Is it necessary to swap the variables like this? The functionality is fine, but it is difficult to read and to follow where the matrix objects end up after the if-else chain.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] OlgaOvcharenko commented on a change in pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
OlgaOvcharenko commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r603288340



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}
+
+		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			inputArgs, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+
+		FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+	}
+
+
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		if(!mo1.isFederated())
+			if(mo2.isFederated()) {
+				mo1 = mo2;
+				mo2 = ec.getMatrixObject(input1);
+			} else {
+				mo1 = mo3;
+				mo3 = ec.getMatrixObject(input1);
+			}
+
+		FederatedRequest fr3;
+		// all 3 inputs aligned on the one worker
+		if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {

Review comment:
       I fixed the L2SVM inner loop a few minutes ago, but I had to modify a few federated instructions for this. You can use it if you need to, but I am not sure whether the other tests still pass now 😄
   But yes, "+*" and "-*" are implemented and supported.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] sebwrede commented on pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
sebwrede commented on pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#issuecomment-810901079


   I think the latest commit is a workaround for not having a BinaryMatrixMatrixFEDInstruction with a federated right input. I had this case covered in another branch that has not been merged into master yet. I made some changes to your branch, but I cannot push to your fork, so I added them as a branch in my own fork instead. You can find it here: https://github.com/sebwrede/systemds/commit/bbf679cae0a73a2e8eee7dbc16efff7e998bba36
   I think it is better not to adapt the algorithms just to avoid calling the federated instructions with a federated right input. If you agree, you can merge from my branch or develop your own solution. 
   Once this is done, I think this PR looks good and I can merge it. 


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] OlgaOvcharenko commented on a change in pull request #1193: [SYSTEMDS-2904] Federated ternary instruction

Posted by GitBox <gi...@apache.org>.
OlgaOvcharenko commented on a change in pull request #1193:
URL: https://github.com/apache/systemds/pull/1193#discussion_r602795376



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
##########
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.fed;
+
+import java.util.Objects;
+
+import com.sun.tools.javac.util.List;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
+
+public class TernaryFEDInstruction extends ComputationFEDInstruction {
+
+	private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
+		super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, str);
+	}
+
+	public static TernaryFEDInstruction parseInstruction(String str)
+	{
+		String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+		String opcode=parts[0];
+		CPOperand operand1 = new CPOperand(parts[1]);
+		CPOperand operand2 = new CPOperand(parts[2]);
+		CPOperand operand3 = new CPOperand(parts[3]);
+		CPOperand outOperand = new CPOperand(parts[4]);
+		TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
+		return new TernaryFEDInstruction(op, operand1, operand2, operand3, outOperand, opcode, str);
+	}
+
+	@Override
+	public void processInstruction( ExecutionContext ec ) {
+		MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
+		MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
+		MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3.getName()) : null;
+
+		long matrixInputsCount = List.of(mo1, mo2, mo3).stream().filter(Objects::nonNull).count();
+
+		if(matrixInputsCount == 3)
+			processMatrixInput(ec, mo1, mo2, mo3);
+		else if (matrixInputsCount == 1) {
+			CPOperand in = mo1 == null ? mo2 == null ? input3 : input2 : input1;
+			mo1 = mo1 == null ? mo2 == null ? mo3 : mo2 : mo1;
+			processMatrixScalarInput(ec, mo1, in);
+		} else
+			process2MatrixScalarInput(ec, mo1, mo2, mo3);
+	}
+
+	private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo1, CPOperand in) {
+		FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+			new CPOperand[] {in}, new long[] {mo1.getFedMapping().getID()});
+		mo1.getFedMapping().execute(getTID(), true, fr1);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
+	}
+
+	private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		CPOperand[] inputArgs = new CPOperand[] {input1, input2};
+		if(mo1 != null && mo1.isFederated() && mo2 == null) {
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input1, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 == null) {
+			mo1 = mo2;
+			mo2 = mo3;
+			inputArgs = new CPOperand[] {input2, input3};
+		} else if(mo2 != null && mo2.isFederated() && mo1 != null) {
+			mo1 = mo2;
+			mo2 = ec.getMatrixObject(input1);
+			inputArgs = new CPOperand[] {input2, input1};
+		} else if(mo3 != null && mo3.isFederated() && mo1 == null) {
+			mo1 = mo3;
+			inputArgs = new CPOperand[] {input3, input2};
+		} else if(mo3 != null && mo3.isFederated() && mo1 != null) {
+			mo1 = mo3;
+			mo2 = ec.getMatrixObject(input1);
+
+			inputArgs = new CPOperand[] {input3, input1};
+		}
+
+		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
+
+		FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
+			inputArgs, new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+
+		FederatedRequest fr3 = mo1.getFedMapping().cleanup(getTID(), fr1[0].getID());
+		mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+
+		//derive new fed mapping for output
+		MatrixObject out = ec.getMatrixObject(output);
+		out.getDataCharacteristics().set(mo1.getDataCharacteristics());
+		out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr3.getID()));
+	}
+
+
+	private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
+		if(!mo1.isFederated())
+			if(mo2.isFederated()) {
+				mo1 = mo2;
+				mo2 = ec.getMatrixObject(input1);
+			} else {
+				mo1 = mo3;
+				mo3 = ec.getMatrixObject(input1);
+			}
+
+		FederatedRequest fr3;
+		// all 3 inputs aligned on the one worker
+		if(mo1.isFederated() && mo2.isFederated() && mo3.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false) && mo1.getFedMapping().isAligned(mo3.getFedMapping(), false)) {

Review comment:
       > Thanks for the PR!
   > It looks like a good start.
   > 
   > I think we could use some more tests that cover all branches of the processing of federated ternary instructions. One of the comments I added is regarding whether a part of the code can even be reached, so maybe we could think about a test that could cover this part (or if this is not possible, this part of the code could be removed).
   > I think it would also be interesting to look at other ternary operations, for instance the "+_" and "-_". This is relevant for L2SVM and I have already looked at this in a separate branch with a solution that is targeted for this single purpose, but my approach is still incomplete, so it is more relevant to build this in your TernaryFEDInstruction version.
   
   What do you mean by "+_" and "-_"? 
   Heavy hitter instructions:
     1  m_l2svm        0,137      1
     2  fed_ba+*       0,116     15
     3  fed_fedinit    0,061      1
     4  ba+*           0,022     29
     5  write          0,012      1
     6  rightIndex     0,011      1
     7  rmvar          0,003    459
     8  +*             0,002     35
     9  1-*            0,002     21
    10  createvar      0,002    227
    11  tak+*          0,002     28
    12  *              0,002     91
    13  list           0,001      4
    14  max            0,001     21
    15  tsmm           0,001     42
    16  r'             0,001     30
    17  >              0,001     14
    18  castdts        0,000     56
    19  -              0,000     35
    20  +              0,000     59




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org