You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2018/02/07 06:16:33 UTC
[2/2] systemml git commit: [SYSTEMML-2082] Codegen support for ternary ifelse in cell/magg tmpls
[SYSTEMML-2082] Codegen support for ternary ifelse in cell/magg tmpls
This patch adds basic support for ternary ifelse operations in the codegen
cell and multi-aggregate (magg) templates, along with related tests.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/aa537dad
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/aa537dad
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/aa537dad
Branch: refs/heads/master
Commit: aa537dad43f2cf21badaedcb8629b27ad301032b
Parents: 5457066
Author: Matthias Boehm <mb...@gmail.com>
Authored: Tue Feb 6 20:06:28 2018 -0800
Committer: Matthias Boehm <mb...@gmail.com>
Committed: Tue Feb 6 20:06:28 2018 -0800
----------------------------------------------------------------------
.../sysml/hops/codegen/cplan/CNodeTernary.java | 23 ++++++++------
.../hops/codegen/template/TemplateCell.java | 12 +++++---
.../functions/codegen/CellwiseTmplTest.java | 18 ++++++++++-
.../scripts/functions/codegen/cellwisetmpl18.R | 32 ++++++++++++++++++++
.../functions/codegen/cellwisetmpl18.dml | 30 ++++++++++++++++++
5 files changed, 99 insertions(+), 16 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
index 155cc8b..dc8ff82 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/cplan/CNodeTernary.java
@@ -27,7 +27,7 @@ public class CNodeTernary extends CNode
{
public enum TernaryType {
PLUS_MULT, MINUS_MULT,
- REPLACE, REPLACE_NAN,
+ REPLACE, REPLACE_NAN, IFELSE,
LOOKUP_RC1, LOOKUP_RVECT1;
@@ -52,7 +52,10 @@ public class CNodeTernary extends CNode
case REPLACE_NAN:
return " double %TMP% = Double.isNaN(%IN1%) ? %IN3% : %IN1%;\n";
-
+
+ case IFELSE:
+ return " double %TMP% = (%IN1% != 0) ? %IN2% : %IN3%;\n";
+
case LOOKUP_RC1:
return sparse ?
" double %TMP% = getValue(%IN1v%, %IN1i%, ai, alen, %IN3%-1);\n" :
@@ -124,15 +127,14 @@ public class CNodeTernary extends CNode
@Override
public String toString() {
switch(_type) {
- case PLUS_MULT: return "t(+*)";
- case MINUS_MULT: return "t(-*)";
- case REPLACE:
- case REPLACE_NAN: return "t(rplc)";
- case LOOKUP_RC1: return "u(ixrc1)";
+ case PLUS_MULT: return "t(+*)";
+ case MINUS_MULT: return "t(-*)";
+ case REPLACE:
+ case REPLACE_NAN: return "t(rplc)";
+ case IFELSE: return "t(ifelse)";
+ case LOOKUP_RC1: return "u(ixrc1)";
case LOOKUP_RVECT1: return "u(ixrv1)";
-
- default:
- return super.toString();
+ default: return super.toString();
}
}
@@ -143,6 +145,7 @@ public class CNodeTernary extends CNode
case MINUS_MULT:
case REPLACE:
case REPLACE_NAN:
+ case IFELSE:
case LOOKUP_RC1:
_rows = 0;
_cols = 0;
http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
index 50b42ea..2b8db2a 100644
--- a/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
+++ b/src/main/java/org/apache/sysml/hops/codegen/template/TemplateCell.java
@@ -34,6 +34,7 @@ import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.Hop.AggOp;
import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOp3;
import org.apache.sysml.hops.Hop.ParamBuiltinOp;
import org.apache.sysml.hops.IndexingOp;
import org.apache.sysml.hops.LiteralOp;
@@ -168,7 +169,7 @@ public class TemplateCell extends TemplateBase
&& HopRewriteUtils.isMatrixMultiply(hop) && i==0 ) //skip transpose
rConstructCplan(c.getInput().get(0), memo, tmp, inHops, compileLiterals);
else {
- CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
+ CNodeData cdata = TemplateUtils.createCNodeData(c, compileLiterals);
tmp.put(c.getHopID(), cdata);
inHops.add(c);
}
@@ -208,6 +209,7 @@ public class TemplateCell extends TemplateBase
//add lookups if required
cdata1 = TemplateUtils.wrapLookupIfNecessary(cdata1, hop.getInput().get(0));
+ cdata2 = TemplateUtils.wrapLookupIfNecessary(cdata2, hop.getInput().get(1));
cdata3 = TemplateUtils.wrapLookupIfNecessary(cdata3, hop.getInput().get(2));
//construct ternary cnode, primitive operation derived from OpOp3
@@ -299,11 +301,11 @@ public class TemplateCell extends TemplateBase
//prepare indicators for ternary operations
boolean isTernaryVectorScalarVector = false;
boolean isTernaryMatrixScalarMatrixDense = false;
+ boolean isTernaryIfElse = (HopRewriteUtils.isTernary(hop, OpOp3.IFELSE) && hop.getDataType().isMatrix());
if( hop instanceof TernaryOp && hop.getInput().size()==3 && hop.dimsKnown()
- && HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX)) {
+ && HopRewriteUtils.checkInputDataTypes(hop, DataType.MATRIX, DataType.SCALAR, DataType.MATRIX) ) {
Hop left = hop.getInput().get(0);
Hop right = hop.getInput().get(2);
-
isTernaryVectorScalarVector = TemplateUtils.isVector(left) && TemplateUtils.isVector(right);
isTernaryMatrixScalarMatrixDense = HopRewriteUtils.isEqualSize(left, right)
&& !HopRewriteUtils.isSparse(left) && !HopRewriteUtils.isSparse(right);
@@ -312,8 +314,8 @@ public class TemplateCell extends TemplateBase
//check supported unary, binary, ternary operations
return hop.getDataType() == DataType.MATRIX && TemplateUtils.isOperationSupported(hop) && (hop instanceof UnaryOp
|| isBinaryMatrixScalar || isBinaryMatrixVector || isBinaryMatrixMatrix
- || isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense
- || (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE));
+ || isTernaryVectorScalarVector || isTernaryMatrixScalarMatrixDense || isTernaryIfElse
+ || (hop instanceof ParameterizedBuiltinOp && ((ParameterizedBuiltinOp)hop).getOp()==ParamBuiltinOp.REPLACE));
}
protected boolean isSparseSafe(List<Hop> roots, Hop mainInput, List<CNode> outputs, List<AggOp> aggOps, boolean onlySum) {
http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
index bd3b36a..2f44f61 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/codegen/CellwiseTmplTest.java
@@ -53,6 +53,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
private static final String TEST_NAME15 = TEST_NAME+15; //colMins(2*log(X))
private static final String TEST_NAME16 = TEST_NAME+16; //colSums(2*log(X));
private static final String TEST_NAME17 = TEST_NAME+17; //xor operation
+ private static final String TEST_NAME18 = TEST_NAME+18; //sum(ifelse(X,Y,Z))
private static final String TEST_DIR = "functions/codegen/";
@@ -66,7 +67,7 @@ public class CellwiseTmplTest extends AutomatedTestBase
@Override
public void setUp() {
TestUtils.clearAssertionInformation();
- for( int i=1; i<=17; i++ ) {
+ for( int i=1; i<=18; i++ ) {
addTestConfiguration( TEST_NAME+i, new TestConfiguration(
TEST_CLASS_DIR, TEST_NAME+i, new String[] {String.valueOf(i)}) );
}
@@ -304,6 +305,21 @@ public class CellwiseTmplTest extends AutomatedTestBase
testCodegenIntegration( TEST_NAME17, true, ExecType.SPARK );
}
+ @Test
+ public void testCodegenCellwiseRewrite18() {
+ testCodegenIntegration( TEST_NAME18, true, ExecType.CP );
+ }
+
+ @Test
+ public void testCodegenCellwise18() {
+ testCodegenIntegration( TEST_NAME18, false, ExecType.CP );
+ }
+
+ @Test
+ public void testCodegenCellwiseRewrite18_sp() {
+ testCodegenIntegration( TEST_NAME18, true, ExecType.SPARK );
+ }
+
private void testCodegenIntegration( String testname, boolean rewrites, ExecType instType )
{
http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/test/scripts/functions/codegen/cellwisetmpl18.R
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl18.R b/src/test/scripts/functions/codegen/cellwisetmpl18.R
new file mode 100644
index 0000000..e6a275a
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl18.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+X = matrix(seq(-1000, 198999), 1000, 200, byrow=TRUE);
+Y = matrix(seq(0, 199999), 1000, 200, byrow=TRUE);
+Z = matrix(seq(1000, 200999), 1000, 200, byrow=TRUE);
+
+R = as.matrix(sum(as.numeric(ifelse(X,Y,Z))));
+
+writeMM(as(R,"CsparseMatrix"), paste(args[2], "S", sep=""));
http://git-wip-us.apache.org/repos/asf/systemml/blob/aa537dad/src/test/scripts/functions/codegen/cellwisetmpl18.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/codegen/cellwisetmpl18.dml b/src/test/scripts/functions/codegen/cellwisetmpl18.dml
new file mode 100644
index 0000000..c178dd3
--- /dev/null
+++ b/src/test/scripts/functions/codegen/cellwisetmpl18.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(-1000, 198999), 1000, 200);
+Y = matrix(seq(0, 199999), 1000, 200);
+Z = matrix(seq(1000, 200999), 1000, 200);
+
+while(FALSE){}
+
+R = as.matrix(sum(ifelse(X,Y,Z)));
+
+write(R, $1)