You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/10/27 19:38:16 UTC

systemml git commit: [SYSTEMML-2498] Fix codegen compiler for cbind w/ vectors and scalars

Repository: systemml
Updated Branches:
  refs/heads/master 0eff9f28d -> 3cbd9d5ab


[SYSTEMML-2498] Fix codegen compiler for cbind w/ vectors and scalars

This patch fixes the codegen compiler for binary and nary cbind
operations to (1) not compile row templates for cbind operations with
row vectors, and (2) improve robustness for a mix of matrix and column
vector inputs, where the column vectors become scalars in the context
of a row template.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3cbd9d5a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3cbd9d5a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3cbd9d5a

Branch: refs/heads/master
Commit: 3cbd9d5ab0e9cd29b4e67183129deaa549c10d30
Parents: 0eff9f2
Author: Matthias Boehm <mb...@gmail.com>
Authored: Sat Oct 27 20:57:37 2018 +0200
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Sat Oct 27 20:57:37 2018 +0200

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeNary.java     | 23 ++++++++++-------
 .../hops/codegen/template/TemplateRow.java      | 14 ++++++-----
 .../test/integration/AutomatedTestBase.java     | 26 ++++++++++----------
 3 files changed, 35 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/3cbd9d5a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
index 28e47f4..1a717d3 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeNary.java
@@ -53,15 +53,20 @@ public class CNodeNary extends CNode
 						boolean sparseInput = sparseGen && input instanceof CNodeData
 							&& input.getVarname().startsWith("a");
 						String varj = input.getVarname();
-						String pos = (input instanceof CNodeData && input.getDataType().isMatrix()) ? 
-								(!varj.startsWith("b")) ? varj+"i" : TemplateUtils.isMatrix(input) ? 
-								varj + ".pos(rix)" : "0" : "0";
-						sb.append( sparseInput ?
-							"    LibSpoofPrimitives.vectWrite("+varj+"vals, %TMP%, "
-								+varj+"ix, "+pos+", "+off+", "+input._cols+");\n" :
-							"    LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj)
-								+", %TMP%, "+pos+", "+off+", "+input._cols+");\n");
-						off += input._cols;
+						if( input.getDataType()==DataType.MATRIX ) {
+							String pos = (input instanceof CNodeData) ?
+								!varj.startsWith("b") ? varj+"i" : varj + ".pos(rix)" : "0";
+							sb.append( sparseInput ?
+								"    LibSpoofPrimitives.vectWrite("+varj+"vals, %TMP%, "
+									+varj+"ix, "+pos+", "+off+", "+input._cols+");\n" :
+								"    LibSpoofPrimitives.vectWrite("+(varj.startsWith("b")?varj+".values(rix)":varj)
+									+", %TMP%, "+pos+", "+off+", "+input._cols+");\n");
+							off += input._cols;	
+						}
+						else { //e.g., col vectors -> scalars
+							sb.append("    %TMP%["+off+"] = "+varj+";\n");
+							off ++;
+						}
 					}
 					return sb.toString();
 				case VECT_MAX_POOL:

http://git-wip-us.apache.org/repos/asf/systemml/blob/3cbd9d5a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
index d9da27b..79213eb 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRow.java
@@ -92,9 +92,8 @@ public class TemplateRow extends TemplateBase
 				&& hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1)
 			|| ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp)
 				&& TemplateCell.isValidOperation(hop) && hop.getDim1() > 1)
-			|| (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown())
 			|| HopRewriteUtils.isTernary(hop, OpOp3.PLUS_MULT, OpOp3.MINUS_MULT)
-			|| (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown())
+			|| isValidBinaryNaryCBind(hop)
 			|| (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) && hop.isMatrix())
 			|| (hop instanceof AggBinaryOp && hop.dimsKnown() && hop.getDim2()==1 //MV
 				&& hop.getInput().get(0).getDim1()>1 && hop.getInput().get(0).getDim2()>1)
@@ -125,8 +124,7 @@ public class TemplateRow extends TemplateBase
 	public boolean fuse(Hop hop, Hop input) {
 		return !isClosed() && 
 			(  (hop instanceof BinaryOp && isValidBinaryOperation(hop)) 
-			|| (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown())
-			|| (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown())
+			|| isValidBinaryNaryCBind(hop)
 			|| (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) && hop.isMatrix())
 			|| ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp) 
 				&& TemplateCell.isValidOperation(hop))
@@ -156,8 +154,7 @@ public class TemplateRow extends TemplateBase
 		return !isClosed() &&
 			((hop instanceof BinaryOp && isValidBinaryOperation(hop)
 				&& hop.getDim1() > 1 && input.getDim1()>1)
-			|| (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown())
-			|| (HopRewriteUtils.isNary(hop, OpOpN.CBIND) && hop.getInput().get(0).isMatrix() && hop.dimsKnown())
+			|| isValidBinaryNaryCBind(hop)
 			|| (HopRewriteUtils.isNary(hop, OpOpN.MIN, OpOpN.MAX) && hop.isMatrix())
 			|| (HopRewriteUtils.isDnn(hop, OpOpDnn.BIASADD, OpOpDnn.BIASMULT)
 				&& hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown()
@@ -191,6 +188,11 @@ public class TemplateRow extends TemplateBase
 		return TemplateUtils.isOperationSupported(hop);
 	}
 	
+	private static boolean isValidBinaryNaryCBind(Hop hop) {
+		return (HopRewriteUtils.isBinary(hop, OpOp2.CBIND) || HopRewriteUtils.isNary(hop, OpOpN.CBIND))
+			&& hop.getInput().get(0).isMatrix() && hop.dimsKnown() && hop.getInput().get(0).getDim1()>1;
+	}
+	
 	private static boolean isFuseSkinnyMatrixMult(Hop hop) {
 		//check for fusable but not opening matrix multiply (vect_outer-mult)
 		Hop in1 = hop.getInput().get(0); //transpose

http://git-wip-us.apache.org/repos/asf/systemml/blob/3cbd9d5a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
index 2135f45..e3576ab 100644
--- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
@@ -213,19 +213,19 @@ public abstract class AutomatedTestBase
 	
 	protected RUNTIME_PLATFORM setRuntimePlatform(ExecType et) {
 		RUNTIME_PLATFORM platformOld = rtplatform;
-        switch (et) {
-            case MR:
-                rtplatform = RUNTIME_PLATFORM.HADOOP;
-                break;
-            case SPARK: {
-                rtplatform = RUNTIME_PLATFORM.SPARK;
-                DMLScript.USE_LOCAL_SPARK_CONFIG = true; // Always use local config for junit tests
-                break;
-            }
-            default:
-                rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
-                break;
-        }
+		switch (et) {
+			case MR:
+				rtplatform = RUNTIME_PLATFORM.HADOOP;
+				break;
+			case SPARK: {
+				rtplatform = RUNTIME_PLATFORM.SPARK;
+				DMLScript.USE_LOCAL_SPARK_CONFIG = true; // Always use local config for junit tests
+				break;
+			}
+			default:
+				rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
+				break;
+		}
 		return platformOld;
 	}