You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/04/01 01:20:06 UTC

incubator-systemml git commit: [SYSTEMML-1447] Extended code generator (replace in rowagg/cell tmpls)

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 8f7cf77be -> 2e48d951b


[SYSTEMML-1447] Extended code generator (replace in rowagg/cell tmpls)

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2e48d951
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2e48d951
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2e48d951

Branch: refs/heads/master
Commit: 2e48d951b825fe4ef85dc13f6d69934b8cadfe46
Parents: 8f7cf77
Author: Matthias Boehm <mb...@gmail.com>
Authored: Fri Mar 31 17:17:55 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Fri Mar 31 18:21:17 2017 -0700

----------------------------------------------------------------------
 .../sysml/hops/ParameterizedBuiltinOp.java      |  5 ++++
 .../sysml/hops/codegen/cplan/CNodeTernary.java  | 16 ++++++++--
 .../hops/codegen/template/TemplateCell.java     | 25 +++++++++++++---
 .../hops/codegen/template/TemplateRowAgg.java   | 18 +++++++++++-
 .../hops/codegen/template/TemplateUtils.java    |  3 ++
 .../functions/codegen/CellwiseTmplTest.java     | 22 ++++++++++++--
 .../scripts/functions/codegen/cellwisetmpl11.R  | 31 ++++++++++++++++++++
 .../functions/codegen/cellwisetmpl11.dml        | 27 +++++++++++++++++
 8 files changed, 138 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
index fa51948..1d6828c 100644
--- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
@@ -154,6 +154,11 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop
 			getInput().get(_paramIndexMap.get("target")) : null;
 	}
 	
