You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/09 07:34:54 UTC

systemml git commit: [SYSTEMML-1755] Fix simplification rewrite binary matrix-scalar ops

Repository: systemml
Updated Branches:
  refs/heads/master b84a4933c -> 352c256a3


[SYSTEMML-1755] Fix simplification rewrite binary matrix-scalar ops

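This patch fixes the rewrite that simplifies matrix-scalar operations to
scalar-scalar operations so that it correctly checks whether the binary
operation is supported over scalars. QUANTILE, CENTRALMOMENT, MINUS1_MULT,
MINUS_NZ, and LOG_NZ are now excluded from the rewrite. The patch also
replaces fully-qualified references to the lops Binary and BinaryScalar
types in Hop.java with imports.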

Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/352c256a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/352c256a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/352c256a

Branch: refs/heads/master
Commit: 352c256a3d71bb587162120134f87e4a9a2df507
Parents: b84a493
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sun Jul 9 00:32:47 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sun Jul 9 00:32:47 2017 -0700

----------------------------------------------------------------------
 src/main/java/org/apache/sysml/hops/Hop.java    | 92 ++++++++++----------
 .../RewriteAlgebraicSimplificationStatic.java   |  8 +-
 2 files changed, 54 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/352c256a/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 8f8afde..80d33f1 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -28,6 +28,8 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.lops.Binary;
+import org.apache.sysml.lops.BinaryScalar;
 import org.apache.sysml.lops.CSVReBlock;
 import org.apache.sysml.lops.Checkpoint;
 import org.apache.sysml.lops.Compression;
@@ -1143,53 +1145,53 @@ public abstract class Hop
 
 	}
 
-	protected static final HashMap<Hop.OpOp2, org.apache.sysml.lops.Binary.OperationTypes> HopsOpOp2LopsB;
+	protected static final HashMap<Hop.OpOp2, Binary.OperationTypes> HopsOpOp2LopsB;
 	static {
-		HopsOpOp2LopsB = new HashMap<Hop.OpOp2, org.apache.sysml.lops.Binary.OperationTypes>();
-		HopsOpOp2LopsB.put(OpOp2.PLUS, org.apache.sysml.lops.Binary.OperationTypes.ADD);
-		HopsOpOp2LopsB.put(OpOp2.MINUS, org.apache.sysml.lops.Binary.OperationTypes.SUBTRACT);
-		HopsOpOp2LopsB.put(OpOp2.MULT, org.apache.sysml.lops.Binary.OperationTypes.MULTIPLY);
-		HopsOpOp2LopsB.put(OpOp2.DIV, org.apache.sysml.lops.Binary.OperationTypes.DIVIDE);
-		HopsOpOp2LopsB.put(OpOp2.MODULUS, org.apache.sysml.lops.Binary.OperationTypes.MODULUS);
-		HopsOpOp2LopsB.put(OpOp2.INTDIV, org.apache.sysml.lops.Binary.OperationTypes.INTDIV);
-		HopsOpOp2LopsB.put(OpOp2.MINUS1_MULT, org.apache.sysml.lops.Binary.OperationTypes.MINUS1_MULTIPLY);
-		HopsOpOp2LopsB.put(OpOp2.LESS, org.apache.sysml.lops.Binary.OperationTypes.LESS_THAN);
-		HopsOpOp2LopsB.put(OpOp2.LESSEQUAL, org.apache.sysml.lops.Binary.OperationTypes.LESS_THAN_OR_EQUALS);
-		HopsOpOp2LopsB.put(OpOp2.GREATER, org.apache.sysml.lops.Binary.OperationTypes.GREATER_THAN);
-		HopsOpOp2LopsB.put(OpOp2.GREATEREQUAL, org.apache.sysml.lops.Binary.OperationTypes.GREATER_THAN_OR_EQUALS);
-		HopsOpOp2LopsB.put(OpOp2.EQUAL, org.apache.sysml.lops.Binary.OperationTypes.EQUALS);
-		HopsOpOp2LopsB.put(OpOp2.NOTEQUAL, org.apache.sysml.lops.Binary.OperationTypes.NOT_EQUALS);
-		HopsOpOp2LopsB.put(OpOp2.MIN, org.apache.sysml.lops.Binary.OperationTypes.MIN);
-		HopsOpOp2LopsB.put(OpOp2.MAX, org.apache.sysml.lops.Binary.OperationTypes.MAX);
-		HopsOpOp2LopsB.put(OpOp2.AND, org.apache.sysml.lops.Binary.OperationTypes.OR);
-		HopsOpOp2LopsB.put(OpOp2.OR, org.apache.sysml.lops.Binary.OperationTypes.AND);
-		HopsOpOp2LopsB.put(OpOp2.SOLVE, org.apache.sysml.lops.Binary.OperationTypes.SOLVE);
-		HopsOpOp2LopsB.put(OpOp2.POW, org.apache.sysml.lops.Binary.OperationTypes.POW);
-		HopsOpOp2LopsB.put(OpOp2.LOG, org.apache.sysml.lops.Binary.OperationTypes.NOTSUPPORTED);
-	}
-
-	protected static final HashMap<Hop.OpOp2, org.apache.sysml.lops.BinaryScalar.OperationTypes> HopsOpOp2LopsBS;
+		HopsOpOp2LopsB = new HashMap<Hop.OpOp2, Binary.OperationTypes>();
+		HopsOpOp2LopsB.put(OpOp2.PLUS, Binary.OperationTypes.ADD);
+		HopsOpOp2LopsB.put(OpOp2.MINUS, Binary.OperationTypes.SUBTRACT);
+		HopsOpOp2LopsB.put(OpOp2.MULT, Binary.OperationTypes.MULTIPLY);
+		HopsOpOp2LopsB.put(OpOp2.DIV, Binary.OperationTypes.DIVIDE);
+		HopsOpOp2LopsB.put(OpOp2.MODULUS, Binary.OperationTypes.MODULUS);
+		HopsOpOp2LopsB.put(OpOp2.INTDIV, Binary.OperationTypes.INTDIV);
+		HopsOpOp2LopsB.put(OpOp2.MINUS1_MULT, Binary.OperationTypes.MINUS1_MULTIPLY);
+		HopsOpOp2LopsB.put(OpOp2.LESS, Binary.OperationTypes.LESS_THAN);
+		HopsOpOp2LopsB.put(OpOp2.LESSEQUAL, Binary.OperationTypes.LESS_THAN_OR_EQUALS);
+		HopsOpOp2LopsB.put(OpOp2.GREATER, Binary.OperationTypes.GREATER_THAN);
+		HopsOpOp2LopsB.put(OpOp2.GREATEREQUAL, Binary.OperationTypes.GREATER_THAN_OR_EQUALS);
+		HopsOpOp2LopsB.put(OpOp2.EQUAL, Binary.OperationTypes.EQUALS);
+		HopsOpOp2LopsB.put(OpOp2.NOTEQUAL, Binary.OperationTypes.NOT_EQUALS);
+		HopsOpOp2LopsB.put(OpOp2.MIN, Binary.OperationTypes.MIN);
+		HopsOpOp2LopsB.put(OpOp2.MAX, Binary.OperationTypes.MAX);
+		HopsOpOp2LopsB.put(OpOp2.AND, Binary.OperationTypes.OR);
+		HopsOpOp2LopsB.put(OpOp2.OR, Binary.OperationTypes.AND);
+		HopsOpOp2LopsB.put(OpOp2.SOLVE, Binary.OperationTypes.SOLVE);
+		HopsOpOp2LopsB.put(OpOp2.POW, Binary.OperationTypes.POW);
+		HopsOpOp2LopsB.put(OpOp2.LOG, Binary.OperationTypes.NOTSUPPORTED);
+	}
+
+	protected static final HashMap<Hop.OpOp2, BinaryScalar.OperationTypes> HopsOpOp2LopsBS;
 	static {
-		HopsOpOp2LopsBS = new HashMap<Hop.OpOp2, org.apache.sysml.lops.BinaryScalar.OperationTypes>();
-		HopsOpOp2LopsBS.put(OpOp2.PLUS, org.apache.sysml.lops.BinaryScalar.OperationTypes.ADD);	
-		HopsOpOp2LopsBS.put(OpOp2.MINUS, org.apache.sysml.lops.BinaryScalar.OperationTypes.SUBTRACT);
-		HopsOpOp2LopsBS.put(OpOp2.MULT, org.apache.sysml.lops.BinaryScalar.OperationTypes.MULTIPLY);
-		HopsOpOp2LopsBS.put(OpOp2.DIV, org.apache.sysml.lops.BinaryScalar.OperationTypes.DIVIDE);
-		HopsOpOp2LopsBS.put(OpOp2.MODULUS, org.apache.sysml.lops.BinaryScalar.OperationTypes.MODULUS);
-		HopsOpOp2LopsBS.put(OpOp2.INTDIV, org.apache.sysml.lops.BinaryScalar.OperationTypes.INTDIV);
-		HopsOpOp2LopsBS.put(OpOp2.LESS, org.apache.sysml.lops.BinaryScalar.OperationTypes.LESS_THAN);
-		HopsOpOp2LopsBS.put(OpOp2.LESSEQUAL, org.apache.sysml.lops.BinaryScalar.OperationTypes.LESS_THAN_OR_EQUALS);
-		HopsOpOp2LopsBS.put(OpOp2.GREATER, org.apache.sysml.lops.BinaryScalar.OperationTypes.GREATER_THAN);
-		HopsOpOp2LopsBS.put(OpOp2.GREATEREQUAL, org.apache.sysml.lops.BinaryScalar.OperationTypes.GREATER_THAN_OR_EQUALS);
-		HopsOpOp2LopsBS.put(OpOp2.EQUAL, org.apache.sysml.lops.BinaryScalar.OperationTypes.EQUALS);
-		HopsOpOp2LopsBS.put(OpOp2.NOTEQUAL, org.apache.sysml.lops.BinaryScalar.OperationTypes.NOT_EQUALS);
-		HopsOpOp2LopsBS.put(OpOp2.MIN, org.apache.sysml.lops.BinaryScalar.OperationTypes.MIN);
-		HopsOpOp2LopsBS.put(OpOp2.MAX, org.apache.sysml.lops.BinaryScalar.OperationTypes.MAX);
-		HopsOpOp2LopsBS.put(OpOp2.AND, org.apache.sysml.lops.BinaryScalar.OperationTypes.AND);
-		HopsOpOp2LopsBS.put(OpOp2.OR, org.apache.sysml.lops.BinaryScalar.OperationTypes.OR);
-		HopsOpOp2LopsBS.put(OpOp2.LOG, org.apache.sysml.lops.BinaryScalar.OperationTypes.LOG);
-		HopsOpOp2LopsBS.put(OpOp2.POW, org.apache.sysml.lops.BinaryScalar.OperationTypes.POW);
-		HopsOpOp2LopsBS.put(OpOp2.PRINT, org.apache.sysml.lops.BinaryScalar.OperationTypes.PRINT);
+		HopsOpOp2LopsBS = new HashMap<Hop.OpOp2, BinaryScalar.OperationTypes>();
+		HopsOpOp2LopsBS.put(OpOp2.PLUS, BinaryScalar.OperationTypes.ADD);	
+		HopsOpOp2LopsBS.put(OpOp2.MINUS, BinaryScalar.OperationTypes.SUBTRACT);
+		HopsOpOp2LopsBS.put(OpOp2.MULT, BinaryScalar.OperationTypes.MULTIPLY);
+		HopsOpOp2LopsBS.put(OpOp2.DIV, BinaryScalar.OperationTypes.DIVIDE);
+		HopsOpOp2LopsBS.put(OpOp2.MODULUS, BinaryScalar.OperationTypes.MODULUS);
+		HopsOpOp2LopsBS.put(OpOp2.INTDIV, BinaryScalar.OperationTypes.INTDIV);
+		HopsOpOp2LopsBS.put(OpOp2.LESS, BinaryScalar.OperationTypes.LESS_THAN);
+		HopsOpOp2LopsBS.put(OpOp2.LESSEQUAL, BinaryScalar.OperationTypes.LESS_THAN_OR_EQUALS);
+		HopsOpOp2LopsBS.put(OpOp2.GREATER, BinaryScalar.OperationTypes.GREATER_THAN);
+		HopsOpOp2LopsBS.put(OpOp2.GREATEREQUAL, BinaryScalar.OperationTypes.GREATER_THAN_OR_EQUALS);
+		HopsOpOp2LopsBS.put(OpOp2.EQUAL, BinaryScalar.OperationTypes.EQUALS);
+		HopsOpOp2LopsBS.put(OpOp2.NOTEQUAL, BinaryScalar.OperationTypes.NOT_EQUALS);
+		HopsOpOp2LopsBS.put(OpOp2.MIN, BinaryScalar.OperationTypes.MIN);
+		HopsOpOp2LopsBS.put(OpOp2.MAX, BinaryScalar.OperationTypes.MAX);
+		HopsOpOp2LopsBS.put(OpOp2.AND, BinaryScalar.OperationTypes.AND);
+		HopsOpOp2LopsBS.put(OpOp2.OR, BinaryScalar.OperationTypes.OR);
+		HopsOpOp2LopsBS.put(OpOp2.LOG, BinaryScalar.OperationTypes.LOG);
+		HopsOpOp2LopsBS.put(OpOp2.POW, BinaryScalar.OperationTypes.POW);
+		HopsOpOp2LopsBS.put(OpOp2.PRINT, BinaryScalar.OperationTypes.PRINT);
 	}
 
 	protected static final HashMap<Hop.OpOp2, org.apache.sysml.lops.Unary.OperationTypes> HopsOpOp2LopsU;

http://git-wip-us.apache.org/repos/asf/systemml/blob/352c256a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index b8f9369..53359cc 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -846,8 +846,14 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 	private Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, int pos ) 
 		throws HopsException
 	{
+		// Note: This rewrite is not applicable for all binary operations because some of them 
+		// are undefined over scalars. We explicitly exclude potentially conflicting matrix-scalar binary
+		// operations; other operations like cbind/rbind will never occur as matrix-scalar operations.
+		
 		if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR)  
-		   && hi.getInput().get(0) instanceof BinaryOp ) 
+			&& hi.getInput().get(0) instanceof BinaryOp
+			&& !HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.QUANTILE, 
+			OpOp2.CENTRALMOMENT, OpOp2.MINUS1_MULT, OpOp2.MINUS_NZ, OpOp2.LOG_NZ)) 
 		{
 			BinaryOp bin = (BinaryOp) hi.getInput().get(0);
 			BinaryOp bout = null;