You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/02/03 06:59:48 UTC

systemml git commit: [SYSTEMML-2108] Performance CP ternary +* and -* operations

Repository: systemml
Updated Branches:
  refs/heads/master c95019fd9 -> f14255f46


[SYSTEMML-2108] Performance CP ternary +* and -* operations

Since the introduction of the general ternary operation framework for
ifelse (which also subsumed the specific +* and -* operations), the +*
and -* operations have shown non-negligible overhead, especially for
sparse-dense combinations. Hence, this patch adds a special case for
matrix-scalar-matrix and matrix-matrix-scalar operations that routes
them to the binary operation framework.

On LeNet over MNIST, +* and -* consumed 28% of the total execution time;
this patch reduces the runtime of these operations by more than 2x.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f14255f4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f14255f4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f14255f4

Branch: refs/heads/master
Commit: f14255f464017a0f3dea1d335160b25810fe20a3
Parents: c95019f
Author: Matthias Boehm <mb...@gmail.com>
Authored: Fri Feb 2 22:59:52 2018 -0800
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Fri Feb 2 22:59:52 2018 -0800

----------------------------------------------------------------------
 .../runtime/functionobjects/MinusMultiply.java  | 22 ++++++++++++++++++--
 .../runtime/functionobjects/PlusMultiply.java   | 22 ++++++++++++++++++--
 .../functionobjects/TernaryValueFunction.java   |  5 +++++
 .../sysml/runtime/matrix/data/MatrixBlock.java  | 14 ++++++++++---
 .../matrix/operators/BinaryOperator.java        |  3 +++
 5 files changed, 59 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
index 794571f..1e3d093 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
@@ -21,14 +21,23 @@ package org.apache.sysml.runtime.functionobjects;
 
 import java.io.Serializable;
 
