You are viewing a plain-text version of this content. The link to the canonical version ("here") was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@systemds.apache.org by mb...@apache.org on 2023/07/15 13:36:32 UTC
[systemds] branch main updated: [MINOR] Support Non-literals in Federated Reshape Instructions
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new f33be5b1d9 [MINOR] Support Non-literals in Federated Reshape Instructions
f33be5b1d9 is described below
commit f33be5b1d9100e781dfa8f0ebf63390817b606bb
Author: ywcb00 <yw...@ywcb.org>
AuthorDate: Sat Jul 15 15:32:56 2023 +0200
[MINOR] Support Non-literals in Federated Reshape Instructions
AMLS project SoSe'23, part I
Closes #1862.
---
.../instructions/fed/ReshapeFEDInstruction.java | 18 ++++++++++--------
.../functions/federated/io/FederatedReaderTest.java | 12 +++++-------
.../federated/primitives/FederatedMisAlignedTest.java | 8 ++++----
.../functions/federated/FederatedReshapeTest.dml | 7 ++++++-
4 files changed, 25 insertions(+), 20 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
index 3d355cd1dd..521dbe8e51 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.Arrays;
import java.util.stream.Collectors;
+import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.Lop;
@@ -119,7 +120,7 @@ public class ReshapeFEDInstruction extends UnaryFEDInstruction {
mo1.getFedMapping().execute(getTID(), true, fr1, new FederatedRequest[0]);
// set new fed map
- FederationMap reshapedFedMap = mo1.getFedMapping();
+ FederationMap reshapedFedMap = mo1.getFedMapping().copyWithNewID(fr1[0].getID());
for(int i = 0; i < reshapedFedMap.getFederatedRanges().length; i++) {
long cells = reshapedFedMap.getFederatedRanges()[i].getSize();
long row = byRow.getBooleanValue() ? cells / cols : rows;
@@ -140,7 +141,7 @@ public class ReshapeFEDInstruction extends UnaryFEDInstruction {
//derive output federated mapping
MatrixObject out = ec.getMatrixObject(output);
out.getDataCharacteristics().set(rows, cols, (int) mo1.getBlocksize(), mo1.getNnz());
- out.setFedMapping(reshapedFedMap.copyWithNewID(fr1[0].getID()));
+ out.setFedMapping(reshapedFedMap);
}
else {
// TODO support tensor out, frame and list
@@ -156,14 +157,15 @@ public class ReshapeFEDInstruction extends UnaryFEDInstruction {
.collect(Collectors.toSet()).size();
sameFedSize = sameFedSize == 1 ? 1 : mo1.getFedMapping().getSize();
+ String execTypeName = InstructionUtils.getExecType(instString).name();
+ String[] instParts = InstructionUtils.getInstructionPartsWithValueType(instString);
for(int i = 0; i < sameFedSize; i++) {
- String[] instParts = instString.split(Lop.OPERAND_DELIMITOR);
long size = mo1.getFedMapping().getFederatedRanges()[i].getSize();
- String oldInstStringPart = byRow ? instParts[3] : instParts[4];
- String newInstStringPart = byRow ?
- oldInstStringPart.replace(String.valueOf(rows), String.valueOf(size/cols)) :
- oldInstStringPart.replace(String.valueOf(cols), String.valueOf(size/rows));
- instStrings[i] = instString.replace(oldInstStringPart, newInstStringPart);
+ instParts[2] = InstructionUtils.createLiteralOperand(
+ String.valueOf((int)(byRow ? size/cols : rows)), Types.ValueType.INT64);
+ instParts[3] = InstructionUtils.createLiteralOperand(
+ String.valueOf((int)(byRow ? cols : size/rows)), Types.ValueType.INT64);
+ instStrings[i] = InstructionUtils.concatOperands(ArrayUtils.addFirst(instParts, execTypeName));
}
if(sameFedSize == 1)
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
index ff68c8328e..295fe54770 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
@@ -40,7 +40,7 @@ import org.junit.runners.Parameterized;
public class FederatedReaderTest extends AutomatedTestBase {
private static final Log LOG = LogFactory.getLog(FederatedReaderTest.class.getName());
- private final static String TEST_DIR = "functions/federated/ioR/";
+ private final static String TEST_DIR = "functions/federated/io/";
private final static String TEST_NAME = "FederatedReaderTest";
private final static String TEST_CLASS_DIR = TEST_DIR + FederatedReaderTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@@ -50,8 +50,6 @@ public class FederatedReaderTest extends AutomatedTestBase {
public int cols;
@Parameterized.Parameter(2)
public boolean rowPartitioned;
- @Parameterized.Parameter(3)
- public int fedCount;
@Override
public void setUp() {
@@ -62,7 +60,7 @@ public class FederatedReaderTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
// number of rows or cols has to be >= number of federated locations.
- return Arrays.asList(new Object[][] {{10, 13, true, 2}});
+ return Arrays.asList(new Object[][] {{10, 13, true}});
}
@Test
@@ -111,11 +109,11 @@ public class FederatedReaderTest extends AutomatedTestBase {
// Run reference dml script with normal matrix
if(workerCount == 1) {
- fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + "1Reference.dml";
+ fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + "1Reference.dml";
programArgs = new String[] {"-stats", "-args", input("X1")};
}
else {
- fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME
+ fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME
+ (rowPartitioned ? "Row" : "Col") + "2Reference.dml";
programArgs = new String[] {"-stats", "-args", input("X1"), input("X2")};
}
@@ -125,7 +123,7 @@ public class FederatedReaderTest extends AutomatedTestBase {
LOG.debug(refOut);
// Run federated
- fullDMLScriptName = SCRIPT_DIR + "functions/federated/io/" + TEST_NAME + ".dml";
+ fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "-args", input("X.json")};
String out = runTest(null).toString();
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java
index ecc8a7b90f..5b4b350b08 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java
@@ -205,10 +205,10 @@ public class FederatedMisAlignedTest extends AutomatedTestBase {
c = cols;
}
- double[][] X1 = getRandomMatrix(r, c, 3, 3, 1, 3);
- double[][] X2 = getRandomMatrix(r, c, 3, 3, 1, 7);
- double[][] X3 = getRandomMatrix(r, c, 3, 3, 1, 8);
- double[][] X4 = getRandomMatrix(r, c, 3, 3, 1, 9);
+ double[][] X1 = getRandomMatrix(r, c, 3, 4, 1, 3);
+ double[][] X2 = getRandomMatrix(r, c, 3, 4, 1, 7);
+ double[][] X3 = getRandomMatrix(r, c, 3, 4, 1, 8);
+ double[][] X4 = getRandomMatrix(r, c, 3, 4, 1, 9);
MatrixCharacteristics mc = new MatrixCharacteristics(r, c, blocksize, r * c);
writeInputMatrixWithMTD("X1", X1, false, mc);
diff --git a/src/test/scripts/functions/federated/FederatedReshapeTest.dml b/src/test/scripts/functions/federated/FederatedReshapeTest.dml
index 6aa8a165b5..f133bcff17 100644
--- a/src/test/scripts/functions/federated/FederatedReshapeTest.dml
+++ b/src/test/scripts/functions/federated/FederatedReshapeTest.dml
@@ -27,5 +27,10 @@ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
ranges=list(list(0, 0), list(2, 12), list(2, 0), list(4, $cols),
list(4, 0), list(10, $cols), list(10, 0), list(12, $cols)));
-s = matrix(A, rows=$r_rows, cols=$r_cols);
+# materialize the scalar input (non-literal)
+reshape_cols = $r_cols;
+while(FALSE) {}
+reshape_cols = reshape_cols;
+
+s = matrix(A, rows=$r_rows, cols=reshape_cols);
write(s, $out_S);