This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 75385a9493 [SYSTEMDS-3390] Improve performance of countDistinctApprox()
75385a9493 is described below

commit 75385a949312a8ea3591559633f9961e811ac067
Author: Badrul Chowdhury <[email protected]>
AuthorDate: Wed Aug 3 16:03:42 2022 +0200

    [SYSTEMDS-3390] Improve performance of countDistinctApprox()
    
    This patch improves the performance of countDistinctApprox() row/col
    aggregation by replacing matrix slicing with direct ops on the input
    matrix. This has the most impact in local CP execution mode, as
    some simple experiments show:
    
    (numbers represent the average over 3 runs)
    1. row aggregation
        (A) dense: 10000x1000 with sparsity=0.9
        1.198s with slicing, 0.874s without slicing - a 27% improvement
    
        (B) sparse: 10000x1000 with sparsity=0.1
        0.528s with slicing, 0.512s without slicing - a 3% improvement
    
    As expected, the larger and denser the input matrix,
    the larger the performance improvement.
    
    2. col aggregation
        (A) dense: 1000x10000 with sparsity=0.9
        1.186s with slicing, 1.036s without slicing - a 13% improvement
    
        (B) sparse: 1000x10000 with sparsity=0.1
        1.272s with slicing, 0.647s without slicing - a 49% improvement
    
    In this case, the sparser the input matrix, the larger the performance
    improvement. This phenomenon is a result of employing a hash map M
    in the implementation: as the RxC input matrix becomes denser, M's
    keyset size approaches C, and the performance approaches the baseline,
    which uses slicing.
    
    Closes #1650
---
 .../cp/AggregateUnaryCPInstruction.java            |  41 ++-
 .../matrix/data/LibMatrixCountDistinct.java        | 323 +++++++++++++++++----
 .../runtime/matrix/data/sketch/MatrixSketch.java   |  24 +-
 .../CountDistinctApproxSketch.java                 |   2 +-
 .../data/sketch/countdistinctapprox/KMVSketch.java | 224 +++++++++-----
 .../countdistinctapprox/SmallestPriorityQueue.java |   5 +
 .../test/component/matrix/CountDistinctTest.java   |   2 +-
 .../countDistinct/CountDistinctApproxCol.java      |  48 +++
 .../countDistinct/CountDistinctApproxRow.java      |  48 +++
 .../functions/countDistinct/CountDistinctBase.java |   6 +-
 .../countDistinct/CountDistinctRowOrColBase.java   |  45 ++-
 11 files changed, 576 insertions(+), 192 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
index fbcf6ff7f3..ddf00ada2b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
@@ -82,8 +82,12 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
                                in1, out, AUType.valueOf(opcode.toUpperCase()), 
opcode, str);
                } 
                else if(opcode.equalsIgnoreCase("uacd")){
-                       return new AggregateUnaryCPInstruction(new 
SimpleOperator(null),
-                       in1, out, AUType.COUNT_DISTINCT, opcode, str);
+                       CountDistinctOperator op = new 
CountDistinctOperator(AUType.COUNT_DISTINCT)
+                                       .setDirection(Types.Direction.RowCol)
+                                       
.setIndexFunction(ReduceAll.getReduceAllFnObject());
+
+                       return new AggregateUnaryCPInstruction(op, in1, out, 
AUType.COUNT_DISTINCT,
+                                       opcode, str);
                }
                else if(opcode.equalsIgnoreCase("uacdap")){
                        CountDistinctOperator op = new 
CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX)
@@ -199,9 +203,15 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
                                if( 
!ec.getVariables().keySet().contains(input1.getName()) )
                                        throw new DMLRuntimeException("Variable 
'" + input1.getName() + "' does not exist.");
                                MatrixBlock input = 
ec.getMatrixInput(input1.getName());
-                               CountDistinctOperator op = new 
CountDistinctOperator(_type);
+
+                               // Operator type: test and cast
+                               if (!(_optr instanceof CountDistinctOperator)) {
+                                       throw new DMLRuntimeException("Operator 
should be instance of " + CountDistinctOperator.class.getSimpleName());
+                               }
+                               CountDistinctOperator op = 
(CountDistinctOperator) (_optr);
+
                                //TODO add support for row or col count 
distinct.
-                               int res = 
LibMatrixCountDistinct.estimateDistinctValues(input, op);
+                               int res = (int) 
LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0);
                                ec.releaseMatrixInput(input1.getName());
                                ec.setScalarOutput(output_name, new 
IntObject(res));
                                break;
@@ -219,27 +229,16 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
                                CountDistinctOperator op = 
(CountDistinctOperator) _optr;  // It is safe to cast at this point
 
                                if (op.getDirection().isRowCol()) {
-                                       int res = 
LibMatrixCountDistinct.estimateDistinctValues(input, op);
+                                       long res = (long) 
LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0);
                                        ec.releaseMatrixInput(input1.getName());
                                        ec.setScalarOutput(output_name, new 
IntObject(res));
-                               } else if (op.getDirection().isRow()) {
-                                       //TODO Do not slice out the matrix but 
directly process on the input
-                                       MatrixBlock res = input.slice(0, 
input.getNumRows() - 1, 0, 0);
-                                       for (int i = 0; i < input.getNumRows(); 
++i) {
-                                               res.setValue(i, 0, 
LibMatrixCountDistinct.estimateDistinctValues(input.slice(i, i), op));
-                                       }
-                                       ec.releaseMatrixInput(input1.getName());
-                                       ec.setMatrixOutput(output_name, res);
-                               } else if (op.getDirection().isCol()) {
-                                       //TODO Do not slice out the matrix but 
directly process on the input
-                                       MatrixBlock res = input.slice(0, 0, 0, 
input.getNumColumns() - 1);
-                                       for (int j = 0; j < 
input.getNumColumns(); ++j) {
-                                               res.setValue(0, j, 
LibMatrixCountDistinct.estimateDistinctValues(input.slice(0, input.getNumRows() 
- 1, j, j), op));
-                                       }
+                               } else {  // Row/Col
+                                       // Note that for each row, the max 
number of distinct values < NNZ < max number of columns = 1000:
+                                       // Since count distinct approximate 
estimates are unreliable for values < 1024,
+                                       // we will force a naive count.
+                                       MatrixBlock res = 
LibMatrixCountDistinct.estimateDistinctValues(input, op);
                                        ec.releaseMatrixInput(input1.getName());
                                        ec.setMatrixOutput(output_name, res);
-                               } else {
-                                       throw new 
DMLRuntimeException("Direction for CountDistinctOperator not recognized");
                                }
 
                                break;
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
index 4b13abc995..1198b18dd5 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
@@ -19,81 +19,93 @@
 
 package org.apache.sysds.runtime.matrix.data;
 