+	public Hop getParameterHop(String name) {
+		return _paramIndexMap.containsKey(name) ?   
+			getInput().get(_paramIndexMap.get(name)) : null;	
+	}
+	
 	@Override
 	public void setMaxNumThreads( int k ) {
 		_maxNumThreads = k;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
index eb26eff..a8bbcb2 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
@@ -28,6 +28,7 @@ public class CNodeTernary extends CNode
 {
 	public enum TernaryType {
 		PLUS_MULT, MINUS_MULT,
+		REPLACE, REPLACE_NAN,
 		LOOKUP_RC1;
 		
 		public static boolean contains(String value) {
@@ -40,10 +41,17 @@ public class CNodeTernary extends CNode
 		public String getTemplate(boolean sparse) {
 			switch (this) {
 				case PLUS_MULT:
-					return "    double %TMP% = %IN1% + %IN2% * %IN3%;\n" ;
+					return "    double %TMP% = %IN1% + %IN2% * %IN3%;\n";
 				
 				case MINUS_MULT:
-					return "    double %TMP% = %IN1% - %IN2% * %IN3%;\n" ;
+					return "    double %TMP% = %IN1% - %IN2% * %IN3%;\n";
+					
+				case REPLACE:
+					return "    double %TMP% = (%IN1% == %IN2% || (Double.isNaN(%IN1%) "
+							+ "&& Double.isNaN(%IN2%))) ? %IN3% : %IN1%;\n";
+				
+				case REPLACE_NAN:
+					return "    double %TMP% = Double.isNaN(%IN1%) ? %IN3% : %IN1%;\n";
 					
 				case LOOKUP_RC1:
 					return "    double %TMP% = %IN1%[rowIndex*%IN2%+%IN3%-1];\n";	
@@ -101,6 +109,8 @@ public class CNodeTernary extends CNode
 		switch(_type) {
 			case PLUS_MULT: return "t(+*)";
 			case MINUS_MULT: return "t(-*)";
+			case REPLACE: 
+			case REPLACE_NAN: return "t(rplc)";
 			case LOOKUP_RC1: return "u(ixrc1)";
 			default:
 				return super.toString();	
@@ -112,6 +122,8 @@ public class CNodeTernary extends CNode
 		switch(_type) {
 			case PLUS_MULT: 
 			case MINUS_MULT:
+			case REPLACE:
+			case REPLACE_NAN:
 			case LOOKUP_RC1:
 				_rows = 0;
 				_cols = 0;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index 87ec899..447f6d6 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -34,8 +34,10 @@ import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
 import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.ParamBuiltinOp;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
 import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.codegen.cplan.CNode;
 import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
@@ -157,7 +159,7 @@ public class TemplateCell extends TemplateBase
 			else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() )
 				cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC);
 			
-			String primitiveOpName = ((UnaryOp)hop).getOp().toString();
+			String primitiveOpName = ((UnaryOp)hop).getOp().name();
 			out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
 		}
 		else if(hop instanceof BinaryOp)
@@ -165,7 +167,7 @@ public class TemplateCell extends TemplateBase
 			BinaryOp bop = (BinaryOp) hop;
 			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
 			CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
-			String primitiveOpName = bop.getOp().toString();
+			String primitiveOpName = bop.getOp().name();
 			
 			//cdata1 is vector
 			if( TemplateUtils.isColVector(cdata1) )
@@ -207,7 +209,21 @@ public class TemplateCell extends TemplateBase
 			
 			//construct ternary cnode, primitive operation derived from OpOp3
 			out = new CNodeTernary(cdata1, cdata2, cdata3, 
-					TernaryType.valueOf(top.getOp().toString()));
+					TernaryType.valueOf(top.getOp().name()));
+		}
+		else if( hop instanceof ParameterizedBuiltinOp ) 
+		{
+			CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID());
+			if( TemplateUtils.isColVector(cdata1) )
+				cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
+			else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() )
+				cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC);
+			
+			CNode cdata2 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID());
+			CNode cdata3 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID());
+			TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? 
+					TernaryType.REPLACE_NAN : TernaryType.REPLACE;
+			out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
 		}
 		else if( hop instanceof IndexingOp ) 
 		{
@@ -285,7 +301,8 @@ public class TemplateCell extends TemplateBase
 		//check supported unary, binary, ternary operations
 		return hop.getDataType() == DataType.MATRIX && TemplateUtils.isOperationSupported(hop) && (hop instanceof UnaryOp 
 				|| isBinaryMatrixScalar || isBinaryMatrixVector || isBinaryMatrixMatrixDense 
-				|| isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense);	
+				|| isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense
+				|| (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE));	
 	}
 	
 	/**

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
index f8f1508..2883893 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
@@ -32,6 +32,7 @@ import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.IndexingOp;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
 import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.UnaryOp;
 import org.apache.sysml.hops.codegen.cplan.CNode;
@@ -78,7 +79,8 @@ public class TemplateRowAgg extends TemplateBase
 		return !isClosed() && 
 			(  (hop instanceof BinaryOp && (HopRewriteUtils.isBinaryMatrixColVectorOperation(hop)
 					|| HopRewriteUtils.isBinaryMatrixScalarOperation(hop)) ) 
-			|| (hop instanceof UnaryOp && TemplateCell.isValidOperation(hop))		
+			|| ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp) 
+					&& TemplateCell.isValidOperation(hop))		
 			|| (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol)
 			|| (hop instanceof AggBinaryOp && hop.getDim1()>1 
 				&& HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
@@ -255,6 +257,20 @@ public class TemplateRowAgg extends TemplateBase
 			out = new CNodeTernary(cdata1, cdata2, cdata3, 
 					TernaryType.valueOf(top.getOp().toString()));
 		}
+		else if( hop instanceof ParameterizedBuiltinOp ) 
+		{
+			CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID());
+			if( TemplateUtils.isColVector(cdata1) )
+				cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
+			else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() )
+				cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC);
+			
+			CNode cdata2 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID());
+			CNode cdata3 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID());
+			TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ? 
+					TernaryType.REPLACE_NAN : TernaryType.REPLACE;
+			out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
+		}
 		else if( hop instanceof IndexingOp ) 
 		{
 			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 3f5fed9..b959638 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -30,6 +30,7 @@ import org.apache.sysml.hops.AggUnaryOp;
 import org.apache.sysml.hops.BinaryOp;
 import org.apache.sysml.hops.Hop;
 import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
 import org.apache.sysml.hops.TernaryOp;
 import org.apache.sysml.hops.Hop.AggOp;
 import org.apache.sysml.hops.Hop.Direction;
@@ -105,6 +106,8 @@ public class TemplateUtils
 			return BinType.contains(((BinaryOp)h).getOp().name());
 		else if(h instanceof TernaryOp)
 			return TernaryType.contains(((TernaryOp)h).getOp().name());
+		else if(h instanceof ParameterizedBuiltinOp) 
+			return TernaryType.contains(((ParameterizedBuiltinOp)h).getOp().name());
 		return false;
 	}
 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
index 066b761..10aa038 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
@@ -46,6 +46,8 @@ public class CellwiseTmplTest extends AutomatedTestBase
 	private static final String TEST_NAME8 = TEST_NAME+8;
 	private static final String TEST_NAME9 = TEST_NAME+9;   //sum((X + 7 * Y)^2)
 	private static final String TEST_NAME10 = TEST_NAME+10; //min/max(X + 7 * Y)
+	private static final String TEST_NAME11 = TEST_NAME+11; //replace((0 / (X - 500))+1, 0/0, 7);
+	
 
 	private static final String TEST_DIR = "functions/codegen/";
 	private static final String TEST_CLASS_DIR = TEST_DIR + CellwiseTmplTest.class.getSimpleName() + "/";
@@ -58,7 +60,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
 	@Override
 	public void setUp() {
 		TestUtils.clearAssertionInformation();
-		for( int i=1; i<=10; i++ ) {
+		for( int i=1; i<=11; i++ ) {
 			addTestConfiguration( TEST_NAME+i, new TestConfiguration(
 					TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) );
 		}
@@ -114,6 +116,11 @@ public class CellwiseTmplTest extends AutomatedTestBase
 	public void testCodegenCellwiseRewrite10() {
 		testCodegenIntegration( TEST_NAME10, true, ExecType.CP  );
 	}
+	
+	@Test
+	public void testCodegenCellwiseRewrite11() {
+		testCodegenIntegration( TEST_NAME11, true, ExecType.CP  );
+	}
 
 	@Test
 	public void testCodegenCellwise1() {
@@ -165,6 +172,11 @@ public class CellwiseTmplTest extends AutomatedTestBase
 	public void testCodegenCellwise10() {
 		testCodegenIntegration( TEST_NAME10, false, ExecType.CP  );
 	}
+	
+	@Test
+	public void testCodegenCellwise11() {
+		testCodegenIntegration( TEST_NAME11, false, ExecType.CP  );
+	}
 
 	@Test
 	public void testCodegenCellwiseRewrite1_sp() {
@@ -191,6 +203,11 @@ public class CellwiseTmplTest extends AutomatedTestBase
 		testCodegenIntegration( TEST_NAME10, true, ExecType.SPARK );
 	}
 	
+	@Test
+	public void testCodegenCellwiseRewrite11_sp() {
+		testCodegenIntegration( TEST_NAME11, true, ExecType.SPARK );
+	}
+	
 	private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
 	{	
 		
@@ -247,7 +264,8 @@ public class CellwiseTmplTest extends AutomatedTestBase
 				Assert.assertTrue(!heavyHittersContainsSubString("tsmm"));
 			else if( testname.equals(TEST_NAME10) ) //ensure min/max is fused
 				Assert.assertTrue(!heavyHittersContainsSubString("uamin","uamax"));
-				
+			else if( testname.equals(TEST_NAME11) ) //ensure replace is fused
+				Assert.assertTrue(!heavyHittersContainsSubString("replace"));	
 		}
 		finally {
 			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrites;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/test/scripts/functions/codegen/cellwisetmpl11.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl11.R b/src/test/scripts/functions/codegen/cellwisetmpl11.R
new file mode 100644
index 0000000..33531ba
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl11.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = matrix(seq(7, 1006), 500, 2, byrow=TRUE);
+
+Y = (0 / (X - 500))+1;
+R = replace(Y, is.nan(Y), 7);
+
+writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep=""));

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/test/scripts/functions/codegen/cellwisetmpl11.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl11.dml b/src/test/scripts/functions/codegen/cellwisetmpl11.dml
new file mode 100644
index 0000000..c77da08
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl11.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(7, 1006), 500, 2);
+
+Y = (0 / (X - 500))+1;
+R = replace(target=Y, pattern=0/0, replacement=7);
+
+write(R, $1)