You are viewing a plain-text version of this content; the canonical version is linked from the original mailing-list archive page.
Posted to commits@systemml.apache.org by mb...@apache.org on 2017/07/15 04:15:40 UTC
[14/23] systemml git commit: Review comments 3
Review comments 3
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/04f692df
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/04f692df
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/04f692df
Branch: refs/heads/master
Commit: 04f692dfcb25a032044dabb7064241073f959300
Parents: de469d2
Author: Dylan Hutchison <dh...@cs.washington.edu>
Authored: Sun Jun 18 16:54:51 2017 -0700
Committer: Dylan Hutchison <dh...@cs.washington.edu>
Committed: Sun Jun 18 17:43:54 2017 -0700
----------------------------------------------------------------------
.../java/org/apache/sysml/hops/AggUnaryOp.java | 4 +-
.../sysml/hops/rewrite/ProgramRewriter.java | 2 +-
...ementwiseMultChainOptimizationChainTest.java | 127 -------------------
...iteElementwiseMultChainOptimizationTest.java | 127 +++++++++++++++++++
.../functions/misc/ZPackageSuite.java | 2 +-
5 files changed, 132 insertions(+), 130 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index a207831..8e681c1 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -647,7 +647,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
handled = true;
} else if (input11 instanceof BinaryOp ) {
BinaryOp b11 = (BinaryOp)input11;
- switch (b11.getOp()) {
+ switch( b11.getOp() ) {
case MULT: // A*B*C case
in1 = input11.getInput().get(0).constructLops();
in2 = input11.getInput().get(1).constructLops();
@@ -664,6 +664,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
handled = true;
}
break;
+ default: break;
}
} else if( input12 instanceof BinaryOp ) {
BinaryOp b12 = (BinaryOp)input12;
@@ -683,6 +684,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop
handled = true;
}
break;
+ default: break;
}
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
index 1053850..7ee3ccb 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/ProgramRewriter.java
@@ -97,7 +97,7 @@ public class ProgramRewriter
if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION )
_dagRuleSet.add( new RewriteCommonSubexpressionElimination() );
if ( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES)
- _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
+ _dagRuleSet.add( new RewriteElementwiseMultChainOptimization() ); //dependency: cse
if( OptimizerUtils.ALLOW_CONSTANT_FOLDING )
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION )
http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
deleted file mode 100644
index e490750..0000000
--- a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationChainTest.java
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.test.integration.functions.misc;
-
-import java.util.HashMap;
-
-import org.apache.sysml.api.DMLScript;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.lops.LopProperties.ExecType;
-import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.apache.sysml.test.integration.TestConfiguration;
-import org.apache.sysml.test.utils.TestUtils;
-import org.junit.Assert;
-import org.junit.Test;
-
-/**
- * Test whether `2*X*3*Y*4*X` successfully rewrites to `Y*(X^2)*24`.
- */
-public class RewriteElementwiseMultChainOptimizationChainTest extends AutomatedTestBase
-{
- private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
- private static final String TEST_DIR = "functions/misc/";
- private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationChainTest.class.getSimpleName() + "/";
-
- private static final int rows = 123;
- private static final int cols = 321;
- private static final double eps = Math.pow(10, -10);
-
- @Override
- public void setUp() {
- TestUtils.clearAssertionInformation();
- addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
- }
-
- @Test
- public void testMatrixMultChainOptNoRewritesCP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
- }
-
- @Test
- public void testMatrixMultChainOptNoRewritesSP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
- }
-
- @Test
- public void testMatrixMultChainOptRewritesCP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
- }
-
- @Test
- public void testMatrixMultChainOptRewritesSP() {
- testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
- }
-
- private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
- {
- RUNTIME_PLATFORM platformOld = rtplatform;
- switch( et ){
- case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
- case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
- default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
- }
-
- boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
- if( rtplatform == RUNTIME_PLATFORM.SPARK )
- DMLScript.USE_LOCAL_SPARK_CONFIG = true;
-
- boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
- OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
-
- try
- {
- TestConfiguration config = getTestConfiguration(testname);
- loadTestConfiguration(config);
-
- String HOME = SCRIPT_DIR + TEST_DIR;
- fullDMLScriptName = HOME + testname + ".dml";
- programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
- fullRScriptName = HOME + testname + ".R";
- rCmd = getRCmd(inputDir(), expectedDir());
-
- double Xsparsity = 0.8, Ysparsity = 0.6;
- double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
- double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
- writeInputMatrixWithMTD("X", X, true);
- writeInputMatrixWithMTD("Y", Y, true);
-
- //execute tests
- runTest(true, false, null, -1);
- runRScript(true);
-
- //compare matrices
- HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
- HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
- TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
-
- //check for presence of power operator, if we did a rewrite
- if( rewrites ) {
- Assert.assertTrue(heavyHittersContainsSubString("^2"));
- }
- }
- finally {
- OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
- rtplatform = platformOld;
- DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
new file mode 100644
index 0000000..91cb4e0
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/functions/misc/RewriteElementwiseMultChainOptimizationTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.misc;
+
+import java.util.HashMap;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test whether `2*X*3*Y*4*X` successfully rewrites to `Y*(X^2)*24`.
+ */
+public class RewriteElementwiseMultChainOptimizationTest extends AutomatedTestBase
+{
+ private static final String TEST_NAME1 = "RewriteEMultChainOpXYX";
+ private static final String TEST_DIR = "functions/misc/";
+ private static final String TEST_CLASS_DIR = TEST_DIR + RewriteElementwiseMultChainOptimizationTest.class.getSimpleName() + "/";
+
+ private static final int rows = 123;
+ private static final int cols = 321;
+ private static final double eps = Math.pow(10, -10);
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+ }
+
+ @Test
+ public void testMatrixMultChainOptNoRewritesCP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.CP);
+ }
+
+ @Test
+ public void testMatrixMultChainOptNoRewritesSP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, false, ExecType.SPARK);
+ }
+
+ @Test
+ public void testMatrixMultChainOptRewritesCP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.CP);
+ }
+
+ @Test
+ public void testMatrixMultChainOptRewritesSP() {
+ testRewriteMatrixMultChainOp(TEST_NAME1, true, ExecType.SPARK);
+ }
+
+ private void testRewriteMatrixMultChainOp(String testname, boolean rewrites, ExecType et)
+ {
+ RUNTIME_PLATFORM platformOld = rtplatform;
+ switch( et ){
+ case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+ case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+ default: rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; break;
+ }
+
+ boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+ if( rtplatform == RUNTIME_PLATFORM.SPARK )
+ DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+ boolean rewritesOld = OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES;
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewrites;
+
+ try
+ {
+ TestConfiguration config = getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + testname + ".dml";
+ programArgs = new String[] { "-explain", "hops", "-stats", "-args", input("X"), input("Y"), output("R") };
+ fullRScriptName = HOME + testname + ".R";
+ rCmd = getRCmd(inputDir(), expectedDir());
+
+ double Xsparsity = 0.8, Ysparsity = 0.6;
+ double[][] X = getRandomMatrix(rows, cols, -1, 1, Xsparsity, 7);
+ double[][] Y = getRandomMatrix(rows, cols, -1, 1, Ysparsity, 3);
+ writeInputMatrixWithMTD("X", X, true);
+ writeInputMatrixWithMTD("Y", Y, true);
+
+ //execute tests
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ //compare matrices
+ HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("R");
+ HashMap<CellIndex, Double> rfile = readRMatrixFromFS("R");
+ TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+ //check for presence of power operator, if we did a rewrite
+ if( rewrites ) {
+ Assert.assertTrue(heavyHittersContainsSubString("^2"));
+ }
+ }
+ finally {
+ OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES = rewritesOld;
+ rtplatform = platformOld;
+ DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/04f692df/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
----------------------------------------------------------------------
diff --git a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
index deea784..860cdbe 100644
--- a/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
+++ b/src/test_suites/java/org/apache/sysml/test/integration/functions/misc/ZPackageSuite.java
@@ -50,7 +50,7 @@ import org.junit.runners.Suite;
ReadAfterWriteTest.class,
RewriteCSETransposeScalarTest.class,
RewriteCTableToRExpandTest.class,
- RewriteElementwiseMultChainOptimizationChainTest.class,
+ RewriteElementwiseMultChainOptimizationTest.class,
RewriteEliminateAggregatesTest.class,
RewriteFuseBinaryOpChainTest.class,
RewriteFusedRandTest.class,