You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by mb...@apache.org on 2021/09/02 22:16:13 UTC
[systemds] 01/02: [SYSTEMDS-3118] Extended parfor parser/runtime (frame result variables)
This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 60d16c474b76ecb4d45d3cd6e36580672fc6f1da
Author: Matthias Boehm <mb...@gmail.com>
AuthorDate: Fri Sep 3 00:12:53 2021 +0200
[SYSTEMDS-3118] Extended parfor parser/runtime (frame result variables)
This patch extends parfor with support for frame result variables in
dependency analysis and in the merge of worker result variables. So far,
only the in-memory merge of frame results is supported.
---
.../apache/sysds/parser/ParForStatementBlock.java | 3 +-
.../runtime/controlprogram/ParForProgramBlock.java | 92 +++++++++--------
.../controlprogram/caching/FrameObject.java | 7 ++
.../runtime/controlprogram/parfor/ResultMerge.java | 90 ++--------------
.../parfor/ResultMergeFrameLocalMemory.java | 114 +++++++++++++++++++++
.../parfor/ResultMergeLocalAutomatic.java | 4 +-
.../parfor/ResultMergeLocalFile.java | 2 +-
.../parfor/ResultMergeLocalMemory.java | 2 +-
.../{ResultMerge.java => ResultMergeMatrix.java} | 50 ++-------
.../parfor/ResultMergeRemoteSpark.java | 2 +-
.../parfor/ResultMergeRemoteSparkWCompare.java | 2 +-
.../parfor/ParForDependencyAnalysisTest.java | 10 +-
...est.java => ParForListFrameResultVarsTest.java} | 22 +++-
src/test/scripts/component/parfor/parfor54e.dml | 26 +++++
src/test/scripts/component/parfor/parfor54f.dml | 26 +++++
.../functions/parfor/parfor_frameResults.dml | 32 ++++++
16 files changed, 306 insertions(+), 178 deletions(-)
diff --git a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
index 74c55c5..607641c 100644
--- a/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
+++ b/src/main/java/org/apache/sysds/parser/ParForStatementBlock.java
@@ -677,7 +677,7 @@ public class ParForStatementBlock extends ForStatementBlock
for(DataIdentifier write : datsUpdated) {
if( !c._var.equals( write.getName() ) ) continue;
- if( cdt != DataType.MATRIX && cdt != DataType.LIST ) {
+ if( cdt != DataType.MATRIX && cdt != DataType.FRAME && cdt != DataType.LIST ) {
//cannot infer type, need to exit (conservative approach)
throw new LanguageException("PARFOR loop dependency analysis: cannot check "
+ "for dependencies due to unknown datatype of var '"+c._var+"': "+cdt.name()+".");
@@ -716,6 +716,7 @@ public class ParForStatementBlock extends ForStatementBlock
return;
}
else if( (cdt == DataType.MATRIX && dat2dt == DataType.MATRIX)
+ || (cdt == DataType.FRAME && dat2dt == DataType.FRAME )
|| (cdt == DataType.LIST && dat2dt == DataType.LIST ) )
{
boolean invalid = false;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index 25d49bb..42ab8bc 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -38,6 +38,7 @@ import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
@@ -51,6 +52,7 @@ import org.apache.sysds.runtime.controlprogram.parfor.RemoteDPParForSpark;
import org.apache.sysds.runtime.controlprogram.parfor.RemoteParForJobReturn;
import org.apache.sysds.runtime.controlprogram.parfor.RemoteParForSpark;
import org.apache.sysds.runtime.controlprogram.parfor.ResultMerge;
+import org.apache.sysds.runtime.controlprogram.parfor.ResultMergeFrameLocalMemory;
import org.apache.sysds.runtime.controlprogram.parfor.ResultMergeLocalAutomatic;
import org.apache.sysds.runtime.controlprogram.parfor.ResultMergeLocalFile;
import org.apache.sysds.runtime.controlprogram.parfor.ResultMergeLocalMemory;
@@ -1056,9 +1058,9 @@ public class ParForProgramBlock extends ForProgramBlock
* @param out output matrix
* @param in array of input matrix objects
*/
- private static void cleanWorkerResultVariables(ExecutionContext ec, MatrixObject out, MatrixObject[] in, boolean parallel) {
+ private static void cleanWorkerResultVariables(ExecutionContext ec, CacheableData<?> out, CacheableData<?>[] in, boolean parallel) {
//check for empty inputs (no iterations executed)
- Stream<MatrixObject> results = Arrays.stream(in).filter(m -> m!=null && m!=out);
+ Stream<CacheableData<?>> results = Arrays.stream(in).filter(m -> m!=null && m!=out);
//perform cleanup (parallel to mitigate file deletion bottlenecks)
(parallel ? results.parallel() : results)
.forEach(m -> ec.cleanupCacheableData(m));
@@ -1307,33 +1309,41 @@ public class ParForProgramBlock extends ForProgramBlock
return dp;
}
- private ResultMerge createResultMerge( PResultMerge prm, MatrixObject out, MatrixObject[] in, String fname, boolean accum, ExecutionContext ec )
+ private ResultMerge<?> createResultMerge( PResultMerge prm,
+ CacheableData<?> out, CacheableData<?>[] in, String fname, boolean accum, ExecutionContext ec )
{
- ResultMerge rm = null;
+ ResultMerge<?> rm = null;
- //create result merge implementation (determine degree of parallelism
- //only for spark to avoid unnecessary spark context creation)
- switch( prm )
- {
- case LOCAL_MEM:
- rm = new ResultMergeLocalMemory( out, in, fname, accum );
- break;
- case LOCAL_FILE:
- rm = new ResultMergeLocalFile( out, in, fname, accum );
- break;
- case LOCAL_AUTOMATIC:
- rm = new ResultMergeLocalAutomatic( out, in, fname, accum );
- break;
- case REMOTE_SPARK:
- int numMap = Math.max(_numThreads,
- SparkExecutionContext.getDefaultParallelism(true));
- int numRed = numMap; //equal map/reduce
- rm = new ResultMergeRemoteSpark( out, in,
- fname, accum, ec, numMap, numRed );
- break;
-
- default:
- throw new DMLRuntimeException("Undefined result merge: '" +prm.toString()+"'.");
+ if( out instanceof FrameObject ) {
+ rm = new ResultMergeFrameLocalMemory((FrameObject)out, (FrameObject[])in, fname, accum);
+ }
+ else if(out instanceof MatrixObject) {
+ //create result merge implementation (determine degree of parallelism
+ //only for spark to avoid unnecessary spark context creation)
+ switch( prm )
+ {
+ case LOCAL_MEM:
+ rm = new ResultMergeLocalMemory( (MatrixObject)out, (MatrixObject[])in, fname, accum );
+ break;
+ case LOCAL_FILE:
+ rm = new ResultMergeLocalFile( (MatrixObject)out, (MatrixObject[])in, fname, accum );
+ break;
+ case LOCAL_AUTOMATIC:
+ rm = new ResultMergeLocalAutomatic( (MatrixObject)out, (MatrixObject[])in, fname, accum );
+ break;
+ case REMOTE_SPARK:
+ int numMap = Math.max(_numThreads,
+ SparkExecutionContext.getDefaultParallelism(true));
+ int numRed = numMap; //equal map/reduce
+ rm = new ResultMergeRemoteSpark( (MatrixObject)out,
+ (MatrixObject[])in, fname, accum, ec, numMap, numRed );
+ break;
+ default:
+ throw new DMLRuntimeException("Undefined result merge: '" +prm.toString()+"'.");
+ }
+ }
+ else {
+ throw new DMLRuntimeException("Unsupported result merge data: "+out.getClass().getSimpleName());
}
return rm;
@@ -1437,14 +1447,15 @@ public class ParForProgramBlock extends ForProgramBlock
{
Data dat = ec.getVariable(var._name);
- if( dat instanceof MatrixObject ) //robustness scalars
+ if( dat instanceof MatrixObject | dat instanceof FrameObject )
{
- MatrixObject out = (MatrixObject) dat;
- MatrixObject[] in = Arrays.stream(results).map(vars ->
- vars.get(var._name)).toArray(MatrixObject[]::new);
+ CacheableData<?> out = (CacheableData<?>) dat;
+ Stream<Object> tmp = Arrays.stream(results).map(vars -> vars.get(var._name));
+ CacheableData<?>[] in = (dat instanceof MatrixObject) ?
+ tmp.toArray(MatrixObject[]::new) : tmp.toArray(FrameObject[]::new);
String fname = constructResultMergeFileName();
- ResultMerge rm = createResultMerge(_resultMerge, out, in, fname, var._isAccum, ec);
- MatrixObject outNew = USE_PARALLEL_RESULT_MERGE ?
+ ResultMerge<?> rm = createResultMerge(_resultMerge, out, in, fname, var._isAccum, ec);
+ CacheableData<?> outNew = USE_PARALLEL_RESULT_MERGE ?
rm.executeParallelMerge(_numThreads) :
rm.executeSerialMerge();
@@ -1653,18 +1664,19 @@ public class ParForProgramBlock extends ForProgramBlock
if( var == LocalTaskQueue.NO_MORE_TASKS ) // task queue closed (no more tasks)
break;
- MatrixObject out = null;
+ CacheableData<?> out = null;
synchronized( _ec.getVariables() ){
- out = _ec.getMatrixObject(var._name);
+ out = _ec.getCacheableData(var._name);
}
- MatrixObject[] in = new MatrixObject[ _refVars.length ];
- for( int i=0; i< _refVars.length; i++ )
- in[i] = (MatrixObject) _refVars[i].get( var._name );
+ Stream<Object> tmp = Arrays.stream(_refVars).map(vars -> vars.get(var._name));
+ CacheableData<?>[] in = (out instanceof MatrixObject) ?
+ tmp.toArray(MatrixObject[]::new) : tmp.toArray(FrameObject[]::new);
+
String fname = constructResultMergeFileName();
- ResultMerge rm = createResultMerge(_resultMerge, out, in, fname, var._isAccum, _ec);
- MatrixObject outNew = null;
+ ResultMerge<?> rm = createResultMerge(_resultMerge, out, in, fname, var._isAccum, _ec);
+ CacheableData<?> outNew = null;
if( USE_PARALLEL_RESULT_MERGE )
outNew = rm.executeParallelMerge( _numThreads );
else
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
index 5eae986..4485388 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/FrameObject.java
@@ -41,6 +41,7 @@ import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageRecomputeUtils;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaData;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -86,6 +87,12 @@ public class FrameObject extends CacheableData<FrameBlock>
*/
public FrameObject(FrameObject fo) {
super(fo);
+
+ MetaDataFormat metaOld = (MetaDataFormat) fo.getMetaData();
+ _metaData = new MetaDataFormat(
+ new MatrixCharacteristics(metaOld.getDataCharacteristics()),
+ metaOld.getFileFormat());
+ _schema = fo._schema.clone();
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java
index 18b09a1..b69ba96 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java
@@ -21,42 +21,33 @@ package org.apache.sysds.runtime.controlprogram.parfor;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import java.io.Serializable;
-import java.util.List;
-/**
- * Due to independence of all iterations, any result has the following properties:
- * (1) non local var, (2) matrix object, and (3) completely independent.
- * These properties allow us to realize result merging in parallel without any synchronization.
- *
- */
-public abstract class ResultMerge implements Serializable
+public abstract class ResultMerge<T extends CacheableData<?>> implements Serializable
{
//note: this class needs to be serializable to ensure that all attributes of
//ResultMergeRemoteSparkWCompare are included in the task closure
- private static final long serialVersionUID = 2620430969346516677L;
+ private static final long serialVersionUID = -6756689640511059030L;
protected static final Log LOG = LogFactory.getLog(ResultMerge.class.getName());
protected static final String NAME_SUFFIX = "_rm";
protected static final BinaryOperator PLUS = InstructionUtils.parseBinaryOperator("+");
//inputs to result merge
- protected MatrixObject _output = null;
- protected MatrixObject[] _inputs = null;
- protected String _outputFName = null;
- protected boolean _isAccum = false;
+ protected T _output = null;
+ protected T[] _inputs = null;
+ protected String _outputFName = null;
+ protected boolean _isAccum = false;
protected ResultMerge( ) {
//do nothing
}
- public ResultMerge( MatrixObject out, MatrixObject[] in, String outputFilename, boolean accum ) {
+ public ResultMerge( T out, T[] in, String outputFilename, boolean accum ) {
_output = out;
_inputs = in;
_outputFName = outputFilename;
@@ -70,7 +61,7 @@ public abstract class ResultMerge implements Serializable
*
* @return output (merged) matrix
*/
- public abstract MatrixObject executeSerialMerge();
+ public abstract T executeSerialMerge();
/**
* Merge all given input matrices in parallel into the given output matrix.
@@ -80,67 +71,6 @@ public abstract class ResultMerge implements Serializable
* @param par degree of parallelism
* @return output (merged) matrix
*/
- public abstract MatrixObject executeParallelMerge( int par );
-
- protected void mergeWithoutComp( MatrixBlock out, MatrixBlock in, boolean appendOnly ) {
- mergeWithoutComp(out, in, appendOnly, false);
- }
+ public abstract T executeParallelMerge(int par);
- protected void mergeWithoutComp( MatrixBlock out, MatrixBlock in, boolean appendOnly, boolean par ) {
- //pass through to matrix block operations
- if( _isAccum )
- out.binaryOperationsInPlace(PLUS, in);
- else
- out.merge(in, appendOnly, par);
- }
-
- /**
- * NOTE: append only not applicable for wiht compare because output must be populated with
- * initial state of matrix - with append, this would result in duplicates.
- *
- * @param out output matrix block
- * @param in input matrix block
- * @param compare ?
- */
- protected void mergeWithComp( MatrixBlock out, MatrixBlock in, DenseBlock compare )
- {
- //Notes for result correctness:
- // * Always iterate over entire block in order to compare all values
- // (using sparse iterator would miss values set to 0)
- // * Explicit NaN awareness because for cases were original matrix contains
- // NaNs, since NaN != NaN, otherwise we would potentially overwrite results
- // * For the case of accumulation, we add out += (new-old) to ensure correct results
- // because all inputs have the old values replicated
-
- if( in.isEmptyBlock(false) ) {
- if( _isAccum ) return; //nothing to do
- for( int i=0; i<in.getNumRows(); i++ )
- for( int j=0; j<in.getNumColumns(); j++ )
- if( compare.get(i, j) != 0 )
- out.quickSetValue(i, j, 0);
- }
- else { //SPARSE/DENSE
- int rows = in.getNumRows();
- int cols = in.getNumColumns();
- for( int i=0; i<rows; i++ )
- for( int j=0; j<cols; j++ ) {
- double valOld = compare.get(i,j);
- double valNew = in.quickGetValue(i,j); //input value
- if( (valNew != valOld && !Double.isNaN(valNew) ) //for changed values
- || Double.isNaN(valNew) != Double.isNaN(valOld) ) //NaN awareness
- {
- double value = !_isAccum ? valNew :
- (out.quickGetValue(i, j) + (valNew - valOld));
- out.quickSetValue(i, j, value);
- }
- }
- }
- }
-
- protected long computeNonZeros( MatrixObject out, List<MatrixObject> in ) {
- //sum of nnz of input (worker result) - output var existing nnz
- long outNNZ = out.getDataCharacteristics().getNonZeros();
- return outNNZ - in.size() * outNNZ + in.stream()
- .mapToLong(m -> m.getDataCharacteristics().getNonZeros()).sum();
- }
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeFrameLocalMemory.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeFrameLocalMemory.java
new file mode 100644
index 0000000..cd2d99f
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeFrameLocalMemory.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.controlprogram.parfor;
+
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.util.UtilFunctions;
+
+public class ResultMergeFrameLocalMemory extends ResultMerge<FrameObject> //in-memory, cell-wise result merge for parfor frame result variables
+{
+ private static final long serialVersionUID = 549739254879310540L;
+ //merges the worker results 'in' into the output frame 'out'; accum is passed to the base class but not read by this class
+ public ResultMergeFrameLocalMemory(FrameObject out, FrameObject[] in, String outputFilename, boolean accum) {
+ super( out, in, outputFilename, accum );
+ }
+ //merges all inputs one at a time: every cell that differs from the original output is copied into a new frame block
+ @Override
+ public FrameObject executeSerialMerge()
+ {
+ FrameObject foNew = null; //new frame object if any input was merged (original note: required for nested parallelism), otherwise the old output
+
+ if( LOG.isTraceEnabled() )
+ LOG.trace("ResultMerge (local, in-memory): Execute serial merge for output "
+ +_output.hashCode()+" (fname="+_output.getFileName()+")");
+
+ try
+ {
+ //pin the old output frame and create the new output block from it via the FrameBlock(FrameBlock) constructor
+ FrameBlock outFB = _output.acquireRead();
+ FrameBlock outFBNew = new FrameBlock(outFB);
+
+ //the original output is the reference for detecting cells modified by a worker
+ FrameBlock compare = outFB;
+ int rlen = compare.getNumRows();
+ int clen = compare.getNumColumns();
+
+ //merge all inputs one after another; flagMerged records whether any input was merged
+ boolean flagMerged = false;
+ for( FrameObject in : _inputs )
+ {
+ //skip missing inputs (e.g., no iterations executed) and inputs that are the output object itself
+ if( in != null && in != _output )
+ {
+ if( LOG.isTraceEnabled() )
+ LOG.trace("ResultMergeFrame (local, in-memory): Merge input "+in.hashCode()+" (fname="+in.getFileName()+")");
+
+ //read/pin input_i
+ FrameBlock inMB = in.acquireRead();
+
+ //core merge: copy every cell that differs from the original output (compared per column value type); NOTE(review): _isAccum is not consulted here
+ for(int j=0; j<clen; j++) {
+ ValueType vt = compare.getSchema()[j];
+ for(int i=0; i<rlen; i++) {
+ Object val1 = compare.get(i, j);
+ Object val2 = inMB.get(i, j);
+ if( UtilFunctions.compareTo(vt, val1, val2) != 0 )
+ outFBNew.set(i, j, val2);
+ }
+ }
+
+ //unpin and clear in-memory input_i
+ in.release();
+ in.clearData();
+ flagMerged = true;
+ }
+ }
+
+ //keep the old output if nothing was merged, otherwise wrap the new block; then unpin the old output
+ foNew = flagMerged ? createNewFrameObject(_output, outFBNew) : _output;
+ _output.release();
+ }
+ catch(Exception ex) {
+ throw new DMLRuntimeException(ex);
+ }
+
+ //NOTE(review): the old output and the current input_i stay pinned if an exception is thrown above (no finally block)
+
+ return foNew;
+ }
+ //no parallel frame merge yet: only trace-logs the requested parallelism and runs the serial merge
+ @Override
+ public FrameObject executeParallelMerge( int par ) {
+ if( LOG.isTraceEnabled() )
+ LOG.trace("ResultMerge (local, in-memory): Execute parallel (par="+par+") "
+ + "merge for output "+_output.hashCode()+" (fname="+_output.getFileName()+")");
+ return executeSerialMerge();
+ }
+ //wraps the merged block in a new frame object that copies the meta data and schema of the old output
+ private static FrameObject createNewFrameObject( FrameObject foOld, FrameBlock dataNew ) {
+ FrameObject ret = new FrameObject(foOld);
+ ret.acquireModify(dataNew); //pin with the merged data, then unpin
+ ret.release();
+ return ret;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalAutomatic.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalAutomatic.java
index 92ec8f9..ea5195d 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalAutomatic.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalAutomatic.java
@@ -26,11 +26,11 @@ import org.apache.sysds.runtime.controlprogram.parfor.opt.OptimizerRuleBased;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.meta.DataCharacteristics;
-public class ResultMergeLocalAutomatic extends ResultMerge
+public class ResultMergeLocalAutomatic extends ResultMergeMatrix
{
private static final long serialVersionUID = 1600893100602101732L;
- private ResultMerge _rm = null;
+ private ResultMergeMatrix _rm = null;
public ResultMergeLocalAutomatic( MatrixObject out, MatrixObject[] in, String outputFilename, boolean accum ) {
super( out, in, outputFilename, accum );
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalFile.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalFile.java
index db3d741..441ba3e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalFile.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalFile.java
@@ -67,7 +67,7 @@ import java.util.Map.Entry;
* NOTE: file merge typically used due to memory constraints - parallel merge would increase the memory
* consumption again.
*/
-public class ResultMergeLocalFile extends ResultMerge
+public class ResultMergeLocalFile extends ResultMergeMatrix
{
private static final long serialVersionUID = -6905893742840020489L;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalMemory.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalMemory.java
index 5c604dd..f422423 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalMemory.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeLocalMemory.java
@@ -39,7 +39,7 @@ import java.util.ArrayList;
*
*
*/
-public class ResultMergeLocalMemory extends ResultMerge
+public class ResultMergeLocalMemory extends ResultMergeMatrix
{
private static final long serialVersionUID = -3543612508601511701L;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeMatrix.java
similarity index 67%
copy from src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java
copy to src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeMatrix.java
index 18b09a1..7d0776c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMerge.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeMatrix.java
@@ -19,13 +19,9 @@
package org.apache.sysds.runtime.controlprogram.parfor;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.data.DenseBlock;
-import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import java.io.Serializable;
import java.util.List;
@@ -36,52 +32,18 @@ import java.util.List;
* These properties allow us to realize result merging in parallel without any synchronization.
*
*/
-public abstract class ResultMerge implements Serializable
+public abstract class ResultMergeMatrix extends ResultMerge<MatrixObject> implements Serializable
{
- //note: this class needs to be serializable to ensure that all attributes of
- //ResultMergeRemoteSparkWCompare are included in the task closure
- private static final long serialVersionUID = 2620430969346516677L;
+ private static final long serialVersionUID = 5319002218804570071L;
- protected static final Log LOG = LogFactory.getLog(ResultMerge.class.getName());
- protected static final String NAME_SUFFIX = "_rm";
- protected static final BinaryOperator PLUS = InstructionUtils.parseBinaryOperator("+");
-
- //inputs to result merge
- protected MatrixObject _output = null;
- protected MatrixObject[] _inputs = null;
- protected String _outputFName = null;
- protected boolean _isAccum = false;
-
- protected ResultMerge( ) {
- //do nothing
+ public ResultMergeMatrix() {
+ super();
}
- public ResultMerge( MatrixObject out, MatrixObject[] in, String outputFilename, boolean accum ) {
- _output = out;
- _inputs = in;
- _outputFName = outputFilename;
- _isAccum = accum;
+ public ResultMergeMatrix(MatrixObject out, MatrixObject[] in, String outputFilename, boolean accum) {
+ super(out, in, outputFilename, accum);
}
- /**
- * Merge all given input matrices sequentially into the given output matrix.
- * The required space in-memory is the size of the output matrix plus the size
- * of one input matrix at a time.
- *
- * @return output (merged) matrix
- */
- public abstract MatrixObject executeSerialMerge();
-
- /**
- * Merge all given input matrices in parallel into the given output matrix.
- * The required space in-memory is the size of the output matrix plus the size
- * of all input matrices.
- *
- * @param par degree of parallelism
- * @return output (merged) matrix
- */
- public abstract MatrixObject executeParallelMerge( int par );
-
protected void mergeWithoutComp( MatrixBlock out, MatrixBlock in, boolean appendOnly ) {
mergeWithoutComp(out, in, appendOnly, false);
}
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSpark.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSpark.java
index 8a70ecf..6f33225 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSpark.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSpark.java
@@ -44,7 +44,7 @@ import org.apache.sysds.utils.Statistics;
import java.util.Arrays;
-public class ResultMergeRemoteSpark extends ResultMerge
+public class ResultMergeRemoteSpark extends ResultMergeMatrix
{
private static final long serialVersionUID = -6924566953903424820L;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSparkWCompare.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSparkWCompare.java
index a152c52..6b8d424 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSparkWCompare.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/ResultMergeRemoteSparkWCompare.java
@@ -31,7 +31,7 @@ import org.apache.sysds.runtime.util.DataConverter;
import scala.Tuple2;
-public class ResultMergeRemoteSparkWCompare extends ResultMerge implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Iterable<MatrixBlock>,MatrixBlock>>, MatrixIndexes, MatrixBlock>
+public class ResultMergeRemoteSparkWCompare extends ResultMergeMatrix implements PairFunction<Tuple2<MatrixIndexes,Tuple2<Iterable<MatrixBlock>,MatrixBlock>>, MatrixIndexes, MatrixBlock>
{
private static final long serialVersionUID = -5970805069405942836L;
diff --git a/src/test/java/org/apache/sysds/test/component/parfor/ParForDependencyAnalysisTest.java b/src/test/java/org/apache/sysds/test/component/parfor/ParForDependencyAnalysisTest.java
index 04f575a..cf7c71a 100644
--- a/src/test/java/org/apache/sysds/test/component/parfor/ParForDependencyAnalysisTest.java
+++ b/src/test/java/org/apache/sysds/test/component/parfor/ParForDependencyAnalysisTest.java
@@ -66,8 +66,8 @@ import org.apache.sysds.test.TestConfiguration;
* 49a: dep, 49b: dep
* * accumulators
* 53a: no, 53b dep, 53c dep, 53d dep, 53e dep
- * * lists
- * 54a: no, 54b: no, 54c: dep, 54d: dep
+ * * lists/frames
+ * 54a: no, 54b: no, 54c: dep, 54d: dep, 54e: no-dep, 54f: dep
* * negative loop increment
* 55a: no, 55b: yes
*/
@@ -328,6 +328,12 @@ public class ParForDependencyAnalysisTest extends AutomatedTestBase
public void testDependencyAnalysis54d() { runTest("parfor54d.dml", true); }
@Test
+ public void testDependencyAnalysis54e() { runTest("parfor54e.dml", false); }
+
+ @Test
+ public void testDependencyAnalysis54f() { runTest("parfor54f.dml", true); }
+
+ @Test
public void testDependencyAnalysis55a() { runTest("parfor55a.dml", false); }
@Test
diff --git a/src/test/java/org/apache/sysds/test/functions/parfor/misc/ParForListResultVarsTest.java b/src/test/java/org/apache/sysds/test/functions/parfor/misc/ParForListFrameResultVarsTest.java
similarity index 75%
rename from src/test/java/org/apache/sysds/test/functions/parfor/misc/ParForListResultVarsTest.java
rename to src/test/java/org/apache/sysds/test/functions/parfor/misc/ParForListFrameResultVarsTest.java
index fc952e1..a206781 100644
--- a/src/test/java/org/apache/sysds/test/functions/parfor/misc/ParForListResultVarsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/parfor/misc/ParForListFrameResultVarsTest.java
@@ -25,16 +25,18 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
-public class ParForListResultVarsTest extends AutomatedTestBase
+public class ParForListFrameResultVarsTest extends AutomatedTestBase
{
private final static String TEST_DIR = "functions/parfor/";
private final static String TEST_NAME1 = "parfor_listResults";
- private final static String TEST_CLASS_DIR = TEST_DIR + ParForListResultVarsTest.class.getSimpleName() + "/";
+ private final static String TEST_NAME2 = "parfor_frameResults";
+
+ private final static String TEST_CLASS_DIR = TEST_DIR + ParForListFrameResultVarsTest.class.getSimpleName() + "/";
@Override
public void setUp() {
- addTestConfiguration(TEST_NAME1,
- new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R" }) );
+ addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"}));
+ addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[]{"R"}));
}
@Test
@@ -47,11 +49,21 @@ public class ParForListResultVarsTest extends AutomatedTestBase
runListResultVarTest(TEST_NAME1, 35, 10);
}
+ @Test
+ public void testParForFrameResult1a() {
+ runListResultVarTest(TEST_NAME2, 2, 1);
+ }
+
+ @Test
+ public void testParForFrameResult1b() {
+ runListResultVarTest(TEST_NAME2, 35, 10);
+ }
+
private void runListResultVarTest(String testName, int rows, int cols) {
loadTestConfiguration(getTestConfiguration(testName));
String HOME = SCRIPT_DIR + TEST_DIR;
- fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
+ fullDMLScriptName = HOME + testName + ".dml";
programArgs = new String[]{"-explain","-args",
String.valueOf(rows), String.valueOf(cols), output("R") };
diff --git a/src/test/scripts/component/parfor/parfor54e.dml b/src/test/scripts/component/parfor/parfor54e.dml
new file mode 100644
index 0000000..70837e9
--- /dev/null
+++ b/src/test/scripts/component/parfor/parfor54e.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rbind(as.frame("a"), as.frame("b"), as.frame("c")); # single-column frame with rows "a", "b", "c"
+parfor( i in 1:nrow(A) )
+ A[i,1] = as.frame(as.scalar(A[i,1])+"-"+i); # iteration i reads and writes only cell (i,1): no loop-carried dependency expected
+print(toString(A));
diff --git a/src/test/scripts/component/parfor/parfor54f.dml b/src/test/scripts/component/parfor/parfor54f.dml
new file mode 100644
index 0000000..23bcf44
--- /dev/null
+++ b/src/test/scripts/component/parfor/parfor54f.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = rbind(as.frame("a"), as.frame("b"), as.frame("c")); # single-column frame with rows "a", "b", "c"
+parfor( i in 1:nrow(A) )
+ A[i,1] = as.frame(as.scalar(A[1,1])+"-"+i); # every iteration reads cell (1,1), which iteration 1 writes: dependency expected
+print(toString(A));
diff --git a/src/test/scripts/functions/parfor/parfor_frameResults.dml b/src/test/scripts/functions/parfor/parfor_frameResults.dml
new file mode 100644
index 0000000..b1a54be
--- /dev/null
+++ b/src/test/scripts/functions/parfor/parfor_frameResults.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+F = as.frame(matrix(0,7,1)); # 7x1 frame initialized with zeros
+# each parfor iteration writes only row i of the frame result variable F
+parfor(i in 1:nrow(F))
+ F[i,1] = as.frame(rowMeans(as.matrix(F[i]))+i); # expected new value: i (assuming F[i] selects row i, initially 0)
+# collect the merged column into a matrix, row by row
+R1 = matrix(0,0,1)
+for(i in 1:length(F)) # NOTE(review): length(F) equals nrow(F) here only because F has a single column
+ R1 = rbind(R1, as.matrix(F[i,1]));
+# count cells equal to their row index; 7 means every worker result was merged
+R = as.matrix(sum(R1==seq(1,7)));
+write(R, $3); # NOTE(review): $1/$2 (rows, cols) passed by the test are unused; the frame size is fixed to 7