You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/10/31 19:11:08 UTC

[systemds] branch master updated: [SYSTEMDS-2836] Extended update in-place for unary operators

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new be4f940  [SYSTEMDS-2836] Extended update in-place for unary operators
be4f940 is described below

commit be4f9404a62b291e997ee5205db395d6ff1b2ae7
Author: Ismael Ibrahim <is...@student.tugraz.at>
AuthorDate: Sun Oct 31 20:08:34 2021 +0100

    [SYSTEMDS-2836] Extended update in-place for unary operators
    
    AMLS project SS2021.
    Closes #1406.
    
    Co-authored-by: Maximilian Theiner <ma...@student.tugraz.at>
    Co-authored-by: Alexander Kropiunig <al...@student.tugraz.at>
    Co-authored-by: Matthias Boehm <mb...@gmail.com>
---
 src/main/java/org/apache/sysds/hops/Hop.java       |  2 +-
 .../java/org/apache/sysds/hops/OptimizerUtils.java |  8 +++
 src/main/java/org/apache/sysds/hops/UnaryOp.java   | 17 ++++-
 src/main/java/org/apache/sysds/lops/Unary.java     |  5 ++
 .../instructions/cp/UnaryCPInstruction.java        |  8 +--
 .../sysds/runtime/matrix/data/LibMatrixAgg.java    | 10 +--
 .../sysds/runtime/matrix/data/MatrixBlock.java     | 19 +++--
 .../updateinplace/UnaryUpdateInPlaceTest.java      | 80 ++++++++++++++++++++++
 .../functions/updateinplace/UnaryUpdateInplace.dml | 36 ++++++++++
 9 files changed, 168 insertions(+), 17 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index a25cf10..9114c55 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -65,7 +65,7 @@ import org.apache.sysds.runtime.util.UtilFunctions;
 
 public abstract class Hop implements ParseInfo {
 	private static final Log LOG =  LogFactory.getLog(Hop.class.getName());
-	
+
 	public static final long CPThreshold = 2000;
 
 	// static variable to assign an unique ID to every hop that is created
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index be916d9..1b94413 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -191,6 +191,14 @@ public class OptimizerUtils
 	public static boolean ALLOW_LOOP_UPDATE_IN_PLACE = true;
 	
 	/**
+	 * Enables the update-in-place for all unary operators with a single
+	 * consumer. In this case we do not allocate the output, but directly
+	 * write the output values back to the input block.
+	 */
+	//TODO enabling it by default requires modifications in lineage-based reuse
+	public static boolean ALLOW_UNARY_UPDATE_IN_PLACE = false;
+	
+	/**
 	 * Replace eval second-order function calls with normal function call
 	 * if the function name is a known string (after constant propagation).
 	 */
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index 38199b2..d4e8f34 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -43,6 +43,7 @@ import org.apache.sysds.runtime.util.UtilFunctions;
 import java.util.ArrayList;
 
 
+
 /* Unary (cell operations): e.g, b_ij = round(a_ij)
  * 		Semantic: given a value, perform the operation (independent of other values)
  */
@@ -57,7 +58,7 @@ public class UnaryOp extends MultiThreadedHop
 	private UnaryOp() {
 		//default constructor for clone
 	}
-	
+
 	public UnaryOp(String l, DataType dt, ValueType vt, OpOp1 o, Hop inp) {
 		super(l, dt, vt);
 
@@ -130,7 +131,7 @@ public class UnaryOp extends MultiThreadedHop
 		try 
 		{
 			Hop input = getInput().get(0);
-			
+
 			if(    getDataType() == DataType.SCALAR //value type casts or matrix to scalar
 				|| (_op == OpOp1.CAST_AS_MATRIX && getInput().get(0).getDataType()==DataType.SCALAR)
 				|| (_op == OpOp1.CAST_AS_FRAME && getInput().get(0).getDataType()==DataType.SCALAR))
@@ -165,10 +166,20 @@ public class UnaryOp extends MultiThreadedHop
 				}
 				else //default unary 
 				{
+					boolean inplace = false;
+
+					//check in-place
+					if (OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE
+						&& input.getParent().size() == 1)
+					{
+						inplace = !(input instanceof DataOp)
+							|| !((DataOp) input).isRead();
+					}
+
 					int k = isCumulativeUnaryOperation() || isExpensiveUnaryOperation() ?
 						OptimizerUtils.getConstrainedNumThreads( _maxNumThreads ) : 1;
 					Unary unary1 = new Unary(input.constructLops(),
-						_op, getDataType(), getValueType(), et, k, false);
+						_op, getDataType(), getValueType(), et, k, inplace);
 					setOutputDimensions(unary1);
 					setLineNumbers(unary1);
 					setLops(unary1);
diff --git a/src/main/java/org/apache/sysds/lops/Unary.java b/src/main/java/org/apache/sysds/lops/Unary.java
index ad4b2b8..f0a59fa 100644
--- a/src/main/java/org/apache/sysds/lops/Unary.java
+++ b/src/main/java/org/apache/sysds/lops/Unary.java
@@ -122,6 +122,7 @@ public class Unary extends Lop
 	}
 	
 	public static boolean isMultiThreadedOp(OpOp1 op) {
+		//TODO extend for all basic unary operations
 		return op==OpOp1.CUMSUM
 			|| op==OpOp1.CUMPROD
 			|| op==OpOp1.CUMMIN
@@ -129,6 +130,10 @@ public class Unary extends Lop
 			|| op==OpOp1.CUMSUMPROD
 			|| op==OpOp1.EXP
 			|| op==OpOp1.LOG
+			|| op==OpOp1.ABS
+			|| op==OpOp1.ROUND
+			|| op==OpOp1.FLOOR
+			|| op==OpOp1.CEIL
 			|| op==OpOp1.SIGMOID
 			|| op==OpOp1.POW2
 			|| op==OpOp1.MULT2;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
index 0c98e84..8f92c07 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
@@ -19,10 +19,10 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
-import java.util.Arrays;
-
+import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.Unary;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -61,8 +61,8 @@ public abstract class UnaryCPInstruction extends ComputationCPInstruction {
 			in.split(parts[1]);
 			out.split(parts[2]);
 			func = Builtin.getBuiltinFnObject(opcode);
-			
-			if( Arrays.asList(new String[]{"ucumk+","ucum*","ucumk+*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode) ){
+			Types.OpOp1 op_type = Types.OpOp1.valueOfByOpcode(opcode);
+			if( Unary.isMultiThreadedOp(op_type)){
 				UnaryOperator op = new UnaryOperator(func, Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4]));
 				return new UnaryMatrixCPInstruction(op, in, out, opcode, str);
 			}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
index 22c437d..0d3c007 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
@@ -294,12 +294,14 @@ public class LibMatrixAgg
 	}
 	
 	public static MatrixBlock cumaggregateUnaryMatrix(MatrixBlock in, MatrixBlock out, UnaryOperator uop, double[] agg) {
-		//prepare meta data 
+		//Check this implementation, standard case for cumagg (single threaded)
+
+		//prepare meta data
 		AggType aggtype = getAggType(uop);
 		final int m = in.rlen;
 		final int m2 = out.rlen;
 		final int n2 = out.clen;
-		
+
 		//filter empty input blocks (incl special handling for sparse-unsafe operations)
 		if( in.isEmpty() && (agg == null || aggtype == AggType.CUM_SUM_PROD) ) {
 			return aggregateUnaryMatrixEmpty(in, out, aggtype, null);
@@ -317,7 +319,7 @@ public class LibMatrixAgg
 		}
 		
 		//Timing time = new Timing(true);
-		
+
 		if( !in.sparse )
 			cumaggregateUnaryMatrixDense(in, out, aggtype, uop.fn, agg, 0, m);
 		else
@@ -336,7 +338,7 @@ public class LibMatrixAgg
 		AggregateUnaryOperator uaop = InstructionUtils.parseBasicCumulativeAggregateUnaryOperator(uop);
 		
 		//fall back to sequential if necessary or agg not supported
-		if(    k <= 1 || (long)in.rlen*in.clen < PAR_NUMCELL_THRESHOLD1 || in.rlen <= k
+		if( k <= 1 || (long)in.rlen*in.clen < PAR_NUMCELL_THRESHOLD1 || in.rlen <= k
 			|| out.clen*8*k > PAR_INTERMEDIATE_SIZE_THRESHOLD || uaop == null || !out.isThreadSafe()) {
 			return cumaggregateUnaryMatrix(in, out, uop);
 		}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 0538dd7..2f521b5 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -2755,6 +2755,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 		return ret;
 	}
 
+
 	@Override
 	public MatrixBlock unaryOperations(UnaryOperator op, MatrixValue result) {
 		MatrixBlock ret = checkType(result);
@@ -2769,7 +2770,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 			ret = new MatrixBlock(rlen, n, sp, sp ? nonZeros : rlen*n);
 		else
 			ret.reset(rlen, n, sp);
-		
+
 		//early abort for comparisons w/ special values
 		if( Builtin.isBuiltinCode(op.fn, BuiltinCode.ISNAN, BuiltinCode.ISNA))
 			if( !containsValue(op.getPattern()) )
@@ -2788,7 +2789,11 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 			//note: we apply multi-threading in a best-effort manner here
 			//only for expensive operators such as exp, log, sigmoid, because
 			//otherwise allocation, read and write anyway dominates
-			ret.allocateDenseBlock(false);
+			if (!op.isInplace() || isEmpty())
+				ret.allocateDenseBlock(false);
+			else
+				ret = this;
+
 			DenseBlock a = getDenseBlock();
 			DenseBlock c = ret.getDenseBlock();
 			for(int bi=0; bi<a.numBlocks(); bi++) {
@@ -2797,7 +2802,11 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 			}
 			ret.recomputeNonZeros();
 		}
-		else {
+		else
+		{
+			if (op.isInplace() && !isInSparseFormat() )
+				ret = this;
+			
 			//default execute unary operations
 			if(op.sparseSafe)
 				sparseUnaryOperations(op, ret);
@@ -2870,8 +2879,8 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
 		}
 		else //DENSE <- DENSE
 		{
-			//allocate dense output block
-			ret.allocateDenseBlock(false);
+			if( this != ret ) //!in-place
+				ret.allocateDenseBlock(false);
 			DenseBlock da = getDenseBlock();
 			DenseBlock dc = ret.getDenseBlock();
 			
diff --git a/src/test/java/org/apache/sysds/test/functions/updateinplace/UnaryUpdateInPlaceTest.java b/src/test/java/org/apache/sysds/test/functions/updateinplace/UnaryUpdateInPlaceTest.java
new file mode 100644
index 0000000..0c7c133
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/updateinplace/UnaryUpdateInPlaceTest.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.updateinplace;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.builtin.BuiltinSplitTest;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/** Checks that a script yields the same result with and without unary update in-place. */
+public class UnaryUpdateInPlaceTest extends AutomatedTestBase{
+	private final static String TEST_NAME = "UnaryUpdateInplace";
+	private final static String TEST_DIR = "functions/updateinplace/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + UnaryUpdateInPlaceTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-3;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"Out"}));
+	}
+
+	@Test
+	public void testInPlace() {
+		runInPlaceTest(Types.ExecType.CP);
+	}
+
+	/** Runs the script once per flag value and compares both "Out" matrices (tolerance eps). */
+	private void runInPlaceTest(Types.ExecType instType) {
+		Types.ExecMode platformOld = setExecMode(instType);
+		boolean oldFlag = OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE;
+		//the optimizer flag is static and shared, hence it is restored in finally
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[]{"-explain","-nvargs","Out=" + output("Out") };
+
+			OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE = true;
+			runTest(true, false, null, -1);
+			HashMap<MatrixValue.CellIndex, Double> dmlfileOut1 = readDMLMatrixFromOutputDir("Out");
+			OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE = false;
+			runTest(true, false, null, -1);
+			HashMap<MatrixValue.CellIndex, Double> dmlfileOut2 = readDMLMatrixFromOutputDir("Out");
+
+			//compare matrices
+			TestUtils.compareMatrices(dmlfileOut1,dmlfileOut2,eps,"Stat-DML1","Stat-DML2");
+		}
+		catch(Exception e) {
+			throw new RuntimeException(e); //fail the test instead of silently swallowing the error
+		}
+		finally {
+			rtplatform = platformOld;
+			OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE = oldFlag;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/updateinplace/UnaryUpdateInplace.dml b/src/test/scripts/functions/updateinplace/UnaryUpdateInplace.dml
new file mode 100644
index 0000000..957ffc5
--- /dev/null
+++ b/src/test/scripts/functions/updateinplace/UnaryUpdateInplace.dml
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+# Test script for unary update in-place; the test runs it with the optimizer flag on and off.
+#A = rand(rows = 100, cols = 100)
+#C = rand(rows = 100, cols = 100)
+
+A = matrix(1, 10, 10);
+C = matrix(1, 10, 10);
+while(FALSE){} # NOTE(review): empty loop, presumably used to cut the script into separate statement blocks - confirm
+A = A * seq(1.1,10.1);
+while(FALSE){}
+B = round(A) # does not apply (presumably because A is read from a previous statement block - confirm)
+C = C * seq(1.1,10.1);
+D = log(C) # applies (presumably because the product is an intermediate with log as its only consumer - confirm)
+while(FALSE){}
+C = A + B + D*3
+Out = C
+write(Out, $Out);
+print(as.scalar(C[2, 1]))