You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2022/05/04 11:13:03 UTC
[systemds] branch main updated: [SYSTEMDS-3293] Optimize partitions count with memory estimate
This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 89056f1ec9 [SYSTEMDS-3293] Optimize partitions count with memory estimate
89056f1ec9 is described below
commit 89056f1ec97082a6c720bec7ba4fdcae65f3a8f1
Author: arnabp <ar...@tugraz.at>
AuthorDate: Wed May 4 13:12:41 2022 +0200
[SYSTEMDS-3293] Optimize partitions count with memory estimate
This patch extends the optimizer for transformencode to reduce
the build partitions count if they don't fit in the memory budget.
---
.../runtime/transform/encode/ColumnEncoder.java | 9 +++++
.../transform/encode/ColumnEncoderComposite.java | 6 ++-
.../transform/encode/ColumnEncoderRecode.java | 1 +
.../transform/encode/MultiColumnEncoder.java | 46 +++++++++++++++++-----
.../apache/sysds/runtime/util/DependencyTask.java | 2 +-
.../sysds/utils/stats/TransformStatistics.java | 5 ++-
6 files changed, 56 insertions(+), 13 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index 89423521b1..b243c857c2 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -60,6 +60,7 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder
protected int _colID;
protected ArrayList<Integer> _sparseRowsWZeros = null;
protected long _estMetaSize = 0;
+ protected int _estNumDistincts = 0;
protected int _nBuildPartitions = 0;
protected int _nApplyPartitions = 0;
@@ -291,6 +292,14 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder
return _estMetaSize;
}
+ public void setEstNumDistincts(int numDistincts) { // store the estimated #distinct values for this encoder
+ _estNumDistincts = numDistincts;
+ }
+
+ public int getEstNumDistincts() { // last stored estimate; the field is initialized to 0
+ return _estNumDistincts;
+ }
+
@Override
public int compareTo(ColumnEncoder o) {
return Integer.compare(getEncoderType(this), getEncoderType(o));
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index a22cab19ab..7194939853 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -361,11 +361,15 @@ public class ColumnEncoderComposite extends ColumnEncoder {
}
public void computeRCDMapSizeEstimate(CacheBlock in, int[] sampleIndices) {
+ int estNumDist = 0; // stays 0 if no encoder below is exactly a ColumnEncoderRecode
for (ColumnEncoder e : _columnEncoders)
- if (e.getClass().equals(ColumnEncoderRecode.class))
+ if (e.getClass().equals(ColumnEncoderRecode.class)) {
((ColumnEncoderRecode) e).computeRCDMapSizeEstimate(in, sampleIndices);
+ estNumDist = e.getEstNumDistincts(); // read after the recoder's estimate call; a later recoder overwrites it
+ }
long totEstSize = _columnEncoders.stream().mapToLong(ColumnEncoder::getEstMetaSize).sum();
setEstMetaSize(totEstSize);
+ setEstNumDistincts(estNumDist); // expose the recoder's distinct-count estimate on this composite
}
public void setNumPartitions(int nBuild, int nApply) {
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
index 8ed89856d8..a6e0329d3a 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
@@ -154,6 +154,7 @@ public class ColumnEncoderRecode extends ColumnEncoder {
int[] freq = distinctFreq.values().stream().mapToInt(v -> v).toArray();
int estDistCount = SampleEstimatorFactory.distinctCount(freq, in.getNumRows(),
sampleIndices.length, SampleEstimatorFactory.EstimationType.HassAndStokes);
+ setEstNumDistincts(estDistCount);
// Compute total size estimates for each partial recode map
// We assume each partial map contains all distinct values and have the same size
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index b34a152fc7..a869ef8208 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -46,7 +46,6 @@ import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorSample;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
-import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockCSR;
import org.apache.sysds.runtime.data.SparseRowVector;
@@ -427,18 +426,23 @@ public class MultiColumnEncoder implements Encoder {
while (numBlocks[1] > 1 && nRow/numBlocks[1] < minNumRows)
numBlocks[1]--;
- // Reduce #build blocks if all don't fit in memory
+ // Reduce #build blocks for the recoders if all don't fit in memory
+ int rcdNumBuildBlks = numBlocks[0];
if (numBlocks[0] > 1) {
// Estimate recode map sizes
estimateRCMapSize(in, recodeEncoders);
- long totEstSize = recodeEncoders.stream().mapToLong(ColumnEncoderComposite::getEstMetaSize).sum();
- // Worst case scenario: all partial maps contain all distinct values
- long totPartMapSize = totEstSize * numBlocks[0];
- if (totPartMapSize > InfrastructureAnalyzer.getLocalMaxMemory())
- numBlocks[0] = 1;
- // TODO: Maintain #blocks per encoder. Reduce only the ones with large maps
- // TODO: If this not enough, add dependencies between recode build tasks
+ // Memory budget for maps = 70% of heap - sizeof(input)
+ long memBudget = (long) (OptimizerUtils.getLocalMemBudget() - in.getInMemorySize());
+ // Worst case scenario: all partial maps contain all distinct values (if < #rows)
+ long totMemOverhead = getTotalMemOverhead(in, rcdNumBuildBlks, recodeEncoders);
+ // Reduce recode build blocks count till they fit in the memory budget
+ while (rcdNumBuildBlks > 1 && totMemOverhead > memBudget) {
+ rcdNumBuildBlks--;
+ totMemOverhead = getTotalMemOverhead(in, rcdNumBuildBlks, recodeEncoders);
+ // TODO: Reduce only the ones with large maps
+ }
}
+ // TODO: If they still don't fit, serialize the column encoders
// Set to 1 if not set by the above logics
for (int i=0; i<2; i++)
@@ -448,6 +452,11 @@ public class MultiColumnEncoder implements Encoder {
_partitionDone = true;
// Materialize the partition counts in the encoders
_columnEncoders.forEach(e -> e.setNumPartitions(numBlocks[0], numBlocks[1]));
+ if (rcdNumBuildBlks > 0 && rcdNumBuildBlks != numBlocks[0]) {
+ int rcdNumBlocks = rcdNumBuildBlks;
+ recodeEncoders.forEach(e -> e.setNumPartitions(rcdNumBlocks, numBlocks[1]));
+ }
+ //System.out.println("Block count = ["+numBlocks[0]+", "+numBlocks[1]+"], Recode block count = "+rcdNumBuildBlks);
}
private void estimateRCMapSize(CacheBlock in, List<ColumnEncoderComposite> rcList) {
@@ -477,6 +486,25 @@ public class MultiColumnEncoder implements Encoder {
}
}
+ // Estimate total (worst-case) memory overhead of the partial recode maps of all recoders
+ private long getTotalMemOverhead(CacheBlock in, int nBuildpart, List<ColumnEncoderComposite> rcEncoders) {
+ // At most one build partition: one map per recoder, so sum the estimated map sizes
+ if (nBuildpart <= 1)
+ return rcEncoders.stream().mapToLong(ColumnEncoderComposite::getEstMetaSize).sum();
+ long totMemOverhead = 0;
+ int partSize = in.getNumRows()/nBuildpart; //#rows per build partition (loop-invariant)
+ // Estimate map size of each partition and sum
+ for (ColumnEncoderComposite rce : rcEncoders) {
+ int numDist = rce.getEstNumDistincts();
+ // Guard the division: w/o a distinct-count estimate (0), assume one full-size map per partition
+ long avgEntrySize = (numDist > 0) ? rce.getEstMetaSize()/numDist : rce.getEstMetaSize();
+ long partNumDist = (numDist > 0) ? Math.min(partSize, numDist) : 1; //#distincts not more than #rows
+ long allMapsSize = partNumDist * avgEntrySize * nBuildpart; //worst-case scenario
+ totMemOverhead += allMapsSize;
+ }
+ return totMemOverhead;
+ }
+
private static void outputMatrixPreProcessing(MatrixBlock output, CacheBlock input, boolean hasDC) {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
if(output.isInSparseFormat()) {
diff --git a/src/main/java/org/apache/sysds/runtime/util/DependencyTask.java b/src/main/java/org/apache/sysds/runtime/util/DependencyTask.java
index 69c25fede6..943b344502 100644
--- a/src/main/java/org/apache/sysds/runtime/util/DependencyTask.java
+++ b/src/main/java/org/apache/sysds/runtime/util/DependencyTask.java
@@ -30,7 +30,7 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
public class DependencyTask<E> implements Comparable<DependencyTask<?>>, Callable<E> {
- public static final boolean ENABLE_DEBUG_DATA = false;
+ public static final boolean ENABLE_DEBUG_DATA = false; // explain task graph
protected static final Log LOG = LogFactory.getLog(DependencyTask.class.getName());
private final Callable<E> _task;
diff --git a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
index f1a4f1f3d8..05f06b065c 100644
--- a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
@@ -174,8 +174,9 @@ public class TransformStatistics {
outMatrixPreProcessingTime.longValue()*1e-9)).append(" sec.\n");
sb.append("TransformEncode PostProc. time:\t").append(String.format("%.3f",
outMatrixPostProcessingTime.longValue()*1e-9)).append(" sec.\n");
- sb.append("TransformEncode SizeEst. time:\t").append(String.format("%.3f",
- mapSizeEstimationTime.longValue()*1e-9)).append(" sec.\n");
+ if(mapSizeEstimationTime.longValue() > 0)
+ sb.append("TransformEncode SizeEst. time:\t").append(String.format("%.3f",
+ mapSizeEstimationTime.longValue()*1e-9)).append(" sec.\n");
return sb.toString();
}
return "";