You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemds.apache.org by ar...@apache.org on 2022/08/03 14:53:35 UTC
[systemds] branch main updated: [SYSTEMDS-3390] Improve performance of countDistinctApprox()
This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 75385a9493 [SYSTEMDS-3390] Improve performance of countDistinctApprox()
75385a9493 is described below
commit 75385a949312a8ea3591559633f9961e811ac067
Author: Badrul Chowdhury <ba...@gmail.com>
AuthorDate: Wed Aug 3 16:03:42 2022 +0200
[SYSTEMDS-3390] Improve performance of countDistinctApprox()
This patch improves the performance of countDistinctApprox() row/col
aggregation by replacing matrix slicing with direct ops on the input
matrix. This has the most impact in local CP execution mode, as
some simple experiments show:
(numbers represent the average over 3 runs)
1. row aggregation
(A) dense: 10000x1000 with sparsity=0.9
1.198s with slicing, 0.874s without slicing - a 27% improvement
(B) sparse: 10000x1000 with sparsity=0.1
0.528s with slicing, 0.512s without slicing - a 3% improvement
As expected, the larger and denser the input matrix,
the larger the performance improvement.
2. col aggregation
(A) dense: 1000x10000 with sparsity=0.9
1.186s with slicing, 1.036s without slicing - a 13% improvement
(B) sparse: 1000x10000 with sparsity=0.1
1.272s with slicing, 0.647s without slicing - a 49% improvement
In this case, the sparser the input matrix, the larger the performance
improvement. This phenomenon is a result of employing a hash map M
(column index -> set of distinct values) in the implementation: as the
RxC input matrix becomes denser, M's keyset size approaches C, and the
performance approaches the baseline, which uses slicing.
Closes #1650
---
.../cp/AggregateUnaryCPInstruction.java | 41 ++-
.../matrix/data/LibMatrixCountDistinct.java | 323 +++++++++++++++++----
.../runtime/matrix/data/sketch/MatrixSketch.java | 24 +-
.../CountDistinctApproxSketch.java | 2 +-
.../data/sketch/countdistinctapprox/KMVSketch.java | 224 +++++++++-----
.../countdistinctapprox/SmallestPriorityQueue.java | 5 +
.../test/component/matrix/CountDistinctTest.java | 2 +-
.../countDistinct/CountDistinctApproxCol.java | 48 +++
.../countDistinct/CountDistinctApproxRow.java | 48 +++
.../functions/countDistinct/CountDistinctBase.java | 6 +-
.../countDistinct/CountDistinctRowOrColBase.java | 45 ++-
11 files changed, 576 insertions(+), 192 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
index fbcf6ff7f3..ddf00ada2b 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
@@ -82,8 +82,12 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
in1, out, AUType.valueOf(opcode.toUpperCase()), opcode, str);
}
else if(opcode.equalsIgnoreCase("uacd")){
- return new AggregateUnaryCPInstruction(new SimpleOperator(null),
- in1, out, AUType.COUNT_DISTINCT, opcode, str);
+ CountDistinctOperator op = new CountDistinctOperator(AUType.COUNT_DISTINCT)
+ .setDirection(Types.Direction.RowCol)
+ .setIndexFunction(ReduceAll.getReduceAllFnObject());
+
+ return new AggregateUnaryCPInstruction(op, in1, out, AUType.COUNT_DISTINCT,
+ opcode, str);
}
else if(opcode.equalsIgnoreCase("uacdap")){
CountDistinctOperator op = new CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX)
@@ -199,9 +203,15 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
if( !ec.getVariables().keySet().contains(input1.getName()) )
throw new DMLRuntimeException("Variable '" + input1.getName() + "' does not exist.");
MatrixBlock input = ec.getMatrixInput(input1.getName());
- CountDistinctOperator op = new CountDistinctOperator(_type);
+
+ // Operator type: test and cast
+ if (!(_optr instanceof CountDistinctOperator)) {
+ throw new DMLRuntimeException("Operator should be instance of " + CountDistinctOperator.class.getSimpleName());
+ }
+ CountDistinctOperator op = (CountDistinctOperator) (_optr);
+
//TODO add support for row or col count distinct.
- int res = LibMatrixCountDistinct.estimateDistinctValues(input, op);
+ int res = (int) LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0);
ec.releaseMatrixInput(input1.getName());
ec.setScalarOutput(output_name, new IntObject(res));
break;
@@ -219,27 +229,16 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
CountDistinctOperator op = (CountDistinctOperator) _optr; // It is safe to cast at this point
if (op.getDirection().isRowCol()) {
- int res = LibMatrixCountDistinct.estimateDistinctValues(input, op);
+ long res = (long) LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0);
ec.releaseMatrixInput(input1.getName());
ec.setScalarOutput(output_name, new IntObject(res));
- } else if (op.getDirection().isRow()) {
- //TODO Do not slice out the matrix but directly process on the input
- MatrixBlock res = input.slice(0, input.getNumRows() - 1, 0, 0);
- for (int i = 0; i < input.getNumRows(); ++i) {
- res.setValue(i, 0, LibMatrixCountDistinct.estimateDistinctValues(input.slice(i, i), op));
- }
- ec.releaseMatrixInput(input1.getName());
- ec.setMatrixOutput(output_name, res);
- } else if (op.getDirection().isCol()) {
- //TODO Do not slice out the matrix but directly process on the input
- MatrixBlock res = input.slice(0, 0, 0, input.getNumColumns() - 1);
- for (int j = 0; j < input.getNumColumns(); ++j) {
- res.setValue(0, j, LibMatrixCountDistinct.estimateDistinctValues(input.slice(0, input.getNumRows() - 1, j, j), op));
- }
+ } else { // Row/Col
+ // Note that for each row, the max number of distinct values < NNZ < max number of columns = 1000:
+ // Since count distinct approximate estimates are unreliable for values < 1024,
+ // we will force a naive count.
+ MatrixBlock res = LibMatrixCountDistinct.estimateDistinctValues(input, op);
ec.releaseMatrixInput(input1.getName());
ec.setMatrixOutput(output_name, res);
- } else {
- throw new DMLRuntimeException("Direction for CountDistinctOperator not recognized");
}
break;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
index 4b13abc995..1198b18dd5 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
@@ -19,81 +19,93 @@
package org.apache.sysds.runtime.matrix.data;
-import java.util.HashSet;
-import java.util.Set;
+import java.util.*;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLException;
-import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
-import org.apache.sysds.runtime.data.DenseBlock;
-import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.*;
import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
import org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.KMVSketch;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
import org.apache.sysds.utils.Hash.HashType;
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
/**
* This class contains various methods for counting the number of distinct values inside a MatrixBlock
*/
public interface LibMatrixCountDistinct {
- static final Log LOG = LogFactory.getLog(LibMatrixCountDistinct.class.getName());
+ Log LOG = LogFactory.getLog(LibMatrixCountDistinct.class.getName());
/**
* The minimum number NonZero of cells in the input before using approximate techniques for counting number of
* distinct values.
*/
- public static int minimumSize = 1024;
+ int minimumSize = 1024;
/**
* Public method to count the number of distinct values inside a matrix. Depending on which CountDistinctOperator
* selected it either gets the absolute number or a estimated value.
*
* TODO: Support counting num distinct in rows, or columns axis.
- *
- * TODO: Add support for distributed spark operations
- *
* TODO: If the MatrixBlock type is CompressedMatrix, simply read the values from the ColGroups.
*
* @param in the input matrix to count number distinct values in
* @param op the selected operator to use
- * @return the distinct count
+ * @return A matrix block containing the absolute distinct count for the entire input or along given row/col axis
*/
- public static int estimateDistinctValues(MatrixBlock in, CountDistinctOperator op) {
- int res = 0;
+ static MatrixBlock estimateDistinctValues(MatrixBlock in, CountDistinctOperator op) {
if(op.getOperatorType() == CountDistinctOperatorTypes.KMV &&
(op.getHashType() == HashType.ExpHash || op.getHashType() == HashType.StandardJava)) {
throw new DMLException(
"Invalid hashing configuration using " + op.getHashType() + " and " + op.getOperatorType());
}
else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL) {
- throw new NotImplementedException("HyperLogLog not implemented");
+ throw new NotImplementedException("HyperLogLog has not been implemented yet");
+ }
+
+ // shortcut in the simplest case.
+ if(in.getLength() == 1 || in.isEmpty()) {
+ return new MatrixBlock(1);
}
- // shortcut in simplest case.
- if(in.getLength() == 1 || in.isEmpty())
- return 1;
- else if(in.getNonZeros() < minimumSize) {
- // Just use naive implementation if the number of nonZeros values size is small.
- res = countDistinctValuesNaive(in);
+
+ long averageNnzPerRowOrCol;
+ if (op.getDirection().isRowCol()) {
+ averageNnzPerRowOrCol = in.getNonZeros();
+ } else if (op.getDirection().isRow()) {
+ // The average nnz per row is susceptible to skew. However, given that CP instructions is limited to
+ // matrices of size at most 1000 x 1000, the performance impact of using naive counting over sketch per
+ // row/col as determined by the average is negligible. Besides, the average is the simplest measure
+ // available without calculating nnz per row/col.
+ averageNnzPerRowOrCol = (long) Math.floor(in.getNonZeros() / (double) in.getNumRows());
+ } else if (op.getDirection().isCol()) {
+ averageNnzPerRowOrCol = (long) Math.floor(in.getNonZeros() / (double) in.getNumColumns());
+ } else {
+ throw new IllegalArgumentException("Unrecognized direction " + op.getDirection());
}
- else {
+
+ // Result is a dense 1x1 (RowCol), Mx1 (Row), or 1xN (Col) matrix
+ MatrixBlock res;
+ if (averageNnzPerRowOrCol < minimumSize) {
+ // Resort to naive counting for small enough matrices
+ res = countDistinctValuesNaive(in, op);
+ } else {
switch(op.getOperatorType()) {
case COUNT:
- res = countDistinctValuesNaive(in);
+ res = countDistinctValuesNaive(in, op);
break;
case KMV:
- res = new KMVSketch(op).getScalarValue(in);
+ res = new KMVSketch(op).getValue(in);
break;
default:
- throw new DMLException("Invalid or not implemented Estimator Type");
+ throw new DMLException("Invalid estimator type for aggregation: " + LibMatrixCountDistinct.class.getSimpleName());
}
}
- if(res <= 0)
- throw new DMLRuntimeException("Impossible estimate of distinct values");
return res;
}
@@ -102,66 +114,257 @@ public interface LibMatrixCountDistinct {
*
* Benefit: precise, but uses memory, on the scale of inputs number of distinct values.
*
- * @param in The input matrix to count number distinct values in
- * @return The absolute distinct count
+ * @param blkIn The input matrix to count number distinct values in
+ * @return A matrix block containing the absolute distinct count for the entire input or along given row/col axis
*/
- private static int countDistinctValuesNaive(MatrixBlock in) {
+ private static MatrixBlock countDistinctValuesNaive(MatrixBlock blkIn, CountDistinctOperator op) {
+
+ if (blkIn.isEmpty()) {
+ return new MatrixBlock(1);
+ }
+ else if(blkIn instanceof CompressedMatrixBlock) {
+ throw new NotImplementedException("countDistinct() does not support CompressedMatrixBlock");
+ }
+
Set<Double> distinct = new HashSet<>();
+ MatrixBlock blkOut;
double[] data;
- if(in.isEmpty())
- return 1;
- else if(in instanceof CompressedMatrixBlock)
- throw new NotImplementedException();
- long nonZeros = in.getNonZeros();
+ if (op.getDirection().isRowCol()) {
+ blkOut = new MatrixBlock(1, 1, false);
- if(nonZeros != -1 && nonZeros < in.getNumColumns() * in.getNumRows()) {
- distinct.add(0d);
- }
+ long distinctCount = 0;
+ long nonZeros = blkIn.getNonZeros();
- if(in.sparseBlock != null) {
- SparseBlock sb = in.sparseBlock;
+ // Check if input matrix contains any 0 values for RowCol case.
+ // This does not apply to row/col case, where we count nnz per row or col during iteration.
+ if(nonZeros != -1 && nonZeros < (long) blkIn.getNumColumns() * blkIn.getNumRows()) {
+ distinct.add(0d);
+ }
- if(in.sparseBlock.isContiguous()) {
- data = sb.values(0);
- countDistinctValuesNaive(data, distinct);
+ if(blkIn.getSparseBlock() != null) {
+ SparseBlock sb = blkIn.getSparseBlock();
+ if(blkIn.getSparseBlock().isContiguous()) {
+ // COO, CSR
+ data = sb.values(0);
+ distinctCount = countDistinctValuesNaive(data, distinct);
+ } else {
+ // MCSR
+ for(int i = 0; i < blkIn.getNumRows(); i++) {
+ if(!sb.isEmpty(i)) {
+ data = blkIn.getSparseBlock().values(i);
+ distinctCount = countDistinctValuesNaive(data, distinct);
+ }
+ }
+ }
+ } else if(blkIn.getDenseBlock() != null) {
+ DenseBlock db = blkIn.getDenseBlock();
+ for (int i = 0; i <= db.numBlocks(); i++) {
+ data = db.valuesAt(i);
+ distinctCount = countDistinctValuesNaive(data, distinct);
+ }
}
- else {
- for(int i = 0; i < in.getNumRows(); i++) {
- if(!sb.isEmpty(i)) {
- data = in.sparseBlock.values(i);
+
+ blkOut.setValue(0, 0, distinctCount);
+ } else if (op.getDirection().isRow()) {
+ blkOut = new MatrixBlock(blkIn.getNumRows(), 1, false, blkIn.getNumRows());
+ blkOut.allocateBlock();
+
+ if (blkIn.getDenseBlock() != null) {
+ // The naive approach would be to iterate through every (i, j) in the input. However, can do better
+ // by exploiting the physical layout of dense blocks - contiguous blocks in row-major order - in memory.
+ DenseBlock db = blkIn.getDenseBlock();
+ for (int bix=0; bix<db.numBlocks(); ++bix) {
+ data = db.valuesAt(bix);
+ for (int rix=bix * db.blockSize(); rix<blkIn.getNumRows(); rix++) {
+ distinct.clear();
+ for (int cix=0; cix<blkIn.getNumColumns(); ++cix) {
+ distinct.add(data[db.pos(rix, cix)]);
+ }
+ blkOut.setValue(rix, 0, distinct.size());
+ }
+ }
+ } else if (blkIn.getSparseBlock() != null) {
+ // Each sparse block type - COO, CSR, MCSR - has a different data representation, which we will exploit
+ // separately.
+ SparseBlock sb = blkIn.getSparseBlock();
+ if (SparseBlockFactory.isSparseBlockType(sb, SparseBlock.Type.MCSR)) {
+ // Currently, SparseBlockIterator only provides an interface for cell-wise iteration.
+ // TODO Explore row-wise and column-wise methods for SparseBlockIterator
+
+ // MCSR enables O(1) access to column values per row
+ for (int rix=0; rix<blkIn.getNumRows(); ++rix) {
+ if (sb.isEmpty(rix)) {
+ continue;
+ }
+ distinct.clear();
+ data = sb.values(rix);
countDistinctValuesNaive(data, distinct);
+ blkOut.setValue(rix, 0, distinct.size());
+ }
+ } else if (SparseBlockFactory.isSparseBlockType(sb, SparseBlock.Type.CSR)) {
+ // Casting is safe given if-condition
+ SparseBlockCSR csrBlock = (SparseBlockCSR) sb;
+
+ // Data lies in one contiguous block in CSR format. We will iterate in row-major using O(1) op
+ // size(row) to determine the number of columns per row.
+ data = csrBlock.values();
+ // We want to iterate through all rows to keep track of the row index for constructing the output
+ for (int rix=0; rix<blkIn.getNumRows(); ++rix) {
+ if (csrBlock.isEmpty(rix)) {
+ continue;
+ }
+ distinct.clear();
+ int rpos = csrBlock.pos(rix);
+ int clen = csrBlock.size(rix);
+ for (int colOffset=0; colOffset<clen; ++colOffset) {
+ distinct.add(data[rpos + colOffset]);
+ }
+ blkOut.setValue(rix, 0, distinct.size());
+ }
+ } else { // COO
+ if (!(sb instanceof SparseBlockCOO)) {
+ throw new IllegalArgumentException("Input matrix is of unrecognized type: "
+ + sb.getClass().getSimpleName());
+ }
+ SparseBlockCOO cooBlock = (SparseBlockCOO) sb;
+
+ // For COO, we want to avoid using pos(row) and size(row) as they use binary search, which is a
+ // O(log N) op. Also, isEmpty(row) uses pos(row) internally.
+ int[] rixs = cooBlock.rowIndexes();
+ data = cooBlock.values();
+ int i = 0; // data iterator
+ int rix = 0; // row index
+ while (rix < cooBlock.numRows() && i < rixs.length) {
+ distinct.clear();
+ while (i + 1 < rixs.length && rixs[i] == rixs[i + 1]) {
+ distinct.add(data[i]);
+ i++;
+ }
+ if (i + 1 < rixs.length) { // rixs[i] != rixs[i + 1]
+ distinct.add(data[i]);
+ }
+ blkOut.setValue(rix, 0, distinct.size());
+ rix = (i + 1 < rixs.length)? rixs[i + 1] : rix;
+ i++;
}
}
}
- }
- else if(in.denseBlock != null) {
- DenseBlock db = in.denseBlock;
- for(int i = 0; i <= db.numBlocks(); i++) {
- data = db.valuesAt(i);
- countDistinctValuesNaive(data, distinct);
+ } else { // Col aggregation
+ blkOut = new MatrixBlock(1, blkIn.getNumColumns(), false, blkIn.getNumRows());
+ blkOut.allocateBlock();
+
+ // All dense and sparse formats (COO, CSR, MCSR) are row-major formats, so there is no obvious way to iterate
+ // in column-major order besides iterating through every (i, j) pair. getValue() skips over empty cells in CSR
+ // and MCSR formats, but not so in COO format. This results in O(log2 R * log2 C) time for every lookup,
+ // amounting to O(RC * log2R * log2C) for the whole block (R, C <= 1000 in CP case). We will eschew this
+ // approach in favor of one using a hash map M of (column index, distinct values) to obtain a pseudo column-major
+ // grouping of distinct values instead. Given this setup, we will simply iterate over the input
+ // (according to specific dense/sparse format) in row-major order and populate M. Finally, an O(C) iteration
+ // over M will yield the final result.
+ Map<Integer, Set<Double>> distinctValuesByCol = new HashMap<>();
+ if (blkIn.getDenseBlock() != null) {
+ DenseBlock db = blkIn.getDenseBlock();
+ for (int bix=0; bix<db.numBlocks(); ++bix) {
+ data = db.valuesAt(bix);
+ for (int cix=0; cix<blkIn.getNumColumns(); ++cix) {
+ Set<Double> distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+ for (int rix=bix * db.blockSize(); rix<blkIn.getNumRows(); rix++) {
+ double val = data[db.pos(rix, cix)];
+ distinctValues.add(val);
+ }
+ distinctValuesByCol.put(cix, distinctValues);
+ }
+ }
+ } else if (blkIn.getSparseBlock() != null) {
+ SparseBlock sb = blkIn.getSparseBlock();
+ if (SparseBlockFactory.isSparseBlockType(sb, SparseBlock.Type.MCSR)) {
+ for (int rix=0; rix<blkIn.getNumRows(); ++rix) {
+ if (sb.isEmpty(rix)) {
+ continue;
+ }
+ int[] cixs = sb.indexes(rix);
+ data = sb.values(rix);
+ for (int j=0; j<sb.size(rix); ++j) {
+ int cix = cixs[j];
+ Set<Double> distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+ distinctValues.add(data[j]);
+ distinctValuesByCol.put(cix, distinctValues);
+ }
+ }
+ } else if (SparseBlockFactory.isSparseBlockType(sb, SparseBlock.Type.CSR)) {
+ SparseBlockCSR csrBlock = (SparseBlockCSR) sb;
+ data = csrBlock.values();
+ for (int rix=0; rix<blkIn.getNumRows(); ++rix) {
+ if (csrBlock.isEmpty(rix)) {
+ continue;
+ }
+ distinct.clear();
+ int rpos = csrBlock.pos(rix);
+ int clen = csrBlock.size(rix);
+ int[] cixs = csrBlock.indexes();
+ for (int colOffset=0; colOffset<clen; ++colOffset) {
+ int cix = cixs[rpos + colOffset];
+ Set<Double> distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+ distinctValues.add(data[rpos + colOffset]);
+ distinctValuesByCol.put(cix, distinctValues);
+ }
+ }
+ } else { // COO
+ if (!(sb instanceof SparseBlockCOO)) {
+ throw new IllegalArgumentException("Input matrix is of unrecognized type: "
+ + sb.getClass().getSimpleName());
+ }
+ SparseBlockCOO cooBlock = (SparseBlockCOO) sb;
+
+ int[] rixs = cooBlock.rowIndexes();
+ int[] cixs = cooBlock.indexes();
+ data = cooBlock.values();
+ int i = 0; // data iterator
+ while (i < rixs.length) {
+ while (i + 1 < rixs.length && rixs[i] == rixs[i + 1]) {
+ int cix = cixs[i];
+ Set<Double> distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+ distinctValues.add(data[i]);
+ distinctValuesByCol.put(cix, distinctValues);
+ i++;
+ }
+ if (i + 1 < rixs.length) {
+ int cix = cixs[i];
+ Set<Double> distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+ distinctValues.add(data[i]);
+ distinctValuesByCol.put(cix, distinctValues);
+ }
+ i++;
+ }
+ }
+ }
+ // Fill in output block with column aggregation results
+ for (int cix : distinctValuesByCol.keySet()) {
+ blkOut.setValue(0, cix, distinctValuesByCol.get(cix).size());
}
}
- return distinct.size();
+ return blkOut;
}
- private static Set<Double> countDistinctValuesNaive(double[] valuesPart, Set<Double> distinct) {
- for(double v : valuesPart)
+ private static long countDistinctValuesNaive(double[] valuesPart, Set<Double> distinct) {
+ for(double v : valuesPart)
distinct.add(v);
- return distinct;
+
+ return distinct.size();
}
- public static MatrixBlock countDistinctValuesFromSketch(CorrMatrixBlock arg0, CountDistinctOperator op) {
+ static MatrixBlock countDistinctValuesFromSketch(CorrMatrixBlock arg0, CountDistinctOperator op) {
if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
- return new KMVSketch(op).getMatrixValue(arg0);
+ return new KMVSketch(op).getValueFromSketch(arg0);
else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
throw new NotImplementedException("Not implemented yet");
else
throw new NotImplementedException("Not implemented yet");
}
- public static CorrMatrixBlock createSketch(MatrixBlock blkIn, CountDistinctOperator op) {
+ static CorrMatrixBlock createSketch(MatrixBlock blkIn, CountDistinctOperator op) {
if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
return new KMVSketch(op).create(blkIn);
else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
@@ -170,7 +373,7 @@ public interface LibMatrixCountDistinct {
throw new NotImplementedException("Not implemented yet");
}
- public static CorrMatrixBlock unionSketch(CorrMatrixBlock arg0, CorrMatrixBlock arg1, CountDistinctOperator op) {
+ static CorrMatrixBlock unionSketch(CorrMatrixBlock arg0, CorrMatrixBlock arg1, CountDistinctOperator op) {
if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
return new KMVSketch(op).union(arg0, arg1);
else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
index f9c5f63a03..6feb52a140 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
@@ -22,15 +22,15 @@ package org.apache.sysds.runtime.matrix.data.sketch;
import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-public interface MatrixSketch<T> {
+public interface MatrixSketch {
/**
- * Get scalar distinct count from a input matrix block.
+ * Get scalar distinct count from an input matrix block.
*
- * @param blkIn A input block to estimate the number of distinct values in
- * @return The distinct count estimate
+ * @param blkIn An input block to estimate the number of distinct values in
+ * @return The result matrix block containing the distinct count estimate
*/
- T getScalarValue(MatrixBlock blkIn);
+ MatrixBlock getValue(MatrixBlock blkIn);
/**
* Obtain matrix distinct count value from estimation Used for estimating distinct in rows or columns.
@@ -38,31 +38,31 @@ public interface MatrixSketch<T> {
* @param blkIn The sketch block to extract the count from
* @return The result matrix block
*/
- public MatrixBlock getMatrixValue(CorrMatrixBlock blkIn);
+ MatrixBlock getValueFromSketch(CorrMatrixBlock blkIn);
/**
- * Create a initial sketch of a given block.
+ * Create an initial sketch of a given block.
*
* @param blkIn A block to process
* @return A sketch
*/
- public CorrMatrixBlock create(MatrixBlock blkIn);
+ CorrMatrixBlock create(MatrixBlock blkIn);
/**
* Union two sketches together to from a combined sketch.
*
* @param arg0 Sketch one
* @param arg1 Sketch two
- * @return The combined sketch
+ * @return The sketch union is a sketch
*/
- public CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock arg1);
+ CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock arg1);
/**
* Intersect two sketches
*
* @param arg0 Sketch one
* @param arg1 Sketch two
- * @return The intersected sketch
+ * @return The sketch intersection is a sketch
*/
- public CorrMatrixBlock intersection(CorrMatrixBlock arg0, CorrMatrixBlock arg1);
+ CorrMatrixBlock intersection(CorrMatrixBlock arg0, CorrMatrixBlock arg1);
}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
index 9893e098c5..d5df3b241a 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
@@ -26,7 +26,7 @@ import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
// Package private
-abstract class CountDistinctApproxSketch implements MatrixSketch<Integer> {
+abstract class CountDistinctApproxSketch implements MatrixSketch {
CountDistinctOperator op;
CountDistinctApproxSketch(Operator op) {
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
index 01cfb289e5..31e7d15c5d 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
@@ -22,7 +22,6 @@ package org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
-import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.data.DenseBlock;
@@ -52,49 +51,89 @@ public class KMVSketch extends CountDistinctApproxSketch {
}
@Override
- public Integer getScalarValue(MatrixBlock in) {
-
- // D is the number of possible distinct values in the MatrixBlock.
- // plus 1 to take account of 0 input.
- long D = in.getNonZeros() + 1;
-
- /**
- * To ensure that the likelihood to hash to the same value we need O(D^2) positions to hash to assign. If the
- * value is higher than int (which is the area we hash to) then use Integer Max value as largest hashing space.
- */
- long tmp = D * D;
- int M = (tmp > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : (int) tmp;
- /**
- * The estimator is asymptotically unbiased as k becomes large, but memory usage also scales with k. Furthermore k
- * value must be within range: D >> k >> 0
- */
- int k = D > 64 ? 64 : (int) D;
+ public MatrixBlock getValue(MatrixBlock blkIn) {
+
+ if (this.op.getDirection().isRowCol()) {
+ // D is the number of possible distinct values in the MatrixBlock.
+ // plus 1 to take account of 0 input.
+ long D = blkIn.getNonZeros() + 1;
+
+ /**
+ * To ensure that the likelihood to hash to the same value we need O(D^2) positions to hash to assign. If the
+ * value is higher than int (which is the area we hash to) then use Integer Max value as largest hashing space.
+ */
+ long tmp = D * D;
+ int M = (tmp > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : (int) tmp;
+ /**
+ * The estimator is asymptotically unbiased as k becomes large, but memory usage also scales with k. Furthermore k
+ * value must be within range: D >> k >> 0
+ */
+ int k = D > 64 ? 64 : (int) D;
+
+ SmallestPriorityQueue spq = getKSmallestHashes(blkIn, k, M);
- SmallestPriorityQueue spq = getKSmallestHashes(in, k, M);
+ if(LOG.isDebugEnabled()) {
+ LOG.debug("M not forced to int size: " + tmp);
+ LOG.debug("M: " + M);
+ LOG.debug("M: " + M);
+ LOG.debug("kth smallest hash:" + spq.peek());
+ LOG.debug("spq: " + spq);
+ }
- if(LOG.isDebugEnabled()) {
- LOG.debug("M not forced to int size: " + tmp);
- LOG.debug("M: " + M);
- LOG.debug("M: " + M);
- LOG.debug("kth smallest hash:" + spq.peek());
- LOG.debug("spq: " + spq.toString());
- }
- if(spq.size() < k) {
- return spq.size();
- }
- else {
- double kthSmallestHash = spq.poll();
- double U_k = kthSmallestHash / (double) M;
- double estimate = (double) (k - 1) / U_k;
- double ceilEstimate = Math.min(estimate, (double) D);
+ long res = countDistinctValuesKMV(spq, k, M, D);
+ if(res <= 0) {
+ throw new DMLRuntimeException("Impossible estimate of distinct values");
+ }
- if(LOG.isDebugEnabled()) {
- LOG.debug("U_k : " + U_k);
- LOG.debug("Estimate: " + estimate);
- LOG.debug("Ceil worst case: " + D);
+ // Result is a 1x1 matrix block
+ return new MatrixBlock(res);
+
+ } else if (this.op.getDirection().isRow()) {
+ long D = (long) Math.floor(blkIn.getNonZeros() / (double) blkIn.getNumRows()) + 1;
+ long tmp = D * D;
+ int M = (tmp > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : (int) tmp;
+ int k = D > 64 ? 64 : (int) D;
+
+ MatrixBlock resultMatrix = new MatrixBlock(blkIn.getNumRows(), 1, false, blkIn.getNumRows());
+ resultMatrix.allocateBlock();
+
+ SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
+ for (int i=0; i<blkIn.getNumRows(); ++i) {
+ for (int j=0; j<blkIn.getNumColumns(); ++j) {
+ spq.add(blkIn.getValue(i, j));
+ }
+
+ long res = countDistinctValuesKMV(spq, k, M, D);
+ resultMatrix.setValue(i, 0, res);
+
+ spq.clear();
+ }
+
+ return resultMatrix;
+
+ } else { // Col
+ long D = (long) Math.floor(blkIn.getNonZeros() / (double) blkIn.getNumColumns()) + 1;
+ long tmp = D * D;
+ int M = (tmp > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : (int) tmp;
+ int k = D > 64 ? 64 : (int) D;
+
+ MatrixBlock resultMatrix = new MatrixBlock(1, blkIn.getNumColumns(), false, blkIn.getNumColumns());
+ resultMatrix.allocateBlock();
+
+ SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
+ for (int j=0; j<blkIn.getNumColumns(); ++j) {
+ for (int i=0; i<blkIn.getNumRows(); ++i) {
+ spq.add(blkIn.getValue(i, j));
+ }
+
+ long res = countDistinctValuesKMV(spq, k, M, D);
+ resultMatrix.setValue(0, j, res);
+
+ spq.clear();
}
- return (int) ceilEstimate;
+
+ return resultMatrix;
}
}
@@ -146,21 +185,45 @@ public class KMVSketch extends CountDistinctApproxSketch {
}
}
+ private long countDistinctValuesKMV(SmallestPriorityQueue spq, int k, int M, long D) {
+ long res;
+ if(spq.size() < k) {
+ res = spq.size();
+ }
+ else {
+ double kthSmallestHash = spq.poll();
+ double U_k = kthSmallestHash / (double) M;
+ double estimate = (double) (k - 1) / U_k;
+ double ceilEstimate = Math.min(estimate, (double) D);
+
+ if(LOG.isDebugEnabled()) {
+ LOG.debug("U_k : " + U_k);
+ LOG.debug("Estimate: " + estimate);
+ LOG.debug("Ceil worst case: " + D);
+ }
+ res = Math.round(ceilEstimate);
+ }
+
+ return res;
+ }
+
@Override
- public MatrixBlock getMatrixValue(CorrMatrixBlock arg0) {
+ public MatrixBlock getValueFromSketch(CorrMatrixBlock arg0) {
MatrixBlock blkIn = arg0.getValue();
- if(op.getDirection() == Types.Direction.Row) {
- // 1000 x 1 blkOut -> slice out the first column of the matrix
- MatrixBlock blkOut = blkIn.slice(0, blkIn.getNumRows() - 1, 0, 0);
+ if(op.getDirection().isRow()) {
+ // (numRows x 1) blkOut: one distinct-count estimate per input row
+ MatrixBlock blkOut = new MatrixBlock(blkIn.getNumRows(), 1, false, blkIn.getNumRows());
+ blkOut.allocateBlock();
for(int i = 0; i < blkIn.getNumRows(); ++i) {
getDistinctCountFromSketchByIndex(arg0, i, blkOut);
}
return blkOut;
}
- else if(op.getDirection() == Types.Direction.Col) {
- // 1 x 1000 blkOut -> slice out the first row of the matrix
- MatrixBlock blkOut = blkIn.slice(0, 0, 0, blkIn.getNumColumns() - 1);
+ else if(op.getDirection().isCol()) {
+ // (1 x numCols) blkOut: one distinct-count estimate per input column
+ MatrixBlock blkOut = new MatrixBlock(1, blkIn.getNumColumns(), false, blkIn.getNumColumns());
+ blkOut.allocateBlock();
for(int j = 0; j < blkIn.getNumColumns(); ++j) {
getDistinctCountFromSketchByIndex(arg0, j, blkOut);
}
@@ -169,8 +232,9 @@ public class KMVSketch extends CountDistinctApproxSketch {
}
else { // op.getDirection().isRowCol()
- // 1 x 1 blkOut -> slice out the first row and column of the matrix
- MatrixBlock blkOut = blkIn.slice(0, 0, 0, 0);
+ // 1 x 1 blkOut
+ MatrixBlock blkOut = new MatrixBlock(1, 1, false, 1);
+ blkOut.allocateBlock();
getDistinctCountFromSketchByIndex(arg0, 0, blkOut);
return blkOut;
@@ -181,41 +245,43 @@ public class KMVSketch extends CountDistinctApproxSketch {
MatrixBlock blkIn = arg0.getValue();
MatrixBlock blkInCorr = arg0.getCorrection();
- if(op.getOperatorType() == CountDistinctOperatorTypes.KMV) {
- double kthSmallestHash;
- if(op.getDirection().isRow() || op.getDirection().isRowCol()) {
- kthSmallestHash = blkIn.getValue(idx, 0);
- }
- else { // op.getDirection().isCol()
- kthSmallestHash = blkIn.getValue(0, idx);
- }
+ if(op.getOperatorType() != CountDistinctOperatorTypes.KMV) {
+ throw new IllegalArgumentException(this.getClass().getSimpleName() + " cannot use " + op.getOperatorType());
+ }
- double nHashes = blkInCorr.getValue(idx, 0);
- double k = blkInCorr.getValue(idx, 1);
- double D = blkInCorr.getValue(idx, 2);
+ double kthSmallestHash;
+ if(op.getDirection().isRow() || op.getDirection().isRowCol()) {
+ kthSmallestHash = blkIn.getValue(idx, 0);
+ }
+ else { // op.getDirection().isCol()
+ kthSmallestHash = blkIn.getValue(0, idx);
+ }
- double D2 = D * D;
- double M = (D2 > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : D2;
+ double nHashes = blkInCorr.getValue(idx, 0);
+ double k = blkInCorr.getValue(idx, 1);
+ double D = blkInCorr.getValue(idx, 2);
- double ceilEstimate;
- if(nHashes != 0 && nHashes < k) {
- ceilEstimate = nHashes;
- }
- else if(nHashes == 0) {
- ceilEstimate = 1;
- }
- else {
- double U_k = kthSmallestHash / M;
- double estimate = (k - 1) / U_k;
- ceilEstimate = Math.min(estimate, D);
- }
+ double D2 = D * D;
+ double M = (D2 > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : D2;
- if(op.getDirection().isRow() || op.getDirection().isRowCol()) {
- blkOut.setValue(idx, 0, ceilEstimate);
- }
- else { // op.getDirection().isCol()
- blkOut.setValue(0, idx, ceilEstimate);
- }
+ double ceilEstimate;
+ if(nHashes != 0 && nHashes < k) {
+ ceilEstimate = nHashes;
+ }
+ else if(nHashes == 0) {
+ ceilEstimate = 1;
+ }
+ else {
+ double U_k = kthSmallestHash / M;
+ double estimate = (k - 1) / U_k;
+ ceilEstimate = Math.min(estimate, D);
+ }
+
+ if(op.getDirection().isRow() || op.getDirection().isRowCol()) {
+ blkOut.setValue(idx, 0, ceilEstimate);
+ }
+ else { // op.getDirection().isCol()
+ blkOut.setValue(0, idx, ceilEstimate);
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
index 0a29028c66..f3f7336181 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
@@ -77,6 +77,11 @@ public class SmallestPriorityQueue {
return this.size() == 0;
}
+ public void clear() {
+ this.containedSet.clear();
+ this.smallestHashes.clear();
+ }
+
@Override
public String toString() {
return smallestHashes.toString();
diff --git a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
index cd20b67c35..5de18c4b3e 100644
--- a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
+++ b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
@@ -145,7 +145,7 @@ public class CountDistinctTest {
});
}
else {
- int out = LibMatrixCountDistinct.estimateDistinctValues(in, op);
+ int out = (int) LibMatrixCountDistinct.estimateDistinctValues(in, op).getValue(0, 0);
int count = out;
boolean success = Math.abs(nrUnique - count) <= nrUnique * epsilon;
StringBuilder sb = new StringBuilder();
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
index e808cb5a76..5a7eccc447 100644
--- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
@@ -20,6 +20,8 @@
package org.apache.sysds.test.functions.countDistinct;
import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.junit.Test;
public class CountDistinctApproxCol extends CountDistinctRowOrColBase {
@@ -51,4 +53,50 @@ public class CountDistinctApproxCol extends CountDistinctRowOrColBase {
public void setUp() {
super.addTestConfiguration();
}
+
+ @Test
+ public void testCPSparseLargeDefaultMCSR() {
+ Types.ExecType ex = Types.ExecType.CP;
+
+ int actualDistinctCount = 10;
+ int rows = 1000, cols = 10000;
+ double sparsity = 0.1;
+ double tolerance = actualDistinctCount * this.percentTolerance;
+
+ countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance);
+ }
+
+ @Test
+ public void testCPSparseLargeCSR() {
+ int actualDistinctCount = 10;
+ int rows = 1000, cols = 10000;
+ double sparsity = 0.1;
+ double tolerance = actualDistinctCount * this.percentTolerance;
+
+ super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Col, rows, cols, actualDistinctCount, sparsity,
+ tolerance);
+ }
+
+ @Test
+ public void testCPSparseLargeCOO() {
+ int actualDistinctCount = 10;
+ int rows = 1000, cols = 10000;
+ double sparsity = 0.1;
+ double tolerance = actualDistinctCount * this.percentTolerance;
+
+ super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Col, rows, cols, actualDistinctCount, sparsity,
+ tolerance);
+ }
+
+ @Test
+ public void testCPDenseLarge() {
+ Types.ExecType ex = Types.ExecType.CP;
+
+ int actualDistinctCount = 100;
+ int rows = 1000, cols = 10000;
+ double sparsity = 0.9;
+ double tolerance = actualDistinctCount * this.percentTolerance;
+
+ countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance);
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
index 05a125636f..c9aa75e375 100644
--- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
@@ -20,6 +20,8 @@
package org.apache.sysds.test.functions.countDistinct;
import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.junit.Test;
public class CountDistinctApproxRow extends CountDistinctRowOrColBase {
@@ -51,4 +53,50 @@ public class CountDistinctApproxRow extends CountDistinctRowOrColBase {
public void setUp() {
super.addTestConfiguration();
}
+
+ @Test
+ public void testCPSparseLargeDefaultMCSR() {
+ Types.ExecType ex = Types.ExecType.CP;
+
+ int actualDistinctCount = 10;
+ int rows = 10000, cols = 1000;
+ double sparsity = 0.1;
+ double tolerance = actualDistinctCount * this.percentTolerance;
+
+ countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance);
+ }
+
+ @Test
+ public void testCPSparseLargeCSR() {
+ int actualDistinctCount = 10;
+ int rows = 10000, cols = 1000;
+ double sparsity = 0.1;
+ double tolerance = actualDistinctCount * this.percentTolerance;
+
+ super.testCPSparseLarge(SparseBlock.Type.CSR, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity,
+ tolerance);
+ }
+
+ @Test
+ public void testCPSparseLargeCOO() {
+ int actualDistinctCount = 10;
+ int rows = 10000, cols = 1000;
+ double sparsity = 0.1;
+ double tolerance = actualDistinctCount * this.percentTolerance;
+
+ super.testCPSparseLarge(SparseBlock.Type.COO, Types.Direction.Row, rows, cols, actualDistinctCount, sparsity,
+ tolerance);
+ }
+
+ @Test
+ public void testCPDenseLarge() {
+ Types.ExecType ex = Types.ExecType.CP;
+
+ int actualDistinctCount = 100;
+ int rows = 10000, cols = 1000;
+ double sparsity = 0.9;
+ double tolerance = actualDistinctCount * this.percentTolerance;
+
+ countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance);
+ }
}
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
index 041cf51a00..5bf850d49a 100644
--- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
@@ -19,13 +19,13 @@
package org.apache.sysds.test.functions.countDistinct;
-import static org.junit.Assert.assertTrue;
-
import org.apache.sysds.common.Types;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
+import static org.junit.Assert.assertTrue;
+
public abstract class CountDistinctBase extends AutomatedTestBase {
protected double percentTolerance = 0.0;
protected double baseTolerance = 0.0001;
@@ -88,7 +88,7 @@ public abstract class CountDistinctBase extends AutomatedTestBase {
}
}
- private double[][] getExpectedMatrixRowOrCol(Types.Direction dir, int cols, int rows, long expectedValue) {
+ protected double[][] getExpectedMatrixRowOrCol(Types.Direction dir, int cols, int rows, long expectedValue) {
double[][] expectedResult;
if(dir.isRow()) {
expectedResult = new double[rows][1];
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
index df2ea8a0ce..a880c0d0dd 100644
--- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
@@ -20,6 +20,12 @@
package org.apache.sysds.test.functions.countDistinct;
import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.functionobjects.ReduceCol;
+import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
+import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Test;
@@ -44,24 +50,15 @@ public abstract class CountDistinctRowOrColBase extends CountDistinctBase {
this.percentTolerance = 0.2;
}
+ /**
+ * This is a contrived example in which the size of each row/col exceeds 1024, which forces the calculation of a sketch.
+ */
@Test
- public void testCPSparseLarge() {
+ public void testCPDenseXLarge() {
Types.ExecType ex = Types.ExecType.CP;
- int actualDistinctCount = 10;
- int rows = 10000, cols = 1000;
- double sparsity = 0.1;
- double tolerance = actualDistinctCount * this.percentTolerance;
-
- countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, ex, tolerance);
- }
-
- @Test
- public void testCPDenseLarge() {
- Types.ExecType ex = Types.ExecType.CP;
-
- int actualDistinctCount = 100;
- int rows = 10000, cols = 1000;
+ int actualDistinctCount = 10000;
+ int rows = 10000, cols = 10000;
double sparsity = 0.9;
double tolerance = actualDistinctCount * this.percentTolerance;
@@ -139,4 +136,22 @@ public abstract class CountDistinctRowOrColBase extends CountDistinctBase {
countDistinctMatrixTest(getDirection(), actualDistinctCount, cols, rows, sparsity, execType, tolerance);
}
+
+ protected void testCPSparseLarge(SparseBlock.Type sparseBlockType, Types.Direction direction, int rows, int cols,
+ int actualDistinctCount, double sparsity, double tolerance) {
+ MatrixBlock blkIn = TestUtils.round(TestUtils.generateTestMatrixBlock(rows, cols, 0, actualDistinctCount, sparsity, 7));
+ if (!blkIn.isInSparseFormat()) {
+ blkIn.denseToSparse(false);
+ }
+ blkIn = new MatrixBlock(blkIn, sparseBlockType, true);
+
+ CountDistinctOperator op = new CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX)
+ .setDirection(direction)
+ .setIndexFunction(ReduceCol.getReduceColFnObject());
+
+ MatrixBlock blkOut = LibMatrixCountDistinct.estimateDistinctValues(blkIn, op);
+ double[][] expectedMatrix = getExpectedMatrixRowOrCol(direction, cols, rows, actualDistinctCount);
+
+ TestUtils.compareMatrices(expectedMatrix, blkOut, tolerance, "");
+ }
}