You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/09 07:34:54 UTC
systemml git commit: [SYSTEMML-1755] Fix simplification rewrite binary matrix-scalar ops
Repository: systemml
Updated Branches:
refs/heads/master b84a4933c -> 352c256a3
[SYSTEMML-1755] Fix simplification rewrite binary matrix-scalar ops
This patch fixes the rewrite that simplifies matrix-scalar operations to
scalar-scalar operations so that it is applied only to binary operations
that are supported over scalars.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/352c256a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/352c256a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/352c256a
Branch: refs/heads/master
Commit: 352c256a3d71bb587162120134f87e4a9a2df507
Parents: b84a493
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun Jul 9 00:32:47 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sun Jul 9 00:32:47 2017 -0700
----------------------------------------------------------------------
src/main/java/org/apache/sysml/hops/Hop.java | 92 ++++++++++----------
.../RewriteAlgebraicSimplificationStatic.java | 8 +-
2 files changed, 54 insertions(+), 46 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/352c256a/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 8f8afde..80d33f1 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -28,6 +28,8 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.lops.Binary;
+import org.apache.sysml.lops.BinaryScalar;
import org.apache.sysml.lops.CSVReBlock;
import org.apache.sysml.lops.Checkpoint;
import org.apache.sysml.lops.Compression;
@@ -1143,53 +1145,53 @@ public abstract class Hop
}
- protected static final HashMap<Hop.OpOp2, org.apache.sysml.lops.Binary.OperationTypes> HopsOpOp2LopsB;
+ protected static final HashMap<Hop.OpOp2, Binary.OperationTypes> HopsOpOp2LopsB;
static {
- HopsOpOp2LopsB = new HashMap<Hop.OpOp2, org.apache.sysml.lops.Binary.OperationTypes>();
- HopsOpOp2LopsB.put(OpOp2.PLUS, org.apache.sysml.lops.Binary.OperationTypes.ADD);
- HopsOpOp2LopsB.put(OpOp2.MINUS, org.apache.sysml.lops.Binary.OperationTypes.SUBTRACT);
- HopsOpOp2LopsB.put(OpOp2.MULT, org.apache.sysml.lops.Binary.OperationTypes.MULTIPLY);
- HopsOpOp2LopsB.put(OpOp2.DIV, org.apache.sysml.lops.Binary.OperationTypes.DIVIDE);
- HopsOpOp2LopsB.put(OpOp2.MODULUS, org.apache.sysml.lops.Binary.OperationTypes.MODULUS);
- HopsOpOp2LopsB.put(OpOp2.INTDIV, org.apache.sysml.lops.Binary.OperationTypes.INTDIV);
- HopsOpOp2LopsB.put(OpOp2.MINUS1_MULT, org.apache.sysml.lops.Binary.OperationTypes.MINUS1_MULTIPLY);
- HopsOpOp2LopsB.put(OpOp2.LESS, org.apache.sysml.lops.Binary.OperationTypes.LESS_THAN);
- HopsOpOp2LopsB.put(OpOp2.LESSEQUAL, org.apache.sysml.lops.Binary.OperationTypes.LESS_THAN_OR_EQUALS);
- HopsOpOp2LopsB.put(OpOp2.GREATER, org.apache.sysml.lops.Binary.OperationTypes.GREATER_THAN);
- HopsOpOp2LopsB.put(OpOp2.GREATEREQUAL, org.apache.sysml.lops.Binary.OperationTypes.GREATER_THAN_OR_EQUALS);
- HopsOpOp2LopsB.put(OpOp2.EQUAL, org.apache.sysml.lops.Binary.OperationTypes.EQUALS);
- HopsOpOp2LopsB.put(OpOp2.NOTEQUAL, org.apache.sysml.lops.Binary.OperationTypes.NOT_EQUALS);
- HopsOpOp2LopsB.put(OpOp2.MIN, org.apache.sysml.lops.Binary.OperationTypes.MIN);
- HopsOpOp2LopsB.put(OpOp2.MAX, org.apache.sysml.lops.Binary.OperationTypes.MAX);
- HopsOpOp2LopsB.put(OpOp2.AND, org.apache.sysml.lops.Binary.OperationTypes.OR);
- HopsOpOp2LopsB.put(OpOp2.OR, org.apache.sysml.lops.Binary.OperationTypes.AND);
- HopsOpOp2LopsB.put(OpOp2.SOLVE, org.apache.sysml.lops.Binary.OperationTypes.SOLVE);
- HopsOpOp2LopsB.put(OpOp2.POW, org.apache.sysml.lops.Binary.OperationTypes.POW);
- HopsOpOp2LopsB.put(OpOp2.LOG, org.apache.sysml.lops.Binary.OperationTypes.NOTSUPPORTED);
- }
-
- protected static final HashMap<Hop.OpOp2, org.apache.sysml.lops.BinaryScalar.OperationTypes> HopsOpOp2LopsBS;
+ HopsOpOp2LopsB = new HashMap<Hop.OpOp2, Binary.OperationTypes>();
+ HopsOpOp2LopsB.put(OpOp2.PLUS, Binary.OperationTypes.ADD);
+ HopsOpOp2LopsB.put(OpOp2.MINUS, Binary.OperationTypes.SUBTRACT);
+ HopsOpOp2LopsB.put(OpOp2.MULT, Binary.OperationTypes.MULTIPLY);
+ HopsOpOp2LopsB.put(OpOp2.DIV, Binary.OperationTypes.DIVIDE);
+ HopsOpOp2LopsB.put(OpOp2.MODULUS, Binary.OperationTypes.MODULUS);
+ HopsOpOp2LopsB.put(OpOp2.INTDIV, Binary.OperationTypes.INTDIV);
+ HopsOpOp2LopsB.put(OpOp2.MINUS1_MULT, Binary.OperationTypes.MINUS1_MULTIPLY);
+ HopsOpOp2LopsB.put(OpOp2.LESS, Binary.OperationTypes.LESS_THAN);
+ HopsOpOp2LopsB.put(OpOp2.LESSEQUAL, Binary.OperationTypes.LESS_THAN_OR_EQUALS);
+ HopsOpOp2LopsB.put(OpOp2.GREATER, Binary.OperationTypes.GREATER_THAN);
+ HopsOpOp2LopsB.put(OpOp2.GREATEREQUAL, Binary.OperationTypes.GREATER_THAN_OR_EQUALS);
+ HopsOpOp2LopsB.put(OpOp2.EQUAL, Binary.OperationTypes.EQUALS);
+ HopsOpOp2LopsB.put(OpOp2.NOTEQUAL, Binary.OperationTypes.NOT_EQUALS);
+ HopsOpOp2LopsB.put(OpOp2.MIN, Binary.OperationTypes.MIN);
+ HopsOpOp2LopsB.put(OpOp2.MAX, Binary.OperationTypes.MAX);
+ HopsOpOp2LopsB.put(OpOp2.AND, Binary.OperationTypes.OR);
+ HopsOpOp2LopsB.put(OpOp2.OR, Binary.OperationTypes.AND);
+ HopsOpOp2LopsB.put(OpOp2.SOLVE, Binary.OperationTypes.SOLVE);
+ HopsOpOp2LopsB.put(OpOp2.POW, Binary.OperationTypes.POW);
+ HopsOpOp2LopsB.put(OpOp2.LOG, Binary.OperationTypes.NOTSUPPORTED);
+ }
+
+ protected static final HashMap<Hop.OpOp2, BinaryScalar.OperationTypes> HopsOpOp2LopsBS;
static {
- HopsOpOp2LopsBS = new HashMap<Hop.OpOp2, org.apache.sysml.lops.BinaryScalar.OperationTypes>();
- HopsOpOp2LopsBS.put(OpOp2.PLUS, org.apache.sysml.lops.BinaryScalar.OperationTypes.ADD);
- HopsOpOp2LopsBS.put(OpOp2.MINUS, org.apache.sysml.lops.BinaryScalar.OperationTypes.SUBTRACT);
- HopsOpOp2LopsBS.put(OpOp2.MULT, org.apache.sysml.lops.BinaryScalar.OperationTypes.MULTIPLY);
- HopsOpOp2LopsBS.put(OpOp2.DIV, org.apache.sysml.lops.BinaryScalar.OperationTypes.DIVIDE);
- HopsOpOp2LopsBS.put(OpOp2.MODULUS, org.apache.sysml.lops.BinaryScalar.OperationTypes.MODULUS);
- HopsOpOp2LopsBS.put(OpOp2.INTDIV, org.apache.sysml.lops.BinaryScalar.OperationTypes.INTDIV);
- HopsOpOp2LopsBS.put(OpOp2.LESS, org.apache.sysml.lops.BinaryScalar.OperationTypes.LESS_THAN);
- HopsOpOp2LopsBS.put(OpOp2.LESSEQUAL, org.apache.sysml.lops.BinaryScalar.OperationTypes.LESS_THAN_OR_EQUALS);
- HopsOpOp2LopsBS.put(OpOp2.GREATER, org.apache.sysml.lops.BinaryScalar.OperationTypes.GREATER_THAN);
- HopsOpOp2LopsBS.put(OpOp2.GREATEREQUAL, org.apache.sysml.lops.BinaryScalar.OperationTypes.GREATER_THAN_OR_EQUALS);
- HopsOpOp2LopsBS.put(OpOp2.EQUAL, org.apache.sysml.lops.BinaryScalar.OperationTypes.EQUALS);
- HopsOpOp2LopsBS.put(OpOp2.NOTEQUAL, org.apache.sysml.lops.BinaryScalar.OperationTypes.NOT_EQUALS);
- HopsOpOp2LopsBS.put(OpOp2.MIN, org.apache.sysml.lops.BinaryScalar.OperationTypes.MIN);
- HopsOpOp2LopsBS.put(OpOp2.MAX, org.apache.sysml.lops.BinaryScalar.OperationTypes.MAX);
- HopsOpOp2LopsBS.put(OpOp2.AND, org.apache.sysml.lops.BinaryScalar.OperationTypes.AND);
- HopsOpOp2LopsBS.put(OpOp2.OR, org.apache.sysml.lops.BinaryScalar.OperationTypes.OR);
- HopsOpOp2LopsBS.put(OpOp2.LOG, org.apache.sysml.lops.BinaryScalar.OperationTypes.LOG);
- HopsOpOp2LopsBS.put(OpOp2.POW, org.apache.sysml.lops.BinaryScalar.OperationTypes.POW);
- HopsOpOp2LopsBS.put(OpOp2.PRINT, org.apache.sysml.lops.BinaryScalar.OperationTypes.PRINT);
+ HopsOpOp2LopsBS = new HashMap<Hop.OpOp2, BinaryScalar.OperationTypes>();
+ HopsOpOp2LopsBS.put(OpOp2.PLUS, BinaryScalar.OperationTypes.ADD);
+ HopsOpOp2LopsBS.put(OpOp2.MINUS, BinaryScalar.OperationTypes.SUBTRACT);
+ HopsOpOp2LopsBS.put(OpOp2.MULT, BinaryScalar.OperationTypes.MULTIPLY);
+ HopsOpOp2LopsBS.put(OpOp2.DIV, BinaryScalar.OperationTypes.DIVIDE);
+ HopsOpOp2LopsBS.put(OpOp2.MODULUS, BinaryScalar.OperationTypes.MODULUS);
+ HopsOpOp2LopsBS.put(OpOp2.INTDIV, BinaryScalar.OperationTypes.INTDIV);
+ HopsOpOp2LopsBS.put(OpOp2.LESS, BinaryScalar.OperationTypes.LESS_THAN);
+ HopsOpOp2LopsBS.put(OpOp2.LESSEQUAL, BinaryScalar.OperationTypes.LESS_THAN_OR_EQUALS);
+ HopsOpOp2LopsBS.put(OpOp2.GREATER, BinaryScalar.OperationTypes.GREATER_THAN);
+ HopsOpOp2LopsBS.put(OpOp2.GREATEREQUAL, BinaryScalar.OperationTypes.GREATER_THAN_OR_EQUALS);
+ HopsOpOp2LopsBS.put(OpOp2.EQUAL, BinaryScalar.OperationTypes.EQUALS);
+ HopsOpOp2LopsBS.put(OpOp2.NOTEQUAL, BinaryScalar.OperationTypes.NOT_EQUALS);
+ HopsOpOp2LopsBS.put(OpOp2.MIN, BinaryScalar.OperationTypes.MIN);
+ HopsOpOp2LopsBS.put(OpOp2.MAX, BinaryScalar.OperationTypes.MAX);
+ HopsOpOp2LopsBS.put(OpOp2.AND, BinaryScalar.OperationTypes.AND);
+ HopsOpOp2LopsBS.put(OpOp2.OR, BinaryScalar.OperationTypes.OR);
+ HopsOpOp2LopsBS.put(OpOp2.LOG, BinaryScalar.OperationTypes.LOG);
+ HopsOpOp2LopsBS.put(OpOp2.POW, BinaryScalar.OperationTypes.POW);
+ HopsOpOp2LopsBS.put(OpOp2.PRINT, BinaryScalar.OperationTypes.PRINT);
}
protected static final HashMap<Hop.OpOp2, org.apache.sysml.lops.Unary.OperationTypes> HopsOpOp2LopsU;
http://git-wip-us.apache.org/repos/asf/systemml/blob/352c256a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index b8f9369..53359cc 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -846,8 +846,14 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
private Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, int pos )
throws HopsException
{
+ // Note: This rewrite is not applicable to all binary operations because some of them
+ // are undefined over scalars. We explicitly exclude potentially conflicting matrix-scalar binary
+ // operations; other operations like cbind/rbind will never occur as matrix-scalar operations.
+
if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR)
- && hi.getInput().get(0) instanceof BinaryOp )
+ && hi.getInput().get(0) instanceof BinaryOp
+ && !HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.QUANTILE,
+ OpOp2.CENTRALMOMENT, OpOp2.MINUS1_MULT, OpOp2.MINUS_NZ, OpOp2.LOG_NZ))
{
BinaryOp bin = (BinaryOp) hi.getInput().get(0);
BinaryOp bout = null;