You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text copy).
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/04/01 01:20:06 UTC
incubator-systemml git commit: [SYSTEMML-1447] Extended code generator (replace in rowagg/cell tmpls)
Repository: incubator-systemml
Updated Branches:
refs/heads/master 8f7cf77be -> 2e48d951b
[SYSTEMML-1447] Extended code generator (replace in rowagg/cell tmpls)
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/2e48d951
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/2e48d951
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/2e48d951
Branch: refs/heads/master
Commit: 2e48d951b825fe4ef85dc13f6d69934b8cadfe46
Parents: 8f7cf77
Author: Matthias Boehm <mb...@gmail.com>
Authored: Fri Mar 31 17:17:55 2017 -0700
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Fri Mar 31 18:21:17 2017 -0700
----------------------------------------------------------------------
.../sysml/hops/ParameterizedBuiltinOp.java | 5 ++++
.../sysml/hops/codegen/cplan/CNodeTernary.java | 16 ++++++++--
.../hops/codegen/template/TemplateCell.java | 25 +++++++++++++---
.../hops/codegen/template/TemplateRowAgg.java | 18 +++++++++++-
.../hops/codegen/template/TemplateUtils.java | 3 ++
.../functions/codegen/CellwiseTmplTest.java | 22 ++++++++++++--
.../scripts/functions/codegen/cellwisetmpl11.R | 31 ++++++++++++++++++++
.../functions/codegen/cellwisetmpl11.dml | 27 +++++++++++++++++
8 files changed, 138 insertions(+), 9 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
index fa51948..1d6828c 100644
--- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
@@ -154,6 +154,11 @@ public class ParameterizedBuiltinOp extends Hop implements MultiThreadedHop
getInput().get(_paramIndexMap.get("target")) : null;
}
+ public Hop getParameterHop(String name) {
+ return _paramIndexMap.containsKey(name) ?
+ getInput().get(_paramIndexMap.get(name)) : null;
+ }
+
@Override
public void setMaxNumThreads( int k ) {
_maxNumThreads = k;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
index eb26eff..a8bbcb2 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
@@ -28,6 +28,7 @@ public class CNodeTernary extends CNode
{
public enum TernaryType {
PLUS_MULT, MINUS_MULT,
+ REPLACE, REPLACE_NAN,
LOOKUP_RC1;
public static boolean contains(String value) {
@@ -40,10 +41,17 @@ public class CNodeTernary extends CNode
public String getTemplate(boolean sparse) {
switch (this) {
case PLUS_MULT:
- return " double %TMP% = %IN1% + %IN2% * %IN3%;\n" ;
+ return " double %TMP% = %IN1% + %IN2% * %IN3%;\n";
case MINUS_MULT:
- return " double %TMP% = %IN1% - %IN2% * %IN3%;\n" ;
+ return " double %TMP% = %IN1% - %IN2% * %IN3%;\n";
+
+ case REPLACE:
+ return " double %TMP% = (%IN1% == %IN2% || (Double.isNaN(%IN1%) "
+ + "&& Double.isNaN(%IN2%))) ? %IN3% : %IN1%;\n";
+
+ case REPLACE_NAN:
+ return " double %TMP% = Double.isNaN(%IN1%) ? %IN3% : %IN1%;\n";
case LOOKUP_RC1:
return " double %TMP% = %IN1%[rowIndex*%IN2%+%IN3%-1];\n";
@@ -101,6 +109,8 @@ public class CNodeTernary extends CNode
switch(_type) {
case PLUS_MULT: return "t(+*)";
case MINUS_MULT: return "t(-*)";
+ case REPLACE:
+ case REPLACE_NAN: return "t(rplc)";
case LOOKUP_RC1: return "u(ixrc1)";
default:
return super.toString();
@@ -112,6 +122,8 @@ public class CNodeTernary extends CNode
switch(_type) {
case PLUS_MULT:
case MINUS_MULT:
+ case REPLACE:
+ case REPLACE_NAN:
case LOOKUP_RC1:
_rows = 0;
_cols = 0;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index 87ec899..447f6d6 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -34,8 +34,10 @@ import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.Direction;
import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.ParamBuiltinOp;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
import org.apache.sysml.hops.codegen.cplan.CNodeBinary;
@@ -157,7 +159,7 @@ public class TemplateCell extends TemplateBase
else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() )
cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC);
- String primitiveOpName = ((UnaryOp)hop).getOp().toString();
+ String primitiveOpName = ((UnaryOp)hop).getOp().name();
out = new CNodeUnary(cdata1, UnaryType.valueOf(primitiveOpName));
}
else if(hop instanceof BinaryOp)
@@ -165,7 +167,7 @@ public class TemplateCell extends TemplateBase
BinaryOp bop = (BinaryOp) hop;
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
CNode cdata2 = tmp.get(hop.getInput().get(1).getHopID());
- String primitiveOpName = bop.getOp().toString();
+ String primitiveOpName = bop.getOp().name();
//cdata1 is vector
if( TemplateUtils.isColVector(cdata1) )
@@ -207,7 +209,21 @@ public class TemplateCell extends TemplateBase
//construct ternary cnode, primitive operation derived from OpOp3
out = new CNodeTernary(cdata1, cdata2, cdata3,
- TernaryType.valueOf(top.getOp().toString()));
+ TernaryType.valueOf(top.getOp().name()));
+ }
+ else if( hop instanceof ParameterizedBuiltinOp )
+ {
+ CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID());
+ if( TemplateUtils.isColVector(cdata1) )
+ cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
+ else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() )
+ cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC);
+
+ CNode cdata2 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID());
+ CNode cdata3 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID());
+ TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ?
+ TernaryType.REPLACE_NAN : TernaryType.REPLACE;
+ out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
}
else if( hop instanceof IndexingOp )
{
@@ -285,7 +301,8 @@ public class TemplateCell extends TemplateBase
//check supported unary, binary, ternary operations
return hop.getDataType() == DataType.MATRIX && TemplateUtils.isOperationSupported(hop) && (hop instanceof UnaryOp
|| isBinaryMatrixScalar || isBinaryMatrixVector || isBinaryMatrixMatrixDense
- || isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense);
+ || isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense
+ || (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE));
}
/**
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
index f8f1508..2883893 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateRowAgg.java
@@ -32,6 +32,7 @@ import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.codegen.cplan.CNode;
@@ -78,7 +79,8 @@ public class TemplateRowAgg extends TemplateBase
return !isClosed() &&
( (hop instanceof BinaryOp && (HopRewriteUtils.isBinaryMatrixColVectorOperation(hop)
|| HopRewriteUtils.isBinaryMatrixScalarOperation(hop)) )
- || (hop instanceof UnaryOp && TemplateCell.isValidOperation(hop))
+ || ((hop instanceof UnaryOp || hop instanceof ParameterizedBuiltinOp)
+ && TemplateCell.isValidOperation(hop))
|| (hop instanceof AggUnaryOp && ((AggUnaryOp)hop).getDirection()!=Direction.RowCol)
|| (hop instanceof AggBinaryOp && hop.getDim1()>1
&& HopRewriteUtils.isTransposeOperation(hop.getInput().get(0))));
@@ -255,6 +257,20 @@ public class TemplateRowAgg extends TemplateBase
out = new CNodeTernary(cdata1, cdata2, cdata3,
TernaryType.valueOf(top.getOp().toString()));
}
+ else if( hop instanceof ParameterizedBuiltinOp )
+ {
+ CNode cdata1 = tmp.get(((ParameterizedBuiltinOp)hop).getTargetHop().getHopID());
+ if( TemplateUtils.isColVector(cdata1) )
+ cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_R);
+ else if( cdata1 instanceof CNodeData && hop.getInput().get(0).getDataType().isMatrix() )
+ cdata1 = new CNodeUnary(cdata1, UnaryType.LOOKUP_RC);
+
+ CNode cdata2 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("pattern").getHopID());
+ CNode cdata3 = tmp.get(((ParameterizedBuiltinOp)hop).getParameterHop("replacement").getHopID());
+ TernaryType ttype = (cdata2.isLiteral() && cdata2.getVarname().equals("Double.NaN")) ?
+ TernaryType.REPLACE_NAN : TernaryType.REPLACE;
+ out = new CNodeTernary(cdata1, cdata2, cdata3, ttype);
+ }
else if( hop instanceof IndexingOp )
{
CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
index 3f5fed9..b959638 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateUtils.java
@@ -30,6 +30,7 @@ import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.LiteralOp;
+import org.apache.sysml.hops.ParameterizedBuiltinOp;
import org.apache.sysml.hops.TernaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.Direction;
@@ -105,6 +106,8 @@ public class TemplateUtils
return BinType.contains(((BinaryOp)h).getOp().name());
else if(h instanceof TernaryOp)
return TernaryType.contains(((TernaryOp)h).getOp().name());
+ else if(h instanceof ParameterizedBuiltinOp)
+ return TernaryType.contains(((ParameterizedBuiltinOp)h).getOp().name());
return false;
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
index 066b761..10aa038 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
@@ -46,6 +46,8 @@ public class CellwiseTmplTest extends AutomatedTestBase
private static final String TEST_NAME8 = TEST_NAME+8;
private static final String TEST_NAME9 = TEST_NAME+9; //sum((X + 7 * Y)^2)
private static final String TEST_NAME10 = TEST_NAME+10; //min/max(X + 7 * Y)
+ private static final String TEST_NAME11 = TEST_NAME+11; //replace((0 / (X - 500))+1, 0/0, 7);
+
private static final String TEST_DIR = "functions/codegen/";
private static final String TEST_CLASS_DIR = TEST_DIR + CellwiseTmplTest.class.getSimpleName() + "/";
@@ -58,7 +60,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- for( int i=1; i<=10; i++ ) {
+ for( int i=1; i<=11; i++ ) {
addTestConfiguration( TEST_NAME+i, new TestConfiguration(
TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) );
}
@@ -114,6 +116,11 @@ public class CellwiseTmplTest extends AutomatedTestBase
public void testCodegenCellwiseRewrite10() {
testCodegenIntegration( TEST_NAME10, true, ExecType.CP );
}
+
+ @Test
+ public void testCodegenCellwiseRewrite11() {
+ testCodegenIntegration( TEST_NAME11, true, ExecType.CP );
+ }
@Test
public void testCodegenCellwise1() {
@@ -165,6 +172,11 @@ public class CellwiseTmplTest extends AutomatedTestBase
public void testCodegenCellwise10() {
testCodegenIntegration( TEST_NAME10, false, ExecType.CP );
}
+
+ @Test
+ public void testCodegenCellwise11() {
+ testCodegenIntegration( TEST_NAME11, false, ExecType.CP );
+ }
@Test
public void testCodegenCellwiseRewrite1_sp() {
@@ -191,6 +203,11 @@ public class CellwiseTmplTest extends AutomatedTestBase
testCodegenIntegration( TEST_NAME10, true, ExecType.SPARK );
}
+ @Test
+ public void testCodegenCellwiseRewrite11_sp() {
+ testCodegenIntegration( TEST_NAME11, true, ExecType.SPARK );
+ }
+
private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
{
@@ -247,7 +264,8 @@ public class CellwiseTmplTest extends AutomatedTestBase
Assert.assertTrue(!heavyHittersContainsSubString("tsmm"));
else if( testname.equals(TEST_NAME10) ) //ensure min/max is fused
Assert.assertTrue(!heavyHittersContainsSubString("uamin","uamax"));
-
+ else if( testname.equals(TEST_NAME11) ) //ensure replace is fused
+ Assert.assertTrue(!heavyHittersContainsSubString("replace"));
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldRewrites;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/test/scripts/functions/codegen/cellwisetmpl11.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl11.R b/src/test/scripts/functions/codegen/cellwisetmpl11.R
new file mode 100644
index 0000000..33531ba
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl11.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = matrix(seq(7, 1006), 500, 2, byrow=TRUE);
+
+Y = (0 / (X - 500))+1;
+R = replace(Y, is.nan(Y), 7);
+
+writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep=""));
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/2e48d951/src/test/scripts/functions/codegen/cellwisetmpl11.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl11.dml b/src/test/scripts/functions/codegen/cellwisetmpl11.dml
new file mode 100644
index 0000000..c77da08
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl11.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(7, 1006), 500, 2);
+
+Y = (0 / (X - 500))+1;
+R = replace(target=Y, pattern=0/0, replacement=7);
+
+write(R, $1)