Posted to dev@systemds.apache.org by GitBox <gi...@apache.org> on 2021/08/23 14:03:55 UTC

[GitHub] [systemds] ywcb00 opened a new pull request #1371: [SYSTEMDS-3085] Federated CTable - Keep Output Federated

ywcb00 opened a new pull request #1371:
URL: https://github.com/apache/systemds/pull/1371


   Hi,
   This PR keeps the output of the ctable instruction federated whenever possible. Previously, the partial outputs were always pulled to the coordinator.
   Some methods assumed that all federated partitions have the same size, so I rewrote those methods.
   I also added a test case that creates the federated output from matrix inputs instead of just vectors.
   
   Please have a close look at it :monocle_face: 
   
   Thanks for the review :)


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: dev-unsubscribe@systemds.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [systemds] sebwrede commented on a change in pull request #1371: [SYSTEMDS-3085] Federated CTable - Keep Output Federated

Posted by GitBox <gi...@apache.org>.
sebwrede commented on a change in pull request #1371:
URL: https://github.com/apache/systemds/pull/1371#discussion_r696452598



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
##########
@@ -116,119 +120,154 @@ public void processInstruction(ExecutionContext ec) {
 			mo1 = ec.getMatrixObject(input3);
 		}
 
-		long dim1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 && dims1.length == Arrays.stream(dims1).distinct().count();
+		// static non-partitioned output dimension (same for all federated partitions)
+		long staticDim = Collections.max(Arrays.asList(dims1), Long::compare);
+		boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);
 
-		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, dims1, dims2);
+		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, staticDim, dims2);
 	}
 
+	/**
+	 * Broadcast, execute, and finalize the federated instruction according to
+	 * the specified inputs.
+	 *
+	 * @param ec execution context
+	 * @param mo1 input matrix object 1
+	 * @param mo2 input matrix object 2
+	 * @param mo3 input matrix object 3 or null
+	 * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+	 * @param reversedWeights boolean indicating if inputs mo1 and mo3 are reversed
+	 * @param fedOutput boolean indicating if output can be kept federated
+	 * @param staticDim static non-partitioned dimension of the output
+	 * @param dims2 dimensions of the partial outputs along the federated partitioning
+	 */
 	private void processRequest(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3,
-		boolean reversed, boolean reversedWeights, boolean fedOutput, Long[] dims1, Long[] dims2) {
-		Future<FederatedResponse>[] ffr;
+		boolean reversed, boolean reversedWeights, boolean fedOutput, long staticDim, Long[] dims2) {
+
+		FederationMap fedMap = mo1.getFedMapping();
+
+		FederatedRequest[] fr1 = fedMap.broadcastSliced(mo2, false);
+		FederatedRequest[] fr2 = null;
+		FederatedRequest fr3, fr4, fr5;
+		fr3 = fr4 = fr5 = null;
+		Future<FederatedResponse>[] ffr = null;

Review comment:
       The null initializers are redundant. 
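
   A minimal standalone sketch of the rule behind this remark, assuming nothing beyond plain Java (the class and method names are hypothetical and not part of the PR): a local variable that is assigned on every path before it is read needs no `= null` initializer, because the compiler's definite-assignment analysis already guarantees it has a value. A variable that is later compared against `null`, as `fr2` is in the diff above, is a different case and still needs its initializer.

```java
class DefiniteAssignmentSketch {
	// Hypothetical helper: 'label' is assigned in every branch,
	// so the compiler accepts the declaration without an initializer.
	static String describe(int n) {
		String label; // no "= null" required
		if(n < 0)
			label = "negative";
		else if(n == 0)
			label = "zero";
		else
			label = "positive";
		return label;
	}

	public static void main(String[] args) {
		System.out.println(describe(-1));
		System.out.println(describe(0));
		System.out.println(describe(7));
	}
}
```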

##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
##########
@@ -116,119 +120,154 @@ public void processInstruction(ExecutionContext ec) {
 			mo1 = ec.getMatrixObject(input3);
 		}
 
-		long dim1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 && dims1.length == Arrays.stream(dims1).distinct().count();
+		// static non-partitioned output dimension (same for all federated partitions)
+		long staticDim = Collections.max(Arrays.asList(dims1), Long::compare);
+		boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);

Review comment:
       Do you have any thoughts about how to include the `_fedOut` field (forced federated/forced local) in this decision?
   How would it, for instance, influence the `setFedOutput` method? 

##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
##########
@@ -116,119 +120,154 @@ public void processInstruction(ExecutionContext ec) {
 			mo1 = ec.getMatrixObject(input3);
 		}
 
-		long dim1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 && dims1.length == Arrays.stream(dims1).distinct().count();
+		// static non-partitioned output dimension (same for all federated partitions)
+		long staticDim = Collections.max(Arrays.asList(dims1), Long::compare);
+		boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);
 
-		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, dims1, dims2);
+		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, staticDim, dims2);
 	}
 
+	/**
+	 * Broadcast, execute, and finalize the federated instruction according to
+	 * the specified inputs.
+	 *
+	 * @param ec execution context
+	 * @param mo1 input matrix object 1
+	 * @param mo2 input matrix object 2
+	 * @param mo3 input matrix object 3 or null
+	 * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+	 * @param reversedWeights boolean indicating if inputs mo1 and mo3 are reversed
+	 * @param fedOutput boolean indicating if output can be kept federated
+	 * @param staticDim static non-partitioned dimension of the output
+	 * @param dims2 dimensions of the partial outputs along the federated partitioning
+	 */
 	private void processRequest(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3,
-		boolean reversed, boolean reversedWeights, boolean fedOutput, Long[] dims1, Long[] dims2) {
-		Future<FederatedResponse>[] ffr;
+		boolean reversed, boolean reversedWeights, boolean fedOutput, long staticDim, Long[] dims2) {
+
+		FederationMap fedMap = mo1.getFedMapping();
+
+		FederatedRequest[] fr1 = fedMap.broadcastSliced(mo2, false);
+		FederatedRequest[] fr2 = null;
+		FederatedRequest fr3, fr4, fr5;
+		fr3 = fr4 = fr5 = null;
+		Future<FederatedResponse>[] ffr = null;
 
-		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
-		FederatedRequest fr2, fr3;
 		if(mo3 != null && mo1.isFederated() && mo3.isFederated()
-		&& mo1.getFedMapping().isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
+			&& fedMap.isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
 			if(!reversed)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fedMap.getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), mo3.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fedMap.getID(), mo3.getFedMapping().getID()});
 		}
 		else if(mo3 == null) {
 			if(!reversed)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+					new long[] {fedMap.getID(), fr1[0].getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
-
-		} else {
-			FederatedRequest[] fr4 = mo1.getFedMapping().broadcastSliced(mo3, false);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+					new long[] {fr1[0].getID(), fedMap.getID()});
+		}
+		else {
+			fr2 = fedMap.broadcastSliced(mo3, false);
 			if(!reversed && !reversedWeights)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), fr4[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fedMap.getID(), fr1[0].getID(), fr2[0].getID()});
 			else if(reversed && !reversedWeights)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), fr4[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fedMap.getID(), fr2[0].getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), fr4[0].getID(), mo1.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr4, fr2, fr3);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fr2[0].getID(), fedMap.getID()});
 		}
 
-		if(fedOutput && isFedOutput(ffr, dims1)) {
+		if(fedOutput) {
+			if(fr2 != null) // broadcasted mo3
+				fedMap.execute(getTID(), true, fr1, fr2, fr3);
+			else
+				fedMap.execute(getTID(), true, fr1, fr3);
+
 			MatrixObject out = ec.getMatrixObject(output);
-			FederationMap newFedMap = modifyFedRanges(mo1.getFedMapping(), dims1, dims2);
-			setFedOutput(mo1, out, newFedMap, dims1, fr2.getID());
+			FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(fr3.getID()),
+				staticDim, dims2, reversed);
+			setFedOutput(mo1, out, newFedMap, staticDim, dims2, reversed);
 		} else {
+			fr4 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr3.getID());
+			fr5 = fedMap.cleanup(getTID(), fr3.getID());
+			if(fr2 != null) // broadcasted mo3
+				ffr = fedMap.execute(getTID(), true, fr1, fr2, fr3, fr4, fr5);
+			else
+				ffr = fedMap.execute(getTID(), true, fr1, fr3, fr4, fr5);
+
 			ec.setMatrixOutput(output.getName(), aggResult(ffr));
 		}
 	}
 
-	boolean isFedOutput(Future<FederatedResponse>[] ffr,  Long[] dims1) {
-		boolean fedOutput = true;
-
-		long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / ffr.length;
-		try {
-			MatrixBlock curr;
-			MatrixBlock prev =(MatrixBlock) ffr[0].get().getData()[0];
-			for(int i = 1; i < ffr.length && fedOutput; i++) {
-				curr = (MatrixBlock) ffr[i].get().getData()[0];
-				MatrixBlock sliced = curr.slice((int) (curr.getNumRows() - fedSize), curr.getNumRows() - 1);
-
-				if(curr.getNumColumns() != prev.getNumColumns())
-					return false;
-
-				// no intersection
-				if(curr.getNumRows() == (i+1) * prev.getNumRows() && curr.getNonZeros() <= prev.getLength()
-					&& (curr.getNumRows() - sliced.getNumRows()) == i * prev.getNumRows()
-					&& curr.getNonZeros() - sliced.getNonZeros() == 0)
-					continue;
-
-				// check intersect with AND and compare number of nnz
-				MatrixBlock prevExtend = new MatrixBlock(curr.getNumRows(), curr.getNumColumns(), true, 0);
-				prevExtend.copy(0, prev.getNumRows()-1, 0, prev.getNumColumns()-1, prev, true);
-
-				MatrixBlock  intersect = curr.binaryOperationsInPlace(new BinaryOperator(And.getAndFnObject()), prevExtend);
-				if(intersect.getNonZeros() != 0)
-					fedOutput = false;
-				prev = sliced;
-			}
-		}
-		catch(Exception e) {
-			e.printStackTrace();
-		}
-		return fedOutput;
-	}
+	/**
+	 * Evaluate if the output can be kept federated on the different federated
+	 * sites or if the output needs to be aggregated on the coordinator, based
+	 * on the output ranges of mo2.
+	 *
+	 * @param fedMap the federation map of the federated matrix input mo1
+	 * @param mo2 input matrix object mo2
+	 * @return boolean indicating if the output can be kept on the federated sites
+	 */
+	private boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
+		MatrixBlock mb = mo2.acquireReadAndRelease();
+		FederatedRange[] fedRanges = fedMap.getFederatedRanges(); // federated ranges of mo1
+		SortedMap<Double, Double> fedDims = new TreeMap<Double, Double>(); // <beginDim, endDim>
+
+		// collect min and max of the corresponding slices of mo2
+		IntStream.range(0, fedRanges.length).forEach(i -> {
+			MatrixBlock sliced = mb.slice(
+				fedRanges[i].getBeginDimsInt()[0], fedRanges[i].getEndDimsInt()[0] - 1,
+				fedRanges[i].getBeginDimsInt()[1], fedRanges[i].getEndDimsInt()[1] - 1);
+			fedDims.put(sliced.min(), sliced.max());
+		});
 
+		boolean retVal = (fedDims.size() == fedRanges.length); // no duplicate begin dimension entries
 
-	private static void setFedOutput(MatrixObject mo1, MatrixObject out, FederationMap fedMap, Long[] dims1, long outId) {
-		long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / dims1.length;
+		Iterator<SortedMap.Entry<Double, Double>> iter = fedDims.entrySet().iterator();
+		SortedMap.Entry<Double, Double> entry = iter.next(); // first entry does not have to be checked
+		double prevEndDim = entry.getValue().doubleValue();
+		while(iter.hasNext() && retVal) {
+			entry = iter.next();
+			// previous end dimension must be less than current begin dimension (no overlaps of ranges)
+			retVal &= (prevEndDim < entry.getKey());
+			prevEndDim = entry.getValue().doubleValue();
+		}
 
-		long d1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		long d2 = Collections.max(Arrays.asList(dims1), Long::compare);
+		return retVal;
+	}
+
+	/**
+	 * Set the output and its data characteristics on the federated sites.
+	 *
+	 * @param mo1 input matrix object mo1
+	 * @param out input matrix object of the output
+	 * @param fedMap the federation map of the federated matrix input mo1
+	 * @param staticDim static non-partitioned dimension of the output
+	 * @param dims2 dimensions of the partial outputs along the federated partitioning
+	 * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+	 * @return boolean indicating if the output can be kept on the federated sites

Review comment:
       It is a void method, so it does not return anything. 

##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
##########
@@ -116,119 +120,154 @@ public void processInstruction(ExecutionContext ec) {
 			mo1 = ec.getMatrixObject(input3);
 		}
 
-		long dim1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 && dims1.length == Arrays.stream(dims1).distinct().count();
+		// static non-partitioned output dimension (same for all federated partitions)
+		long staticDim = Collections.max(Arrays.asList(dims1), Long::compare);
+		boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);
 
-		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, dims1, dims2);
+		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, staticDim, dims2);
 	}
 
+	/**
+	 * Broadcast, execute, and finalize the federated instruction according to
+	 * the specified inputs.
+	 *
+	 * @param ec execution context
+	 * @param mo1 input matrix object 1
+	 * @param mo2 input matrix object 2
+	 * @param mo3 input matrix object 3 or null
+	 * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+	 * @param reversedWeights boolean indicating if inputs mo1 and mo3 are reversed
+	 * @param fedOutput boolean indicating if output can be kept federated
+	 * @param staticDim static non-partitioned dimension of the output
+	 * @param dims2 dimensions of the partial outputs along the federated partitioning
+	 */
 	private void processRequest(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3,
-		boolean reversed, boolean reversedWeights, boolean fedOutput, Long[] dims1, Long[] dims2) {
-		Future<FederatedResponse>[] ffr;
+		boolean reversed, boolean reversedWeights, boolean fedOutput, long staticDim, Long[] dims2) {
+
+		FederationMap fedMap = mo1.getFedMapping();
+
+		FederatedRequest[] fr1 = fedMap.broadcastSliced(mo2, false);
+		FederatedRequest[] fr2 = null;
+		FederatedRequest fr3, fr4, fr5;
+		fr3 = fr4 = fr5 = null;
+		Future<FederatedResponse>[] ffr = null;
 
-		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
-		FederatedRequest fr2, fr3;
 		if(mo3 != null && mo1.isFederated() && mo3.isFederated()
-		&& mo1.getFedMapping().isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
+			&& fedMap.isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
 			if(!reversed)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fedMap.getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), mo3.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fedMap.getID(), mo3.getFedMapping().getID()});
 		}
 		else if(mo3 == null) {
 			if(!reversed)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+					new long[] {fedMap.getID(), fr1[0].getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
-
-		} else {
-			FederatedRequest[] fr4 = mo1.getFedMapping().broadcastSliced(mo3, false);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+					new long[] {fr1[0].getID(), fedMap.getID()});
+		}
+		else {
+			fr2 = fedMap.broadcastSliced(mo3, false);
 			if(!reversed && !reversedWeights)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), fr4[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fedMap.getID(), fr1[0].getID(), fr2[0].getID()});
 			else if(reversed && !reversedWeights)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), fr4[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fedMap.getID(), fr2[0].getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), fr4[0].getID(), mo1.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr4, fr2, fr3);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fr2[0].getID(), fedMap.getID()});
 		}
 
-		if(fedOutput && isFedOutput(ffr, dims1)) {
+		if(fedOutput) {
+			if(fr2 != null) // broadcasted mo3
+				fedMap.execute(getTID(), true, fr1, fr2, fr3);
+			else
+				fedMap.execute(getTID(), true, fr1, fr3);
+
 			MatrixObject out = ec.getMatrixObject(output);
-			FederationMap newFedMap = modifyFedRanges(mo1.getFedMapping(), dims1, dims2);
-			setFedOutput(mo1, out, newFedMap, dims1, fr2.getID());
+			FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(fr3.getID()),
+				staticDim, dims2, reversed);
+			setFedOutput(mo1, out, newFedMap, staticDim, dims2, reversed);
 		} else {
+			fr4 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr3.getID());
+			fr5 = fedMap.cleanup(getTID(), fr3.getID());
+			if(fr2 != null) // broadcasted mo3
+				ffr = fedMap.execute(getTID(), true, fr1, fr2, fr3, fr4, fr5);
+			else
+				ffr = fedMap.execute(getTID(), true, fr1, fr3, fr4, fr5);
+
 			ec.setMatrixOutput(output.getName(), aggResult(ffr));
 		}
 	}
 
-	boolean isFedOutput(Future<FederatedResponse>[] ffr,  Long[] dims1) {
-		boolean fedOutput = true;
-
-		long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / ffr.length;
-		try {
-			MatrixBlock curr;
-			MatrixBlock prev =(MatrixBlock) ffr[0].get().getData()[0];
-			for(int i = 1; i < ffr.length && fedOutput; i++) {
-				curr = (MatrixBlock) ffr[i].get().getData()[0];
-				MatrixBlock sliced = curr.slice((int) (curr.getNumRows() - fedSize), curr.getNumRows() - 1);
-
-				if(curr.getNumColumns() != prev.getNumColumns())
-					return false;
-
-				// no intersection
-				if(curr.getNumRows() == (i+1) * prev.getNumRows() && curr.getNonZeros() <= prev.getLength()
-					&& (curr.getNumRows() - sliced.getNumRows()) == i * prev.getNumRows()
-					&& curr.getNonZeros() - sliced.getNonZeros() == 0)
-					continue;
-
-				// check intersect with AND and compare number of nnz
-				MatrixBlock prevExtend = new MatrixBlock(curr.getNumRows(), curr.getNumColumns(), true, 0);
-				prevExtend.copy(0, prev.getNumRows()-1, 0, prev.getNumColumns()-1, prev, true);
-
-				MatrixBlock  intersect = curr.binaryOperationsInPlace(new BinaryOperator(And.getAndFnObject()), prevExtend);
-				if(intersect.getNonZeros() != 0)
-					fedOutput = false;
-				prev = sliced;
-			}
-		}
-		catch(Exception e) {
-			e.printStackTrace();
-		}
-		return fedOutput;
-	}
+	/**
+	 * Evaluate if the output can be kept federated on the different federated
+	 * sites or if the output needs to be aggregated on the coordinator, based
+	 * on the output ranges of mo2.
+	 *
+	 * @param fedMap the federation map of the federated matrix input mo1
+	 * @param mo2 input matrix object mo2
+	 * @return boolean indicating if the output can be kept on the federated sites
+	 */
+	private boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
+		MatrixBlock mb = mo2.acquireReadAndRelease();
+		FederatedRange[] fedRanges = fedMap.getFederatedRanges(); // federated ranges of mo1
+		SortedMap<Double, Double> fedDims = new TreeMap<Double, Double>(); // <beginDim, endDim>
+
+		// collect min and max of the corresponding slices of mo2
+		IntStream.range(0, fedRanges.length).forEach(i -> {
+			MatrixBlock sliced = mb.slice(
+				fedRanges[i].getBeginDimsInt()[0], fedRanges[i].getEndDimsInt()[0] - 1,
+				fedRanges[i].getBeginDimsInt()[1], fedRanges[i].getEndDimsInt()[1] - 1);
+			fedDims.put(sliced.min(), sliced.max());
+		});
 
+		boolean retVal = (fedDims.size() == fedRanges.length); // no duplicate begin dimension entries
 
-	private static void setFedOutput(MatrixObject mo1, MatrixObject out, FederationMap fedMap, Long[] dims1, long outId) {
-		long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / dims1.length;
+		Iterator<SortedMap.Entry<Double, Double>> iter = fedDims.entrySet().iterator();
+		SortedMap.Entry<Double, Double> entry = iter.next(); // first entry does not have to be checked
+		double prevEndDim = entry.getValue().doubleValue();
+		while(iter.hasNext() && retVal) {
+			entry = iter.next();
+			// previous end dimension must be less than current begin dimension (no overlaps of ranges)
+			retVal &= (prevEndDim < entry.getKey());
+			prevEndDim = entry.getValue().doubleValue();

Review comment:
       Unnecessary unboxing with `.doubleValue()` on lines 233 and 238.
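
   As a small illustration, assuming only the standard library (the class and method names are hypothetical, not from the PR): assigning a non-null `Double` to a `double` unboxes automatically, so the explicit call yields the same value.

```java
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;

class UnboxingSketch {
	// Hypothetical helper: returns the value of the first entry of a non-empty map.
	static double firstEnd(SortedMap<Double, Double> ranges) {
		Map.Entry<Double, Double> entry = ranges.entrySet().iterator().next();
		double explicit = entry.getValue().doubleValue(); // explicit unboxing
		double implicit = entry.getValue(); // auto-unboxing, same value
		if(explicit != implicit)
			throw new IllegalStateException("explicit and implicit unboxing differ");
		return implicit;
	}

	public static void main(String[] args) {
		SortedMap<Double, Double> ranges = new TreeMap<>();
		ranges.put(1.0, 3.0);
		System.out.println(firstEnd(ranges));
	}
}
```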

##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
##########
@@ -116,119 +120,154 @@ public void processInstruction(ExecutionContext ec) {
 			mo1 = ec.getMatrixObject(input3);
 		}
 
-		long dim1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 && dims1.length == Arrays.stream(dims1).distinct().count();
+		// static non-partitioned output dimension (same for all federated partitions)
+		long staticDim = Collections.max(Arrays.asList(dims1), Long::compare);
+		boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);
 
-		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, dims1, dims2);
+		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, staticDim, dims2);
 	}
 
+	/**
+	 * Broadcast, execute, and finalize the federated instruction according to
+	 * the specified inputs.
+	 *
+	 * @param ec execution context
+	 * @param mo1 input matrix object 1
+	 * @param mo2 input matrix object 2
+	 * @param mo3 input matrix object 3 or null
+	 * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+	 * @param reversedWeights boolean indicating if inputs mo1 and mo3 are reversed
+	 * @param fedOutput boolean indicating if output can be kept federated
+	 * @param staticDim static non-partitioned dimension of the output
+	 * @param dims2 dimensions of the partial outputs along the federated partitioning
+	 */
 	private void processRequest(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3,
-		boolean reversed, boolean reversedWeights, boolean fedOutput, Long[] dims1, Long[] dims2) {
-		Future<FederatedResponse>[] ffr;
+		boolean reversed, boolean reversedWeights, boolean fedOutput, long staticDim, Long[] dims2) {
+
+		FederationMap fedMap = mo1.getFedMapping();
+
+		FederatedRequest[] fr1 = fedMap.broadcastSliced(mo2, false);
+		FederatedRequest[] fr2 = null;
+		FederatedRequest fr3, fr4, fr5;
+		fr3 = fr4 = fr5 = null;
+		Future<FederatedResponse>[] ffr = null;
 
-		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
-		FederatedRequest fr2, fr3;
 		if(mo3 != null && mo1.isFederated() && mo3.isFederated()
-		&& mo1.getFedMapping().isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
+			&& fedMap.isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
 			if(!reversed)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fedMap.getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), mo3.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fedMap.getID(), mo3.getFedMapping().getID()});
 		}
 		else if(mo3 == null) {
 			if(!reversed)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+					new long[] {fedMap.getID(), fr1[0].getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
-
-		} else {
-			FederatedRequest[] fr4 = mo1.getFedMapping().broadcastSliced(mo3, false);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+					new long[] {fr1[0].getID(), fedMap.getID()});
+		}
+		else {
+			fr2 = fedMap.broadcastSliced(mo3, false);
 			if(!reversed && !reversedWeights)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), fr4[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fedMap.getID(), fr1[0].getID(), fr2[0].getID()});
 			else if(reversed && !reversedWeights)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), fr4[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fedMap.getID(), fr2[0].getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), fr4[0].getID(), mo1.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr4, fr2, fr3);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fr2[0].getID(), fedMap.getID()});
 		}
 
-		if(fedOutput && isFedOutput(ffr, dims1)) {
+		if(fedOutput) {
+			if(fr2 != null) // broadcasted mo3
+				fedMap.execute(getTID(), true, fr1, fr2, fr3);
+			else
+				fedMap.execute(getTID(), true, fr1, fr3);
+
 			MatrixObject out = ec.getMatrixObject(output);
-			FederationMap newFedMap = modifyFedRanges(mo1.getFedMapping(), dims1, dims2);
-			setFedOutput(mo1, out, newFedMap, dims1, fr2.getID());
+			FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(fr3.getID()),
+				staticDim, dims2, reversed);
+			setFedOutput(mo1, out, newFedMap, staticDim, dims2, reversed);
 		} else {
+			fr4 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr3.getID());
+			fr5 = fedMap.cleanup(getTID(), fr3.getID());
+			if(fr2 != null) // broadcasted mo3
+				ffr = fedMap.execute(getTID(), true, fr1, fr2, fr3, fr4, fr5);
+			else
+				ffr = fedMap.execute(getTID(), true, fr1, fr3, fr4, fr5);
+
 			ec.setMatrixOutput(output.getName(), aggResult(ffr));
 		}
 	}
 
-	boolean isFedOutput(Future<FederatedResponse>[] ffr,  Long[] dims1) {
-		boolean fedOutput = true;
-
-		long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / ffr.length;
-		try {
-			MatrixBlock curr;
-			MatrixBlock prev =(MatrixBlock) ffr[0].get().getData()[0];
-			for(int i = 1; i < ffr.length && fedOutput; i++) {
-				curr = (MatrixBlock) ffr[i].get().getData()[0];
-				MatrixBlock sliced = curr.slice((int) (curr.getNumRows() - fedSize), curr.getNumRows() - 1);
-
-				if(curr.getNumColumns() != prev.getNumColumns())
-					return false;
-
-				// no intersection
-				if(curr.getNumRows() == (i+1) * prev.getNumRows() && curr.getNonZeros() <= prev.getLength()
-					&& (curr.getNumRows() - sliced.getNumRows()) == i * prev.getNumRows()
-					&& curr.getNonZeros() - sliced.getNonZeros() == 0)
-					continue;
-
-				// check intersect with AND and compare number of nnz
-				MatrixBlock prevExtend = new MatrixBlock(curr.getNumRows(), curr.getNumColumns(), true, 0);
-				prevExtend.copy(0, prev.getNumRows()-1, 0, prev.getNumColumns()-1, prev, true);
-
-				MatrixBlock  intersect = curr.binaryOperationsInPlace(new BinaryOperator(And.getAndFnObject()), prevExtend);
-				if(intersect.getNonZeros() != 0)
-					fedOutput = false;
-				prev = sliced;
-			}
-		}
-		catch(Exception e) {
-			e.printStackTrace();
-		}
-		return fedOutput;
-	}
+	/**
+	 * Evaluate if the output can be kept federated on the different federated
+	 * sites or if the output needs to be aggregated on the coordinator, based
+	 * on the output ranges of mo2.

Review comment:
       Can we make it a bit more explicit what the conditions are for being able to keep the output federated? 
   Is it only possible in the cases where the output does not need to be aggregated across the federated workers? 







[GitHub] [systemds] ywcb00 commented on pull request #1371: [SYSTEMDS-3085] Federated CTable - Keep Output Federated

Posted by GitBox <gi...@apache.org>.
ywcb00 commented on pull request #1371:
URL: https://github.com/apache/systemds/pull/1371#issuecomment-911687538


   Thank you for the review and comments @sebwrede.
   I've made the changes accordingly.





[GitHub] [systemds] ywcb00 commented on a change in pull request #1371: [SYSTEMDS-3085] Federated CTable - Keep Output Federated

Posted by GitBox <gi...@apache.org>.
ywcb00 commented on a change in pull request #1371:
URL: https://github.com/apache/systemds/pull/1371#discussion_r700188882



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
##########
@@ -116,119 +120,154 @@ public void processInstruction(ExecutionContext ec) {
 			mo1 = ec.getMatrixObject(input3);
 		}
 
-		long dim1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 && dims1.length == Arrays.stream(dims1).distinct().count();
+		// static non-partitioned output dimension (same for all federated partitions)
+		long staticDim = Collections.max(Arrays.asList(dims1), Long::compare);
+		boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);
 
-		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, dims1, dims2);
+		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, staticDim, dims2);
 	}
 
+	/**
+	 * Broadcast, execute, and finalize the federated instruction according to
+	 * the specified inputs.
+	 *
+	 * @param ec execution context
+	 * @param mo1 input matrix object 1
+	 * @param mo2 input matrix object 2
+	 * @param mo3 input matrix object 3 or null
+	 * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+	 * @param reversedWeights boolean indicating if inputs mo1 and mo3 are reversed
+	 * @param fedOutput boolean indicating if output can be kept federated
+	 * @param staticDim static non-partitioned dimension of the output
+	 * @param dims2 dimensions of the partial outputs along the federated partitioning
+	 */
 	private void processRequest(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3,
-		boolean reversed, boolean reversedWeights, boolean fedOutput, Long[] dims1, Long[] dims2) {
-		Future<FederatedResponse>[] ffr;
+		boolean reversed, boolean reversedWeights, boolean fedOutput, long staticDim, Long[] dims2) {
+
+		FederationMap fedMap = mo1.getFedMapping();
+
+		FederatedRequest[] fr1 = fedMap.broadcastSliced(mo2, false);
+		FederatedRequest[] fr2 = null;
+		FederatedRequest fr3, fr4, fr5;
+		fr3 = fr4 = fr5 = null;
+		Future<FederatedResponse>[] ffr = null;
 
-		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
-		FederatedRequest fr2, fr3;
 		if(mo3 != null && mo1.isFederated() && mo3.isFederated()
-		&& mo1.getFedMapping().isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
+			&& fedMap.isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
 			if(!reversed)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fedMap.getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), mo3.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fedMap.getID(), mo3.getFedMapping().getID()});
 		}
 		else if(mo3 == null) {
 			if(!reversed)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+					new long[] {fedMap.getID(), fr1[0].getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
-
-		} else {
-			FederatedRequest[] fr4 = mo1.getFedMapping().broadcastSliced(mo3, false);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+					new long[] {fr1[0].getID(), fedMap.getID()});
+		}
+		else {
+			fr2 = fedMap.broadcastSliced(mo3, false);
 			if(!reversed && !reversedWeights)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), fr4[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fedMap.getID(), fr1[0].getID(), fr2[0].getID()});
 			else if(reversed && !reversedWeights)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), fr4[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fedMap.getID(), fr2[0].getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), fr4[0].getID(), mo1.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr4, fr2, fr3);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fr2[0].getID(), fedMap.getID()});
 		}
 
-		if(fedOutput && isFedOutput(ffr, dims1)) {
+		if(fedOutput) {
+			if(fr2 != null) // broadcasted mo3
+				fedMap.execute(getTID(), true, fr1, fr2, fr3);
+			else
+				fedMap.execute(getTID(), true, fr1, fr3);
+
 			MatrixObject out = ec.getMatrixObject(output);
-			FederationMap newFedMap = modifyFedRanges(mo1.getFedMapping(), dims1, dims2);
-			setFedOutput(mo1, out, newFedMap, dims1, fr2.getID());
+			FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(fr3.getID()),
+				staticDim, dims2, reversed);
+			setFedOutput(mo1, out, newFedMap, staticDim, dims2, reversed);
 		} else {
+			fr4 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr3.getID());
+			fr5 = fedMap.cleanup(getTID(), fr3.getID());
+			if(fr2 != null) // broadcasted mo3
+				ffr = fedMap.execute(getTID(), true, fr1, fr2, fr3, fr4, fr5);
+			else
+				ffr = fedMap.execute(getTID(), true, fr1, fr3, fr4, fr5);
+
 			ec.setMatrixOutput(output.getName(), aggResult(ffr));
 		}
 	}
 
-	boolean isFedOutput(Future<FederatedResponse>[] ffr,  Long[] dims1) {
-		boolean fedOutput = true;
-
-		long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / ffr.length;
-		try {
-			MatrixBlock curr;
-			MatrixBlock prev =(MatrixBlock) ffr[0].get().getData()[0];
-			for(int i = 1; i < ffr.length && fedOutput; i++) {
-				curr = (MatrixBlock) ffr[i].get().getData()[0];
-				MatrixBlock sliced = curr.slice((int) (curr.getNumRows() - fedSize), curr.getNumRows() - 1);
-
-				if(curr.getNumColumns() != prev.getNumColumns())
-					return false;
-
-				// no intersection
-				if(curr.getNumRows() == (i+1) * prev.getNumRows() && curr.getNonZeros() <= prev.getLength()
-					&& (curr.getNumRows() - sliced.getNumRows()) == i * prev.getNumRows()
-					&& curr.getNonZeros() - sliced.getNonZeros() == 0)
-					continue;
-
-				// check intersect with AND and compare number of nnz
-				MatrixBlock prevExtend = new MatrixBlock(curr.getNumRows(), curr.getNumColumns(), true, 0);
-				prevExtend.copy(0, prev.getNumRows()-1, 0, prev.getNumColumns()-1, prev, true);
-
-				MatrixBlock  intersect = curr.binaryOperationsInPlace(new BinaryOperator(And.getAndFnObject()), prevExtend);
-				if(intersect.getNonZeros() != 0)
-					fedOutput = false;
-				prev = sliced;
-			}
-		}
-		catch(Exception e) {
-			e.printStackTrace();
-		}
-		return fedOutput;
-	}
+	/**
+	 * Evaluate if the output can be kept federated on the different federated
+	 * sites or if the output needs to be aggregated on the coordinator, based
+	 * on the output ranges of mo2.

Review comment:
       I tried to add a little more explanation there, but I am not sure whether it is clear enough yet.
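
For illustration, the condition checked by the new isFedOutput can be sketched in isolation. This is a minimal sketch and not the code of this PR: the class `FedOutputCheck`, the method `canKeepFederated`, and the plain `double[][]` input are hypothetical stand-ins, where each row holds the values of mo2 that fall into one federated range of mo1. The output stays federated only if the [min, max] value ranges of these slices do not overlap, because then every worker writes to its own disjoint part of the output and nothing has to be aggregated across workers.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical stand-alone illustration; not part of SystemDS.
class FedOutputCheck {
	/**
	 * @param slices one row per federated partition, holding the values of
	 *               mo2 that correspond to that partition (assumed non-empty)
	 * @return true if the value ranges of all slices are pairwise disjoint
	 */
	static boolean canKeepFederated(double[][] slices) {
		// <min, max> of every slice, sorted by min
		TreeMap<Double, Double> ranges = new TreeMap<>();
		for(double[] slice : slices) {
			double min = Arrays.stream(slice).min().getAsDouble();
			double max = Arrays.stream(slice).max().getAsDouble();
			ranges.put(min, max);
		}

		// two slices with the same min collapse into one entry, so they overlap
		if(ranges.size() != slices.length)
			return false;

		// every range must start strictly after the previous one ended
		double prevMax = Double.NEGATIVE_INFINITY;
		for(Map.Entry<Double, Double> e : ranges.entrySet()) {
			if(prevMax >= e.getKey())
				return false;
			prevMax = e.getValue();
		}
		return true;
	}
}
```

With slices {1, 2, 2} and {3, 4} the ranges [1, 2] and [3, 4] are disjoint, so the output could stay federated. With {1, 3} and {3, 4} both partitions contribute to output index 3, so the partial results would have to be aggregated on the coordinator.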







[GitHub] [systemds] ywcb00 commented on a change in pull request #1371: [SYSTEMDS-3085] Federated CTable - Keep Output Federated

Posted by GitBox <gi...@apache.org>.
ywcb00 commented on a change in pull request #1371:
URL: https://github.com/apache/systemds/pull/1371#discussion_r700190923



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
##########
@@ -116,119 +120,154 @@ public void processInstruction(ExecutionContext ec) {
 			mo1 = ec.getMatrixObject(input3);
 		}
 
-		long dim1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 && dims1.length == Arrays.stream(dims1).distinct().count();
+		// static non-partitioned output dimension (same for all federated partitions)
+		long staticDim = Collections.max(Arrays.asList(dims1), Long::compare);
+		boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);
 
-		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, dims1, dims2);
+		processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, staticDim, dims2);
 	}
 
+	/**
+	 * Broadcast, execute, and finalize the federated instruction according to
+	 * the specified inputs.
+	 *
+	 * @param ec execution context
+	 * @param mo1 input matrix object 1
+	 * @param mo2 input matrix object 2
+	 * @param mo3 input matrix object 3 or null
+	 * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+	 * @param reversedWeights boolean indicating if inputs mo1 and mo3 are reversed
+	 * @param fedOutput boolean indicating if output can be kept federated
+	 * @param staticDim static non-partitioned dimension of the output
+	 * @param dims2 dimensions of the partial outputs along the federated partitioning
+	 */
 	private void processRequest(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3,
-		boolean reversed, boolean reversedWeights, boolean fedOutput, Long[] dims1, Long[] dims2) {
-		Future<FederatedResponse>[] ffr;
+		boolean reversed, boolean reversedWeights, boolean fedOutput, long staticDim, Long[] dims2) {
+
+		FederationMap fedMap = mo1.getFedMapping();
+
+		FederatedRequest[] fr1 = fedMap.broadcastSliced(mo2, false);
+		FederatedRequest[] fr2 = null;
+		FederatedRequest fr3, fr4, fr5;
+		fr3 = fr4 = fr5 = null;
+		Future<FederatedResponse>[] ffr = null;
 
-		FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
-		FederatedRequest fr2, fr3;
 		if(mo3 != null && mo1.isFederated() && mo3.isFederated()
-		&& mo1.getFedMapping().isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
+			&& fedMap.isAligned(mo3.getFedMapping(), AlignType.FULL)) { // mo1 and mo3 federated and aligned
 			if(!reversed)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fedMap.getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), mo3.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fedMap.getID(), mo3.getFedMapping().getID()});
 		}
 		else if(mo3 == null) {
 			if(!reversed)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+					new long[] {fedMap.getID(), fr1[0].getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr2, fr3);
-
-		} else {
-			FederatedRequest[] fr4 = mo1.getFedMapping().broadcastSliced(mo3, false);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2},
+					new long[] {fr1[0].getID(), fedMap.getID()});
+		}
+		else {
+			fr2 = fedMap.broadcastSliced(mo3, false);
 			if(!reversed && !reversedWeights)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {mo1.getFedMapping().getID(), fr1[0].getID(), fr4[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fedMap.getID(), fr1[0].getID(), fr2[0].getID()});
 			else if(reversed && !reversedWeights)
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), mo1.getFedMapping().getID(), fr4[0].getID()});
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fedMap.getID(), fr2[0].getID()});
 			else
-				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
-					new long[] {fr1[0].getID(), fr4[0].getID(), mo1.getFedMapping().getID()});
-
-			fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-			ffr = mo1.getFedMapping().execute(getTID(), true, fr1, fr4, fr2, fr3);
+				fr3 = FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, input2, input3},
+					new long[] {fr1[0].getID(), fr2[0].getID(), fedMap.getID()});
 		}
 
-		if(fedOutput && isFedOutput(ffr, dims1)) {
+		if(fedOutput) {
+			if(fr2 != null) // broadcasted mo3
+				fedMap.execute(getTID(), true, fr1, fr2, fr3);
+			else
+				fedMap.execute(getTID(), true, fr1, fr3);
+
 			MatrixObject out = ec.getMatrixObject(output);
-			FederationMap newFedMap = modifyFedRanges(mo1.getFedMapping(), dims1, dims2);
-			setFedOutput(mo1, out, newFedMap, dims1, fr2.getID());
+			FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(fr3.getID()),
+				staticDim, dims2, reversed);
+			setFedOutput(mo1, out, newFedMap, staticDim, dims2, reversed);
 		} else {
+			fr4 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr3.getID());
+			fr5 = fedMap.cleanup(getTID(), fr3.getID());
+			if(fr2 != null) // broadcasted mo3
+				ffr = fedMap.execute(getTID(), true, fr1, fr2, fr3, fr4, fr5);
+			else
+				ffr = fedMap.execute(getTID(), true, fr1, fr3, fr4, fr5);
+
 			ec.setMatrixOutput(output.getName(), aggResult(ffr));
 		}
 	}
 
-	boolean isFedOutput(Future<FederatedResponse>[] ffr,  Long[] dims1) {
-		boolean fedOutput = true;
-
-		long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / ffr.length;
-		try {
-			MatrixBlock curr;
-			MatrixBlock prev =(MatrixBlock) ffr[0].get().getData()[0];
-			for(int i = 1; i < ffr.length && fedOutput; i++) {
-				curr = (MatrixBlock) ffr[i].get().getData()[0];
-				MatrixBlock sliced = curr.slice((int) (curr.getNumRows() - fedSize), curr.getNumRows() - 1);
-
-				if(curr.getNumColumns() != prev.getNumColumns())
-					return false;
-
-				// no intersection
-				if(curr.getNumRows() == (i+1) * prev.getNumRows() && curr.getNonZeros() <= prev.getLength()
-					&& (curr.getNumRows() - sliced.getNumRows()) == i * prev.getNumRows()
-					&& curr.getNonZeros() - sliced.getNonZeros() == 0)
-					continue;
-
-				// check intersect with AND and compare number of nnz
-				MatrixBlock prevExtend = new MatrixBlock(curr.getNumRows(), curr.getNumColumns(), true, 0);
-				prevExtend.copy(0, prev.getNumRows()-1, 0, prev.getNumColumns()-1, prev, true);
-
-				MatrixBlock  intersect = curr.binaryOperationsInPlace(new BinaryOperator(And.getAndFnObject()), prevExtend);
-				if(intersect.getNonZeros() != 0)
-					fedOutput = false;
-				prev = sliced;
-			}
-		}
-		catch(Exception e) {
-			e.printStackTrace();
-		}
-		return fedOutput;
-	}
+	/**
+	 * Evaluate if the output can be kept federated on the different federated
+	 * sites or if the output needs to be aggregated on the coordinator, based
+	 * on the output ranges of mo2.
+	 *
+	 * @param fedMap the federation map of the federated matrix input mo1
+	 * @param mo2 input matrix object mo2
+	 * @return boolean indicating if the output can be kept on the federated sites
+	 */
+	private boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
+		MatrixBlock mb = mo2.acquireReadAndRelease();
+		FederatedRange[] fedRanges = fedMap.getFederatedRanges(); // federated ranges of mo1
+		SortedMap<Double, Double> fedDims = new TreeMap<Double, Double>(); // <beginDim, endDim>
+
+		// collect min and max of the corresponding slices of mo2
+		IntStream.range(0, fedRanges.length).forEach(i -> {
+			MatrixBlock sliced = mb.slice(
+				fedRanges[i].getBeginDimsInt()[0], fedRanges[i].getEndDimsInt()[0] - 1,
+				fedRanges[i].getBeginDimsInt()[1], fedRanges[i].getEndDimsInt()[1] - 1);
+			fedDims.put(sliced.min(), sliced.max());
+		});
 
+		boolean retVal = (fedDims.size() == fedRanges.length); // no duplicate begin dimension entries
 
-	private static void setFedOutput(MatrixObject mo1, MatrixObject out, FederationMap fedMap, Long[] dims1, long outId) {
-		long fedSize = Collections.max(Arrays.asList(dims1), Long::compare) / dims1.length;
+		Iterator<SortedMap.Entry<Double, Double>> iter = fedDims.entrySet().iterator();
+		SortedMap.Entry<Double, Double> entry = iter.next(); // first entry does not have to be checked
+		double prevEndDim = entry.getValue().doubleValue();
+		while(iter.hasNext() && retVal) {
+			entry = iter.next();
+			// previous end dimension must be less than current begin dimension (no overlaps of ranges)
+			retVal &= (prevEndDim < entry.getKey());
+			prevEndDim = entry.getValue().doubleValue();
+		}
 
-		long d1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		long d2 = Collections.max(Arrays.asList(dims1), Long::compare);
+		return retVal;
+	}
+
+	/**
+	 * Set the output and its data characteristics on the federated sites.
+	 *
+	 * @param mo1 input matrix object mo1
+	 * @param out input matrix object of the output
+	 * @param fedMap the federation map of the federated matrix input mo1
+	 * @param staticDim static non-partitioned dimension of the output
+	 * @param dims2 dimensions of the partial outputs along the federated partitioning
+	 * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+	 * @return boolean indicating if the output can be kept on the federated sites

Review comment:
       Copy paste error :sweat_smile: 
   Thanks for catching it







[GitHub] [systemds] sebwrede commented on a change in pull request #1371: [SYSTEMDS-3085] Federated CTable - Keep Output Federated

Posted by GitBox <gi...@apache.org>.
sebwrede commented on a change in pull request #1371:
URL: https://github.com/apache/systemds/pull/1371#discussion_r700247966



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
##########
@@ -116,119 +120,154 @@ public void processInstruction(ExecutionContext ec) {
 			mo1 = ec.getMatrixObject(input3);
 		}
 
-		long dim1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 && dims1.length == Arrays.stream(dims1).distinct().count();
+		// static non-partitioned output dimension (same for all federated partitions)
+		long staticDim = Collections.max(Arrays.asList(dims1), Long::compare);
+		boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);

Review comment:
       I think you are right. The only potential problem is that the current processRequest implementation assumes that mo1 is federated and mo2 is broadcast. This only becomes a problem if the same logic for checking which input is federated is not also part of the federated planner, which sets the _fedOut field. I will have to look at this later when I extend the federated planner to include this instruction.
   The current implementation is sufficient for now. 







[GitHub] [systemds] ywcb00 commented on a change in pull request #1371: [SYSTEMDS-3085] Federated CTable - Keep Output Federated

Posted by GitBox <gi...@apache.org>.
ywcb00 commented on a change in pull request #1371:
URL: https://github.com/apache/systemds/pull/1371#discussion_r700225238



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
##########
@@ -116,119 +120,154 @@ public void processInstruction(ExecutionContext ec) {
 			mo1 = ec.getMatrixObject(input3);
 		}
 
-		long dim1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 && dims1.length == Arrays.stream(dims1).distinct().count();
+		// static non-partitioned output dimension (same for all federated partitions)
+		long staticDim = Collections.max(Arrays.asList(dims1), Long::compare);
+		boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);

Review comment:
       We could just replace the boolean _fedOutput_ of the instruction with the __fedOut_ field, right? Then, in principle, the __fedOut_ field would decide in which branch (federated/local) the output is created.
   This change should not affect the _setFedOutput_ method at all, and it would be an advantage if partially aggregated federated output is supported in the future. :thinking: 







[GitHub] [systemds] asfgit closed pull request #1371: [SYSTEMDS-3085] Federated CTable - Keep Output Federated

Posted by GitBox <gi...@apache.org>.
asfgit closed pull request #1371:
URL: https://github.com/apache/systemds/pull/1371


   





[GitHub] [systemds] ywcb00 commented on a change in pull request #1371: [SYSTEMDS-3085] Federated CTable - Keep Output Federated

Posted by GitBox <gi...@apache.org>.
ywcb00 commented on a change in pull request #1371:
URL: https://github.com/apache/systemds/pull/1371#discussion_r701083148



##########
File path: src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
##########
@@ -116,119 +120,154 @@ public void processInstruction(ExecutionContext ec) {
 			mo1 = ec.getMatrixObject(input3);
 		}
 
-		long dim1 = Collections.max(Arrays.asList(dims1), Long::compare);
-		boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 && dims1.length == Arrays.stream(dims1).distinct().count();
+		// static non-partitioned output dimension (same for all federated partitions)
+		long staticDim = Collections.max(Arrays.asList(dims1), Long::compare);
+		boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);

Review comment:
       Alright, then I'll just leave the boolean there for now.



