You are viewing a plain-text version of this content; the link to the canonical version is not preserved in this rendering.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/10/31 22:07:10 UTC
[systemds] 02/05: [SYSTEMDS-2549] Extended federated binary element-wise operations
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 998d82e27b8add5a0ca55ac687f0bfd9abe54c8b
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Oct 31 20:25:59 2020 +0100
[SYSTEMDS-2549] Extended federated binary element-wise operations
This patch generalizes the existing federated binary element-wise
operations to cover previously unsupported scenarios. Specifically, if
the right-hand-side matrix (instead of the left-hand-side matrix) is
federated, both inputs have equal dimensions, and the operation is
commutative (e.g., mult/add), we canonicalize the inputs by swapping
them so that the federated matrix becomes the left-hand side.
---
.../fed/BinaryMatrixMatrixFEDInstruction.java | 17 +++++++++++++----
.../sysds/runtime/matrix/operators/BinaryOperator.java | 7 +++++++
.../apache/sysds/runtime/meta/DataCharacteristics.java | 4 +++-
.../sysds/runtime/meta/MatrixCharacteristics.java | 12 +++++++++++-
.../sysds/runtime/meta/TensorCharacteristics.java | 9 +++++++++
.../federated/algorithms/FederatedGLMTest.java | 2 +-
.../federated/algorithms/FederatedKmeansTest.java | 4 +++-
7 files changed, 47 insertions(+), 8 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index bceb6ae..ea34df1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -25,6 +25,7 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
@@ -39,8 +40,16 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
MatrixObject mo1 = ec.getMatrixObject(input1);
MatrixObject mo2 = ec.getMatrixObject(input2);
+ //canonicalization for federated lhs
+ if( !mo1.isFederated() && mo2.isFederated()
+ && mo1.getDataCharacteristics().equalDims(mo2.getDataCharacteristics())
+ && ((BinaryOperator)_optr).isCommutative() ) {
+ mo1 = ec.getMatrixObject(input2);
+ mo2 = ec.getMatrixObject(input1);
+ }
+
+ //execute federated operation on mo1 or mo2
FederatedRequest fr2 = null;
-
if( mo2.isFederated() ) {
if(mo1.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
@@ -48,12 +57,12 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
mo1.getFedMapping().execute(getTID(), true, fr2);
}
else {
- throw new DMLRuntimeException("Matrix-matrix binary operations "
- + " with a federated right input are not supported yet.");
+ throw new DMLRuntimeException("Matrix-matrix binary operations with a "
+ + "federated right input are only supported for special cases yet.");
}
}
else {
- //matrix-matrix binary oFederatedRequest fr2 = null;perations -> lhs fed input -> fed output
+ //matrix-matrix binary operations -> lhs fed input -> fed output
if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { //MV row vector
FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
index beca629..bc4cdd0 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
@@ -56,6 +56,7 @@ public class BinaryOperator extends Operator implements Serializable
private static final long serialVersionUID = -2547950181558989209L;
public final ValueFunction fn;
+ public final boolean commutative;
public BinaryOperator(ValueFunction p) {
//binaryop is sparse-safe iff (0 op 0) == 0
@@ -65,6 +66,8 @@ public class BinaryOperator extends Operator implements Serializable
|| p instanceof BitwAnd || p instanceof BitwOr || p instanceof BitwXor
|| p instanceof BitwShiftL || p instanceof BitwShiftR);
fn = p;
+ commutative = p instanceof Plus || p instanceof Multiply
+ || p instanceof And || p instanceof Or || p instanceof Xor;
}
/**
@@ -111,6 +114,10 @@ public class BinaryOperator extends Operator implements Serializable
return null;
}
+ public boolean isCommutative() {
+ return commutative;
+ }
+
@Override
public String toString() {
return "BinaryOperator("+fn.getClass().getSimpleName()+")";
diff --git a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
index d71ce9d..a28d98d 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
@@ -188,9 +188,11 @@ public abstract class DataCharacteristics implements Serializable {
dimOut.set(dim1.getRows(), dim2.getCols(), dim1.getBlocksize());
}
+ public abstract boolean equalDims(Object anObject);
+
@Override
public abstract boolean equals(Object anObject);
-
+
@Override
public abstract int hashCode();
}
diff --git a/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java b/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
index 0b29cce..bdc4b21 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
@@ -229,7 +229,17 @@ public class MatrixCharacteristics extends DataCharacteristics
return !nnzKnown() || numRows==0 || numColumns==0
|| (nonZero < numRows*numColumns - singleBlk);
}
-
+
+ @Override
+ public boolean equalDims(Object anObject) {
+ if( !(anObject instanceof MatrixCharacteristics) )
+ return false;
+ MatrixCharacteristics mc = (MatrixCharacteristics) anObject;
+ return dimsKnown() && mc.dimsKnown()
+ && numRows == mc.numRows
+ && numColumns == mc.numColumns;
+ }
+
@Override
public boolean equals (Object anObject) {
if( !(anObject instanceof MatrixCharacteristics) )
diff --git a/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java b/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java
index 449cc2d..2b554a2 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java
@@ -157,6 +157,15 @@ public class TensorCharacteristics extends DataCharacteristics
}
@Override
+ public boolean equalDims(Object anObject) {
+ if( !(anObject instanceof TensorCharacteristics) )
+ return false;
+ TensorCharacteristics tc = (TensorCharacteristics) anObject;
+ return dimsKnown() && tc.dimsKnown()
+ && Arrays.equals(_dims, tc._dims);
+ }
+
+ @Override
public boolean equals (Object anObject) {
if( !(anObject instanceof TensorCharacteristics) )
return false;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
index 2b9d287..44de28f 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
@@ -123,7 +123,7 @@ public class FederatedGLMTest extends AutomatedTestBase {
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
Assert.assertTrue(heavyHittersContainsString("fed_uark+","fed_uarsqk+"));
Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
- Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+ //Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
Assert.assertTrue(heavyHittersContainsString("fed_mmchain"));
//check that federated input files are still existing
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
index eb70a4b..0dd339f 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
@@ -128,8 +128,10 @@ public class FederatedKmeansTest extends AutomatedTestBase {
// check for federated operations
Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
- Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
+ //Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
Assert.assertTrue(heavyHittersContainsString("fed_uarmin"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uark+"));
+ Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
Assert.assertTrue(heavyHittersContainsString("fed_*"));
Assert.assertTrue(heavyHittersContainsString("fed_+"));
Assert.assertTrue(heavyHittersContainsString("fed_<="));