-public class MinusMultiply extends TernaryValueFunction implements Serializable
+import org.apache.sysml.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+
+public class MinusMultiply extends TernaryValueFunction implements ValueFunctionWithConstant, Serializable
 {
 	private static final long serialVersionUID = 2801982061205871665L;
 	
 	private static MinusMultiply singleObj = null;
 
+	private final double _cnt;
+	
 	private MinusMultiply() {
-		// nothing to do here
+		_cnt = 1;
+	}
+	
+	private MinusMultiply(double cnt) {
+		_cnt = cnt;
 	}
 
 	public static MinusMultiply getFnObject() {
@@ -41,4 +50,13 @@ public class MinusMultiply extends TernaryValueFunction implements Serializable
 	public double execute(double in1, double in2, double in3) {
 		return in1 - in2 * in3;
 	}
+	
+	public BinaryOperator setOp2Constant(double cnt) {
+		return new BinaryOperator(new MinusMultiply(cnt));
+	}
+	
+	@Override
+	public double execute(double in1, double in2) {
+		return in1 - _cnt * in2;
+	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
index cb821f5..041527f 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
@@ -21,14 +21,23 @@ package org.apache.sysml.runtime.functionobjects;
 
 import java.io.Serializable;
 
-public class PlusMultiply extends TernaryValueFunction implements Serializable
+import org.apache.sysml.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+
+public class PlusMultiply extends TernaryValueFunction implements ValueFunctionWithConstant, Serializable
 {
 	private static final long serialVersionUID = 2801982061205871665L;
 
 	private static PlusMultiply singleObj = null;
 
+	private final double _cnt;
+	
 	private PlusMultiply() {
-		// nothing to do here
+		_cnt = 1;
+	}
+	
+	private PlusMultiply(double cnt) {
+		_cnt = cnt;
 	}
 
 	public static PlusMultiply getFnObject() {
@@ -41,4 +50,13 @@ public class PlusMultiply extends TernaryValueFunction implements Serializable
 	public double execute(double in1, double in2, double in3) {
 		return in1 + in2 * in3;
 	}
+
+	public BinaryOperator setOp2Constant(double cnt) {
+		return new BinaryOperator(new PlusMultiply(cnt));
+	}
+	
+	@Override
+	public double execute(double in1, double in2) {
+		return in1 + _cnt * in2;
+	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java b/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java
index c317010..9629746 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java
@@ -22,6 +22,7 @@ package org.apache.sysml.runtime.functionobjects;
 import java.io.Serializable;
 
 import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
 
 public abstract class TernaryValueFunction extends ValueFunction implements Serializable
 {
@@ -29,4 +30,8 @@ public abstract class TernaryValueFunction extends ValueFunction implements Seri
 	
 	public abstract double execute ( double in1, double in2, double in3 )
 		throws DMLRuntimeException;
+	
+	public interface ValueFunctionWithConstant {
+		public BinaryOperator setOp2Constant(double cnt);
+	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index e06c8c1..654cf53 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -51,14 +51,17 @@ import org.apache.sysml.runtime.functionobjects.IfElse;
 import org.apache.sysml.runtime.functionobjects.KahanFunction;
 import org.apache.sysml.runtime.functionobjects.KahanPlus;
 import org.apache.sysml.runtime.functionobjects.KahanPlusSq;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
 import org.apache.sysml.runtime.functionobjects.Multiply;
 import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.PlusMultiply;
 import org.apache.sysml.runtime.functionobjects.ReduceAll;
 import org.apache.sysml.runtime.functionobjects.ReduceCol;
 import org.apache.sysml.runtime.functionobjects.ReduceRow;
 import org.apache.sysml.runtime.functionobjects.RevIndex;
 import org.apache.sysml.runtime.functionobjects.SortIndex;
 import org.apache.sysml.runtime.functionobjects.SwapIndex;
+import org.apache.sysml.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
 import org.apache.sysml.runtime.instructions.cp.CM_COV_Object;
 import org.apache.sysml.runtime.instructions.cp.KahanObject;
 import org.apache.sysml.runtime.instructions.cp.ScalarObject;
@@ -2803,9 +2806,8 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 		//prepare result
 		ret.reset(m, n, false);
 		
-		if( op.fn instanceof IfElse && (s1 || nnz==0 || nnz==(long)m*n) )
-		{
-			//special case for shallow-copy if-else
+		if( op.fn instanceof IfElse && (s1 || nnz==0 || nnz==(long)m*n) ) {
+			//SPECIAL CASE for shallow-copy if-else
 			boolean expr = s1 ? (d1 != 0) : (nnz==(long)m*n);
 			MatrixBlock tmp = expr ? m2 : m3;
 			if( tmp.rlen==m && tmp.clen==n ) {
@@ -2822,6 +2824,12 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 				}
 			}
 		}
+		else if (s2 != s3 && (op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply) ) {
+			//SPECIAL CASE for sparse-dense combinations of common +* and -*
+			BinaryOperator bop = ((ValueFunctionWithConstant)op.fn)
+				.setOp2Constant(s2 ? d2 : d3);
+			LibMatrixBincell.bincellOp(this, s2 ? m3 : m2, ret, bop);
+		}
 		else {
 			ret.allocateDenseBlock();
 			

http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java
index 5245db5..e3b9a06 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java
@@ -39,12 +39,14 @@ import org.apache.sysml.runtime.functionobjects.IntegerDivide;
 import org.apache.sysml.runtime.functionobjects.LessThan;
 import org.apache.sysml.runtime.functionobjects.LessThanEquals;
 import org.apache.sysml.runtime.functionobjects.Minus;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
 import org.apache.sysml.runtime.functionobjects.MinusNz;
 import org.apache.sysml.runtime.functionobjects.Modulus;
 import org.apache.sysml.runtime.functionobjects.Multiply;
 import org.apache.sysml.runtime.functionobjects.NotEquals;
 import org.apache.sysml.runtime.functionobjects.Or;
 import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.PlusMultiply;
 import org.apache.sysml.runtime.functionobjects.Power;
 import org.apache.sysml.runtime.functionobjects.ValueFunction;
 import org.apache.sysml.runtime.functionobjects.Xor;
@@ -58,6 +60,7 @@ public class BinaryOperator  extends Operator implements Serializable
 	public BinaryOperator(ValueFunction p) {
 		//binaryop is sparse-safe iff (0 op 0) == 0
 		super (p instanceof Plus || p instanceof Multiply || p instanceof Minus
+			|| p instanceof PlusMultiply || p instanceof MinusMultiply
 			|| p instanceof And || p instanceof Or || p instanceof Xor
 			|| p instanceof BitwAnd || p instanceof BitwOr || p instanceof BitwXor
 			|| p instanceof BitwShiftL || p instanceof BitwShiftR);