You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/02/03 06:59:48 UTC
systemml git commit: [SYSTEMML-2108] Performance CP ternary +* and -* operations
Repository: systemml
Updated Branches:
refs/heads/master c95019fd9 -> f14255f46
[SYSTEMML-2108] Performance CP ternary +* and -* operations
Since the introduction of the general ternary operation framework for
ifelse (which also subsumed the specific +* and -* operations), the +*
and -* operations have shown non-negligible overhead, especially for
sparse-dense combinations. Hence, this patch adds a special case for
matrix-scalar-matrix and matrix-matrix-scalar operations: because one of
the two multiplication operands is a scalar s, X + s*Y (or X - s*Y) can be
expressed as a binary operation with s folded into the function object,
so these operations are routed to the binary operation framework.
On LeNet over MNIST, +* and -* consumed 28% of the total execution time;
this patch reduces the runtime of these operations by more than 2x.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/f14255f4
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/f14255f4
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/f14255f4
Branch: refs/heads/master
Commit: f14255f464017a0f3dea1d335160b25810fe20a3
Parents: c95019f
Author: Matthias Boehm <mb...@gmail.com>
Authored: Fri Feb 2 22:59:52 2018 -0800
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Fri Feb 2 22:59:52 2018 -0800
----------------------------------------------------------------------
.../runtime/functionobjects/MinusMultiply.java | 22 ++++++++++++++++++--
.../runtime/functionobjects/PlusMultiply.java | 22 ++++++++++++++++++--
.../functionobjects/TernaryValueFunction.java | 5 +++++
.../sysml/runtime/matrix/data/MatrixBlock.java | 14 ++++++++++---
.../matrix/operators/BinaryOperator.java | 3 +++
5 files changed, 59 insertions(+), 7 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
index 794571f..1e3d093 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/MinusMultiply.java
@@ -21,14 +21,23 @@ package org.apache.sysml.runtime.functionobjects;
import java.io.Serializable;
-public class MinusMultiply extends TernaryValueFunction implements Serializable
+import org.apache.sysml.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+
+public class MinusMultiply extends TernaryValueFunction implements ValueFunctionWithConstant, Serializable
{
private static final long serialVersionUID = 2801982061205871665L;
private static MinusMultiply singleObj = null;
+ private final double _cnt;
+
private MinusMultiply() {
- // nothing to do here
+ _cnt = 1;
+ }
+
+ private MinusMultiply(double cnt) {
+ _cnt = cnt;
}
public static MinusMultiply getFnObject() {
@@ -41,4 +50,13 @@ public class MinusMultiply extends TernaryValueFunction implements Serializable
public double execute(double in1, double in2, double in3) {
return in1 - in2 * in3;
}
+
+ public BinaryOperator setOp2Constant(double cnt) {
+ return new BinaryOperator(new MinusMultiply(cnt));
+ }
+
+ @Override
+ public double execute(double in1, double in2) {
+ return in1 - _cnt * in2;
+ }
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
index cb821f5..041527f 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/PlusMultiply.java
@@ -21,14 +21,23 @@ package org.apache.sysml.runtime.functionobjects;
import java.io.Serializable;
-public class PlusMultiply extends TernaryValueFunction implements Serializable
+import org.apache.sysml.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
+
+public class PlusMultiply extends TernaryValueFunction implements ValueFunctionWithConstant, Serializable
{
private static final long serialVersionUID = 2801982061205871665L;
private static PlusMultiply singleObj = null;
+ private final double _cnt;
+
private PlusMultiply() {
- // nothing to do here
+ _cnt = 1;
+ }
+
+ private PlusMultiply(double cnt) {
+ _cnt = cnt;
}
public static PlusMultiply getFnObject() {
@@ -41,4 +50,13 @@ public class PlusMultiply extends TernaryValueFunction implements Serializable
public double execute(double in1, double in2, double in3) {
return in1 + in2 * in3;
}
+
+ public BinaryOperator setOp2Constant(double cnt) {
+ return new BinaryOperator(new PlusMultiply(cnt));
+ }
+
+ @Override
+ public double execute(double in1, double in2) {
+ return in1 + _cnt * in2;
+ }
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java b/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java
index c317010..9629746 100644
--- a/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java
+++ b/src/main/java/org/apache/sysml/runtime/functionobjects/TernaryValueFunction.java
@@ -22,6 +22,7 @@ package org.apache.sysml.runtime.functionobjects;
import java.io.Serializable;
import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.matrix.operators.BinaryOperator;
public abstract class TernaryValueFunction extends ValueFunction implements Serializable
{
@@ -29,4 +30,8 @@ public abstract class TernaryValueFunction extends ValueFunction implements Seri
public abstract double execute ( double in1, double in2, double in3 )
throws DMLRuntimeException;
+
+ public interface ValueFunctionWithConstant {
+ public BinaryOperator setOp2Constant(double cnt);
+ }
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index e06c8c1..654cf53 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -51,14 +51,17 @@ import org.apache.sysml.runtime.functionobjects.IfElse;
import org.apache.sysml.runtime.functionobjects.KahanFunction;
import org.apache.sysml.runtime.functionobjects.KahanPlus;
import org.apache.sysml.runtime.functionobjects.KahanPlusSq;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.PlusMultiply;
import org.apache.sysml.runtime.functionobjects.ReduceAll;
import org.apache.sysml.runtime.functionobjects.ReduceCol;
import org.apache.sysml.runtime.functionobjects.ReduceRow;
import org.apache.sysml.runtime.functionobjects.RevIndex;
import org.apache.sysml.runtime.functionobjects.SortIndex;
import org.apache.sysml.runtime.functionobjects.SwapIndex;
+import org.apache.sysml.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
import org.apache.sysml.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysml.runtime.instructions.cp.KahanObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
@@ -2803,9 +2806,8 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
//prepare result
ret.reset(m, n, false);
- if( op.fn instanceof IfElse && (s1 || nnz==0 || nnz==(long)m*n) )
- {
- //special case for shallow-copy if-else
+ if( op.fn instanceof IfElse && (s1 || nnz==0 || nnz==(long)m*n) ) {
+ //SPECIAL CASE for shallow-copy if-else
boolean expr = s1 ? (d1 != 0) : (nnz==(long)m*n);
MatrixBlock tmp = expr ? m2 : m3;
if( tmp.rlen==m && tmp.clen==n ) {
@@ -2822,6 +2824,12 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
}
}
}
+ else if (s2 != s3 && (op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply) ) {
+ //SPECIAL CASE for sparse-dense combinations of common +* and -*
+ BinaryOperator bop = ((ValueFunctionWithConstant)op.fn)
+ .setOp2Constant(s2 ? d2 : d3);
+ LibMatrixBincell.bincellOp(this, s2 ? m3 : m2, ret, bop);
+ }
else {
ret.allocateDenseBlock();
http://git-wip-us.apache.org/repos/asf/systemml/blob/f14255f4/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java b/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java
index 5245db5..e3b9a06 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/operators/BinaryOperator.java
@@ -39,12 +39,14 @@ import org.apache.sysml.runtime.functionobjects.IntegerDivide;
import org.apache.sysml.runtime.functionobjects.LessThan;
import org.apache.sysml.runtime.functionobjects.LessThanEquals;
import org.apache.sysml.runtime.functionobjects.Minus;
+import org.apache.sysml.runtime.functionobjects.MinusMultiply;
import org.apache.sysml.runtime.functionobjects.MinusNz;
import org.apache.sysml.runtime.functionobjects.Modulus;
import org.apache.sysml.runtime.functionobjects.Multiply;
import org.apache.sysml.runtime.functionobjects.NotEquals;
import org.apache.sysml.runtime.functionobjects.Or;
import org.apache.sysml.runtime.functionobjects.Plus;
+import org.apache.sysml.runtime.functionobjects.PlusMultiply;
import org.apache.sysml.runtime.functionobjects.Power;
import org.apache.sysml.runtime.functionobjects.ValueFunction;
import org.apache.sysml.runtime.functionobjects.Xor;
@@ -58,6 +60,7 @@ public class BinaryOperator extends Operator implements Serializable
public BinaryOperator(ValueFunction p) {
//binaryop is sparse-safe iff (0 op 0) == 0
super (p instanceof Plus || p instanceof Multiply || p instanceof Minus
+ || p instanceof PlusMultiply || p instanceof MinusMultiply
|| p instanceof And || p instanceof Or || p instanceof Xor
|| p instanceof BitwAnd || p instanceof BitwOr || p instanceof BitwXor
|| p instanceof BitwShiftL || p instanceof BitwShiftR);