You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by mb...@apache.org on 2016/05/14 20:22:35 UTC

[2/5] incubator-systemml git commit: [SYSTEMML-695] Fix 'fuse datagen' rewrites for non-uniform pdfs, tests

[SYSTEMML-695] Fix 'fuse datagen' rewrites for non-uniform pdfs, tests

The rewrites for fusing rand operations with unary/binary operations
(e.g., rand(min=0, max=1)*7 -> rand(min=0, max=7)) were applied without
awareness of the given pdf, which produced incorrect results for
non-uniform pdfs such as pdf=normal, where the min/max parameters are
reused as distribution parameters. This patch restricts all related
rewrites to pdf=uniform and introduces the necessary tests for all supported pdfs.

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/d21e7d9a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/d21e7d9a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/d21e7d9a

Branch: refs/heads/master
Commit: d21e7d9aa29d91dfcb4f5905b42b5d99a5c8f753
Parents: d39e7d6
Author: Matthias Boehm <mb...@us.ibm.com>
Authored: Fri May 13 22:53:04 2016 -0700
Committer: Matthias Boehm <mb...@us.ibm.com>
Committed: Sat May 14 00:16:44 2016 -0700

----------------------------------------------------------------------
 .../RewriteAlgebraicSimplificationStatic.java   |  12 +-
 .../functions/misc/RewriteFusedRandTest.java    | 113 +++++++++++++++++++
 .../scripts/functions/misc/RewriteFusedRand.dml |  29 +++++
 .../functions/misc/ZPackageSuite.java           |   1 +
 4 files changed, 152 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d21e7d9a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index d2708a4..fff9310 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -359,12 +359,14 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			{
 				DataGenOp inputGen = (DataGenOp)left;
 				HashMap<String,Integer> params = inputGen.getParamIndexMap();
+				Hop pdf = left.getInput().get(params.get(DataExpression.RAND_PDF));
 				Hop min = left.getInput().get(params.get(DataExpression.RAND_MIN));
 				Hop max = left.getInput().get(params.get(DataExpression.RAND_MAX));
 				double sval = ((LiteralOp)right).getDoubleValue();
 				
 				if( (bop.getOp()==OpOp2.MULT || bop.getOp()==OpOp2.PLUS || bop.getOp() == OpOp2.MINUS)
-					&& min instanceof LiteralOp && max instanceof LiteralOp )
+					&& min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp 
+					&& DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) )
 				{
 					//create fused data gen operator
 					DataGenOp gen = null;
@@ -395,12 +397,14 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			{
 				DataGenOp inputGen = (DataGenOp)right;
 				HashMap<String,Integer> params = inputGen.getParamIndexMap();
+				Hop pdf = right.getInput().get(params.get(DataExpression.RAND_PDF));
 				Hop min = right.getInput().get(params.get(DataExpression.RAND_MIN));
 				Hop max = right.getInput().get(params.get(DataExpression.RAND_MAX));
 				double sval = ((LiteralOp)left).getDoubleValue();
 				
 				if( (bop.getOp()==OpOp2.MULT || bop.getOp()==OpOp2.PLUS)
-					&& min instanceof LiteralOp && max instanceof LiteralOp )
+					&& min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp 
+					&& DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) )
 				{
 					//create fused data gen operator
 					DataGenOp gen = null;
@@ -449,6 +453,7 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 			{
 				DataGenOp inputGen = (DataGenOp)right;
 				HashMap<String,Integer> params = inputGen.getParamIndexMap();
+				Hop pdf = right.getInput().get(params.get(DataExpression.RAND_PDF));
 				int ixMin = params.get(DataExpression.RAND_MIN);
 				int ixMax = params.get(DataExpression.RAND_MAX);
 				Hop min = right.getInput().get(ixMin);
@@ -456,7 +461,8 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule
 				
 				//apply rewrite under additional conditions (for simplicity)
 				if( inputGen.getParent().size()==1 
-					&& min instanceof LiteralOp && max instanceof LiteralOp )
+					&& min instanceof LiteralOp && max instanceof LiteralOp && pdf instanceof LiteralOp 
+					&& DataExpression.RAND_PDF_UNIFORM.equals(((LiteralOp)pdf).getStringValue()) )
 				{
 					//exchange and *-1 (special case 0 stays 0 instead of -0 for consistency)
 					double newMinVal = (((LiteralOp)max).getDoubleValue()==0)?0:(-1 * ((LiteralOp)max).getDoubleValue());

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d21e7d9a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java
new file mode 100644
index 0000000..1491aa6
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteFusedRandTest.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+
+/**
+ * End-to-end test of the fused rand rewrites (e.g., rand(...)*7 -> rand(min=0, max=7)) for the
+ * uniform, normal, and poisson pdfs, each run with algebraic simplification disabled and enabled.
+ */
+public class RewriteFusedRandTest extends AutomatedTestBase 
+{	
+	private static final String TEST_NAME1 = "RewriteFusedRand";
+	private static final String TEST_DIR = "functions/misc/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteFusedRandTest.class.getSimpleName() + "/";
+	
+	private static final int rows = 1932;
+	private static final int cols = 14;
+	private static final int seed = 7; //passed as $4 to both rand calls in the script
+	
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+	}
+
+	@Test
+	public void testRewriteFusedRandUniformNoRewrite()  {
+		testRewriteFusedRand( TEST_NAME1, "uniform", false );
+	}
+	
+	@Test
+	public void testRewriteFusedRandNormalNoRewrite()  {
+		testRewriteFusedRand( TEST_NAME1, "normal", false );
+	}
+	
+	@Test
+	public void testRewriteFusedRandPoissonNoRewrite()  {
+		testRewriteFusedRand( TEST_NAME1, "poisson", false );
+	}
+	
+	@Test
+	public void testRewriteFusedRandUniform()  {
+		testRewriteFusedRand( TEST_NAME1, "uniform", true );
+	}
+	
+	@Test
+	public void testRewriteFusedRandNormal()  {
+		testRewriteFusedRand( TEST_NAME1, "normal", true );
+	}
+	
+	@Test
+	public void testRewriteFusedRandPoisson()  {
+		testRewriteFusedRand( TEST_NAME1, "poisson", true );
+	}
+	
+	/**
+	 * Runs the given DML script for one pdf and asserts that output R[1,1] equals {@code rows}.
+	 * @param testname name of the test configuration and DML script (without ".dml")
+	 * @param pdf rand pdf passed to the script as $3 ("uniform", "normal", or "poisson")
+	 * @param rewrites value of OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION during the run (restored afterwards)
+	 */
+	private void testRewriteFusedRand( String testname, String pdf, boolean rewrites )
+	{	
+		boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; //restored in finally
+		
+		try {
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+			
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+			programArgs = new String[]{ "-args", String.valueOf(rows), 
+					String.valueOf(cols), pdf, String.valueOf(seed), output("R") };
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+			//run DML script (this is a correctness test, not a performance test)
+			runTest(true, false, null, -1); 
+			
+			//expect R[1,1]==rows, i.e., rowSums(X1)==rowSums(X2) on every row; NOTE(review): X1 and X2 are the same expression, so a wrong fusion may go undetected unless it fails at runtime
+			HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+			Assert.assertEquals("Wrong result, expected: "+rows, new Double(rows), dmlfile.get(new CellIndex(1,1)));
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+		}
+	}	
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d21e7d9a/src/test/scripts/functions/misc/RewriteFusedRand.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/misc/RewriteFusedRand.dml b/src/test/scripts/functions/misc/RewriteFusedRand.dml
new file mode 100644
index 0000000..6e2cb58
--- /dev/null
+++ b/src/test/scripts/functions/misc/RewriteFusedRand.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X1 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7; # rand * scalar: candidate for the fused rand rewrite
+
+if(1==1){} #prevent cse, so X2 is not merged with X1
+
+X2 = rand(rows=$1, cols=$2, pdf=$3, seed=$4) * 7; # same expression and seed as X1
+
+R = as.matrix(sum(rowSums(X1)==rowSums(X2))); # number of rows whose row sums agree
+write(R, $5);
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d21e7d9a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index f535de8..0e01913 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -44,6 +44,7 @@ import org.junit.runners.Suite;
 	PrintExpressionTest.class,
 	PrintMatrixTest.class,
 	ReadAfterWriteTest.class,
+	RewriteFusedRandTest.class,
 	RewriteSimplifyRowColSumMVMultTest.class,
 	RewriteSlicedMatrixMultTest.class,
 	ScalarAssignmentTest.class,