-import java.util.HashSet;
-import java.util.Set;
+import java.util.*;
 
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLException;
-import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
-import org.apache.sysds.runtime.data.DenseBlock;
-import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.*;
 import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
 import 
org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.KMVSketch;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
 import org.apache.sysds.utils.Hash.HashType;
 
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
 /**
  * This class contains various methods for counting the number of distinct 
values inside a MatrixBlock
  */
 public interface LibMatrixCountDistinct {
-       static final Log LOG = 
LogFactory.getLog(LibMatrixCountDistinct.class.getName());
+       Log LOG = LogFactory.getLog(LibMatrixCountDistinct.class.getName());
 
        /**
         * The minimum number NonZero of cells in the input before using 
approximate techniques for counting number of
         * distinct values.
         */
-       public static int minimumSize = 1024;
+       int minimumSize = 1024;
 
        /**
         * Public method to count the number of distinct values inside a 
matrix. Depending on which CountDistinctOperator
         * selected it either gets the absolute number or a estimated value.
         * 
         * TODO: Support counting num distinct in rows, or columns axis.
-        * 
-        * TODO: Add support for distributed spark operations
-        * 
         * TODO: If the MatrixBlock type is CompressedMatrix, simply read the 
values from the ColGroups.
         * 
         * @param in the input matrix to count number distinct values in
         * @param op the selected operator to use
-        * @return the distinct count
+        * @return A matrix block containing the absolute distinct count for 
the entire input or along given row/col axis
         */
-       public static int estimateDistinctValues(MatrixBlock in, 
CountDistinctOperator op) {
-               int res = 0;
+       static MatrixBlock estimateDistinctValues(MatrixBlock in, 
CountDistinctOperator op) {
                if(op.getOperatorType() == CountDistinctOperatorTypes.KMV &&
                        (op.getHashType() == HashType.ExpHash || 
op.getHashType() == HashType.StandardJava)) {
                        throw new DMLException(
                                "Invalid hashing configuration using " + 
op.getHashType() + " and " + op.getOperatorType());
                }
                else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL) 
{
-                       throw new NotImplementedException("HyperLogLog not 
implemented");
+                       throw new NotImplementedException("HyperLogLog has not 
been implemented yet");
+               }
+
+               // shortcut in the simplest case.
+               if(in.getLength() == 1 || in.isEmpty()) {
+                       return new MatrixBlock(1);
                }
-               // shortcut in simplest case.
-               if(in.getLength() == 1 || in.isEmpty())
-                       return 1;
-               else if(in.getNonZeros() < minimumSize) {
-                       // Just use naive implementation if the number of 
nonZeros values size is small.
-                       res = countDistinctValuesNaive(in);
+
+               long averageNnzPerRowOrCol;
+               if (op.getDirection().isRowCol()) {
+                       averageNnzPerRowOrCol = in.getNonZeros();
+               } else if (op.getDirection().isRow()) {
+                       // The average nnz per row is susceptible to skew. 
However, given that CP instructions is limited to
+                       // matrices of size at most 1000 x 1000, the 
performance impact of using naive counting over sketch per
+                       // row/col as determined by the average is negligible. 
Besides, the average is the simplest measure
+                       // available without calculating nnz per row/col.
+                       averageNnzPerRowOrCol = (long) 
Math.floor(in.getNonZeros() / (double) in.getNumRows());
+               } else if (op.getDirection().isCol()) {
+                       averageNnzPerRowOrCol = (long) 
Math.floor(in.getNonZeros() / (double) in.getNumColumns());
+               } else {
+                       throw new IllegalArgumentException("Unrecognized 
direction " + op.getDirection());
                }
-               else {
+
+               // Result is a dense 1x1 (RowCol), Mx1 (Row), or 1xN (Col) 
matrix
+               MatrixBlock res;
+               if (averageNnzPerRowOrCol < minimumSize) {
+                       // Resort to naive counting for small enough matrices
+                       res = countDistinctValuesNaive(in, op);
+               } else {
                        switch(op.getOperatorType()) {
                                case COUNT:
-                                       res = countDistinctValuesNaive(in);
+                                       res = countDistinctValuesNaive(in, op);
                                        break;
                                case KMV:
-                                       res = new 
KMVSketch(op).getScalarValue(in);
+                                       res = new KMVSketch(op).getValue(in);
                                        break;
                                default:
-                                       throw new DMLException("Invalid or not 
implemented Estimator Type");
+                                       throw new DMLException("Invalid 
estimator type for aggregation: " + 
LibMatrixCountDistinct.class.getSimpleName());
                        }
                }
 
-               if(res <= 0)
-                       throw new DMLRuntimeException("Impossible estimate of 
distinct values");
                return res;
        }
 
@@ -102,66 +114,257 @@ public interface LibMatrixCountDistinct {
         * 
         * Benefit: precise, but uses memory, on the scale of inputs number of 
distinct values.
         * 
-        * @param in The input matrix to count number distinct values in
-        * @return The absolute distinct count
+        * @param blkIn The input matrix to count number distinct values in
+        * @return A matrix block containing the absolute distinct count for 
the entire input or along given row/col axis
         */
-       private static int countDistinctValuesNaive(MatrixBlock in) {
+       private static MatrixBlock countDistinctValuesNaive(MatrixBlock blkIn, 
CountDistinctOperator op) {
+
+               if (blkIn.isEmpty()) {
+                       return new MatrixBlock(1);
+               }
+               else if(blkIn instanceof CompressedMatrixBlock) {
+                       throw new NotImplementedException("countDistinct() does 
not support CompressedMatrixBlock");
+               }
+
                Set<Double> distinct = new HashSet<>();
+               MatrixBlock blkOut;
                double[] data;
-               if(in.isEmpty())
-                       return 1;
-               else if(in instanceof CompressedMatrixBlock)
-                       throw new NotImplementedException();
 
-               long nonZeros = in.getNonZeros();
+               if (op.getDirection().isRowCol()) {
+                       blkOut = new MatrixBlock(1, 1, false);
 
-               if(nonZeros != -1 && nonZeros < in.getNumColumns() * 
in.getNumRows()) {
-                       distinct.add(0d);
-               }
+                       long distinctCount = 0;
+                       long nonZeros = blkIn.getNonZeros();
 
-               if(in.sparseBlock != null) {
-                       SparseBlock sb = in.sparseBlock;
+                       // Check if input matrix contains any 0 values for 
RowCol case.
+                       // This does not apply to row/col case, where we count 
nnz per row or col during iteration.
+                       if(nonZeros != -1 && nonZeros < (long) 
blkIn.getNumColumns() * blkIn.getNumRows()) {
+                               distinct.add(0d);
+                       }
 
-                       if(in.sparseBlock.isContiguous()) {
-                               data = sb.values(0);
-                               countDistinctValuesNaive(data, distinct);
+                       if(blkIn.getSparseBlock() != null) {
+                               SparseBlock sb = blkIn.getSparseBlock();
+                               if(blkIn.getSparseBlock().isContiguous()) {
+                                       // COO, CSR
+                                       data = sb.values(0);
+                                       distinctCount = 
countDistinctValuesNaive(data, distinct);
+                               } else {
+                                       // MCSR
+                                       for(int i = 0; i < blkIn.getNumRows(); 
i++) {
+                                               if(!sb.isEmpty(i)) {
+                                                       data = 
blkIn.getSparseBlock().values(i);
+                                                       distinctCount = 
countDistinctValuesNaive(data, distinct);
+                                               }
+                                       }
+                               }
+                       } else if(blkIn.getDenseBlock() != null) {
+                               DenseBlock db = blkIn.getDenseBlock();
+                               for (int i = 0; i <= db.numBlocks(); i++) {
+                                       data = db.valuesAt(i);
+                                       distinctCount = 
countDistinctValuesNaive(data, distinct);
+                               }
                        }
-                       else {
-                               for(int i = 0; i < in.getNumRows(); i++) {
-                                       if(!sb.isEmpty(i)) {
-                                               data = in.sparseBlock.values(i);
+
+                       blkOut.setValue(0, 0, distinctCount);
+               } else if (op.getDirection().isRow()) {
+                       blkOut = new MatrixBlock(blkIn.getNumRows(), 1, false, 
blkIn.getNumRows());
+                       blkOut.allocateBlock();
+
+                       if (blkIn.getDenseBlock() != null) {
+                               // The naive approach would be to iterate 
through every (i, j) in the input. However, can do better
+                               // by exploiting the physical layout of dense 
blocks - contiguous blocks in row-major order - in memory.
+                               DenseBlock db = blkIn.getDenseBlock();
+                               for (int bix=0; bix<db.numBlocks(); ++bix) {
+                                       data = db.valuesAt(bix);
+                                       for (int rix=bix * db.blockSize(); 
rix<blkIn.getNumRows(); rix++) {
+                                               distinct.clear();
+                                               for (int cix=0; 
cix<blkIn.getNumColumns(); ++cix) {
+                                                       
distinct.add(data[db.pos(rix, cix)]);
+                                               }
+                                               blkOut.setValue(rix, 0, 
distinct.size());
+                                       }
+                               }
+                       } else if (blkIn.getSparseBlock() != null) {
+                               // Each sparse block type - COO, CSR, MCSR - 
has a different data representation, which we will exploit
+                               // separately.
+                               SparseBlock sb = blkIn.getSparseBlock();
+                               if (SparseBlockFactory.isSparseBlockType(sb, 
SparseBlock.Type.MCSR)) {
+                                       // Currently, SparseBlockIterator only 
provides an interface for cell-wise iteration.
+                                       // TODO Explore row-wise and 
column-wise methods for SparseBlockIterator
+
+                                       // MCSR enables O(1) access to column 
values per row
+                                       for (int rix=0; rix<blkIn.getNumRows(); 
++rix) {
+                                               if (sb.isEmpty(rix)) {
+                                                       continue;
+                                               }
+                                               distinct.clear();
+                                               data = sb.values(rix);
                                                countDistinctValuesNaive(data, 
distinct);
+                                               blkOut.setValue(rix, 0, 
distinct.size());
+                                       }
+                               } else if 
(SparseBlockFactory.isSparseBlockType(sb, SparseBlock.Type.CSR)) {
+                                       // Casting is safe given if-condition
+                                       SparseBlockCSR csrBlock = 
(SparseBlockCSR) sb;
+
+                                       // Data lies in one contiguous block in 
CSR format. We will iterate in row-major using O(1) op
+                                       // size(row) to determine the number of 
columns per row.
+                                       data = csrBlock.values();
+                                       // We want to iterate through all rows 
to keep track of the row index for constructing the output
+                                       for (int rix=0; rix<blkIn.getNumRows(); 
++rix) {
+                                               if (csrBlock.isEmpty(rix)) {
+                                                       continue;
+                                               }
+                                               distinct.clear();
+                                               int rpos = csrBlock.pos(rix);
+                                               int clen = csrBlock.size(rix);
+                                               for (int colOffset=0; 
colOffset<clen; ++colOffset) {
+                                                       distinct.add(data[rpos 
+ colOffset]);
+                                               }
+                                               blkOut.setValue(rix, 0, 
distinct.size());
+                                       }
+                               } else { // COO
+                                       if (!(sb instanceof SparseBlockCOO)) {
+                                               throw new 
IllegalArgumentException("Input matrix is of unrecognized type: "
+                                                               + 
sb.getClass().getSimpleName());
+                                       }
+                                       SparseBlockCOO cooBlock = 
(SparseBlockCOO) sb;
+
+                                       // For COO, we want to avoid using 
pos(row) and size(row) as they use binary search, which is a
+                                       // O(log N) op. Also, isEmpty(row) uses 
pos(row) internally.
+                                       int[] rixs = cooBlock.rowIndexes();
+                                       data = cooBlock.values();
+                                       int i = 0;  // data iterator
+                                       int rix = 0;  // row index
+                                       while (rix < cooBlock.numRows() && i < 
rixs.length) {
+                                               distinct.clear();
+                                               while (i + 1 < rixs.length && 
rixs[i] == rixs[i + 1]) {
+                                                       distinct.add(data[i]);
+                                                       i++;
+                                               }
+                                               if (i + 1 < rixs.length) {  // 
rixs[i] != rixs[i + 1]
+                                                       distinct.add(data[i]);
+                                               }
+                                               blkOut.setValue(rix, 0, 
distinct.size());
+                                               rix = (i + 1 < rixs.length)? 
rixs[i + 1] : rix;
+                                               i++;
                                        }
                                }
                        }
-               }
-               else if(in.denseBlock != null) {
-                       DenseBlock db = in.denseBlock;
-                       for(int i = 0; i <= db.numBlocks(); i++) {
-                               data = db.valuesAt(i);
-                               countDistinctValuesNaive(data, distinct);
+               } else {  // Col aggregation
+                       blkOut = new MatrixBlock(1, blkIn.getNumColumns(), 
false, blkIn.getNumRows());
+                       blkOut.allocateBlock();
+
+                       // All dense and sparse formats (COO, CSR, MCSR) are 
row-major formats, so there is no obvious way to iterate
+                       // in column-major order besides iterating through 
every (i, j) pair. getValue() skips over empty cells in CSR
+                       // and MCSR formats, but not so in COO format. This 
results in O(log2 R * log2 C) time for every lookup,
+                       // amounting to O(RC * log2R * log2C) for the whole 
block (R, C <= 1000 in CP case). We will eschew this
+                       // approach in favor of one using a hash map M of 
(column index, distinct values) to obtain a pseudo column-major
+                       // grouping of distinct values instead. Given this 
setup, we will simply iterate over the input
+                       // (according to specific dense/sparse format) in 
row-major order and populate M. Finally, an O(C) iteration
+                       // over M will yield the final result.
+                       Map<Integer, Set<Double>> distinctValuesByCol = new 
HashMap<>();
+                       if (blkIn.getDenseBlock() != null) {
+                               DenseBlock db = blkIn.getDenseBlock();
+                               for (int bix=0; bix<db.numBlocks(); ++bix) {
+                                       data = db.valuesAt(bix);
+                                       for (int cix=0; 
cix<blkIn.getNumColumns(); ++cix) {
+                                               Set<Double> distinctValues = 
distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+                                               for (int rix=bix * 
db.blockSize(); rix<blkIn.getNumRows(); rix++) {
+                                                       double val = 
data[db.pos(rix, cix)];
+                                                       distinctValues.add(val);
+                                               }
+                                               distinctValuesByCol.put(cix, 
distinctValues);
+                                       }
+                               }
+                       } else if (blkIn.getSparseBlock() != null) {
+                               SparseBlock sb = blkIn.getSparseBlock();
+                               if (SparseBlockFactory.isSparseBlockType(sb, 
SparseBlock.Type.MCSR)) {
+                                       for (int rix=0; rix<blkIn.getNumRows(); 
++rix) {
+                                               if (sb.isEmpty(rix)) {
+                                                       continue;
+                                               }
+                                               int[] cixs = sb.indexes(rix);
+                                               data = sb.values(rix);
+                                               for (int j=0; j<sb.size(rix); 
++j) {
+                                                       int cix = cixs[j];
+                                                       Set<Double> 
distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+                                                       
distinctValues.add(data[j]);
+                                                       
distinctValuesByCol.put(cix, distinctValues);
+                                               }
+                                       }
+                               } else if 
(SparseBlockFactory.isSparseBlockType(sb, SparseBlock.Type.CSR)) {
+                                       SparseBlockCSR csrBlock = 
(SparseBlockCSR) sb;
+                                       data = csrBlock.values();
+                                       for (int rix=0; rix<blkIn.getNumRows(); 
++rix) {
+                                               if (csrBlock.isEmpty(rix)) {
+                                                       continue;
+                                               }
+                                               distinct.clear();
+                                               int rpos = csrBlock.pos(rix);
+                                               int clen = csrBlock.size(rix);
+                                               int[] cixs = csrBlock.indexes();
+                                               for (int colOffset=0; 
colOffset<clen; ++colOffset) {
+                                                       int cix = cixs[rpos + 
colOffset];
+                                                       Set<Double> 
distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+                                                       
distinctValues.add(data[rpos + colOffset]);
+                                                       
distinctValuesByCol.put(cix, distinctValues);
+                                               }
+                                       }
+                               } else {  // COO
+                                       if (!(sb instanceof SparseBlockCOO)) {
+                                               throw new 
IllegalArgumentException("Input matrix is of unrecognized type: "
+                                                               + 
sb.getClass().getSimpleName());
+                                       }
+                                       SparseBlockCOO cooBlock = 
(SparseBlockCOO) sb;
+
+                                       int[] rixs = cooBlock.rowIndexes();
+                                       int[] cixs = cooBlock.indexes();
+                                       data = cooBlock.values();
+                                       int i = 0;  // data iterator
+                                       while (i < rixs.length) {
+                                               while (i + 1 < rixs.length && 
rixs[i] == rixs[i + 1]) {
+                                                       int cix = cixs[i];
+                                                       Set<Double> 
distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+                                                       
distinctValues.add(data[i]);
+                                                       
distinctValuesByCol.put(cix, distinctValues);
+                                                       i++;
+                                               }
+                                               if (i + 1 < rixs.length) {
+                                                       int cix = cixs[i];
+                                                       Set<Double> 
distinctValues = distinctValuesByCol.getOrDefault(cix, new HashSet<>());
+                                                       
distinctValues.add(data[i]);
+                                                       
distinctValuesByCol.put(cix, distinctValues);
+                                               }
+                                               i++;
+                                       }
+                               }
+                       }
+                       // Fill in output block with column aggregation results
+                       for (int cix : distinctValuesByCol.keySet()) {
+                               blkOut.setValue(0, cix, 
distinctValuesByCol.get(cix).size());
                        }
                }
 
-               return distinct.size();
+               return blkOut;
        }
 
-       private static Set<Double> countDistinctValuesNaive(double[] 
valuesPart, Set<Double> distinct) {
-               for(double v : valuesPart) 
+       private static long countDistinctValuesNaive(double[] valuesPart, 
Set<Double> distinct) {
+               for(double v : valuesPart)
                        distinct.add(v);
-               return distinct;
+
+               return distinct.size();
        }
 
-       public static MatrixBlock countDistinctValuesFromSketch(CorrMatrixBlock 
arg0, CountDistinctOperator op) {
+       static MatrixBlock countDistinctValuesFromSketch(CorrMatrixBlock arg0, 
CountDistinctOperator op) {
                if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
-                       return new KMVSketch(op).getMatrixValue(arg0);
+                       return new KMVSketch(op).getValueFromSketch(arg0);
                else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
                        throw new NotImplementedException("Not implemented 
yet");
                else
                        throw new NotImplementedException("Not implemented 
yet");
        }
 
-       public static CorrMatrixBlock createSketch(MatrixBlock blkIn, 
CountDistinctOperator op) {
+       static CorrMatrixBlock createSketch(MatrixBlock blkIn, 
CountDistinctOperator op) {
                if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
                        return new KMVSketch(op).create(blkIn);
                else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
@@ -170,7 +373,7 @@ public interface LibMatrixCountDistinct {
                        throw new NotImplementedException("Not implemented 
yet");
        }
 
-       public static CorrMatrixBlock unionSketch(CorrMatrixBlock arg0, 
CorrMatrixBlock arg1, CountDistinctOperator op) {
+       static CorrMatrixBlock unionSketch(CorrMatrixBlock arg0, 
CorrMatrixBlock arg1, CountDistinctOperator op) {
                if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
                        return new KMVSketch(op).union(arg0, arg1);
                else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
index f9c5f63a03..6feb52a140 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
@@ -22,15 +22,15 @@ package org.apache.sysds.runtime.matrix.data.sketch;
 import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
-public interface MatrixSketch<T> {
+public interface MatrixSketch {
 
        /**
-        * Get scalar distinct count from a input matrix block.
+        * Get scalar distinct count from an input matrix block.
         * 
-        * @param blkIn A input block to estimate the number of distinct values 
in
-        * @return The distinct count estimate
+        * @param blkIn An input block to estimate the number of distinct 
values in
+        * @return The result matrix block containing the distinct count 
estimate
         */
-       T getScalarValue(MatrixBlock blkIn);
+       MatrixBlock getValue(MatrixBlock blkIn);
 
        /**
         * Obtain matrix distinct count value from estimation. Used for 
estimating distinct values in rows or columns.
@@ -38,31 +38,31 @@ public interface MatrixSketch<T> {
         * @param blkIn The sketch block to extract the count from
         * @return The result matrix block
         */
-       public MatrixBlock getMatrixValue(CorrMatrixBlock blkIn);
+       MatrixBlock getValueFromSketch(CorrMatrixBlock blkIn);
 
        /**
-        * Create a initial sketch of a given block.
+        * Create an initial sketch of a given block.
         * 
         * @param blkIn A block to process
         * @return A sketch
         */
-       public CorrMatrixBlock create(MatrixBlock blkIn);
+       CorrMatrixBlock create(MatrixBlock blkIn);
 
        /**
         * Union two sketches together to form a combined sketch.
         * 
         * @param arg0 Sketch one
         * @param arg1 Sketch two
-        * @return The combined sketch
+        * @return The union of the two sketches, which is itself a sketch
         */
-       public CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock 
arg1);
+       CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock arg1);
 
        /**
         * Intersect two sketches
         * 
         * @param arg0 Sketch one
         * @param arg1 Sketch two
-        * @return The intersected sketch
+        * @return The intersection of the two sketches, which is itself a sketch
         */
-       public CorrMatrixBlock intersection(CorrMatrixBlock arg0, 
CorrMatrixBlock arg1);
+       CorrMatrixBlock intersection(CorrMatrixBlock arg0, CorrMatrixBlock 
arg1);
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
index 9893e098c5..d5df3b241a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
@@ -26,7 +26,7 @@ import 
org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
 // Package private
-abstract class CountDistinctApproxSketch implements MatrixSketch<Integer> {
+abstract class CountDistinctApproxSketch implements MatrixSketch {
        CountDistinctOperator op;
 
        CountDistinctApproxSketch(Operator op) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
index 01cfb289e5..31e7d15c5d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
@@ -22,7 +22,6 @@ package 
org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox;
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
-import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.data.DenseBlock;
@@ -52,49 +51,89 @@ public class KMVSketch extends CountDistinctApproxSketch {
        }
 
        @Override
-       public Integer getScalarValue(MatrixBlock in) {
-
-               // D is the number of possible distinct values in the 
MatrixBlock.
-               // plus 1 to take account of 0 input.
-               long D = in.getNonZeros() + 1;
-
-               /**
-                * To ensure that the likelihood to hash to the same value we 
need O(D^2) positions to hash to assign. If the
-                * value is higher than int (which is the area we hash to) then 
use Integer Max value as largest hashing space.
-                */
-               long tmp = D * D;
-               int M = (tmp > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : 
(int) tmp;
-               /**
-                * The estimator is asymptotically unbiased as k becomes large, 
but memory usage also scales with k. Furthermore k
-                * value must be within range: D >> k >> 0
-                */
-               int k = D > 64 ? 64 : (int) D;
+       public MatrixBlock getValue(MatrixBlock blkIn) {
+
+               if (this.op.getDirection().isRowCol()) {
+                       // D is the number of possible distinct values in the 
MatrixBlock.
+                       // plus 1 to take account of 0 input.
+                       long D = blkIn.getNonZeros() + 1;
+
+                       /**
+                        * To keep the likelihood of hashing to the same 
value low, we need O(D^2) positions to hash to. If the
+                        * value is higher than int (which is the area we hash 
to) then use Integer Max value as largest hashing space.
+                        */
+                       long tmp = D * D;
+                       int M = (tmp > (long) Integer.MAX_VALUE) ? 
Integer.MAX_VALUE : (int) tmp;
+                       /**
+                        * The estimator is asymptotically unbiased as k 
becomes large, but memory usage also scales with k. Furthermore k
+                        * value must be within range: D >> k >> 0
+                        */
+                       int k = D > 64 ? 64 : (int) D;
+
+                       SmallestPriorityQueue spq = getKSmallestHashes(blkIn, 
k, M);
 
-               SmallestPriorityQueue spq = getKSmallestHashes(in, k, M);
+                       if(LOG.isDebugEnabled()) {
+                               LOG.debug("M not forced to int size: " + tmp);
+                               LOG.debug("M: " + M);
+                               LOG.debug("M: " + M);
+                               LOG.debug("kth smallest hash:" + spq.peek());
+                               LOG.debug("spq: " + spq);
+                       }
 
-               if(LOG.isDebugEnabled()) {
-                       LOG.debug("M not forced to int size: " + tmp);
-                       LOG.debug("M: " + M);
-                       LOG.debug("M: " + M);
-                       LOG.debug("kth smallest hash:" + spq.peek());
-                       LOG.debug("spq: " + spq.toString());
-               }
 
-               if(spq.size() < k) {
-                       return spq.size();
-               }
-               else {
-                       double kthSmallestHash = spq.poll();
-                       double U_k = kthSmallestHash / (double) M;
-                       double estimate = (double) (k - 1) / U_k;
-                       double ceilEstimate = Math.min(estimate, (double) D);
+                       long res = countDistinctValuesKMV(spq, k, M, D);
+                       if(res <= 0) {
+                               throw new DMLRuntimeException("Impossible 
estimate of distinct values");
+                       }
 
-                       if(LOG.isDebugEnabled()) {
-                               LOG.debug("U_k : " + U_k);
-                               LOG.debug("Estimate: " + estimate);
-                               LOG.debug("Ceil worst case: " + D);
+                       // Result is a 1x1 matrix block
+                       return new MatrixBlock(res);
+
+               } else if (this.op.getDirection().isRow()) {
+                       long D = (long) Math.floor(blkIn.getNonZeros() / 
(double) blkIn.getNumRows()) + 1;
+                       long tmp = D * D;
+                       int M = (tmp > (long) Integer.MAX_VALUE) ? 
Integer.MAX_VALUE : (int) tmp;
+                       int k = D > 64 ? 64 : (int) D;
+
+                       MatrixBlock resultMatrix = new 
MatrixBlock(blkIn.getNumRows(), 1, false, blkIn.getNumRows());
+                       resultMatrix.allocateBlock();
+
+                       SmallestPriorityQueue spq = new 
SmallestPriorityQueue(k);
+                       for (int i=0; i<blkIn.getNumRows(); ++i) {
+                               for (int j=0; j<blkIn.getNumColumns(); ++j) {
+                                       spq.add(blkIn.getValue(i, j));
+                               }
+
+                               long res = countDistinctValuesKMV(spq, k, M, D);
+                               resultMatrix.setValue(i, 0, res);
+
+                               spq.clear();
+                       }
+
+                       return resultMatrix;
+
+               } else {  // Col
+                       long D = (long) Math.floor(blkIn.getNonZeros() / 
(double) blkIn.getNumColumns()) + 1;
+                       long tmp = D * D;
+                       int M = (tmp > (long) Integer.MAX_VALUE) ? 
Integer.MAX_VALUE : (int) tmp;
+                       int k = D > 64 ? 64 : (int) D;
+
+                       MatrixBlock resultMatrix = new MatrixBlock(1, 
blkIn.getNumColumns(), false, blkIn.getNumColumns());
+                       resultMatrix.allocateBlock();
+
+                       SmallestPriorityQueue spq = new 
SmallestPriorityQueue(k);
+                       for (int j=0; j<blkIn.getNumColumns(); ++j) {
+                               for (int i=0; i<blkIn.getNumRows(); ++i) {
+                                       spq.add(blkIn.getValue(i, j));
+                               }
+
+                               long res = countDistinctValuesKMV(spq, k, M, D);
+                               resultMatrix.setValue(0, j, res);
+
+                               spq.clear();
                        }
-                       return (int) ceilEstimate;
+
+                       return resultMatrix;
                }
        }
 
@@ -146,21 +185,45 @@ public class KMVSketch extends CountDistinctApproxSketch {
                }
        }
 
+       private long countDistinctValuesKMV(SmallestPriorityQueue spq, int k, 
int M, long D) {
+               long res;
+               if(spq.size() < k) {
+                       res = spq.size();
+               }
+               else {
+                       double kthSmallestHash = spq.poll();
+                       double U_k = kthSmallestHash / (double) M;
+                       double estimate = (double) (k - 1) / U_k;
+                       double ceilEstimate = Math.min(estimate, (double) D);
+
+                       if(LOG.isDebugEnabled()) {
+                               LOG.debug("U_k : " + U_k);
+                               LOG.debug("Estimate: " + estimate);
+                               LOG.debug("Ceil worst case: " + D);
+                       }
+                       res = Math.round(ceilEstimate);
+               }
+
+               return res;
+       }
+
        @Override
-       public MatrixBlock getMatrixValue(CorrMatrixBlock arg0) {
+       public MatrixBlock getValueFromSketch(CorrMatrixBlock arg0) {
                MatrixBlock blkIn = arg0.getValue();
-               if(op.getDirection() == Types.Direction.Row) {
-                       // 1000 x 1 blkOut -> slice out the first column of the 
matrix
-                       MatrixBlock blkOut = blkIn.slice(0, blkIn.getNumRows() 
- 1, 0, 0);
+               if(op.getDirection().isRow()) {
+                       // 1000 x 1 blkOut
+                       MatrixBlock blkOut = new 
MatrixBlock(blkIn.getNumRows(), 1, false, blkIn.getNumRows());
+                       blkOut.allocateBlock();
                        for(int i = 0; i < blkIn.getNumRows(); ++i) {
                                getDistinctCountFromSketchByIndex(arg0, i, 
blkOut);
                        }
 
                        return blkOut;
                }
-               else if(op.getDirection() == Types.Direction.Col) {
-                       // 1 x 1000 blkOut -> slice out the first row of the 
matrix
-                       MatrixBlock blkOut = blkIn.slice(0, 0, 0, 
blkIn.getNumColumns() - 1);
+               else if(op.getDirection().isCol()) {
+                       // 1 x 1000 blkOut
+                       MatrixBlock blkOut = new MatrixBlock(1, 
blkIn.getNumColumns(), false, blkIn.getNumColumns());
+                       blkOut.allocateBlock();
                        for(int j = 0; j < blkIn.getNumColumns(); ++j) {
                                getDistinctCountFromSketchByIndex(arg0, j, 
blkOut);
                        }
@@ -169,8 +232,9 @@ public class KMVSketch extends CountDistinctApproxSketch {
                }
                else { // op.getDirection().isRowCol()
 
-                       // 1 x 1 blkOut -> slice out the first row and column 
of the matrix
-                       MatrixBlock blkOut = blkIn.slice(0, 0, 0, 0);
+                       // 1 x 1 blkOut
+                       MatrixBlock blkOut = new MatrixBlock(1, 1, false, 1);
+                       blkOut.allocateBlock();
                        getDistinctCountFromSketchByIndex(arg0, 0, blkOut);
 
                        return blkOut;
@@ -181,41 +245,43 @@ public class KMVSketch extends CountDistinctApproxSketch {
                MatrixBlock blkIn = arg0.getValue();
                MatrixBlock blkInCorr = arg0.getCorrection();
 
-               if(op.getOperatorType() == CountDistinctOperatorTypes.KMV) {
-                       double kthSmallestHash;
-                       if(op.getDirection().isRow() || 
op.getDirection().isRowCol()) {
-                               kthSmallestHash = blkIn.getValue(idx, 0);
-                       }
-                       else { // op.getDirection().isCol()
-                               kthSmallestHash = blkIn.getValue(0, idx);
-                       }
+               if(op.getOperatorType() != CountDistinctOperatorTypes.KMV) {
+                       throw new 
IllegalArgumentException(this.getClass().getSimpleName() + " cannot use " + 
op.getOperatorType());
+               }
 
-                       double nHashes = blkInCorr.getValue(idx, 0);
-                       double k = blkInCorr.getValue(idx, 1);
-                       double D = blkInCorr.getValue(idx, 2);
+               double kthSmallestHash;
+               if(op.getDirection().isRow() || op.getDirection().isRowCol()) {
+                       kthSmallestHash = blkIn.getValue(idx, 0);
+               }
+               else { // op.getDirection().isCol()
+                       kthSmallestHash = blkIn.getValue(0, idx);
+               }
 
-                       double D2 = D * D;
-                       double M = (D2 > (long) Integer.MAX_VALUE) ? 
Integer.MAX_VALUE : D2;
+               double nHashes = blkInCorr.getValue(idx, 0);
+               double k = blkInCorr.getValue(idx, 1);
+               double D = blkInCorr.getValue(idx, 2);
 
-                       double ceilEstimate;
-                       if(nHashes != 0 && nHashes < k) {
-                               ceilEstimate = nHashes;
-                       }
-                       else if(nHashes == 0) {
-                               ceilEstimate = 1;
-                       }
-                       else {
-                               double U_k = kthSmallestHash / M;
-                               double estimate = (k - 1) / U_k;
-                               ceilEstimate = Math.min(estimate, D);
-                       }
+               double D2 = D * D;
+               double M = (D2 > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE 
: D2;
 
-                       if(op.getDirection().isRow() || 
op.getDirection().isRowCol()) {
-                               blkOut.setValue(idx, 0, ceilEstimate);
-                       }
-                       else { // op.getDirection().isCol()
-                               blkOut.setValue(0, idx, ceilEstimate);
-                       }
+               double ceilEstimate;
+               if(nHashes != 0 && nHashes < k) {
+                       ceilEstimate = nHashes;
+               }
+               else if(nHashes == 0) {
+                       ceilEstimate = 1;
+               }
+               else {
+                       double U_k = kthSmallestHash / M;
+                       double estimate = (k - 1) / U_k;
+                       ceilEstimate = Math.min(estimate, D);
+               }
+
+               if(op.getDirection().isRow() || op.getDirection().isRowCol()) {
+                       blkOut.setValue(idx, 0, ceilEstimate);
+               }
+               else { // op.getDirection().isCol()
+                       blkOut.setValue(0, idx, ceilEstimate);
                }
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
index 0a29028c66..f3f7336181 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
@@ -77,6 +77,11 @@ public class SmallestPriorityQueue {
                return this.size() == 0;
        }
 
+       public void clear() {
+               this.containedSet.clear();
+               this.smallestHashes.clear();
+       }
+
        @Override
        public String toString() {
                return smallestHashes.toString();
diff --git 
a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java 
b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
index cd20b67c35..5de18c4b3e 100644
--- 
a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
@@ -145,7 +145,7 @@ public class CountDistinctTest {
                                });
                        }
                        else {
-                               int out = 
LibMatrixCountDistinct.estimateDistinctValues(in, op);
+                               int out = (int) 
LibMatrixCountDistinct.estimateDistinctValues(in, op).getValue(0, 0);
                                int count = out;
                                boolean success = Math.abs(nrUnique - count) <= 
nrUnique * epsilon;
                                StringBuilder sb = new StringBuilder();
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
index e808cb5a76..5a7eccc447 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
@@ -20,6 +20,8 @@
 package org.apache.sysds.test.functions.countDistinct;
 
 import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.junit.Test;
 
 public class CountDistinctApproxCol extends CountDistinctRowOrColBase {
 
@@ -51,4 +53,50 @@ public class CountDistinctApproxCol extends 
CountDistinctRowOrColBase {
        public void setUp() {
                super.addTestConfiguration();
        }
+
+       @Test
+       public void testCPSparseLargeDefaultMCSR() {
+               Types.ExecType ex = Types.ExecType.CP;
+
+               int actualDistinctCount = 10;
+               int rows = 1000, cols = 10000;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
+       }
+
+       @Test
+       public void testCPSparseLargeCSR() {
+               int actualDistinctCount = 10;
+               int rows = 1000, cols = 10000;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               super.testCPSparseLarge(SparseBlock.Type.CSR, 
Types.Direction.Col, rows, cols, actualDistinctCount, sparsity,
+                               tolerance);
+       }
+
+       @Test
+       public void testCPSparseLargeCOO() {
+               int actualDistinctCount = 10;
+               int rows = 1000, cols = 10000;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               super.testCPSparseLarge(SparseBlock.Type.COO, 
Types.Direction.Col, rows, cols, actualDistinctCount, sparsity,
+                               tolerance);
+       }
+
+       @Test
+       public void testCPDenseLarge() {
+               Types.ExecType ex = Types.ExecType.CP;
+
+               int actualDistinctCount = 100;
+               int rows = 1000, cols = 10000;
+               double sparsity = 0.9;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
index 05a125636f..c9aa75e375 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
@@ -20,6 +20,8 @@
 package org.apache.sysds.test.functions.countDistinct;
 
 import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.junit.Test;
 
 public class CountDistinctApproxRow extends CountDistinctRowOrColBase {
 
@@ -51,4 +53,50 @@ public class CountDistinctApproxRow extends 
CountDistinctRowOrColBase {
        public void setUp() {
                super.addTestConfiguration();
        }
+
+       @Test
+       public void testCPSparseLargeDefaultMCSR() {
+               Types.ExecType ex = Types.ExecType.CP;
+
+               int actualDistinctCount = 10;
+               int rows = 10000, cols = 1000;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
+       }
+
+       @Test
+       public void testCPSparseLargeCSR() {
+               int actualDistinctCount = 10;
+               int rows = 10000, cols = 1000;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               super.testCPSparseLarge(SparseBlock.Type.CSR, 
Types.Direction.Row, rows, cols, actualDistinctCount, sparsity,
+                               tolerance);
+       }
+
+       @Test
+       public void testCPSparseLargeCOO() {
+               int actualDistinctCount = 10;
+               int rows = 10000, cols = 1000;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               super.testCPSparseLarge(SparseBlock.Type.COO, 
Types.Direction.Row, rows, cols, actualDistinctCount, sparsity,
+                               tolerance);
+       }
+
+       @Test
+       public void testCPDenseLarge() {
+               Types.ExecType ex = Types.ExecType.CP;
+
+               int actualDistinctCount = 100;
+               int rows = 10000, cols = 1000;
+               double sparsity = 0.9;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
index 041cf51a00..5bf850d49a 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
@@ -19,13 +19,13 @@
 
 package org.apache.sysds.test.functions.countDistinct;
 
-import static org.junit.Assert.assertTrue;
-
 import org.apache.sysds.common.Types;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 
+import static org.junit.Assert.assertTrue;
+
 public abstract class CountDistinctBase extends AutomatedTestBase {
        protected double percentTolerance = 0.0;
        protected double baseTolerance = 0.0001;
@@ -88,7 +88,7 @@ public abstract class CountDistinctBase extends 
AutomatedTestBase {
                }
        }
 
-       private double[][] getExpectedMatrixRowOrCol(Types.Direction dir, int 
cols, int rows, long expectedValue) {
+       protected double[][] getExpectedMatrixRowOrCol(Types.Direction dir, int 
cols, int rows, long expectedValue) {
                double[][] expectedResult;
                if(dir.isRow()) {
                        expectedResult = new double[rows][1];
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
index df2ea8a0ce..a880c0d0dd 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
@@ -20,6 +20,12 @@
 package org.apache.sysds.test.functions.countDistinct;
 
 import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.functionobjects.ReduceCol;
+import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
+import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Test;
@@ -44,24 +50,15 @@ public abstract class CountDistinctRowOrColBase extends 
CountDistinctBase {
                this.percentTolerance = 0.2;
        }
 
+       /**
+        * This is a contrived example where size of row/col > 1024, which 
forces the calculation of a sketch.
+        */
        @Test
-       public void testCPSparseLarge() {
+       public void testCPDenseXLarge() {
                Types.ExecType ex = Types.ExecType.CP;
 
-               int actualDistinctCount = 10;
-               int rows = 10000, cols = 1000;
-               double sparsity = 0.1;
-               double tolerance = actualDistinctCount * this.percentTolerance;
-
-               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
-       }
-
-       @Test
-       public void testCPDenseLarge() {
-               Types.ExecType ex = Types.ExecType.CP;
-
-               int actualDistinctCount = 100;
-               int rows = 10000, cols = 1000;
+               int actualDistinctCount = 10000;
+               int rows = 10000, cols = 10000;
                double sparsity = 0.9;
                double tolerance = actualDistinctCount * this.percentTolerance;
 
@@ -139,4 +136,22 @@ public abstract class CountDistinctRowOrColBase extends 
CountDistinctBase {
 
                countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, execType, tolerance);
        }
+
+       protected void testCPSparseLarge(SparseBlock.Type sparseBlockType, 
Types.Direction direction, int rows, int cols,
+                                                                        int 
actualDistinctCount, double sparsity, double tolerance) {
+               MatrixBlock blkIn = 
TestUtils.round(TestUtils.generateTestMatrixBlock(rows, cols, 0, 
actualDistinctCount, sparsity, 7));
+               if (!blkIn.isInSparseFormat()) {
+                       blkIn.denseToSparse(false);
+               }
+               blkIn = new MatrixBlock(blkIn, sparseBlockType, true);
+
+               CountDistinctOperator op = new 
CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX)
+                               .setDirection(direction)
+                               
.setIndexFunction(ReduceCol.getReduceColFnObject());
+
+               MatrixBlock blkOut = 
LibMatrixCountDistinct.estimateDistinctValues(blkIn, op);
+               double[][] expectedMatrix = 
getExpectedMatrixRowOrCol(direction, cols, rows, actualDistinctCount);
+
+               TestUtils.compareMatrices(expectedMatrix, blkOut, tolerance, 
"");
+       }
 }

Reply via email to