You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/07/29 20:21:09 UTC

[systemds] branch master updated: [SYSTEMDS-3077] Fix codegen row template (support for colMeans)

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new ab810e5  [SYSTEMDS-3077] Fix codegen row template (support for colMeans)
ab810e5 is described below

commit ab810e5170954ff600eb23aec75b4ce24bffe084
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Thu Jul 29 22:20:47 2021 +0200

    [SYSTEMDS-3077] Fix codegen row template (support for colMeans)
    
    This patch adds the missing full support for colMeans operations in
    codegen row templates; the lack of this support led to compiler
    failures during operator fusion. We now route colMeans to the basic
    colSums fusion pattern and compensate afterwards with a basic div hop
    that divides the fused result by the number of rows of the input.
    These patterns appeared in various algorithms due to the new rewrites.
---
 src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java | 10 +++++++++-
 .../org/apache/sysds/hops/codegen/template/TemplateRow.java    |  5 +++--
 .../sysds/test/component/compress/workload/WorkloadTest.java   |  3 +--
 .../test/functions/builtin/BuiltinAutoencoder2LayerTest.java   |  2 +-
 .../functions/codegenalg/partone/AlgorithmAutoEncoder.java     |  2 +-
 .../test/functions/federated/algorithms/FederatedPCATest.java  |  1 -
 6 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
index 11712dd..4886691 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/SpoofCompiler.java
@@ -25,13 +25,17 @@ import org.apache.commons.lang3.SystemUtils;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.AggOp;
 import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.Direction;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.hops.AggUnaryOp;
 import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.hops.codegen.cplan.CNode;
 import org.apache.sysds.hops.codegen.cplan.CNodeCell;
@@ -828,8 +832,12 @@ public class SpoofCompiler {
 				hnew = HopRewriteUtils.createUnary(hnew, OpOp1.CAST_AS_MATRIX);
 			}
 			else if( tmpCNode instanceof CNodeRow && (((CNodeRow)tmpCNode).getRowType()==RowType.NO_AGG_CONST
-				|| ((CNodeRow)tmpCNode).getRowType()==RowType.COL_AGG_CONST) )
+				|| ((CNodeRow)tmpCNode).getRowType()==RowType.COL_AGG_CONST) ) {
 				((SpoofFusedOp)hnew).setConstDim2(((CNodeRow)tmpCNode).getConstDim2());
+			}
+			else if( tmpCNode instanceof CNodeRow && HopRewriteUtils.isAggUnaryOp(hop, AggOp.MEAN, Direction.Col) ) {
+				hnew = HopRewriteUtils.createBinary(hnew, new LiteralOp(hop.getInput(0).getDim1()), OpOp2.DIV);
+			}
 			
 			if( !(tmpCNode instanceof CNodeMultiAgg) )
 				HopRewriteUtils.rewireAllParentChildReferences(hop, hnew);
diff --git a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
index 1962a65..d811c3c 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
@@ -298,7 +298,7 @@ public class TemplateRow extends TemplateBase
 		if(hop instanceof AggUnaryOp)
 		{
 			CNode cdata1 = tmp.get(hop.getInput().get(0).getHopID());
-			if( ((AggUnaryOp)hop).getDirection() == Direction.Row && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG) ) {
+			if( ((AggUnaryOp)hop).getDirection().isRow() && HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_ROW_AGG) ) {
 				if(hop.getInput().get(0).getDim2()==1)
 					out = (cdata1.getDataType()==DataType.SCALAR) ? cdata1 : new CNodeUnary(cdata1,UnaryType.LOOKUP_R);
 				else {
@@ -308,7 +308,8 @@ public class TemplateRow extends TemplateBase
 						inHops2.put("X", hop.getInput().get(0));
 				}
 			}
-			else if (((AggUnaryOp)hop).getDirection() == Direction.Col && ((AggUnaryOp)hop).getOp() == AggOp.SUM ) {
+			else if ( HopRewriteUtils.isAggUnaryOp(hop, AggOp.SUM, AggOp.MEAN) 
+				&& ((AggUnaryOp)hop).getDirection().isCol() ) { //closes row template
 				//vector add without temporary copy
 				if( cdata1 instanceof CNodeBinary && ((CNodeBinary)cdata1).getType().isVectorScalarPrimitive() )
 					out = new CNodeBinary(cdata1.getInput().get(0), cdata1.getInput().get(1),
diff --git a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
index d2a3258..aacb716 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
@@ -112,7 +112,7 @@ public class WorkloadTest {
 		tests.add(new Object[] {0, 0, 0, 0, 0, 0, 8, 0, true, false, "functions/scale_onlySide.dml", args});
 
 		tests.add(new Object[] {0, 0, 0, 0, 1, 1, 9, 0, true, false, "functions/pca.dml", args});
-		tests.add(new Object[] {0, 0, 0, 0, 1, 1, 7, 0, true, true, "functions/pca.dml", args});
+		tests.add(new Object[] {0, 0, 0, 0, 1, 1, 6, 0, true, true, "functions/pca.dml", args});
 
 		args = new HashMap<>();
 		args.put("$1", testFile);
@@ -220,5 +220,4 @@ public class WorkloadTest {
 			throw new DMLRuntimeException("Error in parsing", e);
 		}
 	}
-
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinAutoencoder2LayerTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinAutoencoder2LayerTest.java
index 81f93a5..f5ef171 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinAutoencoder2LayerTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinAutoencoder2LayerTest.java
@@ -38,7 +38,7 @@ public class BuiltinAutoencoder2LayerTest extends AutomatedTestBase
 	private final static int cols = 784;
 	private final static double sparse = 0.1;
 	private final static double dense = 0.7;
-	private final static double tolerance = 1e-3;
+	private final static double tolerance = 2e-3;
 
 	private static int batchSize = 256;
 	private static double step = 1e-5;
diff --git a/src/test/java/org/apache/sysds/test/functions/codegenalg/partone/AlgorithmAutoEncoder.java b/src/test/java/org/apache/sysds/test/functions/codegenalg/partone/AlgorithmAutoEncoder.java
index 321de61..e72708f 100644
--- a/src/test/java/org/apache/sysds/test/functions/codegenalg/partone/AlgorithmAutoEncoder.java
+++ b/src/test/java/org/apache/sysds/test/functions/codegenalg/partone/AlgorithmAutoEncoder.java
@@ -44,7 +44,7 @@ public class AlgorithmAutoEncoder extends AutomatedTestBase
 	
 	private final static double sparsity1 = 0.7; //dense
 	private final static double sparsity2 = 0.1; //sparse
-	private final static double eps       = 1e-5;
+	private final static double eps       = 3e-4;
 	
 	private final static int H1 = 500;
 	private final static int H2 = 2;
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
index 04299a8..2a41866 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedPCATest.java
@@ -131,7 +131,6 @@ public class FederatedPCATest extends AutomatedTestBase {
 
 		// check for federated operations
 		Assert.assertTrue(heavyHittersContainsString("fed_ba+*"));
-		Assert.assertTrue(heavyHittersContainsString("fed_uack+"));
 		Assert.assertTrue(heavyHittersContainsString("fed_tsmm"));
 		if(scaleAndShift) {
 			Assert.assertTrue(heavyHittersContainsString("fed_uacsqk+"));