You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ba...@apache.org on 2021/09/25 18:35:54 UTC

[systemds] branch master updated (b11521b -> 535263b)

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git.


    from b11521b  [MINOR] Add config to SystemDS Worker startup
     new b537181  [SYSTEMDS-2610] CLA Updates
     new 535263b  [SYSTEMDS-3144] Spark Local Context from command line

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 src/main/java/org/apache/sysds/conf/DMLConfig.java |   6 +-
 .../runtime/compress/CompressedMatrixBlock.java    | 100 +---------
 .../compress/CompressedMatrixBlockFactory.java     |   1 +
 .../runtime/compress/CompressionSettings.java      |   8 +-
 .../compress/CompressionSettingsBuilder.java       |  21 +-
 .../runtime/compress/cocode/CoCodeGreedy.java      |  33 ++--
 .../{CoCodeStatic.java => CoCodeHybrid.java}       |  33 +++-
 .../runtime/compress/cocode/CoCodePriorityQue.java |  32 +--
 .../runtime/compress/cocode/CoCoderFactory.java    |   5 +-
 .../runtime/compress/colgroup/ColGroupSDC.java     |   8 +-
 .../compress/colgroup/ColGroupSDCSingle.java       |   9 +-
 .../compress/colgroup/dictionary/ADictionary.java  |   4 +-
 .../compress/colgroup/dictionary/Dictionary.java   |  11 +-
 .../compress/cost/ComputationCostEstimator.java    |  68 ++++---
 .../estim/CompressedSizeEstimatorFactory.java      |  27 ++-
 .../runtime/compress/lib/CLALibBinaryCellOp.java   |  16 +-
 .../runtime/compress/lib/CLALibDecompress.java     | 218 +++++++++++++++++++++
 .../runtime/compress/lib/CLALibLeftMultBy.java     |  44 +----
 .../sysds/runtime/compress/lib/CLALibScalar.java   |   5 +-
 .../sysds/runtime/compress/lib/CLALibUtils.java    |  79 ++++++++
 .../compress/workload/WorkloadAnalyzer.java        |  58 +++---
 .../context/SparkExecutionContext.java             |  50 +++--
 .../component/compress/workload/WorkloadTest.java  |   2 +-
 .../compress/configuration/CompressForce.java      |   2 +-
 24 files changed, 570 insertions(+), 270 deletions(-)
 copy src/main/java/org/apache/sysds/runtime/compress/cocode/{CoCodeStatic.java => CoCodeHybrid.java} (51%)
 create mode 100644 src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java
 create mode 100644 src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java

[systemds] 01/02: [SYSTEMDS-2610] CLA Updates

Posted by ba...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit b537181802f552f89014c282b681fea7c06ef404
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Mon Sep 20 18:42:02 2021 +0200

    [SYSTEMDS-2610] CLA Updates
    
    - CLA decompression exploiting the common value
    - Multiply the cost of some operations by the compression ratio
    - Lower the default sample ratio (0.05 -> 0.02)
    - If the sample size is very small, double it
    - Workload tree corrections
    - Set a max sample size in the compression settings
    - Add Hybrid co-coding strategy that uses both PriorityQue and Greedy (all-compare) co-coding
---
 src/main/java/org/apache/sysds/conf/DMLConfig.java |   2 +-
 .../runtime/compress/CompressedMatrixBlock.java    | 100 +---------
 .../compress/CompressedMatrixBlockFactory.java     |   1 +
 .../runtime/compress/CompressionSettings.java      |   8 +-
 .../compress/CompressionSettingsBuilder.java       |  21 +-
 .../runtime/compress/cocode/CoCodeGreedy.java      |  33 ++--
 .../runtime/compress/cocode/CoCodeHybrid.java      |  56 ++++++
 .../runtime/compress/cocode/CoCodePriorityQue.java |  32 +--
 .../runtime/compress/cocode/CoCoderFactory.java    |   5 +-
 .../runtime/compress/colgroup/ColGroupSDC.java     |   8 +-
 .../compress/colgroup/ColGroupSDCSingle.java       |   9 +-
 .../compress/colgroup/dictionary/ADictionary.java  |   4 +-
 .../compress/colgroup/dictionary/Dictionary.java   |  11 +-
 .../compress/cost/ComputationCostEstimator.java    |  68 ++++---
 .../estim/CompressedSizeEstimatorFactory.java      |  27 ++-
 .../runtime/compress/lib/CLALibBinaryCellOp.java   |  16 +-
 .../runtime/compress/lib/CLALibDecompress.java     | 218 +++++++++++++++++++++
 .../runtime/compress/lib/CLALibLeftMultBy.java     |  44 +----
 .../sysds/runtime/compress/lib/CLALibScalar.java   |   5 +-
 .../sysds/runtime/compress/lib/CLALibUtils.java    |  79 ++++++++
 .../compress/workload/WorkloadAnalyzer.java        |  58 +++---
 .../component/compress/workload/WorkloadTest.java  |   2 +-
 .../compress/configuration/CompressForce.java      |   2 +-
 23 files changed, 566 insertions(+), 243 deletions(-)

diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index db59505..a59101b 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -131,7 +131,7 @@ public class DMLConfig
 		_defaultVals.put(COMPRESSED_LOSSY,       "false" );
 		_defaultVals.put(COMPRESSED_VALID_COMPRESSIONS, "SDC,DDC");
 		_defaultVals.put(COMPRESSED_OVERLAPPING, "true" );
-		_defaultVals.put(COMPRESSED_SAMPLING_RATIO, "0.05");
+		_defaultVals.put(COMPRESSED_SAMPLING_RATIO, "0.02");
 		_defaultVals.put(COMPRESSED_COCODE,      "AUTO");
 		_defaultVals.put(COMPRESSED_COST_MODEL,  "AUTO");
 		_defaultVals.put(COMPRESSED_TRANSPOSE,   "auto");
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
index bff9a21..4ff5bcb 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java
@@ -27,12 +27,8 @@ import java.io.ObjectOutput;
 import java.lang.ref.SoftReference;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Comparator;
 import java.util.Iterator;
 import java.util.List;
-import java.util.concurrent.Callable;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 
 import org.apache.commons.lang.NotImplementedException;
@@ -56,6 +52,7 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
 import org.apache.sysds.runtime.compress.lib.CLALibAppend;
 import org.apache.sysds.runtime.compress.lib.CLALibBinaryCellOp;
 import org.apache.sysds.runtime.compress.lib.CLALibCompAgg;
+import org.apache.sysds.runtime.compress.lib.CLALibDecompress;
 import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy;
 import org.apache.sysds.runtime.compress.lib.CLALibReExpand;
 import org.apache.sysds.runtime.compress.lib.CLALibRightMultBy;
@@ -104,7 +101,6 @@ import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
 import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
 import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
-import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.IndexRange;
 import org.apache.sysds.utils.DMLCompressionStatistics;
 
@@ -250,15 +246,10 @@ public class CompressedMatrixBlock extends MatrixBlock {
 
 		ret.allocateDenseBlock();
 
-		if(isOverlapping()){
-			Comparator<AColGroup> comp = Comparator.comparing(x -> effect(x));
-			_colGroups.sort(comp);
-		}
-
 		if(k == 1)
-			decompress(ret);
+			CLALibDecompress.decompress(ret, getColGroups(), nonZeros, isOverlapping());
 		else
-			decompress(ret, k);
+			CLALibDecompress.decompress(ret, getColGroups(), isOverlapping(), k);
 
 		if(this.isOverlapping())
 			ret.recomputeNonZeros();
@@ -275,47 +266,6 @@ public class CompressedMatrixBlock extends MatrixBlock {
 		return ret;
 	}
 
-	private double effect(AColGroup x){
-		return - Math.max(x.getMax(), Math.abs(x.getMin()));
-	}
-
-	private MatrixBlock decompress(MatrixBlock ret) {
-
-		ret.setNonZeros(nonZeros == -1 && !this.isOverlapping() ? recomputeNonZeros() : nonZeros);
-		final int block = (int) Math.ceil((double) (CompressionSettings.BITMAP_BLOCK_SZ) / getNumColumns());
-		final int blklen = block > 1000 ? block + 1000 - block % 1000 : Math.max(64, block);
-		for(int i = 0; i < getNumRows(); i += blklen)
-			for(AColGroup grp : _colGroups)
-				grp.decompressToBlockUnSafe(ret, i, Math.min(i + blklen, rlen));
-
-		return ret;
-	}
-
-	private MatrixBlock decompress(MatrixBlock ret, int k) {
-		try {
-			final ExecutorService pool = CommonThreadPool.get(k);
-			final int rlen = getNumRows();
-			final int block = (int) Math.ceil((double) (CompressionSettings.BITMAP_BLOCK_SZ) / getNumColumns());
-			final int blklen = block > 1000 ? block + 1000 - block % 1000 : Math.max(64, block);
-			final ArrayList<DecompressTask> tasks = new ArrayList<>();
-			for(int i = 0; i * blklen < getNumRows(); i++)
-				tasks.add(new DecompressTask(_colGroups, ret, i * blklen, Math.min((i + 1) * blklen, rlen),
-					overlappingColGroups));
-			List<Future<Long>> rtasks = pool.invokeAll(tasks);
-			pool.shutdown();
-
-			long nnz = 0;
-			for(Future<Long> rt : rtasks)
-				nnz += rt.get();
-			ret.setNonZeros(nnz);
-		}
-		catch(InterruptedException | ExecutionException ex) {
-			throw new DMLCompressionException("Parallel decompression failed", ex);
-		}
-
-		return ret;
-	}
-
 	/**
 	 * Get the cached decompressed matrix (if it exists otherwise null).
 	 * 
@@ -673,10 +623,10 @@ public class CompressedMatrixBlock extends MatrixBlock {
 			ReorgOperator r_op = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), op.getNumThreads());
 			ret = ret.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0);
 		}
-		
+
 		if(ret.getNumRows() == 0 || ret.getNumColumns() == 0)
 			throw new DMLCompressionException("Error in outputted MM no dimensions");
-		
+
 		return ret;
 	}
 
@@ -788,46 +738,6 @@ public class CompressedMatrixBlock extends MatrixBlock {
 		return null;
 	}
 
-	private static class DecompressTask implements Callable<Long> {
-		private final List<AColGroup> _colGroups;
-		private final MatrixBlock _ret;
-		private final int _rl;
-		private final int _ru;
-		private final boolean _overlapping;
-
-		protected DecompressTask(List<AColGroup> colGroups, MatrixBlock ret, int rl, int ru, boolean overlapping) {
-			_colGroups = colGroups;
-			_ret = ret;
-			_rl = rl;
-			_ru = ru;
-			_overlapping = overlapping;
-		}
-
-		@Override
-		public Long call() {
-
-			// preallocate sparse rows to avoid repeated alloc
-			if(!_overlapping && _ret.isInSparseFormat()) {
-				int[] rnnz = new int[_ru - _rl];
-				for(AColGroup grp : _colGroups)
-					grp.countNonZerosPerRow(rnnz, _rl, _ru);
-				SparseBlock rows = _ret.getSparseBlock();
-				for(int i = _rl; i < _ru; i++)
-					rows.allocate(i, rnnz[i - _rl]);
-			}
-
-			// decompress row partition
-			for(AColGroup grp : _colGroups)
-				grp.decompressToBlockUnSafe(_ret, _rl, _ru);
-
-			// post processing (sort due to append)
-			if(_ret.isInSparseFormat())
-				_ret.sortSparseRows(_rl, _ru);
-
-			return _overlapping ? 0 : _ret.recomputeNonZeros(_rl, _ru - 1);
-		}
-	}
-
 	@Override
 	public String toString() {
 		StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
index 98567e0..d77ab82 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
@@ -484,6 +484,7 @@ public class CompressedMatrixBlockFactory {
 				case 1:
 					LOG.debug("--compression phase " + phase + " Grouping  : " + getLastTimePhase());
 					LOG.debug("Grouping using: " + compSettings.columnPartitioner);
+					LOG.debug("Cost Calculated using: " + costEstimator);
 					LOG.debug("--Cocoded Columns estimated Compression:" + _stats.estimatedSizeCoCoded);
 					break;
 				case 2:
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
index b347031..c1a9cd4 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java
@@ -92,6 +92,11 @@ public class CompressionSettings {
 	 */
 	public final int minimumSampleSize;
 
+	/**
+	 * The maximum size of the sample extracted.
+	 */
+	public final int maxSampleSize;
+
 	/** The sample type used for sampling */
 	public final EstimationType estimationType;
 
@@ -110,7 +115,7 @@ public class CompressionSettings {
 
 	protected CompressionSettings(double samplingRatio, boolean allowSharedDictionary, String transposeInput, int seed,
 		boolean lossy, EnumSet<CompressionType> validCompressions, boolean sortValuesByLength,
-		PartitionerType columnPartitioner, int maxColGroupCoCode, double coCodePercentage, int minimumSampleSize,
+		PartitionerType columnPartitioner, int maxColGroupCoCode, double coCodePercentage, int minimumSampleSize, int maxSampleSize,
 		EstimationType estimationType, CostType costComputationType, double minimumCompressionRatio) {
 		this.samplingRatio = samplingRatio;
 		this.allowSharedDictionary = allowSharedDictionary;
@@ -123,6 +128,7 @@ public class CompressionSettings {
 		this.maxColGroupCoCode = maxColGroupCoCode;
 		this.coCodePercentage = coCodePercentage;
 		this.minimumSampleSize = minimumSampleSize;
+		this.maxSampleSize= maxSampleSize;
 		this.estimationType = estimationType;
 		this.costComputationType = costComputationType;
 		this.minimumCompressionRatio = minimumCompressionRatio;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
index 156fdfc..d5fd036 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java
@@ -42,6 +42,7 @@ public class CompressionSettingsBuilder {
 	private int maxColGroupCoCode = 10000;
 	private double coCodePercentage = 0.01;
 	private int minimumSampleSize = 2000;
+	private int maxSampleSize = 1000000;
 	private EstimationType estimationType = EstimationType.HassAndStokes;
 	private PartitionerType columnPartitioner;
 	private CostType costType;
@@ -247,6 +248,18 @@ public class CompressionSettingsBuilder {
 	}
 
 	/**
+	 * Set the maximum sample size to extract from a given matrix. This overrules the sampling ratio if the sample
+	 * size implied by that ratio is higher than this maximum bound.
+	 * 
+	 * @param maxSampleSize The maximum sample size to extract
+	 * @return The CompressionSettingsBuilder
+	 */
+	public CompressionSettingsBuilder setMaxSampleSize(int maxSampleSize) {
+		this.maxSampleSize = maxSampleSize;
+		return this;
+	}
+
+	/**
 	 * Set the estimation type used for the sampled estimates.
 	 * 
 	 * @param estimationType the estimation type in used.
@@ -268,6 +281,12 @@ public class CompressionSettingsBuilder {
 		return this;
 	}
 
+	/**
+	 * Set the minimum compression ratio to be achieved by the compression.
+	 * 
+	 * @param ratio The ratio to achieve while compressing
+	 * @return The CompressionSettingsBuilder
+	 */
 	public CompressionSettingsBuilder setMinimumCompressionRatio(double ratio) {
 		this.minimumCompressionRatio = ratio;
 		return this;
@@ -281,6 +300,6 @@ public class CompressionSettingsBuilder {
 	public CompressionSettings create() {
 		return new CompressionSettings(samplingRatio, allowSharedDictionary, transposeInput, seed, lossy,
 			validCompressions, sortValuesByLength, columnPartitioner, maxColGroupCoCode, coCodePercentage,
-			minimumSampleSize, estimationType, costType, minimumCompressionRatio);
+			minimumSampleSize, maxSampleSize, estimationType, costType, minimumCompressionRatio);
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java
index cf366e7..55bf3b9 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java
@@ -36,25 +36,26 @@ import org.apache.sysds.runtime.compress.utils.Util;
 
 public class CoCodeGreedy extends AColumnCoCoder {
 
-
-	private final Memorizer mem;
-
 	protected CoCodeGreedy(CompressedSizeEstimator sizeEstimator, ICostEstimate costEstimator,
 		CompressionSettings cs) {
 		super(sizeEstimator, costEstimator, cs);
-		mem = new Memorizer();
 	}
 
 	@Override
 	protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) {
-		for(CompressedSizeInfoColGroup g : colInfos.compressionInfo)
+		colInfos.setInfo(join(colInfos.compressionInfo, _sest, _cest, _cs));
+		return colInfos;
+	}
+
+	protected static List<CompressedSizeInfoColGroup> join(List<CompressedSizeInfoColGroup> inputColumns, CompressedSizeEstimator sEst, ICostEstimate cEst, CompressionSettings cs) {
+		Memorizer mem = new Memorizer(cs, sEst);
+		for(CompressedSizeInfoColGroup g : inputColumns)
 			mem.put(g);
 		
-		colInfos.setInfo(coCodeBruteForce(colInfos.compressionInfo));
-		return colInfos;
+		return coCodeBruteForce(inputColumns, cEst, mem);
 	}
 
-	private List<CompressedSizeInfoColGroup> coCodeBruteForce(List<CompressedSizeInfoColGroup> inputColumns) {
+	private static List<CompressedSizeInfoColGroup> coCodeBruteForce(List<CompressedSizeInfoColGroup> inputColumns, ICostEstimate cEst, Memorizer mem) {
 
 		List<ColIndexes> workset = new ArrayList<>(inputColumns.size());
 
@@ -69,8 +70,8 @@ public class CoCodeGreedy extends AColumnCoCoder {
 				for(int j = i + 1; j < workset.size(); j++) {
 					final ColIndexes c1 = workset.get(i);
 					final ColIndexes c2 = workset.get(j);
-					final double costC1 = _cest.getCostOfColumnGroup(mem.get(c1));
-					final double costC2 = _cest.getCostOfColumnGroup(mem.get(c2));
+					final double costC1 = cEst.getCostOfColumnGroup(mem.get(c1));
+					final double costC2 = cEst.getCostOfColumnGroup(mem.get(c2));
 
 					mem.incst1();
 					// pruning filter : skip dominated candidates
@@ -82,7 +83,7 @@ public class CoCodeGreedy extends AColumnCoCoder {
 					// Join the two column groups.
 					// and Memorize the new join.
 					final CompressedSizeInfoColGroup c1c2Inf = mem.getOrCreate(c1, c2);
-					final double costC1C2 = _cest.getCostOfColumnGroup(c1c2Inf);
+					final double costC1C2 = cEst.getCostOfColumnGroup(c1c2Inf);
 
 					final double newSizeChangeIfSelected = costC1C2 - costC1 - costC2;
 
@@ -120,11 +121,15 @@ public class CoCodeGreedy extends AColumnCoCoder {
 		return ret;
 	}
 
-	protected class Memorizer {
+	protected static class Memorizer {
+		private final CompressionSettings _cs;
+		private final CompressedSizeEstimator _sEst;
 		private final Map<ColIndexes, CompressedSizeInfoColGroup> mem;
 		private int st1 = 0, st2 = 0, st3 = 0, st4 = 0;
 
-		public Memorizer() {
+		public Memorizer(CompressionSettings cs, CompressedSizeEstimator sEst) {
+			_cs = cs;
+			_sEst = sEst;
 			mem = new HashMap<>();
 		}
 
@@ -159,7 +164,7 @@ public class CoCodeGreedy extends AColumnCoCoder {
 					g = CompressedSizeInfoColGroup.addConstGroup(c, left, _cs.validCompressions);
 				else {
 					st3++;
-					g = _sest.estimateJoinCompressedSize(c, left, right);
+					g = _sEst.estimateJoinCompressedSize(c, left, right);
 				}
 
 				if(leftConst || rightConst)
diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeHybrid.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeHybrid.java
new file mode 100644
index 0000000..f762c0d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeHybrid.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.cocode;
+
+import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.cost.ICostEstimate;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimator;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
+
+/**
+ * This co-coding strategy starts out with the priority queue co-coder until a threshold number of column groups is
+ * reached, then shifts to a greedy all-compare strategy.
+ */
+public class CoCodeHybrid extends AColumnCoCoder {
+
+    protected CoCodeHybrid(CompressedSizeEstimator sizeEstimator, ICostEstimate costEstimator, CompressionSettings cs) {
+        super(sizeEstimator, costEstimator, cs);
+    }
+
+    @Override
+    protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) {
+        final int startSize = colInfos.getInfo().size();
+        final int PriorityQueGoal = 40;
+        if(startSize > 200) {
+
+            colInfos.setInfo(CoCodePriorityQue.join(colInfos.getInfo(), _sest, _cest, PriorityQueGoal));
+
+            final int pqSize = colInfos.getInfo().size();
+            if(pqSize <= PriorityQueGoal)
+                colInfos.setInfo(CoCodeGreedy.join(colInfos.getInfo(), _sest, _cest, _cs));
+        }
+        else {
+            colInfos.setInfo(CoCodeGreedy.join(colInfos.getInfo(), _sest, _cest, _cs));
+        }
+
+        return colInfos;
+    }
+
+}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java
index 27b678c..b53859e 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.java
@@ -48,12 +48,13 @@ public class CoCodePriorityQue extends AColumnCoCoder {
 
 	@Override
 	protected CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) {
-		colInfos.setInfo(join(colInfos.getInfo()));
+		colInfos.setInfo(join(colInfos.getInfo(), _sest, _cest, 1));
 		return colInfos;
 	}
 
-	private List<CompressedSizeInfoColGroup> join(List<CompressedSizeInfoColGroup> currentGroups) {
-		Comparator<CompressedSizeInfoColGroup> comp = Comparator.comparing(x -> _cest.getCostOfColumnGroup(x));
+	protected static List<CompressedSizeInfoColGroup> join(List<CompressedSizeInfoColGroup> currentGroups,
+		CompressedSizeEstimator sEst, ICostEstimate cEst, int minNumGroups) {
+		Comparator<CompressedSizeInfoColGroup> comp = Comparator.comparing(x -> cEst.getCostOfColumnGroup(x));
 		Queue<CompressedSizeInfoColGroup> que = new PriorityQueue<>(currentGroups.size(), comp);
 		List<CompressedSizeInfoColGroup> ret = new ArrayList<>();
 
@@ -62,15 +63,16 @@ public class CoCodePriorityQue extends AColumnCoCoder {
 				que.add(g);
 
 		CompressedSizeInfoColGroup l = null;
-		if(_cest.isCompareAll()) {
-			double costBeforeJoin = _cest.getCostOfCollectionOfGroups(que);
+		if(cEst.isCompareAll()) {
+			double costBeforeJoin = cEst.getCostOfCollectionOfGroups(que);
 			l = que.poll();
-			while(que.peek() != null) {
+			int groupNr = ret.size() + que.size();
+			while(que.peek() != null && groupNr >= minNumGroups) {
 
 				CompressedSizeInfoColGroup r = que.poll();
-				final CompressedSizeInfoColGroup g = _sest.estimateJoinCompressedSize(l, r);
+				final CompressedSizeInfoColGroup g = sEst.estimateJoinCompressedSize(l, r);
 				if(g != null) {
-					final double costOfJoin = _cest.getCostOfCollectionOfGroups(que, g);
+					final double costOfJoin = cEst.getCostOfCollectionOfGroups(que, g);
 					if(costOfJoin < costBeforeJoin) {
 						costBeforeJoin = costOfJoin;
 						que.add(g);
@@ -86,17 +88,19 @@ public class CoCodePriorityQue extends AColumnCoCoder {
 				}
 
 				l = que.poll();
+				groupNr = ret.size() + que.size();
 			}
 		}
 		else {
 			l = que.poll();
-			while(que.peek() != null) {
+			int groupNr = ret.size() + que.size();
+			while(que.peek() != null && groupNr >= minNumGroups) {
 				CompressedSizeInfoColGroup r = que.peek();
-				if(_cest.shouldTryJoin(l, r)) {
-					CompressedSizeInfoColGroup g = _sest.estimateJoinCompressedSize(l, r);
+				if(cEst.shouldTryJoin(l, r)) {
+					CompressedSizeInfoColGroup g = sEst.estimateJoinCompressedSize(l, r);
 					if(g != null) {
-						double costOfJoin = _cest.getCostOfColumnGroup(g);
-						double costIndividual = _cest.getCostOfColumnGroup(l) + _cest.getCostOfColumnGroup(r);
+						double costOfJoin = cEst.getCostOfColumnGroup(g);
+						double costIndividual = cEst.getCostOfColumnGroup(l) + cEst.getCostOfColumnGroup(r);
 
 						if(costOfJoin < costIndividual) {
 							que.poll();
@@ -112,8 +116,10 @@ public class CoCodePriorityQue extends AColumnCoCoder {
 					ret.add(l);
 
 				l = que.poll();
+				groupNr = ret.size() + que.size();
 			}
 		}
+
 		if(l != null)
 			ret.add(l);
 
diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java
index c467f70..eaf9cb4 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCoderFactory.java
@@ -63,10 +63,7 @@ public class CoCoderFactory {
 				// TODO make decision better depending on how much time is allocated for the compression
 				// for instance if the compressed object is used for a million instructions, it might be good to
 				// search for a really good compression even if it take longer.
-				if(est.getNumColumns() > 200)
-					return new CoCodePriorityQue(est, costEstimator, cs);
-				else
-					return new CoCodeGreedy(est, costEstimator, cs);
+				return new CoCodeHybrid(est, costEstimator, cs);
 			case GREEDY:
 				return new CoCodeGreedy(est, costEstimator, cs);
 			case BIN_PACKING:
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index 5c6d365..fb93ed6 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -558,11 +558,13 @@ public class ColGroupSDC extends ColGroupValue {
 	public ColGroupSDCZeros extractCommon(double[] constV) {
 		double[] commonV = _dict.getTuple(getNumValues() - 1, _colIndexes.length);
 
-		for(int i = 0; i < _colIndexes.length; i++) {
+		if(commonV == null) // The common tuple was all zero. Therefore this column group should never have been SDC.
+			return new ColGroupSDCZeros(_colIndexes, _numRows, _dict, _indexes, _data, getCounts());
+
+		for(int i = 0; i < _colIndexes.length; i++)
 			constV[_colIndexes[i]] += commonV[i];
-		}
+
 		ADictionary subtractedDict = _dict.subtractTuple(commonV);
 		return new ColGroupSDCZeros(_colIndexes, _numRows, subtractedDict, _indexes, _data, getCounts());
 	}
-
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
index de0b4a8..b2b4283 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java
@@ -117,8 +117,8 @@ public class ColGroupSDCSingle extends ColGroupValue {
 			for(int j = 0; j < nCol; j++)
 				c[off + _colIndexes[j]] += values[offsetToDefault + j];
 		}
-		
-		_indexes.cacheIterator(it, ru );
+
+		_indexes.cacheIterator(it, ru);
 	}
 
 	@Override
@@ -473,11 +473,14 @@ public class ColGroupSDCSingle extends ColGroupValue {
 	public ColGroupSDCSingleZeros extractCommon(double[] constV) {
 		double[] commonV = _dict.getTuple(getNumValues() - 1, _colIndexes.length);
 
+		if(commonV == null) // The common tuple was all zero. Therefore this column group should never have been SDC.
+			return new ColGroupSDCSingleZeros(_colIndexes, _numRows, _dict, _indexes, getCachedCounts());
+
 		for(int i = 0; i < _colIndexes.length; i++)
 			constV[_colIndexes[i]] += commonV[i];
 
 		ADictionary subtractedDict = _dict.subtractTuple(commonV);
-		return new ColGroupSDCSingleZeros(_colIndexes, _numRows, subtractedDict, _indexes, getCounts());
+		return new ColGroupSDCSingleZeros(_colIndexes, _numRows, subtractedDict, _indexes, getCachedCounts());
 	}
 
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
index 15c74b0..df75268 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
@@ -35,7 +35,7 @@ import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
 public abstract class ADictionary implements Serializable {
 
 	private static final long serialVersionUID = 9118692576356558592L;
-	
+
 	protected static final Log LOG = LogFactory.getLog(ADictionary.class.getName());
 
 	/**
@@ -342,6 +342,8 @@ public abstract class ADictionary implements Serializable {
 	/**
 	 * Get the values contained in a specific tuple of the dictionary.
 	 * 
+	 * If the entire row is zero, return null.
+	 * 
 	 * @param index The index where the values are located
 	 * @param nCol  The number of columns contained in this dictionary
 	 * @return a materialized double array containing the tuple.
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
index 5b9d834..0b65ed1 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
@@ -43,7 +43,7 @@ import org.apache.sysds.utils.MemoryEstimates;
 public class Dictionary extends ADictionary {
 
 	private static final long serialVersionUID = -6517136537249507753L;
-	
+
 	private final double[] _values;
 
 	public Dictionary(double[] values) {
@@ -329,14 +329,13 @@ public class Dictionary extends ADictionary {
 		if(colIndexes == 1)
 			sb.append(Arrays.toString(_values));
 		else {
-			sb.append("[\n");
+			sb.append("[\n\t");
 			for(int i = 0; i < _values.length - 1; i++) {
 				sb.append(_values[i]);
-				sb.append((i) % (colIndexes) == colIndexes - 1 ? "\nt" + i + ": " : ", ");
+				sb.append((i) % (colIndexes) == colIndexes - 1 ? "\n\t" : ", ");
 			}
 			sb.append(_values[_values.length - 1]);
-
-			sb.append("\n]");
+			sb.append("]");
 		}
 		return sb.toString();
 	}
@@ -476,7 +475,7 @@ public class Dictionary extends ADictionary {
 	}
 
 	@Override
-	public Dictionary preaggValuesFromDense(int numVals, int[] colIndexes, int[] aggregateColumns, double[] b, 
+	public Dictionary preaggValuesFromDense(int numVals, int[] colIndexes, int[] aggregateColumns, double[] b,
 		int cut) {
 		double[] ret = new double[numVals * aggregateColumns.length];
 		for(int k = 0, off = 0;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
index 1653f51..07cf58b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
@@ -23,6 +23,7 @@ import java.util.Collection;
 
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
 public class ComputationCostEstimator implements ICostEstimate {
 
@@ -41,7 +42,6 @@ public class ComputationCostEstimator implements ICostEstimate {
 	private final int _leftMultiplications;
 	private final int _rightMultiplications;
 	private final int _compressedMultiplication;
-	// private final int _rowBasedOps;
 	private final int _dictionaryOps;
 
 	private final boolean _isDensifying;
@@ -100,16 +100,21 @@ public class ComputationCostEstimator implements ICostEstimate {
 		cost += _scans * scanCost(g);
 		cost += _decompressions * decompressionCost(g);
 		cost += _overlappingDecompressions * overlappingDecompressionCost(g);
-		// 16 is assuming that the left side is 16 rows.
-		double lmc = leftMultCost(g) * 16;
-		cost += _leftMultiplications * lmc;
-		// 16 is assuming that the right side is 16 rows.
-		double rmc = rightMultCost(g) * 16;
-		cost += _rightMultiplications * rmc;
-
-		// cost += _compressedMultiplication * (lmc + rmc);
-		cost += _compressedMultiplication * _compressedMultCost(g);
+		// 16 is assuming that the left / right side is 16 rows/cols.
+		final int rowsCols = 16;
+		cost += _leftMultiplications *  leftMultCost(g) * rowsCols;
+		cost += _rightMultiplications * rightMultCost(g) * rowsCols;
 		cost += _dictionaryOps * dictionaryOpsCost(g);
+		
+		double size = g.getMinSize();
+		final double compressionRatio =  size / MatrixBlock.estimateSizeDenseInMemory(_nRows, _nCols) / g.getColumns().length;
+
+		cost *=  0.001 + compressionRatio;
+
+		cost += _compressedMultiplication * _compressedMultCost(g) * rowsCols;
+
+		// double uncompressedSize = g.getCompressionSize(CompressionType.UNCOMPRESSED);
+
 		return cost;
 	}
 
@@ -118,13 +123,14 @@ public class ComputationCostEstimator implements ICostEstimate {
 	}
 
 	private double leftMultCost(CompressedSizeInfoColGroup g) {
-		final int nCols = g.getColumns().length;
-		final double preAggregateCost = _nRows;
+		final int nColsInGroup = g.getColumns().length;
+		final double mcf = g.getMostCommonFraction();
+		final double preAggregateCost = mcf > 0.6 ? _nRows * (1 - 0.4 * mcf) : _nRows;
 
 		final double numberTuples = g.getNumVals();
 		final double tupleSparsity = g.getTupleSparsity();
-		final double postScalingCost = (nCols > 1 && tupleSparsity > 0.4) ? numberTuples * nCols * tupleSparsity *
-			1.4 : numberTuples * nCols;
+		final double postScalingCost = (nColsInGroup > 1 && tupleSparsity > 0.4) ? numberTuples * nColsInGroup * tupleSparsity *
+			1.4 : numberTuples * nColsInGroup;
 		if(numberTuples < 64000)
 			return preAggregateCost + postScalingCost;
 		else
@@ -134,10 +140,11 @@ public class ComputationCostEstimator implements ICostEstimate {
 
 	private double _compressedMultCost(CompressedSizeInfoColGroup g) {
 		final int nColsInGroup = g.getColumns().length;
-		final double mcf = g.getMostCommonFraction();
-		final double preAggregateCost = mcf > 0.6 ? _nRows * (1 - 0.7 * mcf) : _nRows;
+		// final double mcf = g.getMostCommonFraction();
+		// final double preAggregateCost = (mcf > 0.6 ? _nRows * (1 - 0.6 * mcf) : _nRows) * 4;
+		final double preAggregateCost = _nRows;
 
-		final double numberTuples = (float) g.getNumVals();
+		final double numberTuples = g.getNumVals();
 		final double tupleSparsity = g.getTupleSparsity();
 		final double postScalingCost = (nColsInGroup > 1 && tupleSparsity > 0.4) ? numberTuples * nColsInGroup * tupleSparsity *
 			1.4 : numberTuples * nColsInGroup;
@@ -163,7 +170,10 @@ public class ComputationCostEstimator implements ICostEstimate {
 	}
 
 	private double overlappingDecompressionCost(CompressedSizeInfoColGroup g) {
-		return _nRows * 16 * (g.getNumVals() / 64000 + 1);
+		final double mcf = g.getMostCommonFraction();
+		final double rowsCost = mcf > 0.6 ? _nRows * (1 - 0.6 * mcf) : _nRows;
+		// Scaling by 16 to mark decompression as expensive.
+		return rowsCost * 16 * ((float)g.getNumVals() / 64000 + 1);
 	}
 
 	private static double dictionaryOpsCost(CompressedSizeInfoColGroup g) {
@@ -259,15 +269,19 @@ public class ComputationCostEstimator implements ICostEstimate {
 	public String toString() {
 		StringBuilder sb = new StringBuilder();
 		sb.append(this.getClass().getSimpleName());
-		sb.append("\n");
-		sb.append(_nRows + "  ");
-		sb.append(_scans + " ");
-		sb.append(_decompressions + " ");
-		sb.append(_overlappingDecompressions + " ");
-		sb.append(_leftMultiplications + " ");
-		sb.append(_rightMultiplications + " ");
-		sb.append(_compressedMultiplication + " ");
-		sb.append(_dictionaryOps + " ");
+		sb.append("dims(");
+		sb.append(_nRows + ",");
+		sb.append(_nCols + ") ");
+		sb.append("CostVector:[");
+		sb.append(_scans + ",");
+		sb.append(_decompressions + ",");
+		sb.append(_overlappingDecompressions + ",");
+		sb.append(_leftMultiplications + ",");
+		sb.append(_rightMultiplications + ",");
+		sb.append(_compressedMultiplication + ",");
+		sb.append(_dictionaryOps + "]");
+		sb.append(" Densifying:");
+		sb.append(_isDensifying);
 		return sb.toString();
 	}
 
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorFactory.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorFactory.java
index a4b1c99..a9a8e44 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorFactory.java
@@ -27,7 +27,6 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
 public class CompressedSizeEstimatorFactory {
 	protected static final Log LOG = LogFactory.getLog(CompressedSizeEstimatorFactory.class.getName());
-	private static final int maxSampleSize = 1000000;
 
 	public static CompressedSizeEstimator getSizeEstimator(MatrixBlock data, CompressionSettings cs, int k) {
 
@@ -36,7 +35,7 @@ public class CompressedSizeEstimatorFactory {
 		final int nnzRows = (int) Math.ceil(data.getNonZeros() / nCols);
 
 		final double sampleRatio = cs.samplingRatio;
-		final int sampleSize = Math.min(getSampleSize(sampleRatio, nRows, cs.minimumSampleSize), maxSampleSize);
+		final int sampleSize = getSampleSize(sampleRatio, nRows, nCols, cs.minimumSampleSize, cs.maxSampleSize);
 
 		if(nCols > 1000) {
 			return tryToMakeSampleEstimator(data, cs, sampleRatio, sampleSize / 10, nRows, nnzRows, k);
@@ -79,7 +78,27 @@ public class CompressedSizeEstimatorFactory {
 		return cs.samplingRatio >= 1.0 || nRows < cs.minimumSampleSize || sampleSize >= nnzRows;
 	}
 
-	private static int getSampleSize(double sampleRatio, int nRows, int minimumSampleSize) {
-		return Math.max((int) Math.ceil(nRows * sampleRatio), minimumSampleSize);
+	/**
+	 * This function returns the sample size to use.
+	 * 
+	 * The sample size is bounded by the minimum and maximum sample sizes; between those bounds it scales linearly
+	 * with the sample ratio.
+	 * 
+	 * Also influencing the sample size is the number of columns. If the number of columns is large the sample size is
+	 * scaled down, this gives worse estimations of distinct items, but it makes sure that the compression time is more
+	 * consistent.
+	 * 
+	 * @param sampleRatio       The sample ratio
+	 * @param nRows             The number of rows
+	 * @param nCols             The number of columns
+	 * @param minSampleSize     the minimum sample size
+	 * @param maxSampleSize     the maximum sample size
+	 * @return The sample size to use.
+	 */
+	private static int getSampleSize(double sampleRatio, int nRows, int nCols, int minSampleSize, int maxSampleSize) {
+		int sampleSize = (int) Math.ceil(nRows * sampleRatio / Math.max(1, (double)nCols / 150));
+		if(sampleSize < 20000)
+			sampleSize *= 2;
+		return Math.min(Math.max(sampleSize, minSampleSize), maxSampleSize);
 	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java
index e330ab8..b9b64fa 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java
@@ -105,7 +105,8 @@ public class CLALibBinaryCellOp {
 				result = CompressedMatrixBlockFactory.createConstant(m1.getNumRows(), m1.getNumColumns(), 0);
 			else if(fn instanceof Minus1Multiply)
 				result = CompressedMatrixBlockFactory.createConstant(m1.getNumRows(), m1.getNumColumns(), 1);
-			else if(fn instanceof Minus || fn instanceof Plus || fn instanceof MinusMultiply || fn instanceof PlusMultiply){
+			else if(fn instanceof Minus || fn instanceof Plus || fn instanceof MinusMultiply ||
+				fn instanceof PlusMultiply) {
 				CompressedMatrixBlock ret = new CompressedMatrixBlock();
 				ret.copy(m1);
 				return ret;
@@ -132,7 +133,11 @@ public class CLALibBinaryCellOp {
 		// TODO optimize to allow for sparse outputs.
 		final int outCells = outRows * outCols;
 		if(atype == BinaryAccessType.MATRIX_COL_VECTOR) {
-			result.reset(outRows, Math.max(outCols, that.getNumColumns()), outCells);
+			if(result != null)
+				result.reset(outRows, Math.max(outCols, that.getNumColumns()), outCells);
+			else
+				result = new MatrixBlock(outRows, Math.max(outCols, that.getNumColumns()), outCells);
+
 			MatrixBlock d_compressed = m1.getCachedDecompressed();
 			if(d_compressed != null) {
 				if(left)
@@ -146,12 +151,15 @@ public class CLALibBinaryCellOp {
 
 		}
 		else if(atype == BinaryAccessType.MATRIX_MATRIX) {
-			result.reset(outRows, outCols, outCells);
+			if(result != null)
+				result.reset(outRows, outCols, outCells);
+			else
+				result = new MatrixBlock(outRows, outCols, outCells);
 
 			MatrixBlock d_compressed = m1.getCachedDecompressed();
 			if(d_compressed == null)
 				d_compressed = m1.getUncompressed("MatrixMatrix " + op);
-			
+
 			if(left)
 				LibMatrixBincell.bincellOp(that, d_compressed, result, op);
 			else
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java
new file mode 100644
index 0000000..ff2d211
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.lib;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+
+import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.CommonThreadPool;
+
+/**
+ * Library to decompress a list of column groups into a matrix.
+ */
+public class CLALibDecompress {
+	public static MatrixBlock decompress(MatrixBlock ret, List<AColGroup> groups, long nonZeros, boolean overlapping) {
+
+		final int rlen = ret.getNumRows();
+		final int clen = ret.getNumColumns();
+		final int block = (int) Math.ceil((double) (CompressionSettings.BITMAP_BLOCK_SZ) / clen);
+		final int blklen = block > 1000 ? block + 1000 - block % 1000 : Math.max(64, block);
+		final boolean containsSDC = CLALibUtils.containsSDC(groups);
+		double[] constV = containsSDC ? new double[ret.getNumColumns()] : null;
+		final List<AColGroup> filteredGroups = containsSDC ? CLALibUtils.filterSDCGroups(groups, constV) : groups;
+
+		sortGroups(filteredGroups, overlapping);
+		// check if we are using filtered groups, and if we are not force constV to null
+		if(groups == filteredGroups)
+			constV = null;
+
+		final double eps = getEps(constV);
+		for(int i = 0; i < rlen; i += blklen) {
+			final int rl = i;
+			final int ru = Math.min(i + blklen, rlen);
+			for(AColGroup grp : filteredGroups)
+				grp.decompressToBlockUnSafe(ret, rl, ru);
+			if(constV != null && !ret.isInSparseFormat())
+				addVector(ret, constV, eps, rl, ru);
+		}
+
+		ret.setNonZeros(nonZeros == -1 || overlapping ? ret.recomputeNonZeros() : nonZeros);
+
+		return ret;
+	}
+
+	public static MatrixBlock decompress(MatrixBlock ret, List<AColGroup> groups, boolean overlapping, int k) {
+
+		try {
+			final ExecutorService pool = CommonThreadPool.get(k);
+			final int rlen = ret.getNumRows();
+			final int block = (int) Math.ceil((double) (CompressionSettings.BITMAP_BLOCK_SZ) / ret.getNumColumns());
+			final int blklen = block > 1000 ? block + 1000 - block % 1000 : Math.max(64, block);
+
+			final boolean containsSDC = CLALibUtils.containsSDC(groups);
+			double[] constV = containsSDC ? new double[ret.getNumColumns()] : null;
+			final List<AColGroup> filteredGroups = containsSDC ? CLALibUtils.filterSDCGroups(groups, constV) : groups;
+			sortGroups(filteredGroups, overlapping);
+
+			// check if we are using filtered groups, and if we are not force constV to null
+			if(groups == filteredGroups)
+				constV = null;
+
+			final double eps = getEps(constV);
+			final ArrayList<DecompressTask> tasks = new ArrayList<>();
+			for(int i = 0; i * blklen < rlen; i++)
+				tasks.add(new DecompressTask(filteredGroups, ret, eps, i * blklen, Math.min((i + 1) * blklen, rlen),
+					overlapping, constV));
+			List<Future<Long>> rtasks = pool.invokeAll(tasks);
+			pool.shutdown();
+
+			long nnz = 0;
+			for(Future<Long> rt : rtasks)
+				nnz += rt.get();
+			ret.setNonZeros(nnz);
+		}
+		catch(InterruptedException | ExecutionException ex) {
+			throw new DMLCompressionException("Parallel decompression failed", ex);
+		}
+
+		return ret;
+	}
+
+	private static void sortGroups(List<AColGroup> groups, boolean overlapping) {
+		if(overlapping) {
+			// add a bit of stability in decompression
+			Comparator<AColGroup> comp = Comparator.comparing(x -> effect(x));
+			groups.sort(comp);
+		}
+	}
+
+	/**
+	 * Calculate an effect value for a column group. This is used to sort the groups before decompression, so that the
+	 * groups with the smallest effect value are decompressed first. The value is -max(getMax(), |getMin()|).
+	 * 
+	 * @param x A column group
+	 * @return The effect value as a double.
+	 */
+	private static double effect(AColGroup x) {
+		return -Math.max(x.getMax(), Math.abs(x.getMin()));
+	}
+
+	/**
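+	 * Get a small epsilon from the constant vector, proportional to the range of its values.
+	 * 
+	 * @param constV the constant vector, may be null.
+	 * @return (max - min) * 1e-13 over the values in constV, or 0 if constV is null.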
+	 */
+	private static double getEps(double[] constV) {
+		if(constV == null)
+			return 0;
+		else {
+			double max = -Double.MAX_VALUE;
+			double min = Double.MAX_VALUE;
+			for(double v : constV){
+				if(v > max)
+					max = v;
+				if(v < min)
+					min = v;
+			}
+			final double eps = (max-min) * 1e-13;
+			return eps;
+		}
+	}
+
+	private static class DecompressTask implements Callable<Long> {
+		private final List<AColGroup> _colGroups;
+		private final MatrixBlock _ret;
+		private final double _eps;
+		private final int _rl;
+		private final int _ru;
+		private final double[] _constV;
+		private final boolean _overlapping;
+
+		protected DecompressTask(List<AColGroup> colGroups, MatrixBlock ret, double eps, int rl, int ru,
+			boolean overlapping, double[] constV) {
+			_colGroups = colGroups;
+			_ret = ret;
+			_eps = eps;
+			_rl = rl;
+			_ru = ru;
+			_overlapping = overlapping;
+			_constV = constV;
+		}
+
+		@Override
+		public Long call() {
+			// decompress row partition
+			for(AColGroup grp : _colGroups)
+				grp.decompressToBlockUnSafe(_ret, _rl, _ru);
+
+			if(_constV != null)
+				addVector(_ret, _constV, _eps, _rl, _ru);
+
+			return _overlapping ? 0 : _ret.recomputeNonZeros(_rl, _ru - 1);
+		}
+	}
+
+	/**
+	 * Add the rowV vector to each row in ret.
+	 * 
+	 * @param ret  matrix to add the vector to
+	 * @param rowV The row vector to add
+	 * @param eps  an epsilon; output values whose absolute value is at most eps are set to zero. If eps is 0, no
+	 *             rounding is done.
+	 * @param rl   The row to start at (inclusive)
+	 * @param ru   The row to end at (exclusive)
+	 */
+	private static void addVector(final MatrixBlock ret, final double[] rowV, final double eps, final int rl,
+		final int ru) {
+		final int nCols = ret.getNumColumns();
+		final DenseBlock db = ret.getDenseBlock();
+		if(eps == 0) {
+			for(int row = rl; row < ru; row++) {
+				final double[] _retV = db.values(row);
+				final int off = db.pos(row);
+				for(int col = 0; col < nCols; col++)
+					_retV[off + col] += rowV[col];
+			}
+		}
+		else {
+			for(int row = rl; row < ru; row++) {
+				final double[] _retV = db.values(row);
+				final int off = db.pos(row);
+				for(int col = 0; col < nCols; col++) {
+					final int out = off + col;
+					_retV[out] += rowV[col];
+					if(Math.abs(_retV[out]) <= eps)
+						_retV[out] = 0;
+				}
+			}
+		}
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
index 8b78e58..dc2df68 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.java
@@ -110,11 +110,11 @@ public class CLALibLeftMultBy {
 			multAllColGroups(groups, groups, result);
 		}
 		else {
-			final boolean containsSDC = containsSDC(groups);
-			final int numColumns = cmb.getNumColumns();
+			final boolean containsSDC = CLALibUtils.containsSDC(groups);
 			final double[] constV = containsSDC ? new double[cmb.getNumColumns()] : null;
-			final List<AColGroup> filteredGroups = filterSDCGroups(groups, constV);
+			final List<AColGroup> filteredGroups = CLALibUtils.filterSDCGroups(groups, constV);
 			final double[] colSums = containsSDC ? new double[cmb.getNumColumns()] : null;
+			final int numColumns = cmb.getNumColumns();
 
 			if(containsSDC)
 				for(int i = 0; i < groups.size(); i++) {
@@ -298,12 +298,13 @@ public class CLALibLeftMultBy {
 		}
 
 		final int numColumnsOut = ret.getNumColumns();
-		final boolean containsSDC = containsSDC(colGroups);
+		final boolean containsSDC = CLALibUtils.containsSDC(colGroups);
 
 		// a constant colgroup summing the default values.
-		final double[] constV = containsSDC ? new double[numColumnsOut] : null;
-		final List<AColGroup> filteredGroups = filterSDCGroups(colGroups, constV);
-
+		double[] constV = containsSDC ? new double[numColumnsOut] : null;
+		final List<AColGroup> filteredGroups = CLALibUtils.filterSDCGroups(colGroups, constV);
+		if(colGroups == filteredGroups)
+			constV = null;
 		final double[] rowSums = containsSDC ? new double[that.getNumRows()] : null;
 
 		if(k == 1) {
@@ -633,33 +634,4 @@ public class CLALibLeftMultBy {
 		Collections.sort(ColGroupValues, Comparator.comparing(AColGroup::getNumValues).reversed());
 		return ColGroupValues;
 	}
-
-	private static boolean containsSDC(List<AColGroup> groups) {
-		boolean containsSDC = false;
-
-		for(AColGroup g : groups) {
-			if(g instanceof ColGroupSDC || g instanceof ColGroupSDCSingle) {
-				containsSDC = true;
-				break;
-			}
-		}
-		return containsSDC;
-	}
-
-	private static List<AColGroup> filterSDCGroups(List<AColGroup> groups, double[] constV) {
-		if(constV != null) {
-			final List<AColGroup> filteredGroups = new ArrayList<>();
-			for(AColGroup g : groups) {
-				if(g instanceof ColGroupSDC)
-					filteredGroups.add(((ColGroupSDC) g).extractCommon(constV));
-				else if(g instanceof ColGroupSDCSingle)
-					filteredGroups.add(((ColGroupSDCSingle) g).extractCommon(constV));
-				else
-					filteredGroups.add(g);
-			}
-			return filteredGroups;
-		}
-		else
-			return groups;
-	}
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java
index 1822685..ace2843 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java
@@ -68,8 +68,7 @@ public class CLALibScalar {
 		if(m1.isOverlapping() && !(sop.fn instanceof Multiply || sop.fn instanceof Divide)) {
 			AColGroup constOverlap = constOverlap(m1, sop);
 			List<AColGroup> newColGroups = (sop instanceof LeftScalarOperator &&
-				sop.fn instanceof Minus) ? processOverlappingSubtractionLeft(m1,
-					sop,
+				sop.fn instanceof Minus) ? processOverlappingSubtractionLeft(m1, sop,
 					ret) : processOverlappingAddition(m1, sop, ret);
 			newColGroups.add(constOverlap);
 			ret.allocateColGroupList(newColGroups);
@@ -93,8 +92,8 @@ public class CLALibScalar {
 		}
 
 		ret.recomputeNonZeros();
-		return ret;
 
+		return ret;
 	}
 
 	private static CompressedMatrixBlock setupRet(CompressedMatrixBlock m1, MatrixValue result) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
new file mode 100644
index 0000000..c701b96
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.lib;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingle;
+
+public class CLALibUtils {
+
+	/**
+	 * Helper method to determine if the column groups contains SDC
+	 * 
+	 * Note that it only returns true, if there is more than one SDC Group.
+	 * 
+	 * @param groups The ColumnGroups to analyze
+	 * @return A boolean saying if there are >= 2 SDC groups.
+	 */
+	protected static boolean containsSDC(List<AColGroup> groups) {
+		int count = 0;
+		for(AColGroup g : groups) {
+			if(g instanceof ColGroupSDC || g instanceof ColGroupSDCSingle) {
+				count++;
+				if(count > 1)
+					break;
+			}
+		}
+		return count > 1;
+	}
+
+	/**
+	 * Helper method to filter out SDC Groups, to add their common value to the ConstV. This allows exploitation of the
+	 * common values in the SDC Groups.
+	 * 
+	 * @param groups The Column Groups
+	 * @param constV The Constant vector to add common values to.
+	 * @return The Filtered list of Column groups containing no SDC Groups but only SDCZero groups.
+	 */
+	protected static List<AColGroup> filterSDCGroups(List<AColGroup> groups, double[] constV) {
+		if(constV != null) {
+			final List<AColGroup> filteredGroups = new ArrayList<>();
+			for(AColGroup g : groups) {
+				if(g instanceof ColGroupSDC)
+					filteredGroups.add(((ColGroupSDC) g).extractCommon(constV));
+				else if(g instanceof ColGroupSDCSingle)
+					filteredGroups.add(((ColGroupSDCSingle) g).extractCommon(constV));
+				else
+					filteredGroups.add(g);
+			}
+			for(double v : constV)
+				if(!Double.isFinite(v))
+					return groups;
+			
+			return filteredGroups;
+		}
+		else
+			return groups;
+	}
+}
diff --git a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
index e086974..06431db 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
@@ -425,40 +425,48 @@ public class WorkloadAnalyzer {
 					setDecompressionOnAllInputs(hop, parent);
 					return;
 				}
-				// shortcut instead of comparing to MatrixScalar or RowVector.
-				else if(hop.getInput(1).getDim1() == 1 || hop.getInput(1).isScalar() || hop.getInput(0).isScalar()) {
-
+				else {
 					ArrayList<Hop> in = hop.getInput();
 					final boolean ol0 = isOverlapping(in.get(0));
 					final boolean ol1 = isOverlapping(in.get(1));
 					final boolean ol = ol0 || ol1;
-					if(ol && HopRewriteUtils.isBinary(hop, OpOp2.PLUS, OpOp2.MULT, OpOp2.DIV, OpOp2.MINUS)) {
-						overlapping.add(hop.getHopID());
-						o = new OpNormal(hop, true);
-						o.setOverlapping();
+
+					// shortcut instead of comparing to MatrixScalar or RowVector.
+					if(in.get(1).getDim1() == 1 || in.get(1).isScalar() || in.get(0).isScalar()) {
+
+						if(ol && HopRewriteUtils.isBinary(hop, OpOp2.PLUS, OpOp2.MULT, OpOp2.DIV, OpOp2.MINUS)) {
+							overlapping.add(hop.getHopID());
+							o = new OpNormal(hop, true);
+							o.setOverlapping();
+						}
+						else if(ol) {
+							treeLookup.get(in.get(0).getHopID()).setDecompressing();
+							return;
+						}
+						else {
+							o = new OpNormal(hop, true);
+						}
+						if(!HopRewriteUtils.isBinarySparseSafe(hop))
+							o.setDensifying();
+
 					}
-					else if(ol) {
-						treeLookup.get(in.get(0).getHopID()).setDecompressing();
+					else if(HopRewriteUtils.isBinaryMatrixMatrixOperation(hop) ||
+						HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) ||
+						HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) {
+						setDecompressionOnAllInputs(hop, parent);
+						return;
+					}
+					else if(ol0 || ol1){
+						setDecompressionOnAllInputs(hop, parent);
 						return;
 					}
 					else {
-						o = new OpNormal(hop, true);
+						String ex = "Setting decompressed because input Binary Op is unknown, please add the case to WorkloadAnalyzer:\n"
+							+ Explain.explain(hop);
+						LOG.warn(ex);
+						setDecompressionOnAllInputs(hop, parent);
+						return;
 					}
-					if(!HopRewriteUtils.isBinarySparseSafe(hop))
-						o.setDensifying();
-
-				}
-				else if(HopRewriteUtils.isBinaryMatrixMatrixOperation(hop) ||
-					HopRewriteUtils.isBinaryMatrixColVectorOperation(hop) ||
-					HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop)) {
-					setDecompressionOnAllInputs(hop, parent);
-					return;
-				}
-				else {
-					String ex = "Setting decompressed because input Binary Op is unknown, please add the case to WorkloadAnalyzer:\n"
-						+ Explain.explain(hop);
-					LOG.warn(ex);
-					setDecompressionOnAllInputs(hop, parent);
 				}
 
 			}
diff --git a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
index f1682a7..ba50a0e 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/workload/WorkloadTest.java
@@ -117,7 +117,7 @@ public class WorkloadTest {
 		args.put("$3", "0");
 
 		// no recompile
-		tests.add(new Object[] {0, 1, 1, 1, 1, 1, 6, 0, true, false, "functions/lmDS.dml", args});
+		tests.add(new Object[] {0, 1, 1, 1, 1, 1, 5, 0, true, false, "functions/lmDS.dml", args});
 		// with recompile
 		tests.add(new Object[] {0, 0, 0, 1, 0, 1, 0, 0, true, true, "functions/lmDS.dml", args});
 		tests.add(new Object[] {0, 0, 0, 1, 10, 10, 1, 0, true, true, "functions/lmCG.dml", args});
diff --git a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java
index dc5b17f..c5710bd 100644
--- a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java
+++ b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java
@@ -188,7 +188,7 @@ public class CompressForce extends CompressBase {
 		// be aware that with multiple blocks it is likely that the small blocks
 		// initially compress, but is to large for overlapping state therefor will decompress.
 		// In this test it decompress the second small block but keeps the first in overlapping state.
-		runTest(1110, 30, 1, 1, ExecType.SPARK, "mmr_sum_plus_2");
+		compressTest(1110, 10, 1.0, ExecType.SPARK, 1, 6, 1, 1, 1, "mmr_sum_plus_2");
 	}
 
 	@Test

[systemds] 02/02: [SYSTEMDS-3144] Spark Local Context from command line

Posted by ba...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 535263b28e4f36b7575f487bcf9967d9d42bffdd
Author: baunsgaard <ba...@tugraz.at>
AuthorDate: Sat Sep 25 19:54:01 2021 +0200

    [SYSTEMDS-3144] Spark Local Context from command line
    
    This commit adds the ability to start a systemDS instance with a
    local spark context, this enables us to use our spark instructions
    even without a spark cluster.
    
    Also added in this commit is a fallback to a local spark instance,
    in case the creation of the spark context is attempted but fails.
    
    Closes #1398
    Closes #1399
---
 src/main/java/org/apache/sysds/conf/DMLConfig.java |  4 +-
 .../context/SparkExecutionContext.java             | 50 ++++++++++++++--------
 2 files changed, 36 insertions(+), 18 deletions(-)

diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index a59101b..6194bc2 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -89,7 +89,8 @@ public class DMLConfig
 	public static final String SYNCHRONIZE_GPU      = "sysds.gpu.sync.postProcess"; // boolean: whether to synchronize GPUs after every instruction
 	public static final String EAGER_CUDA_FREE      = "sysds.gpu.eager.cudaFree"; // boolean: whether to perform eager CUDA free on rmvar
 	public static final String GPU_EVICTION_POLICY  = "sysds.gpu.eviction.policy"; // string: can be lru, lfu, min_evict
-	public static final String LOCAL_SPARK_NUM_THREADS = "sysds.local.spark.number.threads";
+	public static final String USE_LOCAL_SPARK_CONFIG = "sysds.local.spark"; // If set to true, it forces spark execution to a local spark context.
+	public static final String LOCAL_SPARK_NUM_THREADS = "sysds.local.spark.number.threads"; // the number of threads allowed to be used in the local spark configuration, default is * to enable use of all threads.
 	public static final String LINEAGECACHESPILL    = "sysds.lineage.cachespill"; // boolean: whether to spill cache entries to disk
 	public static final String COMPILERASSISTED_RW  = "sysds.lineage.compilerassisted"; // boolean: whether to apply compiler assisted rewrites
 	
@@ -152,6 +153,7 @@ public class DMLConfig
 		_defaultVals.put(GPU_MEMORY_ALLOCATOR,   "cuda");
 		_defaultVals.put(AVAILABLE_GPUS,         "-1");
 		_defaultVals.put(GPU_EVICTION_POLICY,    "min_evict");
+		_defaultVals.put(USE_LOCAL_SPARK_CONFIG, "false");
 		_defaultVals.put(LOCAL_SPARK_NUM_THREADS, "*"); // * Means it allocates the number of available threads on the local host machine.
 		_defaultVals.put(SYNCHRONIZE_GPU,        "false" );
 		_defaultVals.put(EAGER_CUDA_FREE,        "false" );
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index ca3f69c..67efd5c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -42,6 +42,7 @@ import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.storage.RDDInfo;
 import org.apache.spark.storage.StorageLevel;
 import org.apache.spark.util.LongAccumulator;
+import org.apache.sysds.api.DMLException;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.api.mlcontext.MLContext;
 import org.apache.sysds.api.mlcontext.MLContextUtil;
@@ -213,23 +214,15 @@ public class SparkExecutionContext extends ExecutionContext
 		}
 		else
 		{
-			if(DMLScript.USE_LOCAL_SPARK_CONFIG) {
-				// For now set 4 cores for integration testing :)
-				SparkConf conf = createSystemDSSparkConf()
-						.setMaster("local[" +
-							ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.LOCAL_SPARK_NUM_THREADS)+
-							"]").setAppName("My local integration test app");
-				// This is discouraged in spark but have added only for those testcase that cannot stop the context properly
-				// conf.set("spark.driver.allowMultipleContexts", "true");
-				conf.set("spark.ui.enabled", "false");
-				_spctx = new JavaSparkContext(conf);
-			}
-			else //default cluster setup
-			{
-				//setup systemds-preferred spark configuration (w/o user choice)
-				SparkConf conf = createSystemDSSparkConf();
-				_spctx = new JavaSparkContext(conf);
-			}
+			final SparkConf conf = createSystemDSSparkConf();
+			final DMLConfig dmlConfig= ConfigurationManager.getDMLConfig();
+			// Use Spark local config, if already set to True ... keep true, otherwise look up if it should be local.
+			DMLScript.USE_LOCAL_SPARK_CONFIG = DMLScript.USE_LOCAL_SPARK_CONFIG ? true : dmlConfig.getBooleanValue(DMLConfig.USE_LOCAL_SPARK_CONFIG);
+			
+			if(DMLScript.USE_LOCAL_SPARK_CONFIG)
+				setLocalConfSettings(conf);
+			
+			_spctx = createContext(conf);
 
 			_parRDDs.clear();
 		}
@@ -253,6 +246,29 @@ public class SparkExecutionContext extends ExecutionContext
 		}
 	}
 
+
+	private static JavaSparkContext createContext(SparkConf conf){
+		try{
+			return new JavaSparkContext(conf);
+		} 
+		catch(Exception e){
+			if(e.getMessage().contains("A master URL must be set in your configuration")){
+				LOG.error("Error constructing Spark Context, falling back to local Spark context creation");
+				setLocalConfSettings(conf);
+				return createContext(conf);
+			}
+			else
+				throw new DMLException("Error while creating Spark context", e);
+		}
+	}
+
+	private static void setLocalConfSettings(SparkConf conf){
+		final String threads = ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.LOCAL_SPARK_NUM_THREADS);
+		conf.setMaster("local[" + threads + "]");
+		conf.setAppName("LocalSparkContextApp");
+		conf.set("spark.ui.enabled", "false");
+	}
+
 	/**
 	 * Sets up a SystemDS-preferred Spark configuration based on the implicit
 	 * default configuration (as passed via configurations from outside).