You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2022/08/03 14:53:35 UTC

[systemds] branch main updated: [SYSTEMDS-3390] Improve performance of countDistinctApprox()

This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 75385a9493 [SYSTEMDS-3390] Improve performance of countDistinctApprox()
75385a9493 is described below

commit 75385a949312a8ea3591559633f9961e811ac067
Author: Badrul Chowdhury <ba...@gmail.com>
AuthorDate: Wed Aug 3 16:03:42 2022 +0200

    [SYSTEMDS-3390] Improve performance of countDistinctApprox()
    
    This patch improves the performance of countDistinctApprox() row/col
    aggregation by replacing matrix slicing with direct ops on the input
    matrix. This has the most impact in local CP execution mode, as
    some simple experiments show:
    
    (numbers are averages over 3 runs)
    1. row aggregation
        (A) dense: 10000x1000 with sparsity=0.9
        1.198s with slicing, 0.874s without slicing - a 27% improvement
    
        (B) sparse: 10000x1000 with sparsity=0.1
        0.528s with slicing, 0.512s without slicing - a 3% improvement
    
    As expected, the larger and denser the input matrix,
    the larger the performance improvement.
    
    2. col aggregation
        (A) dense: 1000x10000 with sparsity=0.9
        1.186s with slicing, 1.036s without slicing - a 13% improvement
    
        (B) sparse: 1000x10000 with sparsity=0.1
        1.272s with slicing, 0.647s without slicing - a 49% improvement
    
    In this case, the sparser the input matrix, the larger the performance
    improvement. This is because the implementation groups values in a hash
    map M keyed by column index: as the RxC input matrix becomes denser,
    M's key set grows toward C entries, and the performance approaches that
    of the baseline, which uses slicing.
    
    Closes #1650
---
 .../cp/AggregateUnaryCPInstruction.java            |  41 ++-
 .../matrix/data/LibMatrixCountDistinct.java        | 323 +++++++++++++++++----
 .../runtime/matrix/data/sketch/MatrixSketch.java   |  24 +-
 .../CountDistinctApproxSketch.java                 |   2 +-
 .../data/sketch/countdistinctapprox/KMVSketch.java | 224 +++++++++-----
 .../countdistinctapprox/SmallestPriorityQueue.java |   5 +
 .../test/component/matrix/CountDistinctTest.java   |   2 +-
 .../countDistinct/CountDistinctApproxCol.java      |  48 +++
 .../countDistinct/CountDistinctApproxRow.java      |  48 +++
 .../functions/countDistinct/CountDistinctBase.java |   6 +-
 .../countDistinct/CountDistinctRowOrColBase.java   |  45 ++-
 11 files changed, 576 insertions(+), 192 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
index fbcf6ff7f3..ddf00ada2b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
@@ -82,8 +82,12 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
 				in1, out, AUType.valueOf(opcode.toUpperCase()), opcode, str);
 		} 
 		else if(opcode.equalsIgnoreCase("uacd")){
-			return new AggregateUnaryCPInstruction(new SimpleOperator(null),
-			in1, out, AUType.COUNT_DISTINCT, opcode, str);
+			CountDistinctOperator op = new CountDistinctOperator(AUType.COUNT_DISTINCT)
+					.setDirection(Types.Direction.RowCol)
+					.setIndexFunction(ReduceAll.getReduceAllFnObject());
+
+			return new AggregateUnaryCPInstruction(op, in1, out, AUType.COUNT_DISTINCT,
+					opcode, str);
 		}
 		else if(opcode.equalsIgnoreCase("uacdap")){
 			CountDistinctOperator op = new CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX)
@@ -199,9 +203,15 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
 				if( !ec.getVariables().keySet().contains(input1.getName()) )
 					throw new DMLRuntimeException("Variable '" + input1.getName() + "' does not exist.");
 				MatrixBlock input = ec.getMatrixInput(input1.getName());
-				CountDistinctOperator op = new CountDistinctOperator(_type);
+
+				// Operator type: test and cast
+				if (!(_optr instanceof CountDistinctOperator)) {
+					throw new DMLRuntimeException("Operator should be instance of " + CountDistinctOperator.class.getSimpleName());
+				}
+				CountDistinctOperator op = (CountDistinctOperator) (_optr);
+
 				//TODO add support for row or col count distinct.
-				int res = LibMatrixCountDistinct.estimateDistinctValues(input, op);
+				int res = (int) LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0);
 				ec.releaseMatrixInput(input1.getName());
 				ec.setScalarOutput(output_name, new IntObject(res));
 				break;
@@ -219,27 +229,16 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
 				CountDistinctOperator op = (CountDistinctOperator) _optr;  // It is safe to cast at this point
 
 				if (op.getDirection().isRowCol()) {
-					int res = LibMatrixCountDistinct.estimateDistinctValues(input, op);
+					long res = (long) LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0);
 					ec.releaseMatrixInput(input1.getName());
 					ec.setScalarOutput(output_name, new IntObject(res));
-				} else if (op.getDirection().isRow()) {
-					//TODO Do not slice out the matrix but directly process on the input
-					MatrixBlock res = input.slice(0, input.getNumRows() - 1, 0, 0);
-					for (int i = 0; i < input.getNumRows(); ++i) {
-						res.setValue(i, 0, LibMatrixCountDistinct.estimateDistinctValues(input.slice(i, i), op));
-					}
-					ec.releaseMatrixInput(input1.getName());
-					ec.setMatrixOutput(output_name, res);
-				} else if (op.getDirection().isCol()) {
-					//TODO Do not slice out the matrix but directly process on the input
-					MatrixBlock res = input.slice(0, 0, 0, input.getNumColumns() - 1);
-					for (int j = 0; j < input.getNumColumns(); ++j) {
-						res.setValue(0, j, LibMatrixCountDistinct.estimateDistinctValues(input.slice(0, input.getNumRows() - 1, j, j), op));
-					}
+				} else {  // Row/Col
+					// Note that for each row, the max number of distinct values < NNZ < max number of columns = 1000:
+					// Since count distinct approximate estimates are unreliable for values < 1024,
+					// we will force a naive count.
+					MatrixBlock res = LibMatrixCountDistinct.estimateDistinctValues(input, op);
 					ec.releaseMatrixInput(input1.getName());
 					ec.setMatrixOutput(output_name, res);
-				} else {
-					throw new DMLRuntimeException("Direction for CountDistinctOperator not recognized");
 				}
 
 				break;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
