You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/10/31 19:11:08 UTC
[systemds] branch master updated: [SYSTEMDS-2836] Extended update in-place for unary operators
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new be4f940 [SYSTEMDS-2836] Extended update in-place for unary operators
be4f940 is described below
commit be4f9404a62b291e997ee5205db395d6ff1b2ae7
Author: Ismael Ibrahim <is...@student.tugraz.at>
AuthorDate: Sun Oct 31 20:08:34 2021 +0100
[SYSTEMDS-2836] Extended update in-place for unary operators
AMLS project SS2021.
Closes #1406.
Co-authored-by: Maximilian Theiner <ma...@student.tugraz.at>
Co-authored-by: Alexander Kropiunig <al...@student.tugraz.at>
Co-authored-by: Matthias Boehm <mb...@gmail.com>
---
src/main/java/org/apache/sysds/hops/Hop.java | 2 +-
.../java/org/apache/sysds/hops/OptimizerUtils.java | 8 +++
src/main/java/org/apache/sysds/hops/UnaryOp.java | 17 ++++-
src/main/java/org/apache/sysds/lops/Unary.java | 5 ++
.../instructions/cp/UnaryCPInstruction.java | 8 +--
.../sysds/runtime/matrix/data/LibMatrixAgg.java | 10 +--
.../sysds/runtime/matrix/data/MatrixBlock.java | 19 +++--
.../updateinplace/UnaryUpdateInPlaceTest.java | 80 ++++++++++++++++++++++
.../functions/updateinplace/UnaryUpdateInplace.dml | 36 ++++++++++
9 files changed, 168 insertions(+), 17 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java
index a25cf10..9114c55 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -65,7 +65,7 @@ import org.apache.sysds.runtime.util.UtilFunctions;
public abstract class Hop implements ParseInfo {
private static final Log LOG = LogFactory.getLog(Hop.class.getName());
-
+
public static final long CPThreshold = 2000;
// static variable to assign an unique ID to every hop that is created
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index be916d9..1b94413 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -191,6 +191,14 @@ public class OptimizerUtils
public static boolean ALLOW_LOOP_UPDATE_IN_PLACE = true;
/**
+ * Enables the update-in-place for all unary operators with a single
+ * consumer. In this case we do not allocate the output, but directly
+ * write the output values back to the input block.
+ */
+ //TODO enabling it by default requires modifications in lineage-based reuse
+ public static boolean ALLOW_UNARY_UPDATE_IN_PLACE = false;
+
+ /**
* Replace eval second-order function calls with normal function call
* if the function name is a known string (after constant propagation).
*/
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index 38199b2..d4e8f34 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -43,6 +43,7 @@ import org.apache.sysds.runtime.util.UtilFunctions;
import java.util.ArrayList;
+
/* Unary (cell operations): e.g, b_ij = round(a_ij)
* Semantic: given a value, perform the operation (independent of other values)
*/
@@ -57,7 +58,7 @@ public class UnaryOp extends MultiThreadedHop
private UnaryOp() {
//default constructor for clone
}
-
+
public UnaryOp(String l, DataType dt, ValueType vt, OpOp1 o, Hop inp) {
super(l, dt, vt);
@@ -130,7 +131,7 @@ public class UnaryOp extends MultiThreadedHop
try
{
Hop input = getInput().get(0);
-
+
if( getDataType() == DataType.SCALAR //value type casts or matrix to scalar
|| (_op == OpOp1.CAST_AS_MATRIX && getInput().get(0).getDataType()==DataType.SCALAR)
|| (_op == OpOp1.CAST_AS_FRAME && getInput().get(0).getDataType()==DataType.SCALAR))
@@ -165,10 +166,20 @@ public class UnaryOp extends MultiThreadedHop
}
else //default unary
{
+ boolean inplace = false;
+
+ //check in-place
+ if (OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE
+ && input.getParent().size() == 1)
+ {
+ inplace = !(input instanceof DataOp)
+ || !((DataOp) input).isRead();
+ }
+
int k = isCumulativeUnaryOperation() || isExpensiveUnaryOperation() ?
OptimizerUtils.getConstrainedNumThreads( _maxNumThreads ) : 1;
Unary unary1 = new Unary(input.constructLops(),
- _op, getDataType(), getValueType(), et, k, false);
+ _op, getDataType(), getValueType(), et, k, inplace);
setOutputDimensions(unary1);
setLineNumbers(unary1);
setLops(unary1);
diff --git a/src/main/java/org/apache/sysds/lops/Unary.java b/src/main/java/org/apache/sysds/lops/Unary.java
index ad4b2b8..f0a59fa 100644
--- a/src/main/java/org/apache/sysds/lops/Unary.java
+++ b/src/main/java/org/apache/sysds/lops/Unary.java
@@ -122,6 +122,7 @@ public class Unary extends Lop
}
public static boolean isMultiThreadedOp(OpOp1 op) {
+ //TODO extend for all basic unary operations
return op==OpOp1.CUMSUM
|| op==OpOp1.CUMPROD
|| op==OpOp1.CUMMIN
@@ -129,6 +130,10 @@ public class Unary extends Lop
|| op==OpOp1.CUMSUMPROD
|| op==OpOp1.EXP
|| op==OpOp1.LOG
+ || op==OpOp1.ABS
+ || op==OpOp1.ROUND
+ || op==OpOp1.FLOOR
+ || op==OpOp1.CEIL
|| op==OpOp1.SIGMOID
|| op==OpOp1.POW2
|| op==OpOp1.MULT2;
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
index 0c98e84..8f92c07 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
@@ -19,10 +19,10 @@
package org.apache.sysds.runtime.instructions.cp;
-import java.util.Arrays;
-
+import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.Unary;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -61,8 +61,8 @@ public abstract class UnaryCPInstruction extends ComputationCPInstruction {
in.split(parts[1]);
out.split(parts[2]);
func = Builtin.getBuiltinFnObject(opcode);
-
- if( Arrays.asList(new String[]{"ucumk+","ucum*","ucumk+*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode) ){
+ Types.OpOp1 op_type = Types.OpOp1.valueOfByOpcode(opcode);
+ if( Unary.isMultiThreadedOp(op_type)){
UnaryOperator op = new UnaryOperator(func, Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4]));
return new UnaryMatrixCPInstruction(op, in, out, opcode, str);
}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
index 22c437d..0d3c007 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
@@ -294,12 +294,14 @@ public class LibMatrixAgg
}
public static MatrixBlock cumaggregateUnaryMatrix(MatrixBlock in, MatrixBlock out, UnaryOperator uop, double[] agg) {
- //prepare meta data
+ //Check this implementation, standard case for cumagg (single threaded)
+
+ //prepare meta data
AggType aggtype = getAggType(uop);
final int m = in.rlen;
final int m2 = out.rlen;
final int n2 = out.clen;
-
+
//filter empty input blocks (incl special handling for sparse-unsafe operations)
if( in.isEmpty() && (agg == null || aggtype == AggType.CUM_SUM_PROD) ) {
return aggregateUnaryMatrixEmpty(in, out, aggtype, null);
@@ -317,7 +319,7 @@ public class LibMatrixAgg
}
//Timing time = new Timing(true);
-
+
if( !in.sparse )
cumaggregateUnaryMatrixDense(in, out, aggtype, uop.fn, agg, 0, m);
else
@@ -336,7 +338,7 @@ public class LibMatrixAgg
AggregateUnaryOperator uaop = InstructionUtils.parseBasicCumulativeAggregateUnaryOperator(uop);
//fall back to sequential if necessary or agg not supported
- if( k <= 1 || (long)in.rlen*in.clen < PAR_NUMCELL_THRESHOLD1 || in.rlen <= k
+ if( k <= 1 || (long)in.rlen*in.clen < PAR_NUMCELL_THRESHOLD1 || in.rlen <= k
|| out.clen*8*k > PAR_INTERMEDIATE_SIZE_THRESHOLD || uaop == null || !out.isThreadSafe()) {
return cumaggregateUnaryMatrix(in, out, uop);
}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 0538dd7..2f521b5 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -2755,6 +2755,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
return ret;
}
+
@Override
public MatrixBlock unaryOperations(UnaryOperator op, MatrixValue result) {
MatrixBlock ret = checkType(result);
@@ -2769,7 +2770,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
ret = new MatrixBlock(rlen, n, sp, sp ? nonZeros : rlen*n);
else
ret.reset(rlen, n, sp);
-
+
//early abort for comparisons w/ special values
if( Builtin.isBuiltinCode(op.fn, BuiltinCode.ISNAN, BuiltinCode.ISNA))
if( !containsValue(op.getPattern()) )
@@ -2788,7 +2789,11 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
//note: we apply multi-threading in a best-effort manner here
//only for expensive operators such as exp, log, sigmoid, because
//otherwise allocation, read and write anyway dominates
- ret.allocateDenseBlock(false);
+ if (!op.isInplace() || isEmpty())
+ ret.allocateDenseBlock(false);
+ else
+ ret = this;
+
DenseBlock a = getDenseBlock();
DenseBlock c = ret.getDenseBlock();
for(int bi=0; bi<a.numBlocks(); bi++) {
@@ -2797,7 +2802,11 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
}
ret.recomputeNonZeros();
}
- else {
+ else
+ {
+ if (op.isInplace() && !isInSparseFormat() )
+ ret = this;
+
//default execute unary operations
if(op.sparseSafe)
sparseUnaryOperations(op, ret);
@@ -2870,8 +2879,8 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
}
else //DENSE <- DENSE
{
- //allocate dense output block
- ret.allocateDenseBlock(false);
+ if( this != ret ) //!in-place
+ ret.allocateDenseBlock(false);
DenseBlock da = getDenseBlock();
DenseBlock dc = ret.getDenseBlock();
diff --git a/src/test/java/org/apache/sysds/test/functions/updateinplace/UnaryUpdateInPlaceTest.java b/src/test/java/org/apache/sysds/test/functions/updateinplace/UnaryUpdateInPlaceTest.java
new file mode 100644
index 0000000..0c7c133
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/updateinplace/UnaryUpdateInPlaceTest.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.updateinplace;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.builtin.BuiltinSplitTest;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/** Checks that enabling unary update-in-place does not change the results of the test script. */
+public class UnaryUpdateInPlaceTest extends AutomatedTestBase{
+	private final static String TEST_NAME = "UnaryUpdateInplace";
+	private final static String TEST_DIR = "functions/updateinplace/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + UnaryUpdateInPlaceTest.class.getSimpleName() + "/";
+	private final static double eps = 1e-3;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"B",}));
+	}
+
+	@Test
+	public void testInPlace() {
+		runInPlaceTest(Types.ExecType.CP);
+	}
+
+	// Runs the script with ALLOW_UNARY_UPDATE_IN_PLACE on, then off, and requires identical outputs.
+	private void runInPlaceTest(Types.ExecType instType) {
+		Types.ExecMode platformOld = setExecMode(instType);
+		boolean oldFlag = OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE;
+
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[]{"-explain","-nvargs","Out=" + output("Out") };
+
+			OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE = true;
+			runTest(true, false, null, -1);
+			HashMap<MatrixValue.CellIndex, Double> dmlfileOut1 = readDMLMatrixFromOutputDir("Out");
+			OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE = false;
+			runTest(true, false, null, -1);
+			HashMap<MatrixValue.CellIndex, Double> dmlfileOut2 = readDMLMatrixFromOutputDir("Out");
+
+			//in-place and regular execution must produce the same matrix
+			TestUtils.compareMatrices(dmlfileOut1,dmlfileOut2,eps,"Stat-DML1","Stat-DML2");
+		}
+		catch(Exception e) {
+			throw new RuntimeException(e); //fail the test instead of silently passing
+		}
+		finally {
+			rtplatform = platformOld;
+			OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE = oldFlag;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/updateinplace/UnaryUpdateInplace.dml b/src/test/scripts/functions/updateinplace/UnaryUpdateInplace.dml
new file mode 100644
index 0000000..957ffc5
--- /dev/null
+++ b/src/test/scripts/functions/updateinplace/UnaryUpdateInplace.dml
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+# Script for UnaryUpdateInPlaceTest: it is run with unary update-in-place enabled and disabled; both outputs must match.
+#A = rand(rows = 100, cols = 100)
+#C = rand(rows = 100, cols = 100)
+
+A = matrix(1, 10, 10);
+C = matrix(1, 10, 10);
+while(FALSE){} # empty loop, presumably to force a statement-block cut (TODO confirm)
+A = A * seq(1.1,10.1); # non-integral values, so round() below actually changes them
+while(FALSE){}
+B = round(A) # in-place must not apply: A is still used below
+C = C * seq(1.1,10.1);
+D = log(C) # in-place expected to apply: the intermediate C * seq(...) has no other consumer
+while(FALSE){}
+C = A + B + D*3
+Out = C
+write(Out, $Out);
+print(as.scalar(C[2, 1])) # debug output only; the Java test compares the written Out matrix