You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2020/10/31 22:07:10 UTC

[systemds] 02/05: [SYSTEMDS-2549] Extended federated binary element-wise operations

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 998d82e27b8add5a0ca55ac687f0bfd9abe54c8b
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Sat Oct 31 20:25:59 2020 +0100

    [SYSTEMDS-2549] Extended federated binary element-wise operations
    
    This patch generalizes the existing federated binary element-wise
    operations to support previously unsupported scenarios. Specifically,
    if the right-hand-side matrix (instead of the left-hand-side matrix)
    is federated, both inputs have equal dimensions, and the operation is
    commutative (e.g., mult/add), we canonicalize the inputs by swapping
    them so that the federated matrix becomes the left-hand side.
---
 .../fed/BinaryMatrixMatrixFEDInstruction.java           | 17 +++++++++++++----
 .../sysds/runtime/matrix/operators/BinaryOperator.java  |  7 +++++++
 .../apache/sysds/runtime/meta/DataCharacteristics.java  |  4 +++-
 .../sysds/runtime/meta/MatrixCharacteristics.java       | 12 +++++++++++-
 .../sysds/runtime/meta/TensorCharacteristics.java       |  9 +++++++++
 .../federated/algorithms/FederatedGLMTest.java          |  2 +-
 .../federated/algorithms/FederatedKmeansTest.java       |  4 +++-
 7 files changed, 47 insertions(+), 8 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index bceb6ae..ea34df1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -25,6 +25,7 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
 public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
@@ -39,8 +40,16 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
 		MatrixObject mo1 = ec.getMatrixObject(input1);
 		MatrixObject mo2 = ec.getMatrixObject(input2);
 		
+		//canonicalization for federated lhs
+		if( !mo1.isFederated() && mo2.isFederated() 
+			&& mo1.getDataCharacteristics().equalDims(mo2.getDataCharacteristics()) 
+			&& ((BinaryOperator)_optr).isCommutative() ) {
+			mo1 = ec.getMatrixObject(input2);
+			mo2 = ec.getMatrixObject(input1);
+		}
+		
+		//execute federated operation on mo1 or mo2
 		FederatedRequest fr2 = null;
-
 		if( mo2.isFederated() ) {
 			if(mo1.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
 				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
@@ -48,12 +57,12 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
 				mo1.getFedMapping().execute(getTID(), true, fr2);
 			}
 			else {
-				throw new DMLRuntimeException("Matrix-matrix binary operations "
-					+ " with a federated right input are not supported yet.");
+				throw new DMLRuntimeException("Matrix-matrix binary operations with a "
+					+ "federated right input are only supported for special cases yet.");
 			}
 		}
 		else {
-			//matrix-matrix binary oFederatedRequest fr2 = null;perations -> lhs fed input -> fed output
+			//matrix-matrix binary operations -> lhs fed input -> fed output
 			if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { //MV row vector
 				FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false);
 				fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2},
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
index beca629..bc4cdd0 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
@@ -56,6 +56,7 @@ public class BinaryOperator  extends Operator implements Serializable
 	private static final long serialVersionUID = -2547950181558989209L;
 
 	public final ValueFunction fn;
+	public final boolean commutative;
 	
 	public BinaryOperator(ValueFunction p) {
 		//binaryop is sparse-safe iff (0 op 0) == 0
@@ -65,6 +66,8 @@ public class BinaryOperator  extends Operator implements Serializable
 			|| p instanceof BitwAnd || p instanceof BitwOr || p instanceof BitwXor
 			|| p instanceof BitwShiftL || p instanceof BitwShiftR);
 		fn = p;
+		commutative = p instanceof Plus || p instanceof Multiply 
+			|| p instanceof And || p instanceof Or || p instanceof Xor;
 	}
 	
 	/**
@@ -111,6 +114,10 @@ public class BinaryOperator  extends Operator implements Serializable
 		return null;
 	}
 	
+	public boolean isCommutative() {
+		return commutative;
+	}
+	
 	@Override
 	public String toString() {
 		return "BinaryOperator("+fn.getClass().getSimpleName()+")";
diff --git a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
index d71ce9d..a28d98d 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java
@@ -188,9 +188,11 @@ public abstract class DataCharacteristics implements Serializable {
 		dimOut.set(dim1.getRows(), dim2.getCols(), dim1.getBlocksize());
 	}
 
+	public abstract boolean equalDims(Object anObject);
+
 	@Override
 	public abstract boolean equals(Object anObject);
-
+	
 	@Override
 	public abstract int hashCode();
 }
diff --git a/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java b/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
index 0b29cce..bdc4b21 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
@@ -229,7 +229,17 @@ public class MatrixCharacteristics extends DataCharacteristics
 		return !nnzKnown() || numRows==0 || numColumns==0
 			|| (nonZero < numRows*numColumns - singleBlk);
 	}
-
+	
+	@Override
+	public boolean equalDims(Object anObject) {
+		if( !(anObject instanceof MatrixCharacteristics) )
+			return false;
+		MatrixCharacteristics mc = (MatrixCharacteristics) anObject;
+		return dimsKnown() && mc.dimsKnown()
+			&& numRows == mc.numRows
+			&& numColumns == mc.numColumns;
+	}
+	
 	@Override
 	public boolean equals (Object anObject) {
 		if( !(anObject instanceof MatrixCharacteristics) )
diff --git a/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java b/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java
index 449cc2d..2b554a2 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java
@@ -157,6 +157,15 @@ public class TensorCharacteristics extends DataCharacteristics
 	}
 	
 	@Override
+	public boolean equalDims(Object anObject) {
+		if( !(anObject instanceof TensorCharacteristics) )
+			return false;
+		TensorCharacteristics tc = (TensorCharacteristics) anObject;
+		return dimsKnown() && tc.dimsKnown()
+			&& Arrays.equals(_dims, tc._dims);
+	}
+	
+	@Override
 	public boolean equals (Object anObject) {
 		if( !(anObject instanceof TensorCharacteristics) )
 			return false;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
index 2b9d287..44de28f 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java
@@ -123,7 +123,7 @@ public class FederatedGLMTest extends AutomatedTestBase {
 		Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
 		Assert.assertTrue(heavyHittersContainsString("fed_uark+","fed_uarsqk+"));
 		Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
-		Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+		//Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
 		Assert.assertTrue(heavyHittersContainsString("fed_mmchain"));
 		
 		//check that federated input files are still existing
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
index eb70a4b..0dd339f 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java
@@ -128,8 +128,10 @@ public class FederatedKmeansTest extends AutomatedTestBase {
 		
 			// check for federated operations
 			Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
-			Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
+			//Assert.assertTrue(heavyHittersContainsString("fed_uasqk+"));
 			Assert.assertTrue(heavyHittersContainsString("fed_uarmin"));
+			Assert.assertTrue(heavyHittersContainsString("fed_uark+"));
+			Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
 			Assert.assertTrue(heavyHittersContainsString("fed_*"));
 			Assert.assertTrue(heavyHittersContainsString("fed_+"));
 			Assert.assertTrue(heavyHittersContainsString("fed_<="));