index 4b13abc995..1198b18dd5 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
@@ -19,81 +19,93 @@
 
 package org.apache.sysds.runtime.matrix.data;
 
-import java.util.HashSet;
-import java.util.Set;
+import java.util.*;
 
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLException;
-import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
-import org.apache.sysds.runtime.data.DenseBlock;
-import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.*;
 import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
 import org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.KMVSketch;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
 import org.apache.sysds.utils.Hash.HashType;
 
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
 /**
  * This class contains various methods for counting the number of distinct values inside a MatrixBlock
  */
 public interface LibMatrixCountDistinct {
-	static final Log LOG = LogFactory.getLog(LibMatrixCountDistinct.class.getName());
+	Log LOG = LogFactory.getLog(LibMatrixCountDistinct.class.getName());
 
 	/**
 	 * The minimum number NonZero of cells in the input before using approximate techniques for counting number of
 	 * distinct values.
 	 */
-	public static int minimumSize = 1024;
+	int minimumSize = 1024;
 
 	/**
 	 * Public method to count the number of distinct values inside a matrix. Depending on which CountDistinctOperator
 	 * selected it either gets the absolute number or a estimated value.
 	 * 
 	 * TODO: Support counting num distinct in rows, or columns axis.
-	 * 
-	 * TODO: Add support for distributed spark operations
-	 * 
 	 * TODO: If the MatrixBlock type is CompressedMatrix, simply read the values from the ColGroups.
 	 * 
 	 * @param in the input matrix to count number distinct values in
 	 * @param op the selected operator to use
-	 * @return the distinct count
+	 * @return A matrix block containing the absolute distinct count for the entire input or along given row/col axis
 	 */
-	public static int estimateDistinctValues(MatrixBlock in, CountDistinctOperator op) {
-		int res = 0;
+	static MatrixBlock estimateDistinctValues(MatrixBlock in, CountDistinctOperator op) {
 		if(op.getOperatorType() == CountDistinctOperatorTypes.KMV &&
 			(op.getHashType() == HashType.ExpHash || op.getHashType() == HashType.StandardJava)) {
 			throw new DMLException(
 				"Invalid hashing configuration using " + op.getHashType() + " and " + op.getOperatorType());
 		}
 		else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL) {
-			throw new NotImplementedException("HyperLogLog not implemented");
+			throw new NotImplementedException("HyperLogLog has not been implemented yet");
+		}
+
+		// shortcut in the simplest case.
+		if(in.getLength() == 1 || in.isEmpty()) {
+			return new MatrixBlock(1);
 		}
-		// shortcut in simplest case.
-		if(in.getLength() == 1 || in.isEmpty())
-			return 1;
-		else if(in.getNonZeros() < minimumSize) {
-			// Just use naive implementation if the number of nonZeros values size is small.
-			res = countDistinctValuesNaive(in);
+
+		long averageNnzPerRowOrCol;
+		if (op.getDirection().isRowCol()) {
+			averageNnzPerRowOrCol = in.getNonZeros();
+		} else if (op.getDirection().isRow()) {
+			// The average nnz per row is susceptible to skew. However, given that CP instructions is limited to
+			// matrices of size at most 1000 x 1000, the performance impact of using naive counting over sketch per
+			// row/col as determined by the average is negligible. Besides, the average is the simplest measure
+			// available without calculating nnz per row/col.
+			averageNnzPerRowOrCol = (long) Math.floor(in.getNonZeros() / (double) in.getNumRows());
+		} else if (op.getDirection().isCol()) {
+			averageNnzPerRowOrCol = (long) Math.floor(in.getNonZeros() / (double) in.getNumColumns());
+		} else {
+			throw new IllegalArgumentException("Unrecognized direction " + op.getDirection());
 		}
-		else {
+
+		// Result is a dense 1x1 (RowCol), Mx1 (Row), or 1xN (Col) matrix
+		MatrixBlock res;
+		if (averageNnzPerRowOrCol < minimumSize) {
+			// Resort to naive counting for small enough matrices
+			res = countDistinctValuesNaive(in, op);
+		} else {
 			switch(op.getOperatorType()) {
 				case COUNT:
-					res = countDistinctValuesNaive(in);
+					res = countDistinctValuesNaive(in, op);
 					break;
 				case KMV:
-					res = new KMVSketch(op).getScalarValue(in);
+					res = new KMVSketch(op).getValue(in);
 					break;
 				default:
-					throw new DMLException("Invalid or not implemented Estimator Type");
+					throw new DMLException("Invalid estimator type for aggregation: " + LibMatrixCountDistinct.class.getSimpleName());
 			}
 		}
 
-		if(res <= 0)
-			throw new DMLRuntimeException("Impossible estimate of distinct values");
 		return res;
 	}
 
@@ -102,66 +114,257 @@ public interface LibMatrixCountDistinct {
 	 * 
 	 * Benefit: precise, but uses memory, on the scale of inputs number of distinct values.
 	 * 
-	 * @param in The input matrix to count number distinct values in
-	 * @return The absolute distinct count
+	 * @param blkIn The input matrix to count number distinct values in
+	 * @return A matrix block containing the absolute distinct count for the entire input or along given row/col axis
 	 */
-	private static int countDistinctValuesNaive(MatrixBlock in) {
+	private static MatrixBlock countDistinctValuesNaive(MatrixBlock blkIn, CountDistinctOperator op) {
+
+		if (blkIn.isEmpty()) {
+			return new MatrixBlock(1);
+		}
+		else if(blkIn instanceof CompressedMatrixBlock) {
+			throw new NotImplementedException("countDistinct() does not support CompressedMatrixBlock");
+		}
+
 		Set<Double> distinct = new HashSet<>();
+		MatrixBlock blkOut;
 		double[] data;
-		if(in.isEmpty())
-			return 1;
-		else if(in instanceof CompressedMatrixBlock)
-			throw new NotImplementedException();
 
-		long nonZeros = in.getNonZeros();
+		if (op.getDirection().isRowCol()) {
+			blkOut = new MatrixBlock(1, 1, false);
 
-		if(nonZeros != -1 && nonZeros < in.getNumColumns() * in.getNumRows()) {
-			distinct.add(0d);
-		}
+			long distinctCount = 0;
+			long nonZeros = blkIn.getNonZeros();
 
-		if(in.sparseBlock != null) {
-			SparseBlock sb = in.sparseBlock;
+			// Check if input matrix contains any 0 values for RowCol case.
+			// This does not apply to row/col case, where we count nnz per row or col during iteration.
+			if(nonZeros != -1 && nonZeros < (long) blkIn.getNumColumns() * blkIn.getNumRows()) {
+				distinct.add(0d);
+			}
 
-			if(in.sparseBlock.isContiguous()) {
-				data = sb.values(0);
-				countDistinctValuesNaive(data, distinct);
+			if(blkIn.getSparseBlock() != null) {
+				SparseBlock sb = blkIn.getSparseBlock();
+				if(blkIn.getSparseBlock().isContiguous()) {
+					// COO, CSR
+					data = sb.values(0);
+					distinctCount = countDistinctValuesNaive(data, distinct);
+				} else {
+					// MCSR
+					for(int i = 0; i < blkIn.getNumRows(); i++) {
+						if(!sb.isEmpty(i)) {
+							data = blkIn.getSparseBlock().values(i);
+							distinctCount = countDistinctValuesNaive(data, distinct);
+						}
+					}
+				}
+			} else if(blkIn.getDenseBlock() != null) {
+				DenseBlock db = blkIn.getDenseBlock();
+				for (int i = 0; i <= db.numBlocks(); i++) {
+					data = db.valuesAt(i);
+					distinctCount = countDistinctValuesNaive(data, distinct);
+				}
 			}
-			else {
-				for(int i = 0; i < in.getNumRows(); i++) {
-					if(!sb.isEmpty(i)) {
-						data = in.sparseBlock.values(i);
+
+			blkOut.setValue(0, 0, distinctCount);
+		} else if (op.getDirection().isRow()) {
+			blkOut = new MatrixBlock(blkIn.getNumRows(), 1, false, blkIn.getNumRows());
+			blkOut.allocateBlock();
+
+			if (blkIn.getDenseBlock() != null) {
+				// The naive approach would be to iterate through every (i, j) in the input. However, can do better
+				// by exploiting the physical layout of dense blocks - contiguous blocks in row-major order - in memory.
+				DenseBlock db = blkIn.getDenseBlock();
+				for (int bix=0; bix<db.numBlocks(); ++bix) {
+					data = db.valuesAt(bix);
+					for (int rix=bix * db.blockSize(); rix<blkIn.getNumRows(); rix++) {
+						distinct.clear();
+						for (int cix=0; cix<blkIn.getNumColumns(); ++cix) {
+							distinct.add(data[db.pos(rix, cix)]);
+						}
+						blkOut.setValue(rix, 0, distinct.size());
+					}
+				}
+			} else if (blkIn.getSparseBlock() != null) {
+				// Each sparse block type - COO, CSR, MCSR - has a different data representation, which we will exploit
+				// separately.
+				SparseBlock sb = blkIn.getSparseBlock();
+				if (SparseBlockFactory.isSparseBlockType(sb, SparseBlock.Type.MCSR)) {
+					// Currently, SparseBlockIterator only provides an interface for cell-wise iteration.
+					// TODO Explore row-wise and column-wise methods for SparseBlockIterator
+
+					// MCSR enables O(1) access to column values per row
+					for (int rix=0; rix<blkIn.getNumRows(); ++rix) {
+						if (sb.isEmpty(rix)) {
+							continue;
+						}
+						distinct.clear();
+						data = sb.values(rix);
 						countDistinctValuesNaive(data, distinct);
+						blkOut.setValue(rix, 0, distinct.size());
+					}
+				} else if (SparseBlockFactory.isSparseBlockType(sb, SparseBlock.Type.CSR)) {
+					// Casting is safe given if-condition
+					SparseBlockCSR csrBlock = (SparseBlockCSR) sb;
+
+					// Data lies in one contiguous block in CSR format. We will iterate in row-major using O(1) op
+					// size(row) to determine the number of columns per row.
+					data = csrBlock.values();
+					// We want to iterate through all rows to keep track of the row index for constructing the output
+					for (int rix=0; rix<blkIn.getNumRows(); ++rix) {
+						if (csrBlock.isEmpty(rix)) {
+							continue;
+						}
+						distinct.clear();
+						int rpos = csrBlock.pos(rix);
+						int clen = csrBlock.size(rix);
+						for (int colOffset=0; colOffset<clen; ++colOffset) {
+							distinct.add(data[rpos + colOffset]);
+						}
+						blkOut.setValue(rix, 0, distinct.size());
+					}
+				} else { // COO
+					if (!(sb instanceof SparseBlockCOO)) {
+						throw new IllegalArgumentException("Input matrix is of unrecognized type: "
+								+ sb.getClass().getSimpleName());
+					}
+					SparseBlockCOO cooBlock = (SparseBlockCOO) sb;
+
+					// For COO, we want to avoid using pos(row) and size(row) as they use binary search, which is a
+					// O(log N) op. Also, isEmpty(row) uses pos(row) internally.
+					int[] rixs = cooBlock.rowIndexes();
+					data = cooBlock.values();
+					int i = 0;  // data iterator
+					int rix = 0;  // row index
+					while (rix < cooBlock.numRows() && i < rixs.length) {
+						distinct.clear();
+						while (i + 1 < rixs.length && rixs[i] == rixs[i + 1]) {
+							distinct.add(data[i]);
+							i++;
+						}
+						if (i + 1 < rixs.length) {  // rixs[i] != rixs[i + 1]
+							distinct.add(data[i]);
+						}
+						blkOut.setValue(rix, 0, distinct.size());
+						rix = (i + 1 < rixs.length)? rixs[i + 1] : rix;
+						i++;
 					}
 				}
 			}
-		}
-		else if(in.denseBlock != null) {
-			DenseBlock db = in.denseBlock;
-			for(int i = 0; i <= db.numBlocks(); i++) {
-				data = db.valuesAt(i);
-				countDistinctValuesNaive(data, distinct);
+		} else {  // Col aggregation
+			blkOut = new MatrixBlock(1, blkIn.getNumColumns(), false, blkIn.getNumRows());
+			blkOut.allocateBlock();
+
+			// All dense and sparse formats (COO, CSR, MCSR) are row-major formats, so there is no obvious way to iterate
+			// in column-major order besides iterating through every (i, j) pair. getValue() skips over empty cells in CSR
+			// and MCSR formats, but not so in COO format. This results in O(log2 R * log2 C) time for every lookup,
+			// amounting to O(RC * log2R * log2C) for the whole block (R, C <= 1000 in CP case). We will eschew this
+			// approach in favor of one using a hash map M of (column index, distinct values) to obtain a pseudo column-major
+			// grouping of distinct values instead. Given this setup, we will simply iterate over the input
+			// (according to specific dense/sparse format) in row-major order and populate M. Finally, an O(C) iteration
+			// over M will yield the final result.
+			Map<Integer, Set<Double>> distinctValuesByCol = new HashMap<>();
+			if (blkIn.getDenseBlock() != null) {
+				DenseBlock db = blkIn.getDenseBlock();
+				for (int bix=0; bix<db.numBlocks(); ++bix) {
+					data = db.valuesAt(bix);
+					for (int cix=0; cix<blkIn.getNumColumns(); ++cix) {
+						Set<Double> distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+						for (int rix=bix * db.blockSize(); rix<blkIn.getNumRows(); rix++) {
+							double val = data[db.pos(rix, cix)];
+							distinctValues.add(val);
+						}
+						distinctValuesByCol.put(cix, distinctValues);
+					}
+				}
+			} else if (blkIn.getSparseBlock() != null) {
+				SparseBlock sb = blkIn.getSparseBlock();
+				if (SparseBlockFactory.isSparseBlockType(sb, SparseBlock.Type.MCSR)) {
+					for (int rix=0; rix<blkIn.getNumRows(); ++rix) {
+						if (sb.isEmpty(rix)) {
+							continue;
+						}
+						int[] cixs = sb.indexes(rix);
+						data = sb.values(rix);
+						for (int j=0; j<sb.size(rix); ++j) {
+							int cix = cixs[j];
+							Set<Double> distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+							distinctValues.add(data[j]);
+							distinctValuesByCol.put(cix, distinctValues);
+						}
+					}
+				} else if (SparseBlockFactory.isSparseBlockType(sb, SparseBlock.Type.CSR)) {
+					SparseBlockCSR csrBlock = (SparseBlockCSR) sb;
+					data = csrBlock.values();
+					for (int rix=0; rix<blkIn.getNumRows(); ++rix) {
+						if (csrBlock.isEmpty(rix)) {
+							continue;
+						}
+						distinct.clear();
+						int rpos = csrBlock.pos(rix);
+						int clen = csrBlock.size(rix);
+						int[] cixs = csrBlock.indexes();
+						for (int colOffset=0; colOffset<clen; ++colOffset) {
+							int cix = cixs[rpos + colOffset];
+							Set<Double> distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+							distinctValues.add(data[rpos + colOffset]);
+							distinctValuesByCol.put(cix, distinctValues);
+						}
+					}
+				} else {  // COO
+					if (!(sb instanceof SparseBlockCOO)) {
+						throw new IllegalArgumentException("Input matrix is of unrecognized type: "
+								+ sb.getClass().getSimpleName());
+					}
+					SparseBlockCOO cooBlock = (SparseBlockCOO) sb;
+
+					int[] rixs = cooBlock.rowIndexes();
+					int[] cixs = cooBlock.indexes();
+					data = cooBlock.values();
+					int i = 0;  // data iterator
+					while (i < rixs.length) {
+						while (i + 1 < rixs.length && rixs[i] == rixs[i + 1]) {
+							int cix = cixs[i];
+							Set<Double> distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+							distinctValues.add(data[i]);
+							distinctValuesByCol.put(cix, distinctValues);
+							i++;
+						}
+						if (i + 1 < rixs.length) {
+							int cix = cixs[i];
+							Set<Double> distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+							distinctValues.add(data[i]);
+							distinctValuesByCol.put(cix, distinctValues);
+						}
+						i++;
+					}
+				}
+			}
+			// Fill in output block with column aggregation results
+			for (int cix : distinctValuesByCol.keySet()) {
+				blkOut.setValue(0, cix, distinctValuesByCol.get(cix).size());
 			}
 		}
 
-		return distinct.size();
+		return blkOut;
 	}
 
-	private static Set<Double> countDistinctValuesNaive(double[] valuesPart, Set<Double> distinct) {
-		for(double v : valuesPart) 
+	private static long countDistinctValuesNaive(double[] valuesPart, Set<Double> distinct) {
+		for(double v : valuesPart)
 			distinct.add(v);
-		return distinct;
+
+		return distinct.size();
 	}
 
-	public static MatrixBlock countDistinctValuesFromSketch(CorrMatrixBlock arg0, CountDistinctOperator op) {
+	static MatrixBlock countDistinctValuesFromSketch(CorrMatrixBlock arg0, CountDistinctOperator op) {
 		if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
-			return new KMVSketch(op).getMatrixValue(arg0);
+			return new KMVSketch(op).getValueFromSketch(arg0);
 		else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
 			throw new NotImplementedException("Not implemented yet");
 		else
 			throw new NotImplementedException("Not implemented yet");
 	}
 
-	public static CorrMatrixBlock createSketch(MatrixBlock blkIn, CountDistinctOperator op) {
+	static CorrMatrixBlock createSketch(MatrixBlock blkIn, CountDistinctOperator op) {
 		if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
 			return new KMVSketch(op).create(blkIn);
 		else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
@@ -170,7 +373,7 @@ public interface LibMatrixCountDistinct {
 			throw new NotImplementedException("Not implemented yet");
 	}
 
-	public static CorrMatrixBlock unionSketch(CorrMatrixBlock arg0, CorrMatrixBlock arg1, CountDistinctOperator op) {
+	static CorrMatrixBlock unionSketch(CorrMatrixBlock arg0, CorrMatrixBlock arg1, CountDistinctOperator op) {
 		if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
 			return new KMVSketch(op).union(arg0, arg1);
 		else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
index f9c5f63a03..6feb52a140 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
@@ -22,15 +22,15 @@ package org.apache.sysds.runtime.matrix.data.sketch;
 import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
-public interface MatrixSketch<T> {
+public interface MatrixSketch {
 
 	/**
-	 * Get scalar distinct count from a input matrix block.
+	 * Get scalar distinct count from an input matrix block.
 	 * 
-	 * @param blkIn A input block to estimate the number of distinct values in
-	 * @return The distinct count estimate
+	 * @param blkIn An input block to estimate the number of distinct values in
+	 * @return The result matrix block containing the distinct count estimate
 	 */
-	T getScalarValue(MatrixBlock blkIn);
+	MatrixBlock getValue(MatrixBlock blkIn);
 
 	/**
 	 * Obtain matrix distinct count value from estimation Used for estimating distinct in rows or columns.
@@ -38,31 +38,31 @@ public interface MatrixSketch<T> {
 	 * @param blkIn The sketch block to extract the count from
 	 * @return The result matrix block
 	 */
-	public MatrixBlock getMatrixValue(CorrMatrixBlock blkIn);
+	MatrixBlock getValueFromSketch(CorrMatrixBlock blkIn);
 
 	/**
-	 * Create a initial sketch of a given block.
+	 * Create an initial sketch of a given block.
 	 * 
 	 * @param blkIn A block to process
 	 * @return A sketch
 	 */
-	public CorrMatrixBlock create(MatrixBlock blkIn);
+	CorrMatrixBlock create(MatrixBlock blkIn);
 
 	/**
 	 * Union two sketches together to from a combined sketch.
 	 * 
 	 * @param arg0 Sketch one
 	 * @param arg1 Sketch two
-	 * @return The combined sketch
+	 * @return The sketch union is a sketch
 	 */
-	public CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock arg1);
+	CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock arg1);
 
 	/**
 	 * Intersect two sketches
 	 * 
 	 * @param arg0 Sketch one
 	 * @param arg1 Sketch two
-	 * @return The intersected sketch
+	 * @return The sketch intersection is a sketch
 	 */
-	public CorrMatrixBlock intersection(CorrMatrixBlock arg0, CorrMatrixBlock arg1);
+	CorrMatrixBlock intersection(CorrMatrixBlock arg0, CorrMatrixBlock arg1);
 }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
index 9893e098c5..d5df3b241a 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
@@ -26,7 +26,7 @@ import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
 // Package private
-abstract class CountDistinctApproxSketch implements MatrixSketch<Integer> {
+abstract class CountDistinctApproxSketch implements MatrixSketch {
 	CountDistinctOperator op;
 
 	CountDistinctApproxSketch(Operator op) {
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
index 01cfb289e5..31e7d15c5d 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
@@ -22,7 +22,6 @@ package org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox;
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
-import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.data.DenseBlock;
@@ -52,49 +51,89 @@ public class KMVSketch extends CountDistinctApproxSketch {
 	}
 
 	@Override
-	public Integer getScalarValue(MatrixBlock in) {
-
-		// D is the number of possible distinct values in the MatrixBlock.
-		// plus 1 to take account of 0 input.
-		long D = in.getNonZeros() + 1;
-
-		/**
-		 * To ensure that the likelihood to hash to the same value we need O(D^2) positions to hash to assign. If the
-		 * value is higher than int (which is the area we hash to) then use Integer Max value as largest hashing space.
-		 */
-		long tmp = D * D;
-		int M = (tmp > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : (int) tmp;
-		/**
-		 * The estimator is asymptotically unbiased as k becomes large, but memory usage also scales with k. Furthermore k
-		 * value must be within range: D >> k >> 0
-		 */
-		int k = D > 64 ? 64 : (int) D;
+	public MatrixBlock getValue(MatrixBlock blkIn) {
+
+		if (this.op.getDirection().isRowCol()) {
+			// D is the number of possible distinct values in the MatrixBlock.
+			// plus 1 to take account of 0 input.
+			long D = blkIn.getNonZeros() + 1;
+
+			/**
+			 * To ensure that the likelihood to hash to the same value we need O(D^2) positions to hash to assign. If the
+			 * value is higher than int (which is the area we hash to) then use Integer Max value as largest hashing space.
+			 */
+			long tmp = D * D;
+			int M = (tmp > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : (int) tmp;
+			/**
+			 * The estimator is asymptotically unbiased as k becomes large, but memory usage also scales with k. Furthermore k
+			 * value must be within range: D >> k >> 0
+			 */
+			int k = D > 64 ? 64 : (int) D;
+
+			SmallestPriorityQueue spq = getKSmallestHashes(blkIn, k, M);
 
-		SmallestPriorityQueue spq = getKSmallestHashes(in, k, M);
+			if(LOG.isDebugEnabled()) {
+				LOG.debug("M not forced to int size: " + tmp);
+				LOG.debug("M: " + M);
+				LOG.debug("M: " + M);
+				LOG.debug("kth smallest hash:" + spq.peek());
+				LOG.debug("spq: " + spq);
+			}
 
-		if(LOG.isDebugEnabled()) {
-			LOG.debug("M not forced to int size: " + tmp);
-			LOG.debug("M: " + M);
-			LOG.debug("M: " + M);
-			LOG.debug("kth smallest hash:" + spq.peek());
-			LOG.debug("spq: " + spq.toString());
-		}
 
-		if(spq.size() < k) {
-			return spq.size();
-		}
-		else {
-			double kthSmallestHash = spq.poll();
-			double U_k = kthSmallestHash / (double) M;
-			double estimate = (double) (k - 1) / U_k;
-			double ceilEstimate = Math.min(estimate, (double) D);
+			long res = countDistinctValuesKMV(spq, k, M, D);
+			if(res <= 0) {
+				throw new DMLRuntimeException("Impossible estimate of distinct values");
+			}
 
-			if(LOG.isDebugEnabled()) {
-				LOG.debug("U_k : " + U_k);
-				LOG.debug("Estimate: " + estimate);
-				LOG.debug("Ceil worst case: " + D);
+			// Result is a 1x1 matrix block
+			return new MatrixBlock(res);
+
+		} else if (this.op.getDirection().isRow()) {
+			long D = (long) Math.floor(blkIn.getNonZeros() / (double) blkIn.getNumRows()) + 1;
+			long tmp = D * D;
+			int M = (tmp > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : (int) tmp;
+			int k = D > 64 ? 64 : (int) D;
+
+			MatrixBlock resultMatrix = new MatrixBlock(blkIn.getNumRows(), 1, false, blkIn.getNumRows());
+			resultMatrix.allocateBlock();
+
+			SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
+			for (int i=0; i<blkIn.getNumRows(); ++i) {
+				for (int j=0; j<blkIn.getNumColumns(); ++j) {
+					spq.add(blkIn.getValue(i, j));
+				}
+
+				long res = countDistinctValuesKMV(spq, k, M, D);
+				resultMatrix.setValue(i, 0, res);
+
+				spq.clear();
+			}
+
+			return resultMatrix;
+
+		} else {  // Col
+			long D = (long) Math.floor(blkIn.getNonZeros() / (double) blkIn.getNumColumns()) + 1;
+			long tmp = D * D;
+			int M = (tmp > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : (int) tmp;
+			int k = D > 64 ? 64 : (int) D;
+
+			MatrixBlock resultMatrix = new MatrixBlock(1, blkIn.getNumColumns(), false, blkIn.getNumColumns());
+			resultMatrix.allocateBlock();
+
+			SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
+			for (int j=0; j<blkIn.getNumColumns(); ++j) {
+				for (int i=0; i<blkIn.getNumRows(); ++i) {
+					spq.add(blkIn.getValue(i, j));
+				}
+
+				long res = countDistinctValuesKMV(spq, k, M, D);
+				resultMatrix.setValue(0, j, res);
+
+				spq.clear();
 			}
-			return (int) ceilEstimate;
+
+			return resultMatrix;
 		}
 	}
 
@@ -146,21 +185,45 @@ public class KMVSketch extends CountDistinctApproxSketch {
 		}
 	}
 
+	private long countDistinctValuesKMV(SmallestPriorityQueue spq, int k, int M, long D) {
+		long res;
+		if(spq.size() < k) {
+			res = spq.size();
+		}
+		else {
+			double kthSmallestHash = spq.poll();
+			double U_k = kthSmallestHash / (double) M;
+			double estimate = (double) (k - 1) / U_k;
+			double ceilEstimate = Math.min(estimate, (double) D);
+
+			if(LOG.isDebugEnabled()) {
+				LOG.debug("U_k : " + U_k);
+				LOG.debug("Estimate: " + estimate);
+				LOG.debug("Ceil worst case: " + D);
+			}
+			res = Math.round(ceilEstimate);
+		}
+
+		return res;
+	}
+
 	@Override
-	public MatrixBlock getMatrixValue(CorrMatrixBlock arg0) {
+	public MatrixBlock getValueFromSketch(CorrMatrixBlock arg0) {
 		MatrixBlock blkIn = arg0.getValue();
-		if(op.getDirection() == Types.Direction.Row) {
-			// 1000 x 1 blkOut -> slice out the first column of the matrix
-			MatrixBlock blkOut = blkIn.slice(0, blkIn.getNumRows() - 1, 0, 0);
+		if(op.getDirection().isRow()) {
+			// 1000 x 1 blkOut
+			MatrixBlock blkOut = new MatrixBlock(blkIn.getNumRows(), 1, false, blkIn.getNumRows());
+			blkOut.allocateBlock();
 			for(int i = 0; i < blkIn.getNumRows(); ++i) {
 				getDistinctCountFromSketchByIndex(arg0, i, blkOut);
 			}
 
 			return blkOut;
 		}
-		else if(op.getDirection() == Types.Direction.Col) {
-			// 1 x 1000 blkOut -> slice out the first row of the matrix
-			MatrixBlock blkOut = blkIn.slice(0, 0, 0, blkIn.getNumColumns() - 1);
+		else if(op.getDirection().isCol()) {
+			// 1 x 1000 blkOut
+			MatrixBlock blkOut = new MatrixBlock(1, blkIn.getNumColumns(), false, blkIn.getNumColumns());
+			blkOut.allocateBlock();
 			for(int j = 0; j < blkIn.getNumColumns(); ++j) {
 				getDistinctCountFromSketchByIndex(arg0, j, blkOut);
 			}
@@ -169,8 +232,9 @@ public class KMVSketch extends CountDistinctApproxSketch {
 		}
 		else { // op.getDirection().isRowCol()
 
-			// 1 x 1 blkOut -> slice out the first row and column of the matrix
-			MatrixBlock blkOut = blkIn.slice(0, 0, 0, 0);
+			// 1 x 1 blkOut
+			MatrixBlock blkOut = new MatrixBlock(1, 1, false, 1);
+			blkOut.allocateBlock();
 			getDistinctCountFromSketchByIndex(arg0, 0, blkOut);
 
 			return blkOut;
@@ -181,41 +245,43 @@ public class KMVSketch extends CountDistinctApproxSketch {
 		MatrixBlock blkIn = arg0.getValue();
 		MatrixBlock blkInCorr = arg0.getCorrection();
 
-		if(op.getOperatorType() == CountDistinctOperatorTypes.KMV) {
-			double kthSmallestHash;
-			if(op.getDirection().isRow() || op.getDirection().isRowCol()) {
-				kthSmallestHash = blkIn.getValue(idx, 0);
-			}
-			else { // op.getDirection().isCol()
-				kthSmallestHash = blkIn.getValue(0, idx);
-			}
+		if(op.getOperatorType() != CountDistinctOperatorTypes.KMV) {
+			throw new IllegalArgumentException(this.getClass().getSimpleName() + " cannot use " + op.getOperatorType());
+		}
 
-			double nHashes = blkInCorr.getValue(idx, 0);
-			double k = blkInCorr.getValue(idx, 1);
-			double D = blkInCorr.getValue(idx, 2);
+		double kthSmallestHash;
+		if(op.getDirection().isRow() || op.getDirection().isRowCol()) {
+			kthSmallestHash = blkIn.getValue(idx, 0);
+		}
+		else { // op.getDirection().isCol()
+			kthSmallestHash = blkIn.getValue(0, idx);
+		}
 
-			double D2 = D * D;
-			double M = (D2 > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : D2;
+		double nHashes = blkInCorr.getValue(idx, 0);
+		double k = blkInCorr.getValue(idx, 1);
+		double D = blkInCorr.getValue(idx, 2);
 
-			double ceilEstimate;
-			if(nHashes != 0 && nHashes < k) {
-				ceilEstimate = nHashes;
-			}
-			else if(nHashes == 0) {
-				ceilEstimate = 1;
-			}
-			else {
-				double U_k = kthSmallestHash / M;
-				double estimate = (k - 1) / U_k;
-				ceilEstimate = Math.min(estimate, D);
-			}
+		double D2 = D * D;
+		double M = (D2 > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : D2;
 
-			if(op.getDirection().isRow() || op.getDirection().isRowCol()) {
-				blkOut.setValue(idx, 0, ceilEstimate);
-			}
-			else { // op.getDirection().isCol()
-				blkOut.setValue(0, idx, ceilEstimate);
-			}
+		double ceilEstimate;
+		if(nHashes != 0 && nHashes < k) {
+			ceilEstimate = nHashes;
+		}
+		else if(nHashes == 0) {
+			ceilEstimate = 1;
+		}
+		else {
+			double U_k = kthSmallestHash / M;
+			double estimate = (k - 1) / U_k;
+			ceilEstimate = Math.min(estimate, D);
+		}
+
+		if(op.getDirection().isRow() || op.getDirection().isRowCol()) {
+			blkOut.setValue(idx, 0, ceilEstimate);
+		}
+		else { // op.getDirection().isCol()
+			blkOut.setValue(0, idx, ceilEstimate);
 		}
 	}
 
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
index 0a29028c66..f3f7336181 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
@@ -77,6 +77,11 @@ public class SmallestPriorityQueue {
 		return this.size() == 0;
 	}
 
+	public void clear() {
+		this.containedSet.clear();
+		this.smallestHashes.clear();
+	}
+
 	@Override
 	public String toString() {
 		return smallestHashes.toString();
diff --git a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
index cd20b67c35..5de18c4b3e 100644
--- a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
+++ b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
@@ -145,7 +145,7 @@ public class CountDistinctTest {
 				});
 			}
 			else {
-				int out = LibMatrixCountDistinct.estimateDistinctValues(in, op);
+				int out = (int) LibMatrixCountDistinct.estimateDistinctValues(in, op).getValue(0, 0);
 				int count = out;
 				boolean success = Math.abs(nrUnique - count) <= nrUnique * epsilon;
 				StringBuilder sb = new StringBuilder();
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
index e808cb5a76..5a7eccc447 100644
--- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
@@ -20,6 +20,8 @@
 package org.apache.sysds.test.functions.countDistinct;
 
 import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.junit.Test;
 
 public class CountDistinctApproxCol extends CountDistinctRowOrColBase {
 
@@ -51,4 +53,50 @@ public class CountDistinctApproxCol extends CountDistinctRowOrColBase {
 	public void setUp() {
 		super.addTestConfiguration();
 	}
+
+	@Test
+	public void testCPSparseLargeDefaultMCSR() {
+		Types.ExecType ex = Types.ExecType.CP;
+
+		int actualDistinctCount = 10;
+		int rows = 1000, cols = 10000;
+		double sparsity = 0.1;
+		double tolerance = actualDistinctCount * this.percentTolerance;
+
+		countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance);
+	}
+
+	@Test
+	public void testCPSparseLargeCSR() {
+		int actualDistinctCount = 10;
+		int rows = 1000, cols = 10000;
+		double sparsity = 0.1;
+		double tolerance = actualDistinctCount * this.percentTolerance;
+
+		super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Col, rows, cols, actualDistinctCount, sparsity,
+				tolerance);
+	}
+
+	@Test
+	public void testCPSparseLargeCOO() {
+		int actualDistinctCount = 10;
+		int rows = 1000, cols = 10000;
+		double sparsity = 0.1;
+		double tolerance = actualDistinctCount * this.percentTolerance;
+
+		super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Col, rows, cols, actualDistinctCount, sparsity,
+				tolerance);
+	}
+
+	@Test
+	public void testCPDenseLarge() {
+		Types.ExecType ex = Types.ExecType.CP;
+
+		int actualDistinctCount = 100;
+		int rows = 1000, cols = 10000;
+		double sparsity = 0.9;
+		double tolerance = actualDistinctCount * this.percentTolerance;
+
+		countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance);
+	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
index 05a125636f..c9aa75e375 100644
--- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
@@ -20,6 +20,8 @@
 package org.apache.sysds.test.functions.countDistinct;
 
 import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.junit.Test;
 
 public class CountDistinctApproxRow extends CountDistinctRowOrColBase {
 
@@ -51,4 +53,50 @@ public class CountDistinctApproxRow extends CountDistinctRowOrColBase {
 	public void setUp() {
 		super.addTestConfiguration();
 	}
+
+	@Test
+	public void testCPSparseLargeDefaultMCSR() {
+		Types.ExecType ex = Types.ExecType.CP;
+
+		int actualDistinctCount = 10;
+		int rows = 10000, cols = 1000;
+		double sparsity = 0.1;
+		double tolerance = actualDistinctCount * this.percentTolerance;
+
+		countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance);
+	}
+
+	@Test
+	public void testCPSparseLargeCSR() {
+		int actualDistinctCount = 10;
+		int rows = 10000, cols = 1000;
+		double sparsity = 0.1;
+		double tolerance = actualDistinctCount * this.percentTolerance;
+
+		super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity,
+				tolerance);
+	}
+
+	@Test
+	public void testCPSparseLargeCOO() {
+		int actualDistinctCount = 10;
+		int rows = 10000, cols = 1000;
+		double sparsity = 0.1;
+		double tolerance = actualDistinctCount * this.percentTolerance;
+
+		super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity,
+				tolerance);
+	}
+
+	@Test
+	public void testCPDenseLarge() {
+		Types.ExecType ex = Types.ExecType.CP;
+
+		int actualDistinctCount = 100;
+		int rows = 10000, cols = 1000;
+		double sparsity = 0.9;
+		double tolerance = actualDistinctCount * this.percentTolerance;
+
+		countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance);
+	}
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
index 041cf51a00..5bf850d49a 100644
--- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
@@ -19,13 +19,13 @@
 
 package org.apache.sysds.test.functions.countDistinct;
 
-import static org.junit.Assert.assertTrue;
-
 import org.apache.sysds.common.Types;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 
+import static org.junit.Assert.assertTrue;
+
 public abstract class CountDistinctBase extends AutomatedTestBase {
 	protected double percentTolerance = 0.0;
 	protected double baseTolerance = 0.0001;
@@ -88,7 +88,7 @@ public abstract class CountDistinctBase extends AutomatedTestBase {
 		}
 	}
 
-	private double[][] getExpectedMatrixRowOrCol(Types.Direction dir, int cols, int rows, long expectedValue) {
+	protected double[][] getExpectedMatrixRowOrCol(Types.Direction dir, int cols, int rows, long expectedValue) {
 		double[][] expectedResult;
 		if(dir.isRow()) {
 			expectedResult = new double[rows][1];
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
index df2ea8a0ce..a880c0d0dd 100644
--- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
@@ -20,6 +20,12 @@
 package org.apache.sysds.test.functions.countDistinct;
 
 import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.functionobjects.ReduceCol;
+import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
+import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Test;
@@ -44,24 +50,15 @@ public abstract class CountDistinctRowOrColBase extends CountDistinctBase {
 		this.percentTolerance = 0.2;
 	}
 
+	/**
+	 * This is a contrived example where size of row/col > 1024, which forces the calculation of a sketch.
+	 */
 	@Test
-	public void testCPSparseLarge() {
+	public void testCPDenseXLarge() {
 		Types.ExecType ex = Types.ExecType.CP;
 
-		int actualDistinctCount = 10;
-		int rows = 10000, cols = 1000;
-		double sparsity = 0.1;
-		double tolerance = actualDistinctCount * this.percentTolerance;
-
-		countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance);
-	}
-
-	@Test
-	public void testCPDenseLarge() {
-		Types.ExecType ex = Types.ExecType.CP;
-
-		int actualDistinctCount = 100;
-		int rows = 10000, cols = 1000;
+		int actualDistinctCount = 10000;
+		int rows = 10000, cols = 10000;
 		double sparsity = 0.9;
 		double tolerance = actualDistinctCount * this.percentTolerance;
 
@@ -139,4 +136,22 @@ public abstract class CountDistinctRowOrColBase extends CountDistinctBase {
 
 		countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, execType, tolerance);
 	}
+
+	protected void testCPSparseLarge(SparseBlock.Type sparseBlockType, Types.Direction direction, int rows, int cols,
+									 int actualDistinctCount, double sparsity, double tolerance) {
+		MatrixBlock blkIn = TestUtils.round(TestUtils.generateTestMatrixBlock(rows, cols, 0, actualDistinctCount, sparsity, 7));
+		if (!blkIn.isInSparseFormat()) {
+			blkIn.denseToSparse(false);
+		}
+		blkIn = new MatrixBlock(blkIn, sparseBlockType, true);
+
+		CountDistinctOperator op = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX)
+				.setDirection(direction)
+				.setIndexFunction(ReduceCol.getReduceColFnObject());
+
+		MatrixBlock blkOut = LibMatrixCountDistinct.estimateDistinctValues(blkIn, op);
+		double[][] expectedMatrix = getExpectedMatrixRowOrCol(direction, cols, rows, actualDistinctCount);
+
+		TestUtils.compareMatrices(expectedMatrix, blkOut, tolerance, "");
+	}
 }