Posted to commits@systemds.apache.org by ar...@apache.org on 2022/05/04 11:13:03 UTC

[systemds] branch main updated: [SYSTEMDS-3293] Optimize partitions count with memory estimate

This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 89056f1ec9 [SYSTEMDS-3293] Optimize partitions count with memory estimate
89056f1ec9 is described below

commit 89056f1ec97082a6c720bec7ba4fdcae65f3a8f1
Author: arnabp <ar...@tugraz.at>
AuthorDate: Wed May 4 13:12:41 2022 +0200

    [SYSTEMDS-3293] Optimize partitions count with memory estimate
    
    This patch extends the optimizer for transformencode to reduce
    the build partition count if the partial recode maps do not fit
    in the memory budget.
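
In essence, the optimizer now estimates the worst-case aggregate size of the
partial recode maps for a candidate build-partition count and shrinks that
count until the estimate fits the local memory budget (70% of the heap minus
the input size). Below is a minimal, self-contained Java sketch of that
heuristic; the class and method names are hypothetical, and the actual logic
lives in MultiColumnEncoder#getTotalMemOverhead and the partitioning code in
the diff further down.

public class RecodePartitionSketch {
	// Worst case: every partial map holds min(rowsPerPartition, #distincts) entries
	static long worstCaseMapMemory(int nParts, int nRows, long estMapSize, int estDistincts) {
		if (nParts == 1)
			return estMapSize; // a single map holds all estimated distinct values
		long avgEntrySize = estMapSize / Math.max(estDistincts, 1);
		int partDistincts = Math.min(nRows / nParts, estDistincts);
		return (long) partDistincts * avgEntrySize * nParts;
	}

	// Shrink the build-partition count until the partial maps fit into the budget
	static int reduceBuildPartitions(int nParts, int nRows, long estMapSize,
		int estDistincts, long memBudget) {
		while (nParts > 1
			&& worstCaseMapMemory(nParts, nRows, estMapSize, estDistincts) > memBudget)
			nParts--;
		return nParts;
	}

	public static void main(String[] args) {
		// Made-up numbers: 10M rows, ~1M estimated distincts, 64MB map estimate, 256MB budget
		System.out.println(reduceBuildPartitions(16, 10_000_000, 64L << 20, 1_000_000, 256L << 20));
		// prints 4, i.e., the 16 initial build partitions are reduced to 4
	}
}
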
---
 .../runtime/transform/encode/ColumnEncoder.java    |  9 +++++
 .../transform/encode/ColumnEncoderComposite.java   |  6 ++-
 .../transform/encode/ColumnEncoderRecode.java      |  1 +
 .../transform/encode/MultiColumnEncoder.java       | 46 +++++++++++++++++-----
 .../apache/sysds/runtime/util/DependencyTask.java  |  2 +-
 .../sysds/utils/stats/TransformStatistics.java     |  5 ++-
 6 files changed, 56 insertions(+), 13 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index 89423521b1..b243c857c2 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -60,6 +60,7 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder
 	protected int _colID;
 	protected ArrayList<Integer> _sparseRowsWZeros = null;
 	protected long _estMetaSize = 0;
+	protected int _estNumDistincts = 0;
 	protected int _nBuildPartitions = 0;
 	protected int _nApplyPartitions = 0;
 
@@ -291,6 +292,14 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder
 		return _estMetaSize;
 	}
 
+	public void setEstNumDistincts(int numDistincts) {
+		_estNumDistincts = numDistincts;
+	}
+
+	public int getEstNumDistincts() {
+		return _estNumDistincts;
+	}
+
 	@Override
 	public int compareTo(ColumnEncoder o) {
 		return Integer.compare(getEncoderType(this), getEncoderType(o));
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index a22cab19ab..7194939853 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -361,11 +361,15 @@ public class ColumnEncoderComposite extends ColumnEncoder {
 	}
 
 	public void computeRCDMapSizeEstimate(CacheBlock in, int[] sampleIndices) {
+		int estNumDist = 0;
 		for (ColumnEncoder e : _columnEncoders)
-			if (e.getClass().equals(ColumnEncoderRecode.class))
+			if (e.getClass().equals(ColumnEncoderRecode.class)) {
 				((ColumnEncoderRecode) e).computeRCDMapSizeEstimate(in, sampleIndices);
+				estNumDist = e.getEstNumDistincts();
+			}
 		long totEstSize = _columnEncoders.stream().mapToLong(ColumnEncoder::getEstMetaSize).sum();
 		setEstMetaSize(totEstSize);
+		setEstNumDistincts(estNumDist);
 	}
 
 	public void setNumPartitions(int nBuild, int nApply) {
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
index 8ed89856d8..a6e0329d3a 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
@@ -154,6 +154,7 @@ public class ColumnEncoderRecode extends ColumnEncoder {
 		int[] freq = distinctFreq.values().stream().mapToInt(v -> v).toArray();
 		int estDistCount = SampleEstimatorFactory.distinctCount(freq, in.getNumRows(),
 			sampleIndices.length, SampleEstimatorFactory.EstimationType.HassAndStokes);
+		setEstNumDistincts(estDistCount);
 
 		// Compute total size estimates for each partial recode map
 		// We assume each partial map contains all distinct values and have the same size
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index b34a152fc7..a869ef8208 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -46,7 +46,6 @@ import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorSample;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
-import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.data.SparseBlockCSR;
 import org.apache.sysds.runtime.data.SparseRowVector;
@@ -427,18 +426,23 @@ public class MultiColumnEncoder implements Encoder {
 		while (numBlocks[1] > 1 && nRow/numBlocks[1] < minNumRows)
 			numBlocks[1]--;
 
-		// Reduce #build blocks if all don't fit in memory
+		// Reduce #build blocks for the recoders if the partial maps don't all fit in memory
+		int rcdNumBuildBlks = numBlocks[0];
 		if (numBlocks[0] > 1) {
 			// Estimate recode map sizes
 			estimateRCMapSize(in, recodeEncoders);
-			long totEstSize = recodeEncoders.stream().mapToLong(ColumnEncoderComposite::getEstMetaSize).sum();
-			// Worst case scenario: all partial maps contain all distinct values
-			long totPartMapSize = totEstSize * numBlocks[0];
-			if (totPartMapSize > InfrastructureAnalyzer.getLocalMaxMemory())
-				numBlocks[0] = 1;
-			// TODO: Maintain #blocks per encoder. Reduce only the ones with large maps
-			// TODO: If this not enough, add dependencies between recode build tasks
+			// Memory budget for maps = 70% of heap - sizeof(input)
+			long memBudget = (long) (OptimizerUtils.getLocalMemBudget() - in.getInMemorySize());
+			// Worst case scenario: all partial maps contain all distinct values (if < #rows)
+			long totMemOverhead = getTotalMemOverhead(in, rcdNumBuildBlks, recodeEncoders);
+			// Reduce the recode build block count until the maps fit in the memory budget
+			while (rcdNumBuildBlks > 1 && totMemOverhead > memBudget) {
+				rcdNumBuildBlks--;
+				totMemOverhead = getTotalMemOverhead(in, rcdNumBuildBlks, recodeEncoders);
+				// TODO: Reduce only the ones with large maps
+			}
 		}
+		// TODO: If the maps still don't fit, serialize the column encoders
 
 		// Set to 1 if not set by the above logic
 		for (int i=0; i<2; i++)
@@ -448,6 +452,11 @@ public class MultiColumnEncoder implements Encoder {
 		_partitionDone = true;
 		// Materialize the partition counts in the encoders
 		_columnEncoders.forEach(e -> e.setNumPartitions(numBlocks[0], numBlocks[1]));
+		if (rcdNumBuildBlks > 0 && rcdNumBuildBlks != numBlocks[0]) {
+			int rcdNumBlocks = rcdNumBuildBlks;
+			recodeEncoders.forEach(e -> e.setNumPartitions(rcdNumBlocks, numBlocks[1]));
+		}
+		//System.out.println("Block count = ["+numBlocks[0]+", "+numBlocks[1]+"], Recode block count = "+rcdNumBuildBlks);
 	}
 
 	private void estimateRCMapSize(CacheBlock in, List<ColumnEncoderComposite> rcList) {
@@ -477,6 +486,25 @@ public class MultiColumnEncoder implements Encoder {
 		}
 	}
 
+	// Estimate total memory overhead of the partial recode maps of all recoders
+	private long getTotalMemOverhead(CacheBlock in, int nBuildpart, List<ColumnEncoderComposite> rcEncoders) {
+		long totMemOverhead = 0;
+		if (nBuildpart == 1) {
+			// Sum the estimated map sizes
+			totMemOverhead = rcEncoders.stream().mapToLong(ColumnEncoderComposite::getEstMetaSize).sum();
+			return totMemOverhead;
+		}
+		// Estimate map size of each partition and sum
+		for (ColumnEncoderComposite rce : rcEncoders) {
+			long avgEntrySize = rce.getEstMetaSize()/ rce.getEstNumDistincts();
+			int partSize = in.getNumRows()/nBuildpart;
+			int partNumDist = Math.min(partSize, rce.getEstNumDistincts()); //#distincts not more than #rows
+			long allMapsSize = partNumDist * avgEntrySize * nBuildpart; //worst-case scenario
+			totMemOverhead += allMapsSize;
+		}
+		return totMemOverhead;
+	}
+
 	private static void outputMatrixPreProcessing(MatrixBlock output, CacheBlock input, boolean hasDC) {
 		long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
 		if(output.isInSparseFormat()) {
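
For a rough sense of the numbers involved: per the comments above, the budget is 70% of the heap minus the input size, so a 10 GB heap and a 2 GB input frame leave roughly 0.7 * 10 GB - 2 GB = 5 GB for the partial recode maps. With, say, 40M rows, ~1M estimated distinct values, and a ~100 MB estimated recode map (about 100 bytes per entry), 16 build partitions imply at most min(40M/16, 1M) = 1M entries per partial map, i.e. about 16 * 1M * 100 B = 1.6 GB of worst-case overhead, which fits; the loop only starts dropping build partitions once this estimate exceeds the budget. (These row, distinct, and size figures are illustrative assumptions, not measurements.)
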
diff --git a/src/main/java/org/apache/sysds/runtime/util/DependencyTask.java b/src/main/java/org/apache/sysds/runtime/util/DependencyTask.java
index 69c25fede6..943b344502 100644
--- a/src/main/java/org/apache/sysds/runtime/util/DependencyTask.java
+++ b/src/main/java/org/apache/sysds/runtime/util/DependencyTask.java
@@ -30,7 +30,7 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.DMLRuntimeException;
 
 public class DependencyTask<E> implements Comparable<DependencyTask<?>>, Callable<E> {
-	public static final boolean ENABLE_DEBUG_DATA = false;
+	public static final boolean ENABLE_DEBUG_DATA = false; // explain task graph
 	protected static final Log LOG = LogFactory.getLog(DependencyTask.class.getName());
 
 	private final Callable<E> _task;
diff --git a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
index f1a4f1f3d8..05f06b065c 100644
--- a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
@@ -174,8 +174,9 @@ public class TransformStatistics {
 				outMatrixPreProcessingTime.longValue()*1e-9)).append(" sec.\n");
 			sb.append("TransformEncode PostProc. time:\t").append(String.format("%.3f",
 				outMatrixPostProcessingTime.longValue()*1e-9)).append(" sec.\n");
-			sb.append("TransformEncode SizeEst. time:\t").append(String.format("%.3f",
-				mapSizeEstimationTime.longValue()*1e-9)).append(" sec.\n");
+			if(mapSizeEstimationTime.longValue() > 0)
+				sb.append("TransformEncode SizeEst. time:\t").append(String.format("%.3f",
+					mapSizeEstimationTime.longValue()*1e-9)).append(" sec.\n");
 			return sb.toString();
 		}
 		return "";