You are viewing a plain-text version of this content; the hyperlink to its canonical version was lost in the plain-text conversion.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/02/07 06:16:33 UTC

[2/2] systemml git commit: [SYSTEMML-2082] Codegen support for ternary ifelse in cell/magg tmpls

[SYSTEMML-2082] Codegen support for ternary ifelse in cell/magg tmpls

This patch adds basic support for ternary ifelse operations in codegen
cell and magg templates along with related tests.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/aa537dad
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/aa537dad
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/aa537dad

Branch: refs/heads/master
Commit: aa537dad43f2cf21badaedcb8629b27ad301032b
Parents: 5457066
Author: Matthias Boehm <mb...@gmail.com>
Authored: Tue Feb 6 20:06:28 2018 -0800
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Tue Feb 6 20:06:28 2018 -0800

----------------------------------------------------------------------
 .../sysml/hops/codegen/cplan/CNodeTernary.java  | 23 ++++++++------
 .../hops/codegen/template/TemplateCell.java     | 12 +++++---
 .../functions/codegen/CellwiseTmplTest.java     | 18 ++++++++++-
 .../scripts/functions/codegen/cellwisetmpl18.R  | 32 ++++++++++++++++++++
 .../functions/codegen/cellwisetmpl18.dml        | 30 ++++++++++++++++++
 5 files changed, 99 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
index 155cc8b..dc8ff82 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
@@ -27,7 +27,7 @@ public class CNodeTernary extends CNode
 {
 	public enum TernaryType {
 		PLUS_MULT, MINUS_MULT,
-		REPLACE, REPLACE_NAN,
+		REPLACE, REPLACE_NAN, IFELSE,
 		LOOKUP_RC1, LOOKUP_RVECT1;
 		
 		
@@ -52,7 +52,10 @@ public class CNodeTernary extends CNode
 				
 				case REPLACE_NAN:
 					return "    double %TMP% = Double.isNaN(%IN1%) ? %IN3% : %IN1%;\n";
-					
+				
+				case IFELSE:
+					return "    double %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
+				
 				case LOOKUP_RC1:
 					return sparse ?
 						"    double %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
@@ -124,15 +127,14 @@ public class CNodeTernary extends CNode
 	@Override
 	public String toString() {
 		switch(_type) {
-			case PLUS_MULT: return "t(+*)";
-			case MINUS_MULT: return "t(-*)";
-			case REPLACE: 
-			case REPLACE_NAN: return "t(rplc)";
-			case LOOKUP_RC1: return "u(ixrc1)";
+			case PLUS_MULT:     return "t(+*)";
+			case MINUS_MULT:    return "t(-*)";
+			case REPLACE:
+			case REPLACE_NAN:   return "t(rplc)";
+			case IFELSE:        return "t(ifelse)";
+			case LOOKUP_RC1:    return "u(ixrc1)";
 			case LOOKUP_RVECT1: return "u(ixrv1)";
-			
-			default:
-				return super.toString();	
+			default:            return super.toString();
 		}
 	}
 	
@@ -143,6 +145,7 @@ public class CNodeTernary extends CNode
 			case MINUS_MULT:
 			case REPLACE:
 			case REPLACE_NAN:
+			case IFELSE:
 			case LOOKUP_RC1:
 				_rows = 0;
 				_cols = 0;

http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index 50b42ea..2b8db2a 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -34,6 +34,7 @@ import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOp3;
 import org.apache.sysml.hops.Hop.ParamBuiltinOp;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.LiteralOp;
@@ -168,7 +169,7 @@ public class TemplateCell extends TemplateBase
 					&& HopRewriteUtils.isMatrixMultiply(hop) && i==0 ) //skip transpose
 				rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals);
 			else {
-				CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);	
+				CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
 				tmp.put(c.getHopID(), cdata);
 				inHops.add(c);
 			}
@@ -208,6 +209,7 @@ public class TemplateCell extends TemplateBase
 			
 			//add lookups if required
 			cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
+			cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
 			cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
 			
 			//construct ternary cnode, primitive operation derived from OpOp3
@@ -299,11 +301,11 @@ public class TemplateCell extends TemplateBase
 		//prepare indicators for ternary operations
 		boolean isTernaryVectorScalarVector = false;
 		boolean isTernaryMatrixScalarMatrixDense = false;
+		boolean isTernaryIfElse = (HopRewriteUtils.isTernary(hop, OpOp3.IFELSE) && hop.getDataType().isMatrix());
 		if( hop instanceof TernaryOp && hop.getInput().size()==3 && hop.dimsKnown() 
-			&& HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX)) {
+			&& HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX) ) {
 			Hop left = hop.getInput().get(0);
 			Hop right = hop.getInput().get(2);
-			
 			isTernaryVectorScalarVector = TemplateUtils.isVector(left) && TemplateUtils.isVector(right);
 			isTernaryMatrixScalarMatrixDense = HopRewriteUtils.isEqualSize(left, right) 
 				&& !HopRewriteUtils.isSparse(left) && !HopRewriteUtils.isSparse(right);
@@ -312,8 +314,8 @@ public class TemplateCell extends TemplateBase
 		//check supported unary, binary, ternary operations
 		return hop.getDataType() == DataType.MATRIX && TemplateUtils.isOperationSupported(hop) && (hop instanceof UnaryOp 
 				|| isBinaryMatrixScalar || isBinaryMatrixVector || isBinaryMatrixMatrix
-				|| isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense
-				|| (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE));	
+				|| isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense || isTernaryIfElse
+				|| (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE));
 	}
 	
 	protected boolean isSparseSafe(List<Hop> roots, Hop mainInput, List<CNode> outputs, List<AggOp> aggOps, boolean onlySum) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
index bd3b36a..2f44f61 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
@@ -53,6 +53,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
 	private static final String TEST_NAME15 = TEST_NAME+15; //colMins(2*log(X))
 	private static final String TEST_NAME16 = TEST_NAME+16; //colSums(2*log(X));
 	private static final String TEST_NAME17 = TEST_NAME+17; //xor operation
+	private static final String TEST_NAME18 = TEST_NAME+18; //sum(ifelse(X,Y,Z))
 	
 	
 	private static final String TEST_DIR = "functions/codegen/";
@@ -66,7 +67,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
 	@Override
 	public void setUp() {
 		TestUtils.clearAssertionInformation();
-		for( int i=1; i<=17; i++ ) {
+		for( int i=1; i<=18; i++ ) {
 			addTestConfiguration( TEST_NAME+i, new TestConfiguration(
 					TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) );
 		}
@@ -304,6 +305,21 @@ public class CellwiseTmplTest extends AutomatedTestBase
 		testCodegenIntegration( TEST_NAME17, true, ExecType.SPARK );
 	}
 	
+	@Test
+	public void testCodegenCellwiseRewrite18() {
+		testCodegenIntegration( TEST_NAME18, true, ExecType.CP );
+	}
+
+	@Test
+	public void testCodegenCellwise18() {
+		testCodegenIntegration( TEST_NAME18, false, ExecType.CP );
+	}
+
+	@Test
+	public void testCodegenCellwiseRewrite18_sp() {
+		testCodegenIntegration( TEST_NAME18, true, ExecType.SPARK );
+	}
+	
 	
 	private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
 	{			

http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/test/scripts/functions/codegen/cellwisetmpl18.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl18.R b/src/test/scripts/functions/codegen/cellwisetmpl18.R
new file mode 100644
index 0000000..e6a275a
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl18.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = matrix(seq(-1000, 198999), 1000, 200, byrow=TRUE);
+Y = matrix(seq(0, 199999), 1000, 200, byrow=TRUE);
+Z = matrix(seq(1000, 200999), 1000, 200, byrow=TRUE);
+
+R = as.matrix(sum(as.numeric(ifelse(X,Y,Z))));
+
+writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep=""));

http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/test/scripts/functions/codegen/cellwisetmpl18.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl18.dml b/src/test/scripts/functions/codegen/cellwisetmpl18.dml
new file mode 100644
index 0000000..c178dd3
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl18.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(-1000, 198999), 1000, 200);
+Y = matrix(seq(0, 199999), 1000, 200);
+Z = matrix(seq(1000, 200999), 1000, 200);
+
+while(FALSE){}
+
+R = as.matrix(sum(ifelse(X,Y,Z)));
+
+write(R, $1)