You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/09/25 18:35:55 UTC
[systemds] 01/02: [SYSTEMDS-2610] CLA Updates
This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
commit b537181802f552f89014c282b681fea7c06ef404
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Mon Sep 20 18:42:02 2021 +0200
[SYSTEMDS-2610] CLA Updates
- CLA decompression exploiting the common value
- Scale the cost of some operations by the compression ratio
- Lower the default sample ratio from 0.05 to 0.02
- If the sample size is very small, double it
- Workload tree corrections
- Add a maximum sample size to the compression settings
- Add a hybrid co-coding strategy that combines priority-queue and all-compare (greedy) co-coding
---
src/main/java/org/apache/sysds/conf/DMLConfig.java | 2 +-
.../runtime/compress/CompressedMatrixBlock.java | 100 +---------
.../compress/CompressedMatrixBlockFactory.java | 1 +
.../runtime/compress/CompressionSettings.java | 8 +-
.../compress/CompressionSettingsBuilder.java | 21 +-
.../runtime/compress/cocode/CoCodeGreedy.java | 33 ++--
.../runtime/compress/cocode/CoCodeHybrid.java | 56 ++++++
.../runtime/compress/cocode/CoCodePriorityQue.java | 32 +--
.../runtime/compress/cocode/CoCoderFactory.java | 5 +-
.../runtime/compress/colgroup/ColGroupSDC.java | 8 +-
.../compress/colgroup/ColGroupSDCSingle.java | 9 +-
.../compress/colgroup/dictionary/ADictionary.java | 4 +-
.../compress/colgroup/dictionary/Dictionary.java | 11 +-
.../compress/cost/ComputationCostEstimator.java | 68 ++++---
.../estim/CompressedSizeEstimatorFactory.java | 27 ++-
.../runtime/compress/lib/CLALibBinaryCellOp.java | 16 +-
.../runtime/compress/lib/CLALibDecompress.java | 218 +++++++++++++++++++++
.../runtime/compress/lib/CLALibLeftMultBy.java | 44 +----
.../sysds/runtime/compress/lib/CLALibScalar.java | 5 +-
.../sysds/runtime/compress/lib/CLALibUtils.java | 79 ++++++++
.../compress/workload/WorkloadAnalyzer.java | 58 +++---
.../component/compress/workload/WorkloadTest.java | 2 +-
.../compress/configuration/CompressForce.java | 2 +-
23 files changed, 566 insertions(+), 243 deletions(-)
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index db59505..a59101b 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -131,7 +131,7 @@ public class DMLConfig
_defaultVals.put(COMPRESSED_LOSSY, "false" );
_defaultVals.put(COMPRESSED_VALID_COMPRESSIONS, "SDC,DDC");
_defaultVals.put(COMPRESSED_OVERLAPPING, "true" );
- _defaultVals.put(COMPRESSED_SAMPLING_RATIO, "0.05");
+ _defaultVals.put(COMPRESSED_SAMPLING_RATIO, "0.02");
_defaultVals.put(COMPRESSED_COCODE, "AUTO");
_defaultVals.put(COMPRESSED_COST_MODEL, "AUTO");
_defaultVals.put(COMPRESSED_TRANSPOSE, "auto");
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index bff9a21..4ff5bcb 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -27,12 +27,8 @@ import java.io.ObjectOutput;
import java.lang.ref.SoftReference;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
-import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang.NotImplementedException;
@@ -56,6 +52,7 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.compress.lib.CLALibAppend;
import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp;
import org.apache.sysds.runtime.compress.lib.CLALibCompAgg;
+import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy;
import org.apache.sysds.runtime.compress.lib.CLALibReExpand;
import org.apache.sysds.runtime.compress.lib.CLALibRightMultBy;
@@ -104,7 +101,6 @@ import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
-import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.utils.DMLCompressionStatistics;
@@ -250,15 +246,10 @@ public class CompressedMatrixBlock extends MatrixBlock {
ret.allocateDenseBlock();
- if(isOverlapping()){
- Comparator<AColGroup> comp = Comparator.comparing(x -> effect(x));
- _colGroups.sort(comp);
- }
-
if(k == 1)
- decompress(ret);
+ CLALibDecompress.decompress(ret, getColGroups(), nonZeros, isOverlapping());
else
- decompress(ret, k);
+ CLALibDecompress.decompress(ret, getColGroups(), isOverlapping(), k);
if(this.isOverlapping())
ret.recomputeNonZeros();
@@ -275,47 +266,6 @@ public class CompressedMatrixBlock extends MatrixBlock {
return ret;
}
- private double effect(AColGroup x){
- return - Math.max(x.getMax(), Math.abs(x.getMin()));
- }
-
- private MatrixBlock decompress(MatrixBlock ret) {
-
- ret.setNonZeros(nonZeros == -1 && !this.isOverlapping() ? recomputeNonZeros() : nonZeros);
- final int block = (int) Math.ceil((double) (CompressionSettings.BITMAP_BLOCK_SZ) / getNumColumns());
- final int blklen = block > 1000 ? block + 1000 - block % 1000 : Math.max(64, block);
- for(int i = 0; i < getNumRows(); i += blklen)
- for(AColGroup grp : _colGroups)
- grp.decompressToBlockUnSafe(ret, i, Math.min(i + blklen, rlen));
-
- return ret;
- }
-
- private MatrixBlock decompress(MatrixBlock ret, int k) {
- try {
- final ExecutorService pool = CommonThreadPool.get(k);
- final int rlen = getNumRows();
- final int block = (int) Math.ceil((double) (CompressionSettings.BITMAP_BLOCK_SZ) / getNumColumns());
- final int blklen = block > 1000 ? block + 1000 - block % 1000 : Math.max(64, block);
- final ArrayList<DecompressTask> tasks = new ArrayList<>();
- for(int i = 0; i * blklen < getNumRows(); i++)
- tasks.add(new DecompressTask(_colGroups, ret, i * blklen, Math.min((i + 1) * blklen, rlen),
- overlappingColGroups));
- List<Future<Long>> rtasks = pool.invokeAll(tasks);
- pool.shutdown();
-
- long nnz = 0;
- for(Future<Long> rt : rtasks)
- nnz += rt.get();
- ret.setNonZeros(nnz);
- }
- catch(InterruptedException | ExecutionException ex) {
- throw new DMLCompressionException("Parallel decompression failed", ex);
- }
-
- return ret;
- }
-
/**
* Get the cached decompressed matrix (if it exists otherwise null).
*
@@ -673,10 +623,10 @@ public class CompressedMatrixBlock extends MatrixBlock {
ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), op.getNumThreads());
ret = ret.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
}
-
+
if(ret.getNumRows() == 0 || ret.getNumColumns() == 0)
throw new DMLCompressionException("Error in outputted MM no dimensions");
-
+
return ret;
}
@@ -788,46 +738,6 @@ public class CompressedMatrixBlock extends MatrixBlock {
return null;
}
- private static class DecompressTask implements Callable<Long> {
- private final List<AColGroup> _colGroups;
- private final MatrixBlock _ret;
- private final int _rl;
- private final int _ru;
- private final boolean _overlapping;
-
- protected DecompressTask(List<AColGroup> colGroups, MatrixBlock ret, int rl, int ru, boolean overlapping) {
- _colGroups = colGroups;
- _ret = ret;
- _rl = rl;
- _ru = ru;
- _overlapping = overlapping;
- }
-
- @Override
- public Long call() {
-
- // preallocate sparse rows to avoid repeated alloc
- if(!_overlapping && _ret.isInSparseFormat()) {
- int[] rnnz = new int[_ru - _rl];
- for(AColGroup grp : _colGroups)
- grp.countNonZerosPerRow(rnnz, _rl, _ru);
- SparseBlock rows = _ret.getSparseBlock();
- for(int i = _rl; i < _ru; i++)
- rows.allocate(i, rnnz[i - _rl]);
- }
-
- // decompress row partition
- for(AColGroup grp : _colGroups)
- grp.decompressToBlockUnSafe(_ret, _rl, _ru);
-
- // post processing (sort due to append)
- if(_ret.isInSparseFormat())
- _ret.sortSparseRows(_rl, _ru);
-
- return _overlapping ? 0 : _ret.recomputeNonZeros(_rl, _ru - 1);
- }
- }
-
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
index 98567e0..d77ab82 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
@@ -484,6 +484,7 @@ public class CompressedMatrixBlockFactory {
case 1:
LOG.debug("--compression phase " + phase + " Grouping : " + getLastTimePhase());
LOG.debug("Grouping using: " + compSettings.columnPartitioner);
+ LOG.debug("Cost Calculated using: " + costEstimator);
LOG.debug("--Cocoded Columns estimated Compression:" + _stats.estimatedSizeCoCoded);
break;
case 2:
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
index b347031..c1a9cd4 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
@@ -92,6 +92,11 @@ public class CompressionSettings {
*/
public final int minimumSampleSize;
+ /**
+ * The maximum size of the sample extracted.
+ */
+ public final int maxSampleSize;
+
/** The sample type used for sampling */
public final EstimationType estimationType;
@@ -110,7 +115,7 @@ public class CompressionSettings {
protected CompressionSettings(double samplingRatio, boolean allowSharedDictionary, String transposeInput, int seed,
boolean lossy, EnumSet<CompressionType> validCompressions, boolean sortValuesByLength,
- PartitionerType columnPartitioner, int maxColGroupCoCode, double coCodePercentage, int minimumSampleSize,
+ PartitionerType columnPartitioner, int maxColGroupCoCode, double coCodePercentage, int minimumSampleSize, int maxSampleSize,
EstimationType estimationType, CostType costComputationType, double minimumCompressionRatio) {
this.samplingRatio = samplingRatio;
this.allowSharedDictionary = allowSharedDictionary;
@@ -123,6 +128,7 @@ public class CompressionSettings {
this.maxColGroupCoCode = maxColGroupCoCode;
this.coCodePercentage = coCodePercentage;
this.minimumSampleSize = minimumSampleSize;
+ this.maxSampleSize= maxSampleSize;
this.estimationType = estimationType;
this.costComputationType = costComputationType;
this.minimumCompressionRatio = minimumCompressionRatio;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
index 156fdfc..d5fd036 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
@@ -42,6 +42,7 @@ public class CompressionSettingsBuilder {
private int maxColGroupCoCode = 10000;
private double coCodePercentage = 0.01;
private int minimumSampleSize = 2000;
+ private int maxSampleSize = 1000000;
private EstimationType estimationType = EstimationType.HassAndStokes;
private PartitionerType columnPartitioner;
private CostType costType;
@@ -247,6 +248,18 @@ public class CompressionSettingsBuilder {
}
/**
+ * Set the maximum sample size to extract from a given matrix, this overrules the sample percentage if the sample
+ * percentage extracted is higher than this maximum bound.
+ *
+ * @param maxSampleSize The maximum sample size to extract
+ * @return The CompressionSettingsBuilder
+ */
+ public CompressionSettingsBuilder setMaxSampleSize(int maxSampleSize) {
+ this.maxSampleSize = maxSampleSize;
+ return this;
+ }
+
+ /**
* Set the estimation type used for the sampled estimates.
*
* @param estimationType the estimation type in used.
@@ -268,6 +281,12 @@ public class CompressionSettingsBuilder {
return this;
}
+ /**
+ * Set the minimum compression ratio to be achieved by the compression.
+ *
+ * @param ratio The ratio to achieve while compressing
+ * @return The CompressionSettingsBuilder
+ */
public CompressionSettingsBuilder setMinimumCompressionRatio(double ratio) {
this.minimumCompressionRatio = ratio;
return this;
@@ -281,6 +300,6 @@ public class CompressionSettingsBuilder {
public CompressionSettings create() {
return new CompressionSettings(samplingRatio, allowSharedDictionary, transposeInput, seed, lossy,
validCompressions, sortValuesByLength, columnPartitioner, maxColGroupCoCode, coCodePercentage,
- minimumSampleSize, estimationType, costType, minimumCompressionRatio);
+ minimumSampleSize, maxSampleSize, estimationType, costType, minimumCompressionRatio);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java
index cf366e7..55bf3b9 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java
@@ -36,25 +36,26 @@ import org.apache.sysds.runtime.compress.utils.Util;
public class CoCodeGreedy extends AColumnCoCoder {
-
- private final Memorizer mem;
-
protected CoCodeGreedy(CompressedSizeEstimator sizeEstimator, ICostEstimate costEstimator,
CompressionSettings cs) {
super(sizeEstimator, costEstimator, cs);
- mem = new Memorizer();
}
@Override
protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) {
- for(CompressedSizeInfoColGroup g : colInfos.compressionInfo)
+ colInfos.setInfo(join(colInfos.compressionInfo, _sest, _cest, _cs));
+ return colInfos;
+ }
+
+ protected static List<CompressedSizeInfoColGroup> join(List<CompressedSizeInfoColGroup> inputColumns, CompressedSizeEstimator sEst, ICostEstimate cEst, CompressionSettings cs) {
+ Memorizer mem = new Memorizer(cs, sEst);
+ for(CompressedSizeInfoColGroup g : inputColumns)
mem.put(g);
- colInfos.setInfo(coCodeBruteForce(colInfos.compressionInfo));
- return colInfos;
+ return coCodeBruteForce(inputColumns, cEst, mem);
}
- private List<CompressedSizeInfoColGroup> coCodeBruteForce(List<CompressedSizeInfoColGroup> inputColumns) {
+ private static List<CompressedSizeInfoColGroup> coCodeBruteForce(List<CompressedSizeInfoColGroup> inputColumns, ICostEstimate cEst, Memorizer mem) {
List<ColIndexes> workset = new ArrayList<>(inputColumns.size());
@@ -69,8 +70,8 @@ public class CoCodeGreedy extends AColumnCoCoder {
for(int j = i + 1; j < workset.size(); j++) {
final ColIndexes c1 = workset.get(i);
final ColIndexes c2 = workset.get(j);
- final double costC1 = _cest.getCostOfColumnGroup(mem.get(c1));
- final double costC2 = _cest.getCostOfColumnGroup(mem.get(c2));
+ final double costC1 = cEst.getCostOfColumnGroup(mem.get(c1));
+ final double costC2 = cEst.getCostOfColumnGroup(mem.get(c2));
mem.incst1();
// pruning filter : skip dominated candidates
@@ -82,7 +83,7 @@ public class CoCodeGreedy extends AColumnCoCoder {
// Join the two column groups.
// and Memorize the new join.
final CompressedSizeInfoColGroup c1c2Inf = mem.getOrCreate(c1, c2);
- final double costC1C2 = _cest.getCostOfColumnGroup(c1c2Inf);
+ final double costC1C2 = cEst.getCostOfColumnGroup(c1c2Inf);
final double newSizeChangeIfSelected = costC1C2 - costC1 - costC2;
@@ -120,11 +121,15 @@ public class CoCodeGreedy extends AColumnCoCoder {
return ret;
}
- protected class Memorizer {
+ protected static class Memorizer {
+ private final CompressionSettings _cs;
+ private final CompressedSizeEstimator _sEst;
private final Map<ColIndexes, CompressedSizeInfoColGroup> mem;
private int st1 = 0, st2 = 0, st3 = 0, st4 = 0;
- public Memorizer() {
+ public Memorizer(CompressionSettings cs, CompressedSizeEstimator sEst) {
+ _cs = cs;
+ _sEst = sEst;
mem = new HashMap<>();
}
@@ -159,7 +164,7 @@ public class CoCodeGreedy extends AColumnCoCoder {
g = CompressedSizeInfoColGroup.addConstGroup(c, left, _cs.validCompressions);
else {
st3++;
- g = _sest.estimateJoinCompressedSize(c, left, right);
+ g = _sEst.estimateJoinCompressedSize(c, left, right);
}
if(leftConst || rightConst)
diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeHybrid.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeHybrid.java
new file mode 100644
index 0000000..f762c0d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeHybrid.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.cocode;
+
+import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.cost.ICostEstimate;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimator;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
+
+/**
+ * This cocode strategy starts out with priority que until a threshold number of columnGroups is achieved, then the
+ * strategy shifts into a greedy all compare.
+ */
+public class CoCodeHybrid extends AColumnCoCoder {
+
+ protected CoCodeHybrid(CompressedSizeEstimator sizeEstimator, ICostEstimate costEstimator, CompressionSettings cs) {
+ super(sizeEstimator, costEstimator, cs);
+ }
+
+ @Override
+ protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) {
+ final int startSize = colInfos.getInfo().size();
+ final int PriorityQueGoal = 40;
+ if(startSize > 200) {
+
+ colInfos.setInfo(CoCodePriorityQue.join(colInfos.getInfo(), _sest, _cest, PriorityQueGoal));
+
+ final int pqSize = colInfos.getInfo().size();
+ if(pqSize <= PriorityQueGoal)
+ colInfos.setInfo(CoCodeGreedy.join(colInfos.getInfo(), _sest, _cest, _cs));
+ }
+ else {
+ colInfos.setInfo(CoCodeGreedy.join(colInfos.getInfo(), _sest, _cest, _cs));
+ }
+
+ return colInfos;
+ }
+
+}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java
index 27b678c..b53859e 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java
@@ -48,12 +48,13 @@ public class CoCodePriorityQue extends AColumnCoCoder {
@Override
protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) {
- colInfos.setInfo(join(colInfos.getInfo()));
+ colInfos.setInfo(join(colInfos.getInfo(), _sest, _cest, 1));
return colInfos;
}
- private List<CompressedSizeInfoColGroup> join(List<CompressedSizeInfoColGroup> currentGroups) {
- Comparator<CompressedSizeInfoColGroup> comp = Comparator.comparing(x -> _cest.getCostOfColumnGroup(x));
+ protected static List<CompressedSizeInfoColGroup> join(List<CompressedSizeInfoColGroup> currentGroups,
+ CompressedSizeEstimator sEst, ICostEstimate cEst, int minNumGroups) {
+ Comparator<CompressedSizeInfoColGroup> comp = Comparator.comparing(x -> cEst.getCostOfColumnGroup(x));
Queue<CompressedSizeInfoColGroup> que = new PriorityQueue<>(currentGroups.size(), comp);
List<CompressedSizeInfoColGroup> ret = new ArrayList<>();
@@ -62,15 +63,16 @@ public class CoCodePriorityQue extends AColumnCoCoder {
que.add(g);
CompressedSizeInfoColGroup l = null;
- if(_cest.isCompareAll()) {
- double costBeforeJoin = _cest.getCostOfCollectionOfGroups(que);
+ if(cEst.isCompareAll()) {
+ double costBeforeJoin = cEst.getCostOfCollectionOfGroups(que);
l = que.poll();
- while(que.peek() != null) {
+ int groupNr = ret.size() + que.size();
+ while(que.peek() != null && groupNr >= minNumGroups) {
CompressedSizeInfoColGroup r = que.poll();
- final CompressedSizeInfoColGroup g = _sest.estimateJoinCompressedSize(l, r);
+ final CompressedSizeInfoColGroup g = sEst.estimateJoinCompressedSize(l, r);
if(g != null) {
- final double costOfJoin = _cest.getCostOfCollectionOfGroups(que, g);
+ final double costOfJoin = cEst.getCostOfCollectionOfGroups(que, g);
if(costOfJoin < costBeforeJoin) {
costBeforeJoin = costOfJoin;
que.add(g);
@@ -86,17 +88,19 @@ public class CoCodePriorityQue extends AColumnCoCoder {
}
l = que.poll();
+ groupNr = ret.size() + que.size();
}
}
else {
l = que.poll();
- while(que.peek() != null) {
+ int groupNr = ret.size() + que.size();
+ while(que.peek() != null && groupNr >= minNumGroups) {
CompressedSizeInfoColGroup r = que.peek();
- if(_cest.shouldTryJoin(l, r)) {
- CompressedSizeInfoColGroup g = _sest.estimateJoinCompressedSize(l, r);
+ if(cEst.shouldTryJoin(l, r)) {
+ CompressedSizeInfoColGroup g = sEst.estimateJoinCompressedSize(l, r);
if(g != null) {
- double costOfJoin = _cest.getCostOfColumnGroup(g);
- double costIndividual = _cest.getCostOfColumnGroup(l) + _cest.getCostOfColumnGroup(r);
+ double costOfJoin = cEst.getCostOfColumnGroup(g);
+ double costIndividual = cEst.getCostOfColumnGroup(l) + cEst.getCostOfColumnGroup(r);
if(costOfJoin < costIndividual) {
que.poll();
@@ -112,8 +116,10 @@ public class CoCodePriorityQue extends AColumnCoCoder {
ret.add(l);
l = que.poll();
+ groupNr = ret.size() + que.size();
}
}
+
if(l != null)
ret.add(l);
diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java
index c467f70..eaf9cb4 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java
@@ -63,10 +63,7 @@ public class CoCoderFactory {
// TODO make decision better depending on how much time is allocated for the compression
// for instance if the compressed object is used for a million instructions, it might be good to
// search for a really good compression even if it take longer.
- if(est.getNumColumns() > 200)
- return new CoCodePriorityQue(est, costEstimator, cs);
- else
- return new CoCodeGreedy(est, costEstimator, cs);
+ return new CoCodeHybrid(est, costEstimator, cs);
case GREEDY:
return new CoCodeGreedy(est, costEstimator, cs);
case BIN_PACKING:
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index 5c6d365..fb93ed6 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -558,11 +558,13 @@ public class ColGroupSDC extends ColGroupValue {
public ColGroupSDCZeros extractCommon(double[] constV) {
double[] commonV = _dict.getTuple(getNumValues() - 1, _colIndexes.length);
- for(int i = 0; i < _colIndexes.length; i++) {
+ if(commonV == null) // The common tuple was all zero. Therefore this column group should never have been SDC.
+ return new ColGroupSDCZeros(_colIndexes, _numRows, _dict, _indexes, _data, getCounts());
+
+ for(int i = 0; i < _colIndexes.length; i++)
constV[_colIndexes[i]] += commonV[i];
- }
+
ADictionary subtractedDict = _dict.subtractTuple(commonV);
return new ColGroupSDCZeros(_colIndexes, _numRows, subtractedDict, _indexes, _data, getCounts());
}
-
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
index de0b4a8..b2b4283 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
@@ -117,8 +117,8 @@ public class ColGroupSDCSingle extends ColGroupValue {
for(int j = 0; j < nCol; j++)
c[off + _colIndexes[j]] += values[offsetToDefault + j];
}
-
- _indexes.cacheIterator(it, ru );
+
+ _indexes.cacheIterator(it, ru);
}
@Override
@@ -473,11 +473,14 @@ public class ColGroupSDCSingle extends ColGroupValue {
public ColGroupSDCSingleZeros extractCommon(double[] constV) {
double[] commonV = _dict.getTuple(getNumValues() - 1, _colIndexes.length);
+ if(commonV == null) // The common tuple was all zero. Therefore this column group should never have been SDC.
+ return new ColGroupSDCSingleZeros(_colIndexes, _numRows, _dict, _indexes, getCachedCounts());
+
for(int i = 0; i < _colIndexes.length; i++)
constV[_colIndexes[i]] += commonV[i];
ADictionary subtractedDict = _dict.subtractTuple(commonV);
- return new ColGroupSDCSingleZeros(_colIndexes, _numRows, subtractedDict, _indexes, getCounts());
+ return new ColGroupSDCSingleZeros(_colIndexes, _numRows, subtractedDict, _indexes, getCachedCounts());
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
index 15c74b0..df75268 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
@@ -35,7 +35,7 @@ import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
public abstract class ADictionary implements Serializable {
private static final long serialVersionUID = 9118692576356558592L;
-
+
protected static final Log LOG = LogFactory.getLog(ADictionary.class.getName());
/**
@@ -342,6 +342,8 @@ public abstract class ADictionary implements Serializable {
/**
* Get the values contained in a specific tuple of the dictionary.
*
+ * If the entire row is zero return null.
+ *
* @param index The index where the values are located
* @param nCol The number of columns contained in this dictionary
* @return a materialized double array containing the tuple.
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
index 5b9d834..0b65ed1 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
@@ -43,7 +43,7 @@ import org.apache.sysds.utils.MemoryEstimates;
public class Dictionary extends ADictionary {
private static final long serialVersionUID = -6517136537249507753L;
-
+
private final double[] _values;
public Dictionary(double[] values) {
@@ -329,14 +329,13 @@ public class Dictionary extends ADictionary {
if(colIndexes == 1)
sb.append(Arrays.toString(_values));
else {
- sb.append("[\n");
+ sb.append("[\n\t");
for(int i = 0; i < _values.length - 1; i++) {
sb.append(_values[i]);
- sb.append((i) % (colIndexes) == colIndexes - 1 ? "\nt" + i + ": " : ", ");
+ sb.append((i) % (colIndexes) == colIndexes - 1 ? "\n\t" : ", ");
}
sb.append(_values[_values.length - 1]);
-
- sb.append("\n]");
+ sb.append("]");
}
return sb.toString();
}
@@ -476,7 +475,7 @@ public class Dictionary extends ADictionary {
}
@Override
- public Dictionary preaggValuesFromDense(int numVals, int[] colIndexes, int[] aggregateColumns, double[] b,
+ public Dictionary preaggValuesFromDense(int numVals, int[] colIndexes, int[] aggregateColumns, double[] b,
int cut) {
double[] ret = new double[numVals * aggregateColumns.length];
for(int k = 0, off = 0;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
index 1653f51..07cf58b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
@@ -23,6 +23,7 @@ import java.util.Collection;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
public class ComputationCostEstimator implements ICostEstimate {
@@ -41,7 +42,6 @@ public class ComputationCostEstimator implements ICostEstimate {
private final int _leftMultiplications;
private final int _rightMultiplications;
private final int _compressedMultiplication;
- // private final int _rowBasedOps;
private final int _dictionaryOps;
private final boolean _isDensifying;
@@ -100,16 +100,21 @@ public class ComputationCostEstimator implements ICostEstimate {
cost += _scans * scanCost(g);
cost += _decompressions * decompressionCost(g);
cost += _overlappingDecompressions * overlappingDecompressionCost(g);
- // 16 is assuming that the left side is 16 rows.
- double lmc = leftMultCost(g) * 16;
- cost += _leftMultiplications * lmc;
- // 16 is assuming that the right side is 16 rows.
- double rmc = rightMultCost(g) * 16;
- cost += _rightMultiplications * rmc;
-
- // cost += _compressedMultiplication * (lmc + rmc);
- cost += _compressedMultiplication * _compressedMultCost(g);
+ // 16 is assuming that the left / right side is 16 rows/cols.
+ final int rowsCols = 16;
+ cost += _leftMultiplications * leftMultCost(g) * rowsCols;
+ cost += _rightMultiplications * rightMultCost(g) * rowsCols;
cost += _dictionaryOps * dictionaryOpsCost(g);
+
+ double size = g.getMinSize();
+ final double compressionRatio = size / MatrixBlock.estimateSizeDenseInMemory(_nRows, _nCols) / g.getColumns().length;
+
+ cost *= 0.001 + compressionRatio;
+
+ cost += _compressedMultiplication * _compressedMultCost(g) * rowsCols;
+
+ // double uncompressedSize = g.getCompressionSize(CompressionType.UNCOMPRESSED);
+
return cost;
}
@@ -118,13 +123,14 @@ public class ComputationCostEstimator implements ICostEstimate {
}
private double leftMultCost(CompressedSizeInfoColGroup g) {
- final int nCols = g.getColumns().length;
- final double preAggregateCost = _nRows;
+ final int nColsInGroup = g.getColumns().length;
+ final double mcf = g.getMostCommonFraction();
+ final double preAggregateCost = mcf > 0.6 ? _nRows * (1 - 0.4 * mcf) : _nRows;
final double numberTuples = g.getNumVals();
final double tupleSparsity = g.getTupleSparsity();
- final double postScalingCost = (nCols > 1 && tupleSparsity > 0.4) ? numberTuples * nCols * tupleSparsity *
- 1.4 : numberTuples * nCols;
+ final double postScalingCost = (nColsInGroup > 1 && tupleSparsity > 0.4) ? numberTuples * nColsInGroup * tupleSparsity *
+ 1.4 : numberTuples * nColsInGroup;
if(numberTuples < 64000)
return preAggregateCost + postScalingCost;
else
@@ -134,10 +140,11 @@ public class ComputationCostEstimator implements ICostEstimate {
private double _compressedMultCost(CompressedSizeInfoColGroup g) {
final int nColsInGroup = g.getColumns().length;
- final double mcf = g.getMostCommonFraction();
- final double preAggregateCost = mcf > 0.6 ? _nRows * (1 - 0.7 * mcf) : _nRows;
+ // final double mcf = g.getMostCommonFraction();
+ // final double preAggregateCost = (mcf > 0.6 ? _nRows * (1 - 0.6 * mcf) : _nRows) * 4;
+ final double preAggregateCost = _nRows;
- final double numberTuples = (float) g.getNumVals();
+ final double numberTuples = g.getNumVals();
final double tupleSparsity = g.getTupleSparsity();
final double postScalingCost = (nColsInGroup > 1 && tupleSparsity > 0.4) ? numberTuples * nColsInGroup * tupleSparsity *
1.4 : numberTuples * nColsInGroup;
@@ -163,7 +170,10 @@ public class ComputationCostEstimator implements ICostEstimate {
}
private double overlappingDecompressionCost(CompressedSizeInfoColGroup g) {
- return _nRows * 16 * (g.getNumVals() / 64000 + 1);
+ final double mcf = g.getMostCommonFraction();
+ // Discount the row cost for groups dominated by a single value (most common fraction above 0.6).
+ final double rowsCost = mcf > 0.6 ? _nRows * (1 - 0.6 * mcf) : _nRows;
+ // A factor of 16 per row marks overlapping decompression as expensive; the cost grows by one further unit
+ // for every 64000 distinct tuples in the group.
+ return rowsCost * 16 * ((double) g.getNumVals() / 64000 + 1);
}
private static double dictionaryOpsCost(CompressedSizeInfoColGroup g) {
@@ -259,15 +269,19 @@ public class ComputationCostEstimator implements ICostEstimate {
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(this.getClass().getSimpleName());
- sb.append("\n");
- sb.append(_nRows + " ");
- sb.append(_scans + " ");
- sb.append(_decompressions + " ");
- sb.append(_overlappingDecompressions + " ");
- sb.append(_leftMultiplications + " ");
- sb.append(_rightMultiplications + " ");
- sb.append(_compressedMultiplication + " ");
- sb.append(_dictionaryOps + " ");
+ // Produces: <ClassName> dims(<rows>,<cols>) CostVector:[<scans>,<decompressions>,<overlapping decompressions>,
+ // <left multiplications>,<right multiplications>,<compressed multiplications>,<dictionary ops>] Densifying:<flag>
+ sb.append(" dims(");
+ sb.append(_nRows).append(',');
+ sb.append(_nCols).append(") ");
+ sb.append("CostVector:[");
+ sb.append(_scans).append(',');
+ sb.append(_decompressions).append(',');
+ sb.append(_overlappingDecompressions).append(',');
+ sb.append(_leftMultiplications).append(',');
+ sb.append(_rightMultiplications).append(',');
+ sb.append(_compressedMultiplication).append(',');
+ sb.append(_dictionaryOps).append(']');
+ sb.append(" Densifying:");
+ sb.append(_isDensifying);
return sb.toString();
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorFactory.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorFactory.java
index a4b1c99..a9a8e44 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorFactory.java
@@ -27,7 +27,6 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
public class CompressedSizeEstimatorFactory {
protected static final Log LOG = LogFactory.getLog(CompressedSizeEstimatorFactory.class.getName());
- private static final int maxSampleSize = 1000000;
public static CompressedSizeEstimator getSizeEstimator(MatrixBlock data, CompressionSettings cs, int k) {
@@ -36,7 +35,7 @@ public class CompressedSizeEstimatorFactory {
final int nnzRows = (int) Math.ceil(data.getNonZeros() / nCols);
final double sampleRatio = cs.samplingRatio;
- final int sampleSize = Math.min(getSampleSize(sampleRatio, nRows, cs.minimumSampleSize), maxSampleSize);
+ final int sampleSize = getSampleSize(sampleRatio, nRows, nCols, cs.minimumSampleSize, cs.maxSampleSize);
if(nCols > 1000) {
return tryToMakeSampleEstimator(data, cs, sampleRatio, sampleSize / 10, nRows, nnzRows, k);
@@ -79,7 +78,27 @@ public class CompressedSizeEstimatorFactory {
return cs.samplingRatio >= 1.0 || nRows < cs.minimumSampleSize || sampleSize >= nnzRows;
}
- private static int getSampleSize(double sampleRatio, int nRows, int minimumSampleSize) {
- return Math.max((int) Math.ceil(nRows * sampleRatio), minimumSampleSize);
+ /**
+ * This function returns the sample size to use.
+ *
+ * The base sample size is nRows * sampleRatio (rounded up), divided by nCols / 150 when there are more than 150
+ * columns. A base sample size below 20000 is doubled. The result is finally bounded below by minSampleSize and
+ * above by maxSampleSize; if the two bounds conflict, maxSampleSize wins.
+ *
+ * Scaling down with the number of columns gives worse estimations of distinct items, but it makes sure that the
+ * compression time is more consistent.
+ *
+ * @param sampleRatio The sample ratio
+ * @param nRows The number of rows
+ * @param nCols The number of columns
+ * @param minSampleSize The minimum sample size
+ * @param maxSampleSize The maximum sample size
+ * @return The sample size to use.
+ */
+ private static int getSampleSize(double sampleRatio, int nRows, int nCols, int minSampleSize, int maxSampleSize) {
+ // Divisor that grows linearly once there are more than 150 columns.
+ final double colScale = Math.max(1, (double) nCols / 150);
+ int sampleSize = (int) Math.ceil(nRows * sampleRatio / colScale);
+ // Double very small sample sizes.
+ if(sampleSize < 20000)
+ sampleSize *= 2;
+ return Math.min(Math.max(sampleSize, minSampleSize), maxSampleSize);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java
index e330ab8..b9b64fa 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java
@@ -105,7 +105,8 @@ public class CLALibBinaryCellOp {
result = CompressedMatrixBlockFactory.createConstant(m1.getNumRows(), m1.getNumColumns(), 0);
else if(fn instanceof Minus1Multiply)
result = CompressedMatrixBlockFactory.createConstant(m1.getNumRows(), m1.getNumColumns(), 1);
- else if(fn instanceof Minus || fn instanceof Plus || fn instanceof MinusMultiply || fn instanceof PlusMultiply){
+ else if(fn instanceof Minus || fn instanceof Plus || fn instanceof MinusMultiply ||
+ fn instanceof PlusMultiply) {
CompressedMatrixBlock ret = new CompressedMatrixBlock();
ret.copy(m1);
return ret;
@@ -132,7 +133,11 @@ public class CLALibBinaryCellOp {
// TODO optimize to allow for sparse outputs.
final int outCells = outRows * outCols;
if(atype == BinaryAccessType.MATRIX_COL_VECTOR) {
- result.reset(outRows, Math.max(outCols, that.getNumColumns()), outCells);
+ if(result != null)
+ result.reset(outRows, Math.max(outCols, that.getNumColumns()), outCells);
+ else
+ result = new MatrixBlock(outRows, Math.max(outCols, that.getNumColumns()), outCells);
+
MatrixBlock d_compressed = m1.getCachedDecompressed();
if(d_compressed != null) {
if(left)
@@ -146,12 +151,15 @@ public class CLALibBinaryCellOp {
}
else if(atype == BinaryAccessType.MATRIX_MATRIX) {
- result.reset(outRows, outCols, outCells);
+ if(result != null)
+ result.reset(outRows, outCols, outCells);
+ else
+ result = new MatrixBlock(outRows, outCols, outCells);
MatrixBlock d_compressed = m1.getCachedDecompressed();
if(d_compressed == null)
d_compressed = m1.getUncompressed("MatrixMatrix " + op);
-
+
if(left)
LibMatrixBincell.bincellOp(that, d_compressed, result, op);
else
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java
new file mode 100644
index 0000000..ff2d211
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.lib;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
+import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+
+/**
+ * Library to decompress a list of column groups into a matrix.
+ */
+public class CLALibDecompress {
+
+ /**
+ * Decompress the column groups into ret using the calling thread.
+ *
+ * If ret is dense and the groups contain at least two SDC groups, the common values of the SDC groups are extracted
+ * into one constant row vector that is added to every row after the remaining groups are decompressed.
+ *
+ * @param ret The output matrix to decompress into; its row and column counts define the decompressed shape.
+ * @param groups The column groups to decompress. If overlapping, this list may be sorted in place.
+ * @param nonZeros The known number of non zeros, or -1 to recompute them.
+ * @param overlapping If the column groups are overlapping; the non zeros are then always recomputed.
+ * @return ret containing the decompressed values.
+ */
+ public static MatrixBlock decompress(MatrixBlock ret, List<AColGroup> groups, long nonZeros, boolean overlapping) {
+ final int rlen = ret.getNumRows();
+ final int blklen = getBlockLength(ret.getNumColumns());
+ final double[] constV = allocateConstV(ret, groups);
+ final List<AColGroup> filteredGroups = CLALibUtils.filterSDCGroups(groups, constV);
+ // filterSDCGroups returns the input list when the common values are not used; constV must then not be added.
+ final double[] usedConstV = groups == filteredGroups ? null : constV;
+ sortGroups(filteredGroups, overlapping);
+ final double eps = getEps(usedConstV);
+
+ for(int rl = 0; rl < rlen; rl += blklen) {
+ final int ru = Math.min(rl + blklen, rlen);
+ for(AColGroup grp : filteredGroups)
+ grp.decompressToBlockUnSafe(ret, rl, ru);
+ if(usedConstV != null)
+ addVector(ret, usedConstV, eps, rl, ru);
+ }
+
+ ret.setNonZeros(nonZeros == -1 || overlapping ? ret.recomputeNonZeros() : nonZeros);
+ return ret;
+ }
+
+ /**
+ * Decompress the column groups into ret in parallel. The rows are split into blocks, and each block is
+ * decompressed by a separate task.
+ *
+ * The same constant row vector optimization as in the single-threaded decompression is applied on dense outputs.
+ *
+ * @param ret The output matrix to decompress into; its row and column counts define the decompressed shape.
+ * @param groups The column groups to decompress. If overlapping, this list may be sorted in place.
+ * @param overlapping If the column groups are overlapping; the non zeros are then recomputed once all tasks are
+ * done.
+ * @param k The parallelization degree.
+ * @return ret containing the decompressed values.
+ * @throws DMLCompressionException If a task fails or the calling thread is interrupted.
+ */
+ public static MatrixBlock decompress(MatrixBlock ret, List<AColGroup> groups, boolean overlapping, int k) {
+ final int rlen = ret.getNumRows();
+ final int blklen = getBlockLength(ret.getNumColumns());
+ final double[] constV = allocateConstV(ret, groups);
+ final List<AColGroup> filteredGroups = CLALibUtils.filterSDCGroups(groups, constV);
+ // filterSDCGroups returns the input list when the common values are not used; constV must then not be added.
+ final double[] usedConstV = groups == filteredGroups ? null : constV;
+ sortGroups(filteredGroups, overlapping);
+ final double eps = getEps(usedConstV);
+
+ final ExecutorService pool = CommonThreadPool.get(k);
+ try {
+ final List<DecompressTask> tasks = new ArrayList<>();
+ for(int rl = 0; rl < rlen; rl += blklen)
+ tasks.add(new DecompressTask(filteredGroups, ret, eps, rl, Math.min(rl + blklen, rlen), overlapping,
+ usedConstV));
+
+ long nnz = 0;
+ for(Future<Long> rt : pool.invokeAll(tasks))
+ nnz += rt.get();
+ // Tasks report 0 non zeros for overlapping groups, so count them once over the complete result.
+ ret.setNonZeros(overlapping ? ret.recomputeNonZeros() : nnz);
+ }
+ catch(InterruptedException ex) {
+ Thread.currentThread().interrupt();
+ throw new DMLCompressionException("Parallel decompression failed", ex);
+ }
+ catch(ExecutionException ex) {
+ throw new DMLCompressionException("Parallel decompression failed", ex);
+ }
+ finally {
+ pool.shutdown();
+ }
+
+ return ret;
+ }
+
+ /**
+ * Get the number of rows decompressed together in one block, derived from the bitmap block size and the number
+ * of columns.
+ *
+ * @param nCols The number of columns of the output.
+ * @return The block length in rows, at least 64.
+ */
+ private static int getBlockLength(int nCols) {
+ // Guard against division by zero for matrices without columns.
+ final int block = (int) Math.ceil((double) (CompressionSettings.BITMAP_BLOCK_SZ) / Math.max(1, nCols));
+ return block > 1000 ? block + 1000 - block % 1000 : Math.max(64, block);
+ }
+
+ /**
+ * Allocate the constant row vector that collects the common values of the SDC groups.
+ *
+ * The vector is added through the dense block of ret, so it is only allocated for dense outputs. For sparse
+ * outputs the SDC groups are decompressed without extracting their common values.
+ *
+ * @param ret The output matrix.
+ * @param groups The column groups to decompress.
+ * @return A zero vector with one entry per column of ret, or null if no constant vector should be used.
+ */
+ private static double[] allocateConstV(MatrixBlock ret, List<AColGroup> groups) {
+ if(ret.isInSparseFormat() || !CLALibUtils.containsSDC(groups))
+ return null;
+ return new double[ret.getNumColumns()];
+ }
+
+ private static void sortGroups(List<AColGroup> groups, boolean overlapping) {
+ if(overlapping) {
+ // add a bit of stability in decompression
+ groups.sort(Comparator.comparingDouble(CLALibDecompress::effect));
+ }
+ }
+
+ /**
+ * Calculate the sort key of a column group: the negated maximum of getMax() and |getMin()|. Sorting ascending by
+ * this key decompresses the groups with the largest such magnitude first.
+ *
+ * @param x A Group
+ * @return The sort key of the group.
+ */
+ private static double effect(AColGroup x) {
+ return -Math.max(x.getMax(), Math.abs(x.getMin()));
+ }
+
+ /**
+ * Get a small epsilon from the constant vector: 1e-13 times the difference between its largest and smallest value.
+ *
+ * @param constV the constant vector, may be null.
+ * @return epsilon, or 0 if constV is null.
+ */
+ private static double getEps(double[] constV) {
+ if(constV == null)
+ return 0;
+ double max = -Double.MAX_VALUE;
+ double min = Double.MAX_VALUE;
+ for(double v : constV) {
+ if(v > max)
+ max = v;
+ if(v < min)
+ min = v;
+ }
+ return (max - min) * 1e-13;
+ }
+
+ /** Task decompressing the row range [rl, ru) of all column groups into the output matrix. */
+ private static class DecompressTask implements Callable<Long> {
+ private final List<AColGroup> _colGroups;
+ private final MatrixBlock _ret;
+ private final double _eps;
+ private final int _rl;
+ private final int _ru;
+ private final double[] _constV;
+ private final boolean _overlapping;
+
+ protected DecompressTask(List<AColGroup> colGroups, MatrixBlock ret, double eps, int rl, int ru,
+ boolean overlapping, double[] constV) {
+ _colGroups = colGroups;
+ _ret = ret;
+ _eps = eps;
+ _rl = rl;
+ _ru = ru;
+ _overlapping = overlapping;
+ _constV = constV;
+ }
+
+ /**
+ * @return The non zeros in the decompressed row range, or 0 if the groups are overlapping.
+ */
+ @Override
+ public Long call() {
+ // decompress row partition
+ for(AColGroup grp : _colGroups)
+ grp.decompressToBlockUnSafe(_ret, _rl, _ru);
+
+ if(_constV != null)
+ addVector(_ret, _constV, _eps, _rl, _ru);
+
+ // recomputeNonZeros takes an inclusive upper row bound, hence _ru - 1.
+ return _overlapping ? 0 : _ret.recomputeNonZeros(_rl, _ru - 1);
+ }
+ }
+
+ /**
+ * Add the rowV vector to each row in ret. ret must be in dense format, since its dense block is accessed directly.
+ *
+ * @param ret matrix to add the vector to
+ * @param rowV The row vector to add
+ * @param eps an epsilon defined, to round the output value to zero if the value is less than epsilon away from
+ * zero. With eps == 0 no rounding is done.
+ * @param rl The row to start at
+ * @param ru The row to end at (exclusive)
+ */
+ private static void addVector(final MatrixBlock ret, final double[] rowV, final double eps, final int rl,
+ final int ru) {
+ final int nCols = ret.getNumColumns();
+ final DenseBlock db = ret.getDenseBlock();
+ if(eps == 0) {
+ for(int row = rl; row < ru; row++) {
+ final double[] _retV = db.values(row);
+ final int off = db.pos(row);
+ for(int col = 0; col < nCols; col++)
+ _retV[off + col] += rowV[col];
+ }
+ }
+ else {
+ for(int row = rl; row < ru; row++) {
+ final double[] _retV = db.values(row);
+ final int off = db.pos(row);
+ for(int col = 0; col < nCols; col++) {
+ final int out = off + col;
+ _retV[out] += rowV[col];
+ if(Math.abs(_retV[out]) <= eps)
+ _retV[out] = 0;
+ }
+ }
+ }
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index 8b78e58..dc2df68 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -110,11 +110,11 @@ public class CLALibLeftMultBy {
multAllColGroups(groups, groups, result);
}
else {
- final boolean containsSDC = containsSDC(groups);
- final int numColumns = cmb.getNumColumns();
+ final boolean containsSDC = CLALibUtils.containsSDC(groups);
final double[] constV = containsSDC ? new double[cmb.getNumColumns()] : null;
- final List<AColGroup> filteredGroups = filterSDCGroups(groups, constV);
+ final List<AColGroup> filteredGroups = CLALibUtils.filterSDCGroups(groups, constV);
final double[] colSums = containsSDC ? new double[cmb.getNumColumns()] : null;
+ final int numColumns = cmb.getNumColumns();
if(containsSDC)
for(int i = 0; i < groups.size(); i++) {
@@ -298,12 +298,13 @@ public class CLALibLeftMultBy {
}
final int numColumnsOut = ret.getNumColumns();
- final boolean containsSDC = containsSDC(colGroups);
+ final boolean containsSDC = CLALibUtils.containsSDC(colGroups);
// a constant colgroup summing the default values.
- final double[] constV = containsSDC ? new double[numColumnsOut] : null;
- final List<AColGroup> filteredGroups = filterSDCGroups(colGroups, constV);
-
+ double[] constV = containsSDC ? new double[numColumnsOut] : null;
+ final List<AColGroup> filteredGroups = CLALibUtils.filterSDCGroups(colGroups, constV);
+ if(colGroups == filteredGroups)
+ constV = null;
final double[] rowSums = containsSDC ? new double[that.getNumRows()] : null;
if(k == 1) {
@@ -633,33 +634,4 @@ public class CLALibLeftMultBy {
Collections.sort(ColGroupValues, Comparator.comparing(AColGroup::getNumValues).reversed());
return ColGroupValues;
}
-
- private static boolean containsSDC(List<AColGroup> groups) {
- boolean containsSDC = false;
-
- for(AColGroup g : groups) {
- if(g instanceof ColGroupSDC || g instanceof ColGroupSDCSingle) {
- containsSDC = true;
- break;
- }
- }
- return containsSDC;
- }
-
- private static List<AColGroup> filterSDCGroups(List<AColGroup> groups, double[] constV) {
- if(constV != null) {
- final List<AColGroup> filteredGroups = new ArrayList<>();
- for(AColGroup g : groups) {
- if(g instanceof ColGroupSDC)
- filteredGroups.add(((ColGroupSDC) g).extractCommon(constV));
- else if(g instanceof ColGroupSDCSingle)
- filteredGroups.add(((ColGroupSDCSingle) g).extractCommon(constV));
- else
- filteredGroups.add(g);
- }
- return filteredGroups;
- }
- else
- return groups;
- }
}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java
index 1822685..ace2843 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java
@@ -68,8 +68,7 @@ public class CLALibScalar {
if(m1.isOverlapping() && !(sop.fn instanceof Multiply || sop.fn instanceof Divide)) {
AColGroup constOverlap = constOverlap(m1, sop);
List<AColGroup> newColGroups = (sop instanceof LeftScalarOperator &&
- sop.fn instanceof Minus) ? processOverlappingSubtractionLeft(m1,
- sop,
+ sop.fn instanceof Minus) ? processOverlappingSubtractionLeft(m1, sop,
ret) : processOverlappingAddition(m1, sop, ret);
newColGroups.add(constOverlap);
ret.allocateColGroupList(newColGroups);
@@ -93,8 +92,8 @@ public class CLALibScalar {
}
ret.recomputeNonZeros();
- return ret;
+ return ret;
}
private static CompressedMatrixBlock setupRet(CompressedMatrixBlock m1, MatrixValue result) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
new file mode 100644
index 0000000..c701b96
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.lib;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingle;
+
+public class CLALibUtils {
+
+ /**
+ * Helper method to determine if the column groups contain SDC groups (ColGroupSDC or ColGroupSDCSingle).
+ *
+ * Note that it only returns true if there is more than one SDC Group.
+ *
+ * @param groups The ColumnGroups to analyze
+ * @return A Boolean saying if there are >= 2 SDC Groups.
+ */
+ protected static boolean containsSDC(List<AColGroup> groups) {
+ int count = 0;
+ for(AColGroup g : groups) {
+ final boolean isSDC = g instanceof ColGroupSDC || g instanceof ColGroupSDCSingle;
+ if(isSDC && ++count > 1)
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * Helper method to filter out SDC Groups, to add their common value to the ConstV. This allows exploitation of the
+ * common values in the SDC Groups.
+ *
+ * If any resulting value in constV is not finite, the common values cannot be used. In that case constV is
+ * restored to its input content and the input list is returned, so callers can detect the fallback by comparing
+ * the returned list with groups by identity.
+ *
+ * @param groups The Column Groups
+ * @param constV The Constant vector to add common values to, may be null.
+ * @return The Filtered list of Column groups containing no SDC Groups but only SDCZero groups, or groups itself if
+ * constV is null or the common values are not finite.
+ */
+ protected static List<AColGroup> filterSDCGroups(List<AColGroup> groups, double[] constV) {
+ if(constV == null)
+ return groups;
+
+ // Snapshot constV so that it can be restored if the extracted common values cannot be used.
+ final double[] originalConstV = constV.clone();
+ final List<AColGroup> filteredGroups = new ArrayList<>(groups.size());
+ for(AColGroup g : groups) {
+ if(g instanceof ColGroupSDC)
+ filteredGroups.add(((ColGroupSDC) g).extractCommon(constV));
+ else if(g instanceof ColGroupSDCSingle)
+ filteredGroups.add(((ColGroupSDCSingle) g).extractCommon(constV));
+ else
+ filteredGroups.add(g);
+ }
+
+ for(double v : constV) {
+ if(!Double.isFinite(v)) {
+ // Fall back to the input groups; undo the extraction in constV, since the returned list is the input
+ // list and a caller that still adds constV would otherwise add the extracted values on top of it.
+ System.arraycopy(originalConstV, 0, constV, 0, constV.length);
+ return groups;
+ }
+ }
+
+ return filteredGroups;
+ }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
index e086974..06431db 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
@@ -425,40 +425,48 @@ public class WorkloadAnalyzer {
setDecompressionOnAllInputs(hop, parent);
return;
}
- // shortcut instead of comparing to MatrixScalar or RowVector.
- else if(hop.getInput(1).getDim1() == 1 || hop.getInput(1).isScalar() || hop.getInput(0).isScalar()) {
-
+ else {
ArrayList<Hop> in = hop.getInput();
final boolean ol0 = isOverlapping(in.get(0));
final boolean ol1 = isOverlapping(in.get(1));
final boolean ol = ol0 || ol1;
- if(ol && HopRewriteUtils.isBinary(hop, OpOp2.PLUS, OpOp2.MULT, OpOp2.DIV, OpOp2.MINUS)) {
- overlapping.add(hop.getHopID());
- o = new OpNormal(hop, true);
- o.setOverlapping();
+
+ // shortcut instead of comparing to MatrixScalar or RowVector.
+ if(in.get(1).getDim1() == 1 || in.get(1).isScalar() || in.get(0).isScalar()) {
+
+ if(ol && HopRewriteUtils.isBinary(hop, OpOp2.PLUS, OpOp2.MULT, OpOp2.DIV, OpOp2.MINUS)) {
+ overlapping.add(hop.getHopID());
+ o = new OpNormal(hop, true);
+ o.setOverlapping();
+ }
+ else if(ol) {
+ treeLookup.get(in.get(0).getHopID()).setDecompressing();
+ return;
+ }
+ else {
+ o = new OpNormal(hop, true);
+ }
+ if(!HopRewriteUtils.isBinarySparseSafe(hop))
+ o.setDensifying();
+
}
- else if(ol) {
- treeLookup.get(in.get(0).getHopID()).setDecompressing();
+ else if(HopRewriteUtils.isBinaryMatrixMatrixOperation(hop) ||
+ HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) ||
+ HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) {
+ setDecompressionOnAllInputs(hop, parent);
+ return;
+ }
+ else if(ol0 || ol1){
+ setDecompressionOnAllInputs(hop, parent);
return;
}
else {
- o = new OpNormal(hop, true);
+ String ex = "Setting decompressed because input Binary Op is unknown, please add the case to WorkloadAnalyzer:\n"
+ + Explain.explain(hop);
+ LOG.warn(ex);
+ setDecompressionOnAllInputs(hop, parent);
+ return;
}
- if(!HopRewriteUtils.isBinarySparseSafe(hop))
- o.setDensifying();
-
- }
- else if(HopRewriteUtils.isBinaryMatrixMatrixOperation(hop) ||
- HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) ||
- HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) {
- setDecompressionOnAllInputs(hop, parent);
- return;
- }
- else {
- String ex = "Setting decompressed because input Binary Op is unknown, please add the case to WorkloadAnalyzer:\n"
- + Explain.explain(hop);
- LOG.warn(ex);
- setDecompressionOnAllInputs(hop, parent);
}
}
diff --git a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
index f1682a7..ba50a0e 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
@@ -117,7 +117,7 @@ public class WorkloadTest {
args.put("$3", "0");
// no recompile
- tests.add(new Object[] {0, 1, 1, 1, 1, 1, 6, 0, true, false, "functions/lmDS.dml", args});
+ tests.add(new Object[] {0, 1, 1, 1, 1, 1, 5, 0, true, false, "functions/lmDS.dml", args});
// with recompile
tests.add(new Object[] {0, 0, 0, 1, 0, 1, 0, 0, true, true, "functions/lmDS.dml", args});
tests.add(new Object[] {0, 0, 0, 1, 10, 10, 1, 0, true, true, "functions/lmCG.dml", args});
diff --git a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java
index dc5b17f..c5710bd 100644
--- a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java
+++ b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java
@@ -188,7 +188,7 @@ public class CompressForce extends CompressBase {
// be aware that with multiple blocks it is likely that the small blocks
// initially compress, but is to large for overlapping state therefor will decompress.
// In this test it decompress the second small block but keeps the first in overlapping state.
- runTest(1110, 30, 1, 1, ExecType.SPARK, "mmr_sum_plus_2");
+ compressTest(1110, 10, 1.0, ExecType.SPARK, 1, 6, 1, 1, 1, "mmr_sum_plus_2");
}
@Test