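You are viewing a plain text version of this content. The canonical version is available at the link labeled "here" in the original archive page (the link itself is not preserved in this plain-text rendering).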
Posted to commits@systemds.apache.org by mb...@apache.org on 2023/07/15 13:36:32 UTC

[systemds] branch main updated: [MINOR] Support Non-literals in Federated Reshape Instructions

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f33be5b1d9 [MINOR] Support Non-literals in Federated Reshape Instructions
f33be5b1d9 is described below

commit f33be5b1d9100e781dfa8f0ebf63390817b606bb
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Sat Jul 15 15:32:56 2023 +0200

    [MINOR] Support Non-literals in Federated Reshape Instructions
    
    AMLS project SoSe'23, part I
    Closes #1862.
---
 .../instructions/fed/ReshapeFEDInstruction.java        | 18 ++++++++++--------
 .../functions/federated/io/FederatedReaderTest.java    | 12 +++++-------
 .../federated/primitives/FederatedMisAlignedTest.java  |  8 ++++----
 .../functions/federated/FederatedReshapeTest.dml       |  7 ++++++-
 4 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
index 3d355cd1dd..521dbe8e51 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
 import java.util.Arrays;
 import java.util.stream.Collectors;
 
+import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.lops.Lop;
@@ -119,7 +120,7 @@ public class ReshapeFEDInstruction extends UnaryFEDInstruction {
 			mo1.getFedMapping().execute(getTID(), true, fr1, new FederatedRequest[0]);
 
 			// set new fed map
-			FederationMap reshapedFedMap = mo1.getFedMapping();
+			FederationMap reshapedFedMap = mo1.getFedMapping().copyWithNewID(fr1[0].getID());
 			for(int i = 0; i < reshapedFedMap.getFederatedRanges().length; i++) {
 				long cells = reshapedFedMap.getFederatedRanges()[i].getSize();
 				long row = byRow.getBooleanValue() ? cells / cols : rows;
@@ -140,7 +141,7 @@ public class ReshapeFEDInstruction extends UnaryFEDInstruction {
 			//derive output federated mapping
 			MatrixObject out = ec.getMatrixObject(output);
 			out.getDataCharacteristics().set(rows, cols, (int) mo1.getBlocksize(), mo1.getNnz());
-			out.setFedMapping(reshapedFedMap.copyWithNewID(fr1[0].getID()));
+			out.setFedMapping(reshapedFedMap);
 		}
 		else {
 			// TODO support tensor out, frame and list
@@ -156,14 +157,15 @@ public class ReshapeFEDInstruction extends UnaryFEDInstruction {
 			.collect(Collectors.toSet()).size();
 		sameFedSize = sameFedSize == 1 ? 1 : mo1.getFedMapping().getSize();
 
+		String execTypeName = InstructionUtils.getExecType(instString).name();
+		String[] instParts = InstructionUtils.getInstructionPartsWithValueType(instString);
 		for(int i = 0; i < sameFedSize; i++) {
-			String[] instParts = instString.split(Lop.OPERAND_DELIMITOR);
 			long size = mo1.getFedMapping().getFederatedRanges()[i].getSize();
-			String oldInstStringPart = byRow ? instParts[3] : instParts[4];
-			String newInstStringPart = byRow ? 
-				oldInstStringPart.replace(String.valueOf(rows), String.valueOf(size/cols)) :
-				oldInstStringPart.replace(String.valueOf(cols), String.valueOf(size/rows));
-			instStrings[i] = instString.replace(oldInstStringPart, newInstStringPart);
+			instParts[2] = InstructionUtils.createLiteralOperand(
+				String.valueOf((int)(byRow ? size/cols : rows)), Types.ValueType.INT64);
+			instParts[3] = InstructionUtils.createLiteralOperand(
+				String.valueOf((int)(byRow ? cols : size/rows)), Types.ValueType.INT64);
+			instStrings[i] = InstructionUtils.concatOperands(ArrayUtils.addFirst(instParts, execTypeName));
 		}
 
 		if(sameFedSize == 1)
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
index ff68c8328e..295fe54770 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
@@ -40,7 +40,7 @@ import org.junit.runners.Parameterized;
 public class FederatedReaderTest extends AutomatedTestBase {
 
 	private static final Log LOG = LogFactory.getLog(FederatedReaderTest.class.getName());
-	private final static String TEST_DIR = "functions/federated/ioR/";
+	private final static String TEST_DIR = "functions/federated/io/";
 	private final static String TEST_NAME = "FederatedReaderTest";
 	private final static String TEST_CLASS_DIR = TEST_DIR + FederatedReaderTest.class.getSimpleName() + "/";
 	private final static int blocksize = 1024;
@@ -50,8 +50,6 @@ public class FederatedReaderTest extends AutomatedTestBase {
 	public int cols;
 	@Parameterized.Parameter(2)
 	public boolean rowPartitioned;
-	@Parameterized.Parameter(3)
-	public int fedCount;
 
 	@Override
 	public void setUp() {
@@ -62,7 +60,7 @@ public class FederatedReaderTest extends AutomatedTestBase {
 	@Parameterized.Parameters
 	public static Collection<Object[]> data() {
 		// number of rows or cols has to be >= number of federated locations.
-		return Arrays.asList(new Object[][] {{10, 13, true, 2}});
+		return Arrays.asList(new Object[][] {{10, 13, true}});
 	}
 
 	@Test
@@ -111,11 +109,11 @@ public class FederatedReaderTest extends AutomatedTestBase {
 			// Run reference dml script with normal matrix
 
 			if(workerCount == 1) {
-				fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + "1Reference.dml";
+				fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + "1Reference.dml";
 				programArgs = new String[] {"-stats", "-args", input("X1")};
 			}
 			else {
-				fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME
+				fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME
 					+ (rowPartitioned ? "Row" : "Col") + "2Reference.dml";
 				programArgs = new String[] {"-stats", "-args", input("X1"), input("X2")};
 			}
@@ -125,7 +123,7 @@ public class FederatedReaderTest extends AutomatedTestBase {
 			LOG.debug(refOut);
 			
 			// Run federated
-			fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + ".dml";
+			fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml";
 			programArgs = new String[] {"-stats", "-args", input("X.json")};
 			String out = runTest(null).toString();
 
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java
index ecc8a7b90f..5b4b350b08 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java
@@ -205,10 +205,10 @@ public class FederatedMisAlignedTest extends AutomatedTestBase {
 			c = cols;
 		}
 
-		double[][] X1 = getRandomMatrix(r, c, 3, 3, 1, 3);
-		double[][] X2 = getRandomMatrix(r, c, 3, 3, 1, 7);
-		double[][] X3 = getRandomMatrix(r, c, 3, 3, 1, 8);
-		double[][] X4 = getRandomMatrix(r, c, 3, 3, 1, 9);
+		double[][] X1 = getRandomMatrix(r, c, 3, 4, 1, 3);
+		double[][] X2 = getRandomMatrix(r, c, 3, 4, 1, 7);
+		double[][] X3 = getRandomMatrix(r, c, 3, 4, 1, 8);
+		double[][] X4 = getRandomMatrix(r, c, 3, 4, 1, 9);
 
 		MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
 		writeInputMatrixWithMTD("X1", X1, false, mc);
diff --git a/src/test/scripts/functions/federated/FederatedReshapeTest.dml b/src/test/scripts/functions/federated/FederatedReshapeTest.dml
index 6aa8a165b5..f133bcff17 100644
--- a/src/test/scripts/functions/federated/FederatedReshapeTest.dml
+++ b/src/test/scripts/functions/federated/FederatedReshapeTest.dml
@@ -27,5 +27,10 @@ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
     ranges=list(list(0, 0), list(2, 12), list(2, 0), list(4, $cols),
     list(4, 0), list(10, $cols), list(10, 0), list(12, $cols)));
 
-s = matrix(A, rows=$r_rows, cols=$r_cols);
+# materialize the scalar input (non-literal)
+reshape_cols = $r_cols;
+while(FALSE) {}
+reshape_cols = reshape_cols;
+
+s = matrix(A, rows=$r_rows, cols=reshape_cols);
 write(s, $out_S);