This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 5590135  [SYSTEMDS-2996] countDistinctApprox Builtin function
5590135 is described below

commit 5590135bd2d73c50e9528db5a83f15c0f7964d4b
Author: Badrul Chowdhury <[email protected]>
AuthorDate: Sun Jan 30 21:27:16 2022 -0800

    [SYSTEMDS-2996] countDistinctApprox Builtin function
    
    This commit adds the countDistinctApprox instruction, which allows
    faster approximate counting of distinct elements in a matrix.
    It also adds Spark support for the new instruction.
    
    Closes #1531
    Closes #1554
    
    (Just to make sure GitHub sees that you are the author.)
    Co-authored-by: Badrul Chowdhury <[email protected]>
---
 .../java/org/apache/sysds/common/Builtins.java     |   2 +-
 src/main/java/org/apache/sysds/common/Types.java   |   3 +
 src/main/java/org/apache/sysds/conf/DMLConfig.java |   2 +
 .../org/apache/sysds/lops/PartialAggregate.java    |   9 +-
 .../sysds/parser/BuiltinFunctionExpression.java    |  16 +-
 .../org/apache/sysds/parser/DMLTranslator.java     |  60 ++-
 .../ParameterizedBuiltinFunctionExpression.java    | 140 +++++-
 .../estim/CompressedSizeEstimatorUltraSparse.java  |   4 +-
 .../sysds/runtime/functionobjects/Builtin.java     |   2 +-
 .../runtime/instructions/CPInstructionParser.java  |  14 +-
 .../runtime/instructions/SPInstructionParser.java  |  15 +-
 .../cp/AggregateUnaryCPInstruction.java            |  70 ++-
 .../spark/AggregateUnarySPInstruction.java         |   1 -
 .../spark/AggregateUnarySketchSPInstruction.java   | 293 +++++++++++++
 .../runtime/instructions/spark/SPInstruction.java  |   2 +-
 .../matrix/data/LibMatrixCountDistinct.java        | 200 ++-------
 .../runtime/matrix/data/sketch/MatrixSketch.java   |  68 +++
 .../CountDistinctApproxSketch.java                 |  56 +++
 .../data/sketch/countdistinctapprox/KMVSketch.java | 488 +++++++++++++++++++++
 .../countdistinctapprox/SmallestPriorityQueue.java |  84 ++++
 .../matrix/operators/CountDistinctOperator.java    |  60 ++-
 .../operators/CountDistinctOperatorTypes.java}     |  35 +-
 .../test/component/matrix/CountDistinctTest.java   |  62 ++-
 ...inctApprox.java => CountDistinctApproxCol.java} |  37 +-
 ...ntDistinct.java => CountDistinctApproxRow.java} |  29 +-
 .../countDistinct/CountDistinctApproxRowCol.java   | 140 ++++++
 .../functions/countDistinct/CountDistinctBase.java | 107 ++---
 ...CountDistinct.java => CountDistinctRowCol.java} |  14 +-
 .../countDistinct/CountDistinctRowColBase.java     |  81 ++++
 .../countDistinct/CountDistinctRowOrColBase.java   | 142 ++++++
 .../functions/countDistinct/countDistinct.dml      |   1 -
 ...stinctApprox.dml => countDistinctApproxCol.dml} |   4 +-
 ...ountDistinct.dml => countDistinctApproxRow.dml} |   5 +-
 ...tDistinct.dml => countDistinctApproxRowCol.dml} |   5 +-
 34 files changed, 1831 insertions(+), 420 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 2fab87b..8eec1e5 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -93,7 +93,6 @@ public enum Builtins {
        COS("cos", false),
        COSH("cosh", false),
        COUNT_DISTINCT("countDistinct",false),
-       COUNT_DISTINCT_APPROX("countDistinctApprox",false),
        COV("cov", false),
        COX("cox", true),
        CSPLINE("cspline", true),
@@ -306,6 +305,7 @@ public enum Builtins {
        //parameterized builtin functions
        AUTODIFF("autoDiff", false, true),
        CDF("cdf", false, true),
+       COUNT_DISTINCT_APPROX("countDistinctApprox", false, true),
        CVLM("cvlm", true, false),
        GROUPEDAGG("aggregate", "groupedAggregate", false, true),
        INVCDF("icdf", false, true),
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index 85a11e4..916935c 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -153,6 +153,9 @@ public class Types
                public boolean isCol() {
                        return this == Col;
                }
+               public boolean isRowCol() {
+                       return this == RowCol;
+               }
                @Override
                public String toString() {
                        switch(this) {
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java 
b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index 2e27c17..f46be3b 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -25,6 +25,7 @@ import java.io.IOException;
 import java.io.StringWriter;
 import java.util.HashMap;
 
+import javax.xml.XMLConstants;
 import javax.xml.parsers.DocumentBuilder;
 import javax.xml.parsers.DocumentBuilderFactory;
 import javax.xml.parsers.ParserConfigurationException;
@@ -245,6 +246,7 @@ public class DMLConfig
        private DocumentBuilder getDocumentBuilder() throws 
ParserConfigurationException {
                if (_documentBuilder == null) {
                        DocumentBuilderFactory factory = 
DocumentBuilderFactory.newInstance();
+                       
factory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);  // Prevent 
XML Injection
                        factory.setIgnoringComments(true); //ignore XML comments
                        _documentBuilder = factory.newDocumentBuilder();
                }
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java 
b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index 1a3bde7..050d87a 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -217,7 +217,7 @@ public class PartialAggregate extends Lop
        }
 
        /**
-        * Instruction generation for for CP and Spark
+        * Instruction generation for CP and Spark
         */
        @Override
        public String getInstructions(String input1, String output) 
@@ -348,8 +348,11 @@ public class PartialAggregate extends Lop
                        }
                        
                        case COUNT_DISTINCT_APPROX: {
-                               if(dir == Direction.RowCol )
-                                       return "uacdap";
+                               switch (dir) {
+                                       case RowCol: return "uacdap";
+                                       case Row: return "uacdapr";
+                                       case Col: return "uacdapc";
+                               }
                                break;
                        }
                }
diff --git 
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index e3cb0ee..19b7177 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -623,10 +623,10 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                case MEAN:
                        //checkNumParameters(2, false); // mean(Y) or mean(Y,W)
                        if (getSecondExpr() != null) {
-                               checkNumParameters (2);
+                               checkNumParameters(2);
                        }
                        else {
-                               checkNumParameters (1);
+                               checkNumParameters(1);
                        }
                        
                        checkMatrixParam(getFirstExpr());
@@ -933,7 +933,6 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        output.setValueType(ValueType.INT64);
                        break;
                case COUNT_DISTINCT:
-               case COUNT_DISTINCT_APPROX:
                        checkNumParameters(1);
                        checkDataTypeParam(getFirstExpr(), DataType.MATRIX);
                        output.setDataType(DataType.SCALAR);
@@ -941,7 +940,6 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        output.setBlocksize(0);
                        output.setValueType(ValueType.INT64);
                        break;
-               
                case LINEAGE:
                        checkNumParameters(1);
                        checkDataTypeParam(getFirstExpr(),
@@ -951,14 +949,12 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        output.setBlocksize(0);
                        output.setValueType(ValueType.STRING);
                        break;
-                       
                case LIST:
                        output.setDataType(DataType.LIST);
                        output.setValueType(ValueType.UNKNOWN);
                        output.setDimensions(getAllExpr().length, 1);
                        output.setBlocksize(-1);
                        break;
-               
                case EXISTS:
                        checkNumParameters(1);
                        checkStringOrDataIdentifier(getFirstExpr());
@@ -1825,9 +1821,9 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
        protected void checkNumParameters(int count) { //always unconditional
                if (getFirstExpr() == null && _args.length > 0) {
                        raiseValidateError("Missing argument for function " + 
this.getOpCode(), false,
-                               LanguageErrorCodes.INVALID_PARAMETERS);
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
                }
-               
+
                // Not sure the rationale for the first two if loops, but will 
keep them for backward compatibility
                if (((count == 1) && (getSecondExpr() != null || getThirdExpr() 
!= null))
                                || ((count == 2) && (getThirdExpr() != null))) {
@@ -1843,7 +1839,7 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                } else if (count == 0 && (_args.length > 0
                                || getSecondExpr() != null || getThirdExpr() != 
null)) {
                        raiseValidateError("Missing argument for function " + 
this.getOpCode()
-                               + "(). This function doesn't take any 
arguments.", false);
+                                       + "(). This function doesn't take any 
arguments.", false);
                }
        }
 
@@ -1870,7 +1866,7 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                if( !ArrayUtils.contains(dt, e.getOutput().getDataType()) )
                        raiseValidateError("Non-matching expected data type for 
function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
        }
-       
+
        protected void checkMatrixFrameParam(Expression e) { //always 
unconditional
                if (e.getOutput().getDataType() != DataType.MATRIX && 
e.getOutput().getDataType() != DataType.FRAME) {
                        raiseValidateError("Expecting matrix or frame parameter 
for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS);
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 84212a7..ef51904 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -30,6 +30,22 @@ import java.util.stream.Collectors;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Builtins;
+import org.apache.sysds.common.Types.AggOp;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.Direction;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.common.Types.OpOp2;
+import org.apache.sysds.common.Types.OpOp3;
+import org.apache.sysds.common.Types.OpOpDG;
+import org.apache.sysds.common.Types.OpOpData;
+import org.apache.sysds.common.Types.OpOpDnn;
+import org.apache.sysds.common.Types.OpOpN;
+import org.apache.sysds.common.Types.ParamBuiltinOp;
+import org.apache.sysds.common.Types.ReOrgOp;
+import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
 import org.apache.sysds.hops.AggBinaryOp;
@@ -62,22 +78,6 @@ import org.apache.sysds.hops.rewrite.ProgramRewriter;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.lops.LopsException;
 import org.apache.sysds.lops.compile.Dag;
-import org.apache.sysds.api.DMLScript;
-import org.apache.sysds.common.Builtins;
-import org.apache.sysds.common.Types.AggOp;
-import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.common.Types.Direction;
-import org.apache.sysds.common.Types.FileFormat;
-import org.apache.sysds.common.Types.OpOp1;
-import org.apache.sysds.common.Types.OpOp2;
-import org.apache.sysds.common.Types.OpOp3;
-import org.apache.sysds.common.Types.OpOpDG;
-import org.apache.sysds.common.Types.OpOpData;
-import org.apache.sysds.common.Types.OpOpDnn;
-import org.apache.sysds.common.Types.OpOpN;
-import org.apache.sysds.common.Types.ParamBuiltinOp;
-import org.apache.sysds.common.Types.ReOrgOp;
-import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.parser.PrintStatement.PRINTTYPE;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
@@ -91,7 +91,6 @@ import 
org.apache.sysds.runtime.controlprogram.WhileProgramBlock;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
 
-
 public class DMLTranslator 
 {
        private static final Log LOG = 
LogFactory.getLog(DMLTranslator.class.getName());
@@ -2035,6 +2034,29 @@ public class DMLTranslator
                                        target.getValueType(), 
ParamBuiltinOp.LIST, paramHops);
                                break;
 
+                       case COUNT_DISTINCT_APPROX:
+                               // Default direction and data type
+                               Direction dir = Direction.RowCol;
+                               DataType dataType = DataType.SCALAR;
+
+                               LiteralOp dirOp = (LiteralOp) 
paramHops.get("dir");
+                               if (dirOp != null) {
+                                       String dirString = 
dirOp.getStringValue().toUpperCase();
+                                       if 
(dirString.equals(Direction.RowCol.toString())) {
+                                               dir = Direction.RowCol;
+                                               dataType = DataType.SCALAR;
+                                       } else if 
(dirString.equals(Direction.Row.toString())) {
+                                               dir = Direction.Row;
+                                               dataType = DataType.MATRIX;
+                                       } else if 
(dirString.equals(Direction.Col.toString())) {
+                                               dir = Direction.Col;
+                                               dataType = DataType.MATRIX;
+                                       }
+                               }
+
+                               currBuiltinOp = new 
AggUnaryOp(target.getName(), dataType, target.getValueType(),
+                                               
AggOp.valueOf(source.getOpCode().name()), dir, paramHops.get("data"));
+                               break;
                        default:
                                throw new 
ParseException(source.printErrorLocation() + 
                                        
"processParameterizedBuiltinFunctionExpression() -- Unknown operation: " + 
source.getOpCode());
@@ -2335,11 +2357,9 @@ public class DMLTranslator
                case PROD:
                case VAR:
                case COUNT_DISTINCT:
-               case COUNT_DISTINCT_APPROX:
                        currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.SCALAR, target.getValueType(),
-                               AggOp.valueOf(source.getOpCode().name()), 
Direction.RowCol, expr);
+                                       
AggOp.valueOf(source.getOpCode().name()), Direction.RowCol, expr);
                        break;
-
                case MEAN:
                        if ( expr2 == null ) {
                                // example: x = mean(Y);
diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 442d1e6..6b6ca9b 100644
--- 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -29,13 +29,14 @@ import java.util.Set;
 import java.util.stream.Collectors;
 
 import org.antlr.v4.runtime.ParserRuleContext;
-import org.apache.wink.json4j.JSONObject;
 import org.apache.sysds.common.Builtins;
+import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ParamBuiltinOp;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.parser.LanguageException.LanguageErrorCodes;
 import org.apache.sysds.runtime.util.CollectionUtils;
+import org.apache.wink.json4j.JSONObject;
 
 
 public class ParameterizedBuiltinFunctionExpression extends DataIdentifier 
@@ -245,6 +246,10 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                        validateParamserv(output, conditional);
                        break;
 
+               case COUNT_DISTINCT_APPROX:
+                       validateCountDistinctApprox(output, conditional);
+                       break;
+
                default: //always unconditional (because unsupported operation)
                        //handle common issue of transformencode
                        if( getOpCode()==Builtins.TRANSFORMENCODE )
@@ -258,7 +263,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
 
        private void validateAutoDiff(DataIdentifier output, boolean 
conditional) {
                //validate data / metadata (recode maps)
-               checkDataType("lineage", LINEAGE_TRACE, DataType.LIST, 
conditional);
+               checkDataType(false, "lineage", LINEAGE_TRACE, DataType.LIST, 
conditional);
 
                //validate specification
                checkDataValueType(false, "lineage", LINEAGE_TRACE, 
DataType.LIST, ValueType.UNKNOWN, conditional);
@@ -266,7 +271,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                // set output characteristics
                output.setDataType(DataType.LIST);
                output.setValueType(ValueType.UNKNOWN);
-               // TODO dimension should be set to -1 but could not set due to 
lineage parsing error in Spark contetx
+               // TODO dimension should be set to -1 but could not set due to 
lineage parsing error in Spark context
                output.setDimensions(varParams.size(), 1);
                // output.setDimensions(-1, 1);
                output.setBlocksize(-1);
@@ -319,9 +324,9 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                checkInvalidParameters(getOpCode(), getVarParams(), valid);
 
                // check existence and correctness of parameters
-               checkDataType(fname, Statement.PS_MODEL, DataType.LIST, 
conditional); // check the model which is the only non-parameterized argument
-               checkDataType(fname, Statement.PS_FEATURES, DataType.MATRIX, 
conditional);
-               checkDataType(fname, Statement.PS_LABELS, DataType.MATRIX, 
conditional);
+               checkDataType(false, fname, Statement.PS_MODEL, DataType.LIST, 
conditional); // check the model which is the only non-parameterized argument
+               checkDataType(false, fname, Statement.PS_FEATURES, 
DataType.MATRIX, conditional);
+               checkDataType(false, fname, Statement.PS_LABELS, 
DataType.MATRIX, conditional);
                checkDataValueType(true, fname, Statement.PS_VAL_FEATURES, 
DataType.MATRIX, ValueType.FP64, conditional);
                checkDataValueType(true, fname, Statement.PS_VAL_LABELS, 
DataType.MATRIX, ValueType.FP64, conditional);
                checkDataValueType(false, fname, Statement.PS_UPDATE_FUN, 
DataType.SCALAR, ValueType.STRING, conditional);
@@ -347,6 +352,99 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                output.setBlocksize(-1);
        }
 
+       private void validateCountDistinctApprox(DataIdentifier output, boolean 
conditional) {
+               Set<String> validTypeNames = CollectionUtils.asSet("KMV");
+               HashMap<String, Expression> varParams = getVarParams();
+
+               // "data" is the only parameter that is allowed to be unnamed
+               if (varParams.containsKey(null)) {
+                       varParams.put("data", varParams.remove(null));
+               }
+
+               // Validate the number of parameters
+               String fname = getOpCode().getName();
+               String usageMessage = "function " + fname + " takes at least 1 
and at most 3 parameters";
+               if (varParams.size() < 1) {
+                       raiseValidateError("Too few parameters: " + 
usageMessage, conditional);
+               }
+
+               if (varParams.size() > 3) {
+                       raiseValidateError("Too many parameters: " + 
usageMessage, conditional);
+               }
+
+               // Check parameter names are valid
+               Set<String> validParameterNames = CollectionUtils.asSet("data", 
"type", "dir");
+               checkInvalidParameters(getOpCode(), varParams, 
validParameterNames);
+
+               // Check parameter expression data types match expected
+               checkDataType(false, fname, "data", DataType.MATRIX, 
conditional);
+               checkDataValueType(false, fname, "data", DataType.MATRIX, 
ValueType.FP64, conditional);
+
+               // We need the dimensions of the input matrix to determine the 
output matrix characteristics
+               // Validate data parameter, lookup previously defined var or 
resolve expression
+               Identifier dataId = varParams.get("data").getOutput();
+               if (dataId == null) {
+                       raiseValidateError("Cannot parse input parameter 
\"data\" to function " + fname, conditional);
+               }
+
+               checkStringParam(true, fname, "type", conditional);
+               // Check data value of "type" parameter
+               if (varParams.keySet().contains("type")) {
+                       String typeString = 
varParams.get("type").toString().toUpperCase();
+                       if (!validTypeNames.contains(typeString)) {
+                               raiseValidateError("Unrecognized type for 
optional parameter " + typeString, conditional);
+                       }
+               } else {
+                       // default to KMV
+                       addVarParam("type", new StringIdentifier("KMV", this));
+               }
+
+               checkStringParam(true, fname, "dir", conditional);
+               // Check data value of "dir" parameter
+               if (varParams.keySet().contains("dir")) {
+                       String directionString = 
varParams.get("dir").toString().toUpperCase();
+
+                       // Set output type and dimensions based on direction
+
+                       // "r" -> count across all rows, resulting in a Mx1 
matrix
+                       if 
(directionString.equals(Types.Direction.Row.toString())) {
+                               output.setDataType(DataType.MATRIX);
+                               output.setDimensions(dataId.getDim1(), 1);
+                               output.setBlocksize(dataId.getBlocksize());
+                               output.setValueType(ValueType.INT64);
+                               output.setNnz(dataId.getDim1());
+
+                       // "c" -> count across all cols, resulting in a 1xN 
matrix
+                       } else if 
(directionString.equals(Types.Direction.Col.toString())) {
+                               output.setDataType(DataType.MATRIX);
+                               output.setDimensions(1, dataId.getDim2());
+                               output.setBlocksize(dataId.getBlocksize());
+                               output.setValueType(ValueType.INT64);
+                               output.setNnz(dataId.getDim2());
+
+                       // "rc" -> count across all rows and cols in input 
matrix, resulting in a single value
+                       } else if 
(directionString.equals(Types.Direction.RowCol.toString())) {
+                               output.setDataType(DataType.SCALAR);
+                               output.setDimensions(0, 0);
+                               output.setBlocksize(0);
+                               output.setValueType(ValueType.INT64);
+                               output.setNnz(1);
+
+                       // unrecognized value for "dir" parameter, should "cr" 
be valid?
+                       } else {
+                               raiseValidateError("Invalid argument: " + 
directionString + " is not recognized");
+                       }
+
+               // default to dir="rc"
+               } else {
+                       output.setDataType(DataType.SCALAR);
+                       output.setDimensions(0, 0);
+                       output.setBlocksize(0);
+                       output.setValueType(ValueType.INT64);
+                       output.setNnz(1);
+               }
+       }
+
        private void checkStringParam(boolean optional, String fname, String 
pname, boolean conditional) {
                Expression param = getVarParam(pname);
                if (param == null) {
@@ -365,7 +463,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
        private void validateTokenize(DataIdentifier output, boolean 
conditional)
        {
                //validate data / metadata (recode maps)
-               checkDataType("tokenize", TF_FN_PARAM_DATA, DataType.FRAME, 
conditional);
+               checkDataType(false, "tokenize", TF_FN_PARAM_DATA, 
DataType.FRAME, conditional);
 
                //validate specification
                checkDataValueType(false, "tokenize", TF_FN_PARAM_SPEC, 
DataType.SCALAR, ValueType.STRING, conditional);
@@ -381,8 +479,8 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
        private void validateTransformApply(DataIdentifier output, boolean 
conditional) 
        {
                //validate data / metadata (recode maps)
-               checkDataType("transformapply", TF_FN_PARAM_DATA, 
DataType.FRAME, conditional);
-               checkDataType("transformapply", TF_FN_PARAM_MTD2, 
DataType.FRAME, conditional);
+               checkDataType(false, "transformapply", TF_FN_PARAM_DATA, 
DataType.FRAME, conditional);
+               checkDataType(false, "transformapply", TF_FN_PARAM_MTD2, 
DataType.FRAME, conditional);
                
                //validate specification
                checkDataValueType(false, "transformapply", TF_FN_PARAM_SPEC, 
DataType.SCALAR, ValueType.STRING, conditional);
@@ -397,8 +495,8 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
        private void validateTransformDecode(DataIdentifier output, boolean 
conditional) 
        {
                //validate data / metadata (recode maps) 
-               checkDataType("transformdecode", TF_FN_PARAM_DATA, 
DataType.MATRIX, conditional);
-               checkDataType("transformdecode", TF_FN_PARAM_MTD2, 
DataType.FRAME, conditional);
+               checkDataType(false, "transformdecode", TF_FN_PARAM_DATA, 
DataType.MATRIX, conditional);
+               checkDataType(false, "transformdecode", TF_FN_PARAM_MTD2, 
DataType.FRAME, conditional);
                
                //validate specification
                checkDataValueType(false, "transformdecode", TF_FN_PARAM_SPEC, 
DataType.SCALAR, ValueType.STRING, conditional);
@@ -414,7 +512,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
        {
                //validate data / metadata (recode maps) 
                Expression exprTarget = getVarParam(Statement.GAGG_TARGET);
-               checkDataType("transformcolmap", TF_FN_PARAM_DATA, 
DataType.FRAME, conditional);
+               checkDataType(false, "transformcolmap", TF_FN_PARAM_DATA, 
DataType.FRAME, conditional);
                
                //validate specification
                checkDataValueType(false,"transformcolmap", TF_FN_PARAM_SPEC, 
DataType.SCALAR, ValueType.STRING, conditional);
@@ -444,7 +542,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
        private void validateTransformEncode(DataIdentifier output1, 
DataIdentifier output2, boolean conditional) 
        {
                //validate data / metadata (recode maps) 
-               checkDataType("transformencode", TF_FN_PARAM_DATA, 
DataType.FRAME, conditional);
+               checkDataType(false, "transformencode", TF_FN_PARAM_DATA, 
DataType.FRAME, conditional);
                
                //validate specification
                checkDataValueType(false, "transformencode", TF_FN_PARAM_SPEC, 
DataType.SCALAR, ValueType.STRING, conditional);
@@ -871,12 +969,18 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                output.setBlocksize(-1);
        }
 
-       private void checkDataType( String fname, String pname, DataType dt, 
boolean conditional ) {
+       private void checkDataType(boolean optional, String fname, String 
pname, DataType dt, boolean conditional) {
                Expression data = getVarParam(pname);
-               if( data==null )
-                       raiseValidateError("Named parameter '" + pname + "' 
missing. Please specify the input.", conditional, 
LanguageErrorCodes.INVALID_PARAMETERS);
-               else if( data.getOutput().getDataType() != dt )
-                       raiseValidateError("Input to "+fname+"::"+pname+" must 
be of type '"+dt.toString()+"'. It should not be of type 
'"+data.getOutput().getDataType()+"'.", conditional, 
LanguageErrorCodes.INVALID_PARAMETERS);
+               if(data == null) {
+                       if(optional)
+                               return;
+                       raiseValidateError("Named parameter '" + pname + "' 
missing. Please specify the input.", conditional,
+                               LanguageErrorCodes.INVALID_PARAMETERS);
+               }
+               else if(data.getOutput().getDataType() != dt)
+                       raiseValidateError("Input to " + fname + "::" + pname + 
" must be of type '" + dt.toString()
+                               + "'. It should not be of type '" + 
data.getOutput().getDataType() + "'.", conditional,
+                               LanguageErrorCodes.INVALID_PARAMETERS);
        }
 
        private void checkDataValueType(boolean optional, String fname, String 
pname, DataType dt, ValueType vt,
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorUltraSparse.java
 
b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorUltraSparse.java
index 7a31f13..23ad02c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorUltraSparse.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeEstimatorUltraSparse.java
@@ -26,7 +26,7 @@ import 
org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
 import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
-import 
org.apache.sysds.runtime.matrix.operators.CountDistinctOperator.CountDistinctTypes;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
 
 /**
  * UltraSparse compressed size estimator (examines entire dataset).
@@ -39,7 +39,7 @@ public class CompressedSizeEstimatorUltraSparse extends 
CompressedSizeEstimator
 
        private CompressedSizeEstimatorUltraSparse(MatrixBlock data, 
CompressionSettings compSettings) {
                super(data, compSettings);
-               CountDistinctOperator op = new 
CountDistinctOperator(CountDistinctTypes.COUNT);
+               CountDistinctOperator op = new 
CountDistinctOperator(CountDistinctOperatorTypes.COUNT);
                final int _numRows = getNumRows();
 
                if(LOG.isDebugEnabled()) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java 
b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
index 4f423c2..7866f23 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
@@ -223,7 +223,7 @@ public class Builtin extends ValueFunction
                                // compared and performs just the value part of 
the comparison. We
                                // return an integer cast down to a double, 
since the aggregation
                                // API doesn't have any way to return anything 
but a double. The
-                               // integer returned takes on three posssible 
values: //
+                               // integer returned takes on three possible 
values: //
                                // .     0 => keep the index associated with 
in1 //
                                // .     1 => use the index associated with in2 
//
                                // .     2 => use whichever index is higher 
(tie in value) //
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index c08985b..d3b8ad6 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -96,7 +96,7 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "uacvar"  , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uamax"   , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uarmax"  , 
CPType.AggregateUnary);
-               String2CPInstructionType.put( "uarimax", CPType.AggregateUnary);
+               String2CPInstructionType.put( "uarimax" , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uacmax"  , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uamin"   , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uarmin"  , 
CPType.AggregateUnary);
@@ -110,13 +110,15 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "uac*"    , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uatrace" , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uaktrace", 
CPType.AggregateUnary);
-               String2CPInstructionType.put( "nrow"    ,CPType.AggregateUnary);
-               String2CPInstructionType.put( "ncol"    ,CPType.AggregateUnary);
-               String2CPInstructionType.put( "length"  ,CPType.AggregateUnary);
-               String2CPInstructionType.put( "exists"  ,CPType.AggregateUnary);
-               String2CPInstructionType.put( "lineage" ,CPType.AggregateUnary);
+               String2CPInstructionType.put( "nrow"    , 
CPType.AggregateUnary);
+               String2CPInstructionType.put( "ncol"    , 
CPType.AggregateUnary);
+               String2CPInstructionType.put( "length"  , 
CPType.AggregateUnary);
+               String2CPInstructionType.put( "exists"  , 
CPType.AggregateUnary);
+               String2CPInstructionType.put( "lineage" , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uacd"    , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uacdap"  , 
CPType.AggregateUnary);
+               String2CPInstructionType.put( "uacdapr" , 
CPType.AggregateUnary);
+               String2CPInstructionType.put( "uacdapc" , 
CPType.AggregateUnary);
 
                String2CPInstructionType.put( "uaggouterchain", 
CPType.UaggOuterChain);
                
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 56cd49a..965617d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -42,6 +42,7 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import 
org.apache.sysds.runtime.instructions.spark.AggregateTernarySPInstruction;
 import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
+import 
org.apache.sysds.runtime.instructions.spark.AggregateUnarySketchSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.AppendMSPInstruction;
@@ -62,6 +63,7 @@ import 
org.apache.sysds.runtime.instructions.spark.CumulativeOffsetSPInstruction
 import org.apache.sysds.runtime.instructions.spark.DeCompressionSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.DnnSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
+import org.apache.sysds.runtime.instructions.spark.LIBSVMReblockSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.MapmmChainSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.MatrixReshapeSPInstruction;
@@ -87,7 +89,7 @@ import 
org.apache.sysds.runtime.instructions.spark.UnaryFrameSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.UnaryMatrixSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
 import org.apache.sysds.runtime.instructions.spark.ZipmmSPInstruction;
-import org.apache.sysds.runtime.instructions.spark.LIBSVMReblockSPInstruction;
+
 
 public class SPInstructionParser extends InstructionParser
 {
@@ -110,7 +112,7 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "uacvar"  , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uamax"   , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uarmax"  , 
SPType.AggregateUnary);
-               String2SPInstructionType.put( "uarimax" ,  
SPType.AggregateUnary);
+               String2SPInstructionType.put( "uarimax" , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uacmax"  , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uamin"   , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uarmin"  , 
SPType.AggregateUnary);
@@ -124,6 +126,11 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "uac*"    , SPType.AggregateUnary);
                String2SPInstructionType.put( "uatrace" , SPType.AggregateUnary);
                String2SPInstructionType.put( "uaktrace", SPType.AggregateUnary);
+
+               // Aggregate unary sketch operators
+               String2SPInstructionType.put( "uacdap" , SPType.AggregateUnarySketch);
+               String2SPInstructionType.put( "uacdapr", SPType.AggregateUnarySketch);
+               String2SPInstructionType.put( "uacdapc", SPType.AggregateUnarySketch);
 
                //binary aggregate operators (matrix multiplication operators)
                String2SPInstructionType.put( "mapmm"      , SPType.MAPMM);
@@ -388,6 +396,9 @@ public class SPInstructionParser extends InstructionParser
                        case AggregateUnary:
                                return 
AggregateUnarySPInstruction.parseInstruction(str);
 
+                       case AggregateUnarySketch:
+                               return 
AggregateUnarySketchSPInstruction.parseInstruction(str);
+
                        case AggregateTernary:
                                return 
AggregateTernarySPInstruction.parseInstruction(str);
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
index ef1ff08..fbcf6ff 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.runtime.instructions.cp;
 
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -28,6 +29,9 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.data.BasicTensorBlock;
 import org.apache.sysds.runtime.data.TensorBlock;
 import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.ReduceAll;
+import org.apache.sysds.runtime.functionobjects.ReduceCol;
+import org.apache.sysds.runtime.functionobjects.ReduceRow;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.lineage.LineageDedupUtils;
 import org.apache.sysds.runtime.lineage.LineageItem;
@@ -82,8 +86,28 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
                        in1, out, AUType.COUNT_DISTINCT, opcode, str);
                }
                else if(opcode.equalsIgnoreCase("uacdap")){
-                       return new AggregateUnaryCPInstruction(new 
SimpleOperator(null),
-                       in1, out, AUType.COUNT_DISTINCT_APPROX, opcode, str);
+                       CountDistinctOperator op = new 
CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX)
+                                       .setDirection(Types.Direction.RowCol)
+                                       
.setIndexFunction(ReduceAll.getReduceAllFnObject());
+
+                       return new AggregateUnaryCPInstruction(op, in1, out, 
AUType.COUNT_DISTINCT_APPROX,
+                                       opcode, str);
+               }
+               else if(opcode.equalsIgnoreCase("uacdapr")){
+                       CountDistinctOperator op = new 
CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX)
+                                       .setDirection(Types.Direction.Row)
+                                       
.setIndexFunction(ReduceCol.getReduceColFnObject());
+
+                       return new AggregateUnaryCPInstruction(op, in1, out, 
AUType.COUNT_DISTINCT_APPROX,
+                                       opcode, str);
+               }
+               else if(opcode.equalsIgnoreCase("uacdapc")){
+                       CountDistinctOperator op = new 
CountDistinctOperator(AUType.COUNT_DISTINCT_APPROX)
+                                       .setDirection(Types.Direction.Col)
+                                       
.setIndexFunction(ReduceRow.getReduceRowFnObject());
+
+                       return new AggregateUnaryCPInstruction(op, in1, out, 
AUType.COUNT_DISTINCT_APPROX,
+                                       opcode, str);
                }
                else if(opcode.equalsIgnoreCase("uarimax") || 
opcode.equalsIgnoreCase("uarimin")){
                        // parse with number of outputs
@@ -171,17 +195,62 @@ public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
                                ec.setScalarOutput(output_name, new StringObject(out));
                                break;
                        }
-                       case COUNT_DISTINCT:
-                       case COUNT_DISTINCT_APPROX: {
+                       case COUNT_DISTINCT: {
                                if( !ec.getVariables().keySet().contains(input1.getName()) )
                                        throw new DMLRuntimeException("Variable '" + input1.getName() + "' does not exist.");
                                MatrixBlock input = ec.getMatrixInput(input1.getName());
                                CountDistinctOperator op = new CountDistinctOperator(_type);
+                               //TODO add support for row or col count distinct.
                                int res = LibMatrixCountDistinct.estimateDistinctValues(input, op);
                                ec.releaseMatrixInput(input1.getName());
                                ec.setScalarOutput(output_name, new IntObject(res));
                                break;
                        }
+                       case COUNT_DISTINCT_APPROX: {
+                               if(!ec.getVariables().keySet().contains(input1.getName())) {
+                                       throw new DMLRuntimeException("Variable '" + input1.getName() + "' does not exist.");
+                               }
+
+                               // Validate the operator before pinning the input, so that a failed check
+                               // cannot leave the input acquired.
+                               if (!(_optr instanceof CountDistinctOperator)) {
+                                       throw new DMLRuntimeException("Operator should be instance of " + CountDistinctOperator.class.getSimpleName());
+                               }
+                               CountDistinctOperator op = (CountDistinctOperator) _optr;  // It is safe to cast at this point
+
+                               MatrixBlock input = ec.getMatrixInput(input1.getName());
+                               int scalarRes = 0;             // result for Direction.RowCol
+                               MatrixBlock matrixRes = null;  // result for Direction.Row / Direction.Col
+                               try {
+                                       if (op.getDirection().isRowCol()) {
+                                               scalarRes = LibMatrixCountDistinct.estimateDistinctValues(input, op);
+                                       } else if (op.getDirection().isRow()) {
+                                               //TODO Do not slice out the matrix but directly process on the input
+                                               matrixRes = input.slice(0, input.getNumRows() - 1, 0, 0);
+                                               for (int i = 0; i < input.getNumRows(); ++i) {
+                                                       matrixRes.setValue(i, 0, LibMatrixCountDistinct.estimateDistinctValues(input.slice(i, i), op));
+                                               }
+                                       } else if (op.getDirection().isCol()) {
+                                               //TODO Do not slice out the matrix but directly process on the input
+                                               matrixRes = input.slice(0, 0, 0, input.getNumColumns() - 1);
+                                               for (int j = 0; j < input.getNumColumns(); ++j) {
+                                                       matrixRes.setValue(0, j, LibMatrixCountDistinct.estimateDistinctValues(input.slice(0, input.getNumRows() - 1, j, j), op));
+                                               }
+                                       } else {
+                                               throw new DMLRuntimeException("Direction for CountDistinctOperator not recognized");
+                                       }
+                               } finally {
+                                       // Unpin the input on every path, including when estimation throws.
+                                       ec.releaseMatrixInput(input1.getName());
+                               }
+
+                               if (matrixRes == null) {
+                                       ec.setScalarOutput(output_name, new IntObject(scalarRes));
+                               } else {
+                                       ec.setMatrixOutput(output_name, matrixRes);
+                               }
+                               break;
+                       }
                        default: {
                                AggregateUnaryOperator au_op = (AggregateUnaryOperator) _optr;
                                if (input1.getDataType() == DataType.MATRIX) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index 04b6650..38b032c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -289,7 +289,6 @@ public class AggregateUnarySPInstruction extends 
UnarySPInstruction {
                public RDDUAggValueFunction( AggregateUnaryOperator op, int 
blen ) {
                        _op = op;
                        _blen = blen;
-                       _blen = blen;
                        
                        _ix = new MatrixIndexes(1,1);
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
new file mode 100644
index 0000000..71bc75f
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.spark;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.sysds.api.DMLException;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.AggBinaryOp;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.functionobjects.ReduceAll;
+import org.apache.sysds.runtime.functionobjects.ReduceCol;
+import org.apache.sysds.runtime.functionobjects.ReduceRow;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
+import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.utils.Hash;
+import scala.Tuple2;
+
+/**
+ * Spark instruction for approximate count-distinct over a blocked matrix
+ * (opcodes {@code uacdap}, {@code uacdapr}, {@code uacdapc}). Input blocks are
+ * summarized into sketches ({@link CorrMatrixBlock}), sketches are unioned
+ * according to the aggregation type, and the estimate is derived from the
+ * resulting sketch(es).
+ */
+public class AggregateUnarySketchSPInstruction extends UnarySPInstruction {
+    /** Aggregation type; only SINGLE_BLOCK, MULTI_BLOCK and NONE are supported. */
+    private final AggBinaryOp.SparkAggType aggtype;
+    /** Operator configured with the direction and index function implied by the opcode. */
+    private final CountDistinctOperator op;
+
+    /**
+     * Creates the instruction and configures the operator's direction and index
+     * function from the opcode.
+     *
+     * @throws DMLException if the opcode is not uacdap, uacdapr or uacdapc
+     */
+    protected AggregateUnarySketchSPInstruction(Operator op, CPOperand in, CPOperand out,
+            AggBinaryOp.SparkAggType aggtype, String opcode, String instr) {
+        super(SPType.AggregateUnarySketch, op, in, out, opcode, instr);
+        this.op = (CountDistinctOperator) super.getOperator();
+
+        if (opcode.equals("uacdap")) {
+            this.op.setDirection(Types.Direction.RowCol)
+                    .setIndexFunction(ReduceAll.getReduceAllFnObject());
+        } else if (opcode.equals("uacdapr")) {
+            this.op.setDirection(Types.Direction.Row)
+                    .setIndexFunction(ReduceCol.getReduceColFnObject());
+        } else if (opcode.equals("uacdapc")) {
+            this.op.setDirection(Types.Direction.Col)
+                    .setIndexFunction(ReduceRow.getReduceRowFnObject());
+        } else {
+            throw new DMLException("Unrecognized opcode " + opcode);
+        }
+
+        this.aggtype = aggtype;
+    }
+
+    /**
+     * Parses an instruction with the fields {@code opcode, input, output, aggtype}
+     * and creates a KMV count-distinct operator using linear hashing.
+     */
+    public static AggregateUnarySketchSPInstruction parseInstruction(String str) {
+        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
+        InstructionUtils.checkNumFields(parts, 3);
+        String opcode = parts[0];
+
+        CPOperand in1 = new CPOperand(parts[1]);
+        CPOperand out = new CPOperand(parts[2]);
+        AggBinaryOp.SparkAggType aggtype = AggBinaryOp.SparkAggType.valueOf(parts[3]);
+
+        CountDistinctOperator cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, Hash.HashType.LinearHash);
+
+        return new AggregateUnarySketchSPInstruction(cdop, in1, out, aggtype, opcode, str);
+    }
+
+    @Override
+    public void processInstruction(ExecutionContext ec) {
+        if (input1.getDataType() == Types.DataType.MATRIX) {
+            processMatrixSketch(ec);
+        } else {
+            processTensorSketch(ec);
+        }
+    }
+
+    private void processMatrixSketch(ExecutionContext ec) {
+        SparkExecutionContext sec = (SparkExecutionContext) ec;
+
+        // get input
+        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(input1.getName());
+
+        // dir = RowCol and (dim1() > 1000 || dim2() > 1000)
+        if (aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
+            // Create one sketch per input block and merge them into a single sketch
+            JavaRDD<CorrMatrixBlock> sketches = in.map(new AggregateUnarySketchCreateFunction(this.op));
+
+            // fold() with a neutral zero value instead of reduce(). The zero value wraps an empty
+            // matrix block so that it can be serialized properly; the union function treats it
+            // (and the result of any empty partition, which equals the zero value) as the identity.
+            CorrMatrixBlock merged = sketches.fold(new CorrMatrixBlock(new MatrixBlock()),
+                    new AggregateUnarySketchUnionAllFunction(this.op));
+
+            MatrixBlock result = LibMatrixCountDistinct.countDistinctValuesFromSketch(merged, this.op);
+
+            // put output block into symbol table (no lineage because single block)
+            // this also includes implicit maintenance of matrix characteristics
+            sec.setMatrixOutput(output.getName(), result);
+        } else {
+            if (aggtype != AggBinaryOp.SparkAggType.NONE && aggtype != AggBinaryOp.SparkAggType.MULTI_BLOCK) {
+                throw new DMLRuntimeException(String.format("Unsupported aggregation type: %s", aggtype));
+            }
+
+            JavaPairRDD<MatrixIndexes, CorrMatrixBlock> sketches;
+
+            // dir = Row || Col || RowCol and (dim1() <= 1000 || dim2() <= 1000)
+            if (aggtype == AggBinaryOp.SparkAggType.NONE) {
+                // Input matrix is small enough for a single index, so there is no need to execute the index function.
+                // Reuse the CreateCombinerFunction(), although there is no need to merge values within the same
+                // partition, or combiners across partitions for that matter.
+                sketches = in.mapValues(new AggregateUnarySketchCreateCombinerFunction(this.op));
+            } else {
+                // aggType = MULTI_BLOCK: dir = Row || Col and (dim1() > 1000 || dim2() > 1000)
+                // Execute index function to group rows/columns together based on aggregation direction
+                JavaPairRDD<MatrixIndexes, MatrixBlock> grouped = in.mapToPair(new RowColGroupingFunction(this.op));
+
+                // Create sketch per index
+                sketches = grouped.combineByKey(new AggregateUnarySketchCreateCombinerFunction(this.op),
+                        new AggregateUnarySketchMergeValueFunction(this.op),
+                        new AggregateUnarySketchMergeCombinerFunction(this.op));
+            }
+
+            JavaPairRDD<MatrixIndexes, MatrixBlock> results = sketches.mapValues(new CalculateAggregateSketchFunction(this.op));
+
+            updateUnaryAggOutputDataCharacteristics(sec, this.op.getIndexFunction());
+
+            // put output RDD handle into symbol table
+            sec.setRDDHandleForVariable(output.getName(), results);
+            sec.addLineageRDD(output.getName(), input1.getName());
+        }
+    }
+
+    private void processTensorSketch(ExecutionContext ec) {
+        throw new NotImplementedException("Aggregate sketch instruction for tensors has not been implemented yet.");
+    }
+
+    /** Maps each (index, block) pair to the sketch of its block; the index is not needed for a single global sketch. */
+    private static class AggregateUnarySketchCreateFunction implements Function<Tuple2<MatrixIndexes, MatrixBlock>, CorrMatrixBlock> {
+        private static final long serialVersionUID = 7295176181965491548L;
+        private final CountDistinctOperator op;
+
+        public AggregateUnarySketchCreateFunction(CountDistinctOperator op) {
+            this.op = op;
+        }
+
+        @Override
+        public CorrMatrixBlock call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
+            return LibMatrixCountDistinct.createSketch(arg0._2(), this.op);
+        }
+    }
+
+    /**
+     * Unions two sketches. Used with fold(), so it must accept the neutral zero value
+     * (an empty matrix block without sketch metadata) on either side, including on both sides.
+     */
+    private static class AggregateUnarySketchUnionAllFunction implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock> {
+        private static final long serialVersionUID = -3799519241499062936L;
+        private final CountDistinctOperator op;
+
+        public AggregateUnarySketchUnionAllFunction(CountDistinctOperator op) {
+            this.op = op;
+        }
+
+        @Override
+        public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1) throws Exception {
+            // Neutral element merged with neutral element (e.g. fold's zero value merged with the
+            // result of an empty partition) yields the neutral element.
+            if (isNeutral(arg0) && isNeutral(arg1)) {
+                return arg0;
+            }
+
+            // Input matrix blocks must have corresponding sketch metadata
+            if (arg0.getCorrection() == null && arg1.getCorrection() == null) {
+                throw new DMLRuntimeException("Corrupt sketch: metadata is missing");
+            }
+
+            if ((arg0.getValue().getNumRows() == 0 && arg0.getValue().getNumColumns() == 0) || arg0.getCorrection() == null) {
+                arg0.set(arg1.getValue(), arg1.getCorrection());
+                return arg0;
+            } else if ((arg1.getValue().getNumRows() == 0 && arg1.getValue().getNumColumns() == 0) || arg1.getCorrection() == null) {
+                return arg0;
+            }
+
+            return LibMatrixCountDistinct.unionSketch(arg0, arg1, this.op);
+        }
+
+        /** True for the fold() zero value: no sketch metadata and an empty (0x0 or missing) value. */
+        private static boolean isNeutral(CorrMatrixBlock blk) {
+            MatrixBlock value = blk.getValue();
+            return blk.getCorrection() == null
+                    && (value == null || (value.getNumRows() == 0 && value.getNumColumns() == 0));
+        }
+    }
+
+    /** Re-keys each block by its output index (per the operator's index function) without sketching it yet. */
+    private static class RowColGroupingFunction implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
+
+        private static final long serialVersionUID = -3456633769452405482L;
+        private final CountDistinctOperator _op;
+
+        public RowColGroupingFunction(CountDistinctOperator op) {
+            this._op = op;
+        }
+
+        @Override
+        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
+            MatrixIndexes idxIn = arg0._1();
+            MatrixBlock blkIn = arg0._2();
+
+            MatrixIndexes idxOut = new MatrixIndexes();
+            MatrixBlock blkOut = blkIn;  // Do not create sketch yet
+            this._op.getIndexFunction().execute(idxIn, idxOut);
+
+            return new Tuple2<>(idxOut, blkOut);
+        }
+    }
+
+    /** Sketches a single block; used as combineByKey's createCombiner and with mapValues. */
+    private static class AggregateUnarySketchCreateCombinerFunction implements Function<MatrixBlock, CorrMatrixBlock>
+    {
+        private static final long serialVersionUID = 8997980606986435297L;
+        private final CountDistinctOperator op;
+
+        private AggregateUnarySketchCreateCombinerFunction(CountDistinctOperator op) {
+            this.op = op;
+        }
+
+        @Override
+        public CorrMatrixBlock call(MatrixBlock arg0) throws Exception {
+            return LibMatrixCountDistinct.createSketch(arg0, this.op);
+        }
+    }
+
+    /** combineByKey mergeValue: sketches a further block and unions it into the partition-local sketch. */
+    private static class AggregateUnarySketchMergeValueFunction implements Function2<CorrMatrixBlock, MatrixBlock, CorrMatrixBlock>
+    {
+        private static final long serialVersionUID = -7006864809860460549L;
+        private final CountDistinctOperator op;
+
+        public AggregateUnarySketchMergeValueFunction(CountDistinctOperator op) {
+            this.op = op;
+        }
+
+        @Override
+        public CorrMatrixBlock call(CorrMatrixBlock arg0, MatrixBlock arg1) throws Exception {
+            CorrMatrixBlock arg1WithCorr = LibMatrixCountDistinct.createSketch(arg1, this.op);
+            return LibMatrixCountDistinct.unionSketch(arg0, arg1WithCorr, this.op);
+        }
+    }
+
+    /** combineByKey mergeCombiners: unions sketches produced by different partitions. */
+    private static class AggregateUnarySketchMergeCombinerFunction implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock>
+    {
+        private static final long serialVersionUID = 172215143740379070L;
+        private final CountDistinctOperator op;
+
+        public AggregateUnarySketchMergeCombinerFunction(CountDistinctOperator op) {
+            this.op = op;
+        }
+
+        @Override
+        public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1) throws Exception {
+            return LibMatrixCountDistinct.unionSketch(arg0, arg1, this.op);
+        }
+    }
+
+    /** Derives the approximate distinct-value count(s) from a merged sketch. */
+    private static class CalculateAggregateSketchFunction implements Function<CorrMatrixBlock, MatrixBlock>
+    {
+        private static final long serialVersionUID = 7504873483231717138L;
+        private final CountDistinctOperator op;
+
+        public CalculateAggregateSketchFunction(CountDistinctOperator op) {
+            this.op = op;
+        }
+
+        @Override
+        public MatrixBlock call(CorrMatrixBlock arg0) throws Exception {
+            return LibMatrixCountDistinct.countDistinctValuesFromSketch(arg0, this.op);
+        }
+    }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
index c9935b1..830ba4d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
@@ -37,7 +37,7 @@ public abstract class SPInstruction extends Instruction {
                CentralMoment, Covariance, QSort, QPick,
                ParameterizedBuiltin, MAppend, RAppend, GAppend, 
GAlignedAppend, Rand,
                MatrixReshape, Ctable, Quaternary, CumsumAggregate, 
CumsumOffset, BinUaggChain, UaggOuterChain,
-               Write, SpoofFused, Dnn
+               Write, SpoofFused, Dnn, AggregateUnarySketch
        }
 
        protected final SPType _sptype;
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
index fa95aee..4b13abc 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
@@ -19,9 +19,7 @@
 
 package org.apache.sysds.runtime.matrix.data;
 
-import java.util.Collections;
 import java.util.HashSet;
-import java.util.PriorityQueue;
 import java.util.Set;
 
 import org.apache.commons.lang.NotImplementedException;
@@ -32,16 +30,17 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
+import 
org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.KMVSketch;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
-import 
org.apache.sysds.runtime.matrix.operators.CountDistinctOperator.CountDistinctTypes;
-import org.apache.sysds.utils.Hash;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
 import org.apache.sysds.utils.Hash.HashType;
 
 /**
  * This class contains various methods for counting the number of distinct 
values inside a MatrixBlock
  */
-public class LibMatrixCountDistinct {
-       private static final Log LOG = 
LogFactory.getLog(LibMatrixCountDistinct.class.getName());
+public interface LibMatrixCountDistinct {
+       static final Log LOG = 
LogFactory.getLog(LibMatrixCountDistinct.class.getName());
 
        /**
         * The minimum number NonZero of cells in the input before using 
approximate techniques for counting number of
@@ -49,10 +48,6 @@ public class LibMatrixCountDistinct {
         */
        public static int minimumSize = 1024;
 
-       private LibMatrixCountDistinct() {
-               // Prevent instantiation via private constructor.
-       }
-
        /**
         * Public method to count the number of distinct values inside a 
matrix. Depending on which CountDistinctOperator
         * selected it either gets the absolute number or a estimated value.
@@ -61,7 +56,7 @@ public class LibMatrixCountDistinct {
         * 
         * TODO: Add support for distributed spark operations
         * 
-        * TODO: If the MatrixBlock type is CompressedMatrix, simply read the 
vaules from the ColGroups.
+        * TODO: If the MatrixBlock type is CompressedMatrix, simply read the 
values from the ColGroups.
         * 
         * @param in the input matrix to count number distinct values in
         * @param op the selected operator to use
@@ -69,11 +64,12 @@ public class LibMatrixCountDistinct {
         */
        public static int estimateDistinctValues(MatrixBlock in, 
CountDistinctOperator op) {
                int res = 0;
-               if(op.operatorType == CountDistinctTypes.KMV &&
-                       (op.hashType == HashType.ExpHash || op.hashType == 
HashType.StandardJava)) {
-                       throw new DMLException("Invalid hashing configuration 
using " + op.hashType + " and " + op.operatorType);
+               if(op.getOperatorType() == CountDistinctOperatorTypes.KMV &&
+                       (op.getHashType() == HashType.ExpHash || 
op.getHashType() == HashType.StandardJava)) {
+                       throw new DMLException(
+                               "Invalid hashing configuration using " + 
op.getHashType() + " and " + op.getOperatorType());
                }
-               else if(op.operatorType == CountDistinctTypes.HLL) {
+               else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL) 
{
                        throw new NotImplementedException("HyperLogLog not 
implemented");
                }
                // shortcut in simplest case.
@@ -84,27 +80,27 @@ public class LibMatrixCountDistinct {
                        res = countDistinctValuesNaive(in);
                }
                else {
-                       switch(op.operatorType) {
+                       switch(op.getOperatorType()) {
                                case COUNT:
                                        res = countDistinctValuesNaive(in);
                                        break;
                                case KMV:
-                                       res = countDistinctValuesKVM(in, op);
+                                       res = new 
KMVSketch(op).getScalarValue(in);
                                        break;
                                default:
                                        throw new DMLException("Invalid or not 
implemented Estimator Type");
                        }
                }
 
-               if(res == 0)
+               if(res <= 0)
                        throw new DMLRuntimeException("Impossible estimate of 
distinct values");
                return res;
        }
 
        /**
-        * Naive implementation of counting Distinct values.
+        * Naive implementation of counting distinct values.
         * 
-        * Benefit Precise, but uses memory, on the scale of inputs number of 
distinct values.
+        * Benefit: precise, but uses memory, on the scale of inputs number of 
distinct values.
         * 
         * @param in The input matrix to count number distinct values in
         * @return The absolute distinct count
@@ -151,155 +147,35 @@ public class LibMatrixCountDistinct {
        }
 
        private static Set<Double> countDistinctValuesNaive(double[] 
valuesPart, Set<Double> distinct) {
-               for(double v : valuesPart) {
+               for(double v : valuesPart) 
                        distinct.add(v);
-               }
                return distinct;
        }
 
-       /**
-        * KMV synopsis(for k minimum values) Distinct-Value Estimation
-        * 
-        * Kevin S. Beyer, Peter J. Haas, Berthold Reinwald, Yannis Sismanis, 
Rainer Gemulla:
-        * 
-        * On synopses for distinct‐value estimation under multiset operations. 
SIGMOD 2007
-        * 
-        * TODO: Add multi-threaded version
-        * 
-        * @param in The Matrix Block to estimate the number of distinct values 
in
-        * @return The distinct count estimate
-        */
-       private static int countDistinctValuesKVM(MatrixBlock in, 
CountDistinctOperator op) {
-
-               // D is the number of possible distinct values in the 
MatrixBlock.
-               // plus 1 to take account of 0 input.
-               long D = in.getNonZeros() + 1;
-
-               /**
-                * To ensure that the likelihood to hash to the same value we 
need O(D^2) positions to hash to assign. If the
-                * value is higher than int (which is the area we hash to) then 
use Integer Max value as largest hashing space.
-                */
-               long tmp = D * D;
-               int M = (tmp > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : 
(int) tmp;
-               LOG.debug("M not forced to int size: " + tmp);
-               LOG.debug("M: " + M);
-               /**
-                * The estimator is asymptotically unbiased as k becomes large, 
but memory usage also scales with k. Furthermore k
-                * value must be within range: D >> k >> 0
-                */
-               int k = D > 64 ? 64 : (int) D;
-               SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
-
-               countDistinctValuesKVM(in, op.hashType, k, spq, M);
-
-               LOG.debug("M: " + M);
-               LOG.debug("smallest hash:" + spq.peek());
-               LOG.debug("spq: " + spq.toString());
-
-               if(spq.size() < k) {
-                       return spq.size();
-               }
-               else {
-                       double U_k = (double) spq.poll() / (double) M;
-                       LOG.debug("U_k : " + U_k);
-                       double estimate = (double) (k - 1) / U_k;
-                       LOG.debug("Estimate: " + estimate);
-                       double ceilEstimate = Math.min(estimate, (double) D);
-                       LOG.debug("Ceil worst case: " + D);
-                       return (int) ceilEstimate;
-               }
-       }
-
-       private static void countDistinctValuesKVM(MatrixBlock in, HashType 
hashType, int k, SmallestPriorityQueue spq,
-               int m) {
-               double[] data;
-               if(in.isEmpty())
-                       spq.add(0);
-               else if(in instanceof CompressedMatrixBlock)
-                       throw new NotImplementedException();
-               else if(in.sparseBlock != null) {
-                       SparseBlock sb = in.sparseBlock;
-                       if(in.sparseBlock.isContiguous()) {
-                               data = sb.values(0);
-                               countDistinctValuesKVM(data, hashType, k, spq, 
m);
-                       }
-                       else {
-                               for(int i = 0; i < in.getNumRows(); i++) {
-                                       if(!sb.isEmpty(i)) {
-                                               data = in.sparseBlock.values(i);
-                                               countDistinctValuesKVM(data, 
hashType, k, spq, m);
-                                       }
-                               }
-                       }
-               }
-               else {
-                       DenseBlock db = in.denseBlock;
-                       final int bil = db.index(0);
-                       final int biu = db.index(in.rlen);
-                       for(int i = bil; i <= biu; i++) {
-                               data = db.valuesAt(i);
-                               countDistinctValuesKVM(data, hashType, k, spq, 
m);
-                       }
-               }
+       public static MatrixBlock countDistinctValuesFromSketch(CorrMatrixBlock 
arg0, CountDistinctOperator op) {
+               if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
+                       return new KMVSketch(op).getMatrixValue(arg0);
+               else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
+                       throw new NotImplementedException("Not implemented 
yet");
+               else
+                       throw new NotImplementedException("Not implemented 
yet");
        }
 
-       private static void countDistinctValuesKVM(double[] data, HashType 
hashType, int k, SmallestPriorityQueue spq,
-               int m) {
-               for(double fullValue : data) {
-                       int hash = Hash.hash(fullValue, hashType);
-                       int v = (Math.abs(hash)) % (m - 1) + 1;
-                       spq.add(v);
-               }
+       public static CorrMatrixBlock createSketch(MatrixBlock blkIn, 
CountDistinctOperator op) {
+               if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
+                       return new KMVSketch(op).create(blkIn);
+               else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
+                       throw new NotImplementedException("Not implemented 
yet");
+               else
+                       throw new NotImplementedException("Not implemented 
yet");
        }
 
-       /**
-        * Deceiving name, but is used to contain the k smallest values 
inserted.
-        * 
-        * TODO: add utility method to join two partitions
-        * 
-        * TODO: Replace Standard Java Set and Priority Queue with optimized 
versions.
-        */
-       private static class SmallestPriorityQueue {
-               private Set<Integer> containedSet;
-               private PriorityQueue<Integer> smallestHashes;
-               private int k;
-
-               public SmallestPriorityQueue(int k) {
-                       smallestHashes = new PriorityQueue<>(k, 
Collections.reverseOrder());
-                       containedSet = new HashSet<>(1);
-                       this.k = k;
-               }
-
-               public void add(int v) {
-                       if(!containedSet.contains(v)) {
-                               if(smallestHashes.size() < k) {
-                                       smallestHashes.add(v);
-                                       containedSet.add(v);
-                               }
-                               else if(v < smallestHashes.peek()) {
-                                       LOG.trace(smallestHashes.peek() + " -- 
" + v);
-                                       smallestHashes.add(v);
-                                       containedSet.add(v);
-                                       
containedSet.remove(smallestHashes.poll());
-                               }
-                       }
-               }
-
-               public int size() {
-                       return smallestHashes.size();
-               }
-
-               public int peek() {
-                       return smallestHashes.peek();
-               }
-
-               public int poll() {
-                       return smallestHashes.poll();
-               }
-
-               @Override
-               public String toString() {
-                       return smallestHashes.toString();
-               }
+       public static CorrMatrixBlock unionSketch(CorrMatrixBlock arg0, 
CorrMatrixBlock arg1, CountDistinctOperator op) {
+               if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
+                       return new KMVSketch(op).union(arg0, arg1);
+               else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
+                       throw new NotImplementedException("Not implemented 
yet");
+               else
+                       throw new NotImplementedException("Not implemented 
yet");
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
new file mode 100644
index 0000000..f9c5f63
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/MatrixSketch.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data.sketch;
+
+import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+/**
+ * A sketch (compact summary) of matrix blocks that can be created, combined and queried for a result.
+ *
+ * @param <T> type of the scalar result returned by {@link #getScalarValue(MatrixBlock)}
+ */
+public interface MatrixSketch<T> {
+
+       /**
+        * Get the scalar distinct count of an input matrix block.
+        * 
+        * @param blkIn An input block to estimate the number of distinct values in
+        * @return The distinct count estimate
+        */
+       T getScalarValue(MatrixBlock blkIn);
+
+       /**
+        * Obtain the matrix of distinct count values from a sketch. Used for estimating the number of distinct
+        * values in rows or columns.
+        * 
+        * @param blkIn The sketch block to extract the count from
+        * @return The result matrix block
+        */
+       MatrixBlock getMatrixValue(CorrMatrixBlock blkIn);
+
+       /**
+        * Create an initial sketch of a given block.
+        * 
+        * @param blkIn A block to process
+        * @return A sketch
+        */
+       CorrMatrixBlock create(MatrixBlock blkIn);
+
+       /**
+        * Union two sketches together to form a combined sketch.
+        * 
+        * @param arg0 Sketch one
+        * @param arg1 Sketch two
+        * @return The combined sketch
+        */
+       CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock arg1);
+
+       /**
+        * Intersect two sketches.
+        * 
+        * @param arg0 Sketch one
+        * @param arg1 Sketch two
+        * @return The intersected sketch
+        */
+       CorrMatrixBlock intersection(CorrMatrixBlock arg0, CorrMatrixBlock arg1);
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
new file mode 100644
index 0000000..9893e09
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox;
+
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+// Package private
+abstract class CountDistinctApproxSketch implements MatrixSketch<Integer> {
+       CountDistinctOperator op;
+
+       CountDistinctApproxSketch(Operator op) {
+               if(!(op instanceof CountDistinctOperator)) {
+                       throw new DMLRuntimeException(
+                               String.format("Cannot create %s with given 
operator", CountDistinctApproxSketch.class.getSimpleName()));
+               }
+
+               this.op = (CountDistinctOperator) op;
+
+               if(this.op.getDirection() == null) {
+                       throw new DMLRuntimeException("No direction was set for 
the operator");
+               }
+
+               if(!this.op.getDirection().isRow() && 
!this.op.getDirection().isCol() && !this.op.getDirection().isRowCol()) {
+                       throw new DMLRuntimeException(String.format("Unexpected 
direction: %s", this.op.getDirection()));
+               }
+       }
+
+       protected void validateSketchMetadata(MatrixBlock corrBlock) {
+               // (nHashes, k, D) row vector
+               if(corrBlock.getNumColumns() < 3 || corrBlock.getValue(0, 0) < 
0 || corrBlock.getValue(0, 1) < 0 ||
+                       corrBlock.getValue(0, 2) < 0) {
+                       throw new DMLRuntimeException("Sketch metadata is 
corrupt");
+               }
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
new file mode 100644
index 0000000..01cfb28
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
@@ -0,0 +1,488 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.utils.Hash;
+
+/**
+ * KMV synopsis(for k minimum values) Distinct-Value Estimation
+ *
+ * Kevin S. Beyer, Peter J. Haas, Berthold Reinwald, Yannis Sismanis, Rainer 
Gemulla:
+ *
+ * On synopses for distinct‐value estimation under multiset operations. SIGMOD 
2007
+ *
+ * TODO: Add multi-threaded version
+ *
+ */
+public class KMVSketch extends CountDistinctApproxSketch {
+
+       // Logger used for debug output of the estimator internals.
+       private static final Log LOG = 
LogFactory.getLog(KMVSketch.class.getName());
+
+       /**
+        * Create a KMV sketch for the given operator.
+        *
+        * @param op the operator; the CountDistinctApproxSketch constructor rejects it unless it is a
+        *           CountDistinctOperator with a non-null Row, Col or RowCol direction
+        */
+       public KMVSketch(Operator op) {
+               super(op);
+       }
+
+       @Override
+       public Integer getScalarValue(MatrixBlock in) {
+
+               // D is the number of possible distinct values in the 
MatrixBlock.
+               // plus 1 to take account of 0 input.
+               long D = in.getNonZeros() + 1;
+
+               /**
+                * To ensure that the likelihood to hash to the same value we 
need O(D^2) positions to hash to assign. If the
+                * value is higher than int (which is the area we hash to) then 
use Integer Max value as largest hashing space.
+                */
+               long tmp = D * D;
+               int M = (tmp > (long) Integer.MAX_VALUE) ? Integer.MAX_VALUE : 
(int) tmp;
+               /**
+                * The estimator is asymptotically unbiased as k becomes large, 
but memory usage also scales with k. Furthermore k
+                * value must be within range: D >> k >> 0
+                */
+               int k = D > 64 ? 64 : (int) D;
+
+               SmallestPriorityQueue spq = getKSmallestHashes(in, k, M);
+
+               if(LOG.isDebugEnabled()) {
+                       LOG.debug("M not forced to int size: " + tmp);
+                       LOG.debug("M: " + M);
+                       LOG.debug("M: " + M);
+                       LOG.debug("kth smallest hash:" + spq.peek());
+                       LOG.debug("spq: " + spq.toString());
+               }
+
+               if(spq.size() < k) {
+                       return spq.size();
+               }
+               else {
+                       double kthSmallestHash = spq.poll();
+                       double U_k = kthSmallestHash / (double) M;
+                       double estimate = (double) (k - 1) / U_k;
+                       double ceilEstimate = Math.min(estimate, (double) D);
+
+                       if(LOG.isDebugEnabled()) {
+                               LOG.debug("U_k : " + U_k);
+                               LOG.debug("Estimate: " + estimate);
+                               LOG.debug("Ceil worst case: " + D);
+                       }
+                       return (int) ceilEstimate;
+               }
+       }
+
+       private SmallestPriorityQueue getKSmallestHashes(MatrixBlock in, int k, 
int M) {
+               SmallestPriorityQueue spq = new SmallestPriorityQueue(k);
+               countDistinctValuesKMV(in, op.getHashType(), k, spq, M);
+
+               return spq;
+       }
+
+       /**
+        * Feed the hash of every stored value of {@code in} into {@code spq}.
+        *
+        * An empty block contributes a single 0. Compressed blocks are not supported. Sparse blocks only
+        * contribute their stored values, while dense blocks contribute every cell.
+        *
+        * @throws NotImplementedException for compressed matrices
+        */
+       private void countDistinctValuesKMV(MatrixBlock in, Hash.HashType hashType, int k, SmallestPriorityQueue spq,
+               int m) {
+               if(in.isEmpty())
+                       spq.add(0);
+               else if(in instanceof CompressedMatrixBlock)
+                       throw new NotImplementedException("Cannot approximate distinct count for compressed matrices");
+               else if(in.getSparseBlock() != null) {
+                       SparseBlock sb = in.getSparseBlock();
+                       if(sb.isContiguous())
+                               countDistinctValuesKMV(sb.values(0), hashType, k, spq, m);
+                       else {
+                               for(int i = 0; i < in.getNumRows(); i++)
+                                       if(!sb.isEmpty(i))
+                                               countDistinctValuesKMV(sb.values(i), hashType, k, spq, m);
+                       }
+               }
+               else {
+                       DenseBlock db = in.getDenseBlock();
+                       // index(r) is the block holding row r, so the last block is the one holding the last row
+                       // (numRows - 1). index(numRows) can address a block past the end for multi-block storage.
+                       final int bil = db.index(0);
+                       final int biu = db.index(in.getNumRows() - 1);
+                       for(int i = bil; i <= biu; i++)
+                               countDistinctValuesKMV(db.valuesAt(i), hashType, k, spq, m);
+               }
+       }
+
+       /**
+        * Hash every value of {@code data} into [1, m - 1] and offer it to {@code spq}.
+        */
+       private void countDistinctValuesKMV(double[] data, Hash.HashType hashType, int k, SmallestPriorityQueue spq, int m) {
+               for(double fullValue : data) {
+                       int hash = Hash.hash(fullValue, hashType);
+                       // Take the remainder before the absolute value: Math.abs(Integer.MIN_VALUE) is still negative,
+                       // so abs-then-mod could yield a value <= 0. For every other hash both forms are identical.
+                       int v = Math.abs(hash % (m - 1)) + 1;
+                       spq.add(v);
+               }
+       }
+
+       @Override
+       public MatrixBlock getMatrixValue(CorrMatrixBlock arg0) {
+               MatrixBlock blkIn = arg0.getValue();
+               if(op.getDirection() == Types.Direction.Row) {
+                       // 1000 x 1 blkOut -> slice out the first column of the 
matrix
+                       MatrixBlock blkOut = blkIn.slice(0, blkIn.getNumRows() 
- 1, 0, 0);
+                       for(int i = 0; i < blkIn.getNumRows(); ++i) {
+                               getDistinctCountFromSketchByIndex(arg0, i, 
blkOut);
+                       }
+
+                       return blkOut;
+               }
+               else if(op.getDirection() == Types.Direction.Col) {
+                       // 1 x 1000 blkOut -> slice out the first row of the 
matrix
+                       MatrixBlock blkOut = blkIn.slice(0, 0, 0, 
blkIn.getNumColumns() - 1);
+                       for(int j = 0; j < blkIn.getNumColumns(); ++j) {
+                               getDistinctCountFromSketchByIndex(arg0, j, 
blkOut);
+                       }
+
+                       return blkOut;
+               }
+               else { // op.getDirection().isRowCol()
+
+                       // 1 x 1 blkOut -> slice out the first row and column 
of the matrix
+                       MatrixBlock blkOut = blkIn.slice(0, 0, 0, 0);
+                       getDistinctCountFromSketchByIndex(arg0, 0, blkOut);
+
+                       return blkOut;
+               }
+       }
+
+       private void getDistinctCountFromSketchByIndex(CorrMatrixBlock arg0, 
int idx, MatrixBlock blkOut) {
+               MatrixBlock blkIn = arg0.getValue();
+               MatrixBlock blkInCorr = arg0.getCorrection();
+
+               if(op.getOperatorType() == CountDistinctOperatorTypes.KMV) {
+                       double kthSmallestHash;
+                       if(op.getDirection().isRow() || 
op.getDirection().isRowCol()) {
+                               kthSmallestHash = blkIn.getValue(idx, 0);
+                       }
+                       else { // op.getDirection().isCol()
+                               kthSmallestHash = blkIn.getValue(0, idx);
+                       }
+
+                       double nHashes = blkInCorr.getValue(idx, 0);
+                       double k = blkInCorr.getValue(idx, 1);
+                       double D = blkInCorr.getValue(idx, 2);
+
+                       double D2 = D * D;
+                       double M = (D2 > (long) Integer.MAX_VALUE) ? 
Integer.MAX_VALUE : D2;
+
+                       double ceilEstimate;
+                       if(nHashes != 0 && nHashes < k) {
+                               ceilEstimate = nHashes;
+                       }
+                       else if(nHashes == 0) {
+                               ceilEstimate = 1;
+                       }
+                       else {
+                               double U_k = kthSmallestHash / M;
+                               double estimate = (k - 1) / U_k;
+                               ceilEstimate = Math.min(estimate, D);
+                       }
+
+                       if(op.getDirection().isRow() || 
op.getDirection().isRowCol()) {
+                               blkOut.setValue(idx, 0, ceilEstimate);
+                       }
+                       else { // op.getDirection().isCol()
+                               blkOut.setValue(0, idx, ceilEstimate);
+                       }
+               }
+       }
+
+       // Create sketch
+       @Override
+       public CorrMatrixBlock create(MatrixBlock blkIn) {
+
+               // We need a matrix containing sketch metadata per block
+               // N x 3 row vector: (nHashes, k, D)
+               // O(N) extra space
+
+               if(op.getDirection().isRowCol()) {
+                       // (nHashes, k, D) row matrix
+                       MatrixBlock blkOut = new MatrixBlock(blkIn);
+                       MatrixBlock blkOutCorr = new MatrixBlock(1, 3, false);
+                       createSketchByIndex(blkIn, blkOutCorr, 0, blkOut);
+                       return new CorrMatrixBlock(blkOut, blkOutCorr);
+               }
+               else if(op.getDirection().isRow()) {
+                       MatrixBlock blkOut = blkIn;
+                       MatrixBlock blkOutCorr = new 
MatrixBlock(blkIn.getNumRows(), 3, false);
+                       // (nHashes, k, D) row matrix
+                       for(int i = 0; i < blkIn.getNumRows(); ++i) {
+                               createSketchByIndex(blkOut, blkOutCorr, i);
+                       }
+                       return new CorrMatrixBlock(blkOut, blkOutCorr);
+
+               }
+               else if(op.getDirection().isCol()) {
+                       MatrixBlock blkOut = blkIn;
+                       // (nHashes, k, D) row matrix
+                       MatrixBlock blkOutCorr = new 
MatrixBlock(blkIn.getNumColumns(), 3, false);
+                       for(int j = 0; j < blkIn.getNumColumns(); ++j) {
+                               createSketchByIndex(blkOut, blkOutCorr, j);
+                       }
+                       return new CorrMatrixBlock(blkOut, blkOutCorr);
+               }
+               else {
+                       throw new DMLRuntimeException(String.format("Unexpected 
direction: %s", op.getDirection()));
+               }
+       }
+
+       /**
+        * Select the part of {@code blkIn} that slice {@code idx} covers: row idx for Row, column idx for Col,
+        * and the whole block otherwise.
+        */
+       private MatrixBlock sliceMatrixBlockByIndexDirection(MatrixBlock blkIn, int idx) {
+               if(op.getDirection().isRow())
+                       return blkIn.slice(idx, idx);
+               if(op.getDirection().isCol())
+                       return blkIn.slice(0, blkIn.getNumRows() - 1, idx, idx);
+               return blkIn;
+       }
+
+       private void createSketchByIndex(MatrixBlock blkIn, MatrixBlock 
sketchMetaMB, int idx) {
+               createSketchByIndex(blkIn, sketchMetaMB, idx, null);
+       }
+
+       /**
+        * Builds the KMV sketch of slice {@code idx} of {@code blkIn} and stores its (nHashes, k, D) metadata in row
+        * {@code idx} of {@code sketchMetaMB}.
+        * <p>
+        * Hashes are written down column {@code idx} for Col, and along row {@code idx} otherwise. For RowCol they go
+        * into {@code blkOut}, which is first resized to 1 x k.
+        *
+        * @param blkIn        block to sketch; also receives the hashes when {@code blkOut} is null
+        * @param sketchMetaMB metadata matrix with one (nHashes, k, D) row per sketch
+        * @param idx          row index (Row), column index (Col) or 0 (RowCol)
+        * @param blkOut       separate output block, only passed for RowCol; may be null
+        */
+       private void createSketchByIndex(MatrixBlock blkIn, MatrixBlock sketchMetaMB, int idx, MatrixBlock blkOut) {
+               MatrixBlock sketchMB = (blkOut == null) ? blkIn : blkOut;
+
+               MatrixBlock blkInSlice = sliceMatrixBlockByIndexDirection(blkIn, idx);
+               // D: non-zeros of the slice plus one (union() subtracts the doubled "+ 1" again when merging)
+               long D = blkInSlice.getNonZeros() + 1;
+
+               // M = min(D^2, Integer.MAX_VALUE) is forwarded to getKSmallestHashes. D is checked before squaring so
+               // the long product cannot overflow for huge D (46340^2 is the largest square <= Integer.MAX_VALUE).
+               int M = (D > 46340L) ? Integer.MAX_VALUE : (int) (D * D);
+               // Sketch size k = min(D, 64)
+               int k = D > 64 ? 64 : (int) D;
+
+               // blkOut is only passed for RowCol: the entire block produces a single 1 x k sketch, so the output
+               // block is resized (and filled with 0) accordingly
+               if(blkOut != null) {
+                       sketchMB.reset(1, k);
+               }
+
+               if(blkInSlice.getLength() == 1 || blkInSlice.isEmpty()) {
+                       // There can only be 1 distinct value for a 1x1 or empty slice; no hashes are stored and
+                       // getMatrixValue() short-circuits and returns 1 when nHashes = 0
+                       sketchMetaMB.setValue(idx, 0, 0);
+                       sketchMetaMB.setValue(idx, 1, k);
+                       sketchMetaMB.setValue(idx, 2, D);
+                       return;
+               }
+
+               SmallestPriorityQueue spq = getKSmallestHashes(blkInSlice, k, M);
+               // NOTE(review): presumably nHashes <= k, and smaller than k when the slice has fewer than k distinct
+               // hashes -- confirm against getKSmallestHashes
+               int nHashes = spq.size();
+               assert (nHashes > 0);
+
+               // The queue yields its values largest-first; Col sketches run down column idx, all others along row idx
+               boolean alongColumn = op.getDirection().isCol();
+               for(int i = 0; !spq.isEmpty(); ++i) {
+                       double hash = spq.poll();
+                       if(alongColumn) {
+                               sketchMB.setValue(i, idx, hash);
+                       }
+                       else {
+                               sketchMB.setValue(idx, i, hash);
+                       }
+               }
+
+               // Sketch metadata goes to row idx of the separate (nHashes, k, D) matrix
+               sketchMetaMB.setValue(idx, 0, nHashes);
+               sketchMetaMB.setValue(idx, 1, k);
+               sketchMetaMB.setValue(idx, 2, D);
+       }
+
+       /**
+        * Merges the sketches of two row-/column-aligned blocks.
+        * <p>
+        * The larger value block of the two is reused as output value block. It is the wider one for Row and RowCol
+        * (row counts match) and the taller one for Col (column counts match). A new metadata matrix holds one
+        * (nHashes, k, D) row per merged sketch.
+        *
+        * @param arg0 first sketch
+        * @param arg1 second sketch
+        * @return the merged sketch
+        */
+       @Override
+       public CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock arg1) {
+               MatrixBlock matrix0 = arg0.getValue();
+               MatrixBlock matrix1 = arg1.getValue();
+
+               // Per direction: whether the first input is the larger one, and how many sketches are merged
+               final boolean firstIsLarger;
+               final int nSketches;
+               if(op.getDirection().isRow()) {
+                       firstIsLarger = matrix0.getNumColumns() > matrix1.getNumColumns();
+                       nSketches = matrix0.getNumRows();
+               }
+               else if(op.getDirection().isCol()) {
+                       firstIsLarger = matrix0.getNumRows() > matrix1.getNumRows();
+                       nSketches = matrix0.getNumColumns();
+               }
+               else { // op.getDirection().isRowCol()
+                       firstIsLarger = matrix0.getNumColumns() > matrix1.getNumColumns();
+                       nSketches = 1;
+               }
+
+               MatrixBlock combined = firstIsLarger ? matrix0 : matrix1;
+               MatrixBlock combinedCorr = new MatrixBlock(nSketches, 3, false);
+               CorrMatrixBlock merged = new CorrMatrixBlock(combined, combinedCorr);
+               for(int idx = 0; idx < nSketches; ++idx) {
+                       unionSketchByIndex(arg0, arg1, idx, merged);
+               }
+               return merged;
+       }
+
+       /**
+        * Merges sketch {@code idx} of {@code arg0} and {@code arg1} into sketch {@code idx} of {@code blkOut}.
+        * <p>
+        * The merged sketch keeps the smallest max(nHashes0, nHashes1) distinct hashes of both inputs. Its metadata is
+        * (max(nHashes0, nHashes1), max(k0, k1), D0 + D1 - 1).
+        *
+        * @param arg0   first sketch
+        * @param arg1   second sketch
+        * @param idx    row index (Row), column index (Col) or 0 (RowCol) of the sketch to merge
+        * @param blkOut receives the merged hashes in its value block and the metadata in its correction block
+        * @throws DMLRuntimeException if the two sketches are not row-/column-aligned
+        */
+       public void unionSketchByIndex(CorrMatrixBlock arg0, CorrMatrixBlock arg1, int idx, CorrMatrixBlock blkOut) {
+               MatrixBlock corr0 = arg0.getCorrection();
+               MatrixBlock corr1 = arg1.getCorrection();
+
+               validateSketchMetadata(corr0);
+               validateSketchMetadata(corr1);
+
+               // Both matrices are guaranteed to be row-/column-aligned
+               MatrixBlock matrix0 = arg0.getValue();
+               MatrixBlock matrix1 = arg1.getValue();
+
+               if((op.getDirection().isRow() && matrix0.getNumRows() != matrix1.getNumRows()) ||
+                       (op.getDirection().isCol() && matrix0.getNumColumns() != matrix1.getNumColumns())) {
+                       throw new DMLRuntimeException("Cannot take the union of sketches: rows/columns are not aligned");
+               }
+
+               MatrixBlock combined = blkOut.getValue();
+               MatrixBlock combinedCorr = blkOut.getCorrection();
+
+               double nHashes0 = corr0.getValue(idx, 0);
+               double k0 = corr0.getValue(idx, 1);
+               double D0 = corr0.getValue(idx, 2);
+
+               double nHashes1 = corr1.getValue(idx, 0);
+               double k1 = corr1.getValue(idx, 1);
+               double D1 = corr1.getValue(idx, 2);
+
+               double nHashes = Math.max(nHashes0, nHashes1);
+               double k = Math.max(k0, k1);
+               // Each D is (non-zeros + 1), so the "+ 1" is counted twice in the sum
+               double D = D0 + D1 - 1;
+
+               // If neither sketch holds hashes (e.g. both come from empty or 1x1 slices) there is nothing to merge.
+               // Skipping also avoids building a zero-capacity queue: PriorityQueue rejects a capacity below 1.
+               if(nHashes > 0) {
+                       // Row and RowCol sketches are stored along row idx, Col sketches down column idx
+                       boolean alongRow = op.getDirection().isRow() || op.getDirection().isRowCol();
+
+                       // NOTE(review): the union keeps at most max(nHashes0, nHashes1) hashes rather than up to k;
+                       // confirm this is intended when both inputs hold fewer than k hashes
+                       SmallestPriorityQueue hashUnion = new SmallestPriorityQueue((int) nHashes);
+                       for(int i = 0; i < nHashes0; ++i) {
+                               hashUnion.add(alongRow ? matrix0.getValue(idx, i) : matrix0.getValue(i, idx));
+                       }
+                       for(int i = 0; i < nHashes1; ++i) {
+                               hashUnion.add(alongRow ? matrix1.getValue(idx, i) : matrix1.getValue(i, idx));
+                       }
+
+                       for(int i = 0; !hashUnion.isEmpty(); ++i) {
+                               double val = hashUnion.poll();
+                               if(alongRow) {
+                                       combined.setValue(idx, i, val);
+                               }
+                               else {
+                                       combined.setValue(i, idx, val);
+                               }
+                       }
+               }
+
+               combinedCorr.setValue(idx, 0, nHashes);
+               combinedCorr.setValue(idx, 1, k);
+               combinedCorr.setValue(idx, 2, D);
+       }
+
+       @Override
+       public CorrMatrixBlock intersection(CorrMatrixBlock arg0, 
CorrMatrixBlock arg1) {
+               throw new NotImplementedException(
+                       String.format("%s intersection has not been implemented 
yet", KMVSketch.class.getSimpleName()));
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
new file mode 100644
index 0000000..0a29028
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/SmallestPriorityQueue.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.PriorityQueue;
+import java.util.Set;
+
+/**
+ * Retains the k smallest distinct values inserted into it. Despite the name, it is not a general-purpose priority
+ * queue.
+ * <p>
+ * The retained values live in a max-heap, so that the current largest one can be evicted when a smaller value
+ * arrives. They are mirrored in a set so that duplicates are ignored. As a consequence, {@link #peek()} and
+ * {@link #poll()} return the largest retained value.
+ *
+ * TODO: Replace Standard Java Set and Priority Queue with optimized versions.
+ */
+public class SmallestPriorityQueue {
+       private static final Log LOG = LogFactory.getLog(SmallestPriorityQueue.class.getName());
+
+       /** Mirror of the heap contents, used to reject duplicates. */
+       private final Set<Double> containedSet;
+       /** Max-heap (reverse natural order) of the retained values. */
+       private final PriorityQueue<Double> smallestHashes;
+       /** Maximum number of values retained. */
+       private final int k;
+
+       /**
+        * Creates an empty container.
+        *
+        * @param k maximum number of values to retain; with 0 every offered value is discarded
+        * @throws IllegalArgumentException if k is negative
+        */
+       public SmallestPriorityQueue(int k) {
+               if(k < 0) {
+                       throw new IllegalArgumentException("Number of values to retain must be non-negative, got: " + k);
+               }
+               // PriorityQueue rejects an initial capacity below 1, so clamp it to also support k = 0
+               smallestHashes = new PriorityQueue<>(Math.max(1, k), Collections.reverseOrder());
+               containedSet = new HashSet<>();
+               this.k = k;
+       }
+
+       /**
+        * Offers a value. A value not yet contained is retained if fewer than k values are held, or if it is smaller
+        * than the current largest value, which is then evicted.
+        *
+        * @param v value to offer
+        */
+       public void add(double v) {
+               // k == 0: nothing is ever retained (and peek() on the empty heap would fail to unbox)
+               if(k == 0 || containedSet.contains(v)) {
+                       return;
+               }
+               if(smallestHashes.size() < k) {
+                       smallestHashes.add(v);
+                       containedSet.add(v);
+               }
+               else if(v < smallestHashes.peek()) {
+                       // Guard the trace to avoid building the message on every eviction
+                       if(LOG.isTraceEnabled()) {
+                               LOG.trace(smallestHashes.peek() + " -- " + v);
+                       }
+                       smallestHashes.add(v);
+                       containedSet.add(v);
+                       double largest = smallestHashes.poll();
+                       containedSet.remove(largest);
+               }
+       }
+
+       /** @return the number of retained values */
+       public int size() {
+               return smallestHashes.size();
+       }
+
+       /**
+        * @return the largest retained value without removing it
+        * @throws NullPointerException if the container is empty
+        */
+       public double peek() {
+               return smallestHashes.peek();
+       }
+
+       /**
+        * Removes and returns the largest retained value; note that it stays registered in the duplicate filter, so
+        * re-adding the same value is ignored (unchanged behavior).
+        *
+        * @return the largest retained value
+        * @throws NullPointerException if the container is empty
+        */
+       public double poll() {
+               return smallestHashes.poll();
+       }
+
+       /** @return true if no values are retained */
+       public boolean isEmpty() {
+               return this.size() == 0;
+       }
+
+       @Override
+       public String toString() {
+               return smallestHashes.toString();
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
index 3f63ef9..1c430c9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
@@ -19,30 +19,28 @@
 
 package org.apache.sysds.runtime.matrix.operators;
 
+import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.functionobjects.IndexFunction;
 import 
org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction.AUType;
 import org.apache.sysds.utils.Hash.HashType;
 
 public class CountDistinctOperator extends Operator {
        private static final long serialVersionUID = 7615123453265129670L;
 
-       public final CountDistinctTypes operatorType;
-       public final HashType hashType;
-
-       public enum CountDistinctTypes { // The different supported types of 
counting.
-               COUNT, // Baseline naive implementation, iterate though, add to 
hashMap.
-               KMV, // K-Minimum Values algorithm.
-               HLL // HyperLogLog algorithm.
-       }
+       private final CountDistinctOperatorTypes operatorType;
+       private final HashType hashType;
+       private Types.Direction direction;
+       private IndexFunction indexFunction;
 
        public CountDistinctOperator(AUType opType) {
                super(true);
-               switch (opType) {
+               switch(opType) {
                        case COUNT_DISTINCT:
-                               this.operatorType = CountDistinctTypes.COUNT;
+                               this.operatorType = 
CountDistinctOperatorTypes.COUNT;
                                break;
                        case COUNT_DISTINCT_APPROX:
-                               this.operatorType = CountDistinctTypes.KMV;
+                               this.operatorType = 
CountDistinctOperatorTypes.KMV;
                                break;
                        default:
                                throw new DMLRuntimeException(opType + " not 
supported for CountDistinct Operator");
@@ -50,15 +48,49 @@ public class CountDistinctOperator extends Operator {
                this.hashType = HashType.LinearHash;
        }
 
-       public CountDistinctOperator(CountDistinctTypes operatorType) {
+       /**
+        * Creates an operator of the given type that hashes with {@link HashType#StandardJava}.
+        * The index function and direction are left unset.
+        *
+        * @param operatorType counting algorithm to use
+        */
+       public CountDistinctOperator(CountDistinctOperatorTypes operatorType) {
                super(true);
                this.operatorType = operatorType;
                this.hashType = HashType.StandardJava;
        }
 
-       public CountDistinctOperator(CountDistinctTypes operatorType, HashType 
hashType) {
+       /**
+        * Creates an operator of the given type and hash type. The index function and direction are left unset.
+        *
+        * @param operatorType counting algorithm to use
+        * @param hashType     hash function to use
+        */
+       public CountDistinctOperator(CountDistinctOperatorTypes operatorType, 
HashType hashType) {
+               super(true);
+               this.operatorType = operatorType;
+               this.hashType = hashType;
+       }
+
+       /**
+        * Creates an operator of the given type with an index function and hash type. The direction is left unset.
+        *
+        * @param operatorType  counting algorithm to use
+        * @param indexFunction index function to store
+        * @param hashType      hash function to use
+        */
+       public CountDistinctOperator(CountDistinctOperatorTypes operatorType, 
IndexFunction indexFunction,
+               HashType hashType) {
                super(true);
                this.operatorType = operatorType;
+               this.indexFunction = indexFunction;
                this.hashType = hashType;
        }
-}
\ No newline at end of file
+
+       /**
+        * @return the counting algorithm this operator uses
+        */
+       public CountDistinctOperatorTypes getOperatorType() {
+               return this.operatorType;
+       }
+
+       /**
+        * @return the hash function this operator uses
+        */
+       public HashType getHashType() {
+               return this.hashType;
+       }
+
+       /**
+        * @return the stored index function, or null if none was set
+        */
+       public IndexFunction getIndexFunction() {
+               return this.indexFunction;
+       }
+
+       /**
+        * Stores the index function; returns this operator so calls can be chained.
+        *
+        * @param indexFunction index function to store
+        * @return this operator
+        */
+       public CountDistinctOperator setIndexFunction(IndexFunction indexFunction) {
+               this.indexFunction = indexFunction;
+               return this;
+       }
+
+       /**
+        * @return the stored direction, or null if none was set
+        */
+       public Types.Direction getDirection() {
+               return this.direction;
+       }
+
+       /**
+        * Stores the direction; returns this operator so calls can be chained.
+        *
+        * @param direction direction to store
+        * @return this operator
+        */
+       public CountDistinctOperator setDirection(Types.Direction direction) {
+               this.direction = direction;
+               return this;
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperatorTypes.java
similarity index 53%
copy from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
copy to 
src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperatorTypes.java
index 9581bc8..520b02e 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperatorTypes.java
@@ -17,33 +17,10 @@
  * under the License.
  */
 
-package org.apache.sysds.test.functions.countDistinct;
+package org.apache.sysds.runtime.matrix.operators;
 
-import org.apache.sysds.common.Types.ExecType;
-import org.junit.Test;
-
-public class CountDistinct extends CountDistinctBase {
-
-       public String TEST_NAME = "countDistinct";
-       public String TEST_DIR = "functions/countDistinct/";
-       public String TEST_CLASS_DIR = TEST_DIR + 
CountDistinct.class.getSimpleName() + "/";
-
-       protected String getTestClassDir() {
-               return TEST_CLASS_DIR;
-       }
-
-       protected String getTestName() {
-               return TEST_NAME;
-       }
-
-       protected String getTestDir() {
-               return TEST_DIR;
-       }
-
-       @Test
-       public void testSimple1by1() {
-               // test simple 1 by 1.
-               ExecType ex = ExecType.CP;
-               countDistinctTest(1, 1, 1, 1.0, ex, 0.00001);
-       }
-}
\ No newline at end of file
+/**
+ * The supported algorithms for counting distinct values.
+ */
+public enum CountDistinctOperatorTypes {
+       /** Baseline naive implementation: iterate through the values and add them to a hash map. */
+       COUNT,
+       /** K-Minimum Values algorithm. */
+       KMV,
+       /** HyperLogLog algorithm. */
+       HLL
+}
diff --git 
a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java 
b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
index 308aaaa..cd20b67 100644
--- 
a/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/matrix/CountDistinctTest.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.test.component.matrix;
 
 import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 import java.util.ArrayList;
@@ -27,10 +28,11 @@ import java.util.Collection;
 
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.sysds.api.DMLException;
+import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
-import 
org.apache.sysds.runtime.matrix.operators.CountDistinctOperator.CountDistinctTypes;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
 import org.apache.sysds.runtime.util.DataConverter;
 import org.apache.sysds.test.TestUtils;
 import org.apache.sysds.utils.Hash.HashType;
@@ -42,9 +44,9 @@ import org.junit.runners.Parameterized.Parameters;
 @RunWith(value = Parameterized.class)
 public class CountDistinctTest {
 
-       private static CountDistinctTypes[] esT = new CountDistinctTypes[] {
+       private static CountDistinctOperatorTypes[] esT = new 
CountDistinctOperatorTypes[] {
                // The different types of Estimators
-               CountDistinctTypes.COUNT, CountDistinctTypes.KMV, 
CountDistinctTypes.HLL};
+               CountDistinctOperatorTypes.COUNT, 
CountDistinctOperatorTypes.KMV, CountDistinctOperatorTypes.HLL};
 
        @Parameters
        public static Collection<Object[]> data() {
@@ -86,26 +88,26 @@ public class CountDistinctTest {
                
inputs.add(DataConverter.convertToMatrixBlock(TestUtils.generateTestMatrixIntV(1024,
 10241, 0, 3000, 0.1, 7)));
                actualUnique.add(3000L);
 
-               for(CountDistinctTypes et : esT) {
+               for(CountDistinctOperatorTypes et : esT) {
                        for(HashType ht : HashType.values()) {
-                               if((ht == HashType.ExpHash && et == 
CountDistinctTypes.KMV) ||
-                                       (ht == HashType.StandardJava && et == 
CountDistinctTypes.KMV)) {
+                               if((ht == HashType.ExpHash && et == 
CountDistinctOperatorTypes.KMV) ||
+                                       (ht == HashType.StandardJava && et == 
CountDistinctOperatorTypes.KMV)) {
                                        String errorMessage = "Invalid hashing 
configuration using " + ht + " and " + et;
-                                       tests.add(new Object[] {et, 
inputs.get(0), actualUnique.get(0), ht, new DMLException(),
-                                               errorMessage, 0.0});
+                                       tests.add(
+                                               new Object[] {et, 
inputs.get(0), actualUnique.get(0), ht,  new DMLException(), errorMessage, 
0.0});
                                }
-                               else if(et == CountDistinctTypes.HLL) {
+                               else if(et == CountDistinctOperatorTypes.HLL) {
                                        tests.add(new Object[] {et, 
inputs.get(0), actualUnique.get(0), ht, new NotImplementedException(),
                                                "HyperLogLog not implemented", 
0.0});
                                }
-                               else if(et != CountDistinctTypes.COUNT) {
+                               else if(et != CountDistinctOperatorTypes.COUNT) 
{
                                        for(int i = 0; i < inputs.size(); i++) {
                                                // allowing the estimate to be 
15% off
                                                tests.add(new Object[] {et, 
inputs.get(i), actualUnique.get(i), ht, null, null, 0.15});
                                        }
                                }
                        }
-                       if(et == CountDistinctTypes.COUNT) {
+                       if(et == CountDistinctOperatorTypes.COUNT) {
                                for(int i = 0; i < inputs.size(); i++) {
                                        tests.add(new Object[] {et, 
inputs.get(i), actualUnique.get(i), null, null, null, 0.0001});
                                }
@@ -115,7 +117,7 @@ public class CountDistinctTest {
        }
 
        @Parameterized.Parameter
-       public CountDistinctTypes et;
+       public CountDistinctOperatorTypes et;
        @Parameterized.Parameter(1)
        public MatrixBlock in;
        @Parameterized.Parameter(2)
@@ -135,36 +137,28 @@ public class CountDistinctTest {
 
        @Test
        public void testEstimation() {
-
-               Integer out = 0;
-               CountDistinctOperator op = new CountDistinctOperator(et, ht);
                try {
-                       if(expectedException != null){
-                               assertThrows(expectedException.getClass(),  () 
-> {LibMatrixCountDistinct.estimateDistinctValues(in, op);});
-                               return;
+                       // Estimate the number of distinct values over the whole matrix (RowCol direction)
+                       CountDistinctOperator op = new 
CountDistinctOperator(et, ht).setDirection(Types.Direction.RowCol);
+                       if(expectedException != null) {
+                               // Invalid configurations must throw the expected exception type
+                               assertThrows(expectedException.getClass(), () 
-> {
+                                       
LibMatrixCountDistinct.estimateDistinctValues(in, op);
+                               });
+                       }
+                       else {
+                               // The estimate may deviate from nrUnique by at most nrUnique * epsilon
+                               int out = 
LibMatrixCountDistinct.estimateDistinctValues(in, op);
+                               int count = out;
+                               boolean success = Math.abs(nrUnique - count) <= 
nrUnique * epsilon;
+                               StringBuilder sb = new StringBuilder();
+                               sb.append(this.toString());
+                               sb.append("\n" + count + " unique values, 
actual:" + nrUnique + " with eps of " + epsilon);
+                               assertTrue(sb.toString(), success);
                        }
-                       else
-                               out = 
LibMatrixCountDistinct.estimateDistinctValues(in, op);
-               }
-               catch(DMLException e) {
-                       throw e;
-               }
-               catch(NotImplementedException e) {
-                       throw e;
-               }
                }
                catch(Exception e) {
+                       // JUnit assertion failures are AssertionErrors (not Exceptions), so they propagate past this catch
                        e.printStackTrace();
                        fail(this.toString());
                }
 
-               int count = out;
-               boolean success = Math.abs(nrUnique - count) <= nrUnique * 
epsilon;
-               if(!success){
-                       StringBuilder sb = new StringBuilder();
-                       sb.append(this.toString());
-                       sb.append("\n" + count + " unique values, actual:" + 
nrUnique + " with eps of " + epsilon);
-                       fail(sb.toString());
-               }
        }
 
        @Override
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApprox.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
similarity index 65%
rename from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApprox.java
rename to 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
index a15019d..e808cb5 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApprox.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxCol.java
@@ -19,32 +19,13 @@
 
 package org.apache.sysds.test.functions.countDistinct;
 
-import org.apache.sysds.common.Types.ExecType;
-import org.junit.Test;
+import org.apache.sysds.common.Types;
 
-public class CountDistinctApprox extends CountDistinctBase {
+public class CountDistinctApproxCol extends CountDistinctRowOrColBase {
 
-       private final static String TEST_NAME = "countDistinctApprox";
+       private final static String TEST_NAME = "countDistinctApproxCol";
        private final static String TEST_DIR = "functions/countDistinct/";
-       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApprox.class.getSimpleName() + "/";
-
-       public CountDistinctApprox() {
-               percentTolerance = 0.1;
-       }
-
-       @Test
-       public void testXXLarge() {
-               ExecType ex = ExecType.CP;
-               double tolerance = 9000 * percentTolerance;
-               countDistinctTest(9000, 10000, 5000, 0.1, ex, tolerance);
-       }
-
-       @Test
-       public void testSparse500Unique(){
-               ExecType ex = ExecType.CP;
-               double tolerance = 0.00001 + 120 * percentTolerance;
-               countDistinctTest(500, 100, 100000, 0.1, ex, tolerance);
-       }
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxCol.class.getSimpleName() + "/";
 
        @Override
        protected String getTestClassDir() {
@@ -60,4 +41,14 @@ public class CountDistinctApprox extends CountDistinctBase {
        protected String getTestDir() {
                return TEST_DIR;
        }
+
+       @Override
+       protected Types.Direction getDirection() {
+               return Types.Direction.Col;
+       }
+
+       /**
+        * Registers this test's configuration through the base class.
+        */
+       @Override
+       public void setUp() {
+               super.addTestConfiguration();
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
similarity index 65%
copy from 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
copy to 
src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
index 9581bc8..05a1256 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRow.java
@@ -19,31 +19,36 @@
 
 package org.apache.sysds.test.functions.countDistinct;
 
-import org.apache.sysds.common.Types.ExecType;
-import org.junit.Test;
+import org.apache.sysds.common.Types;
 
-public class CountDistinct extends CountDistinctBase {
+/**
+ * Approximate distinct-count test for the Row direction. This class declares no test methods itself; the
+ * tests presumably come from {@link CountDistinctRowOrColBase}.
+ */
+public class CountDistinctApproxRow extends CountDistinctRowOrColBase {
 
-       public String TEST_NAME = "countDistinct";
-       public String TEST_DIR = "functions/countDistinct/";
-       public String TEST_CLASS_DIR = TEST_DIR + 
CountDistinct.class.getSimpleName() + "/";
+       private final static String TEST_NAME = "countDistinctApproxRow";
+       private final static String TEST_DIR = "functions/countDistinct/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxRow.class.getSimpleName() + "/";
 
+       @Override
        protected String getTestClassDir() {
                return TEST_CLASS_DIR;
        }
 
+       @Override
        protected String getTestName() {
                return TEST_NAME;
        }
 
+       @Override
        protected String getTestDir() {
                return TEST_DIR;
        }
 
-       @Test
-       public void testSimple1by1() {
-               // test simple 1 by 1.
-               ExecType ex = ExecType.CP;
-               countDistinctTest(1, 1, 1, 1.0, ex, 0.00001);
+       /**
+        * @return {@link Types.Direction#Row}, the direction exercised by this test class
+        */
+       @Override
+       protected Types.Direction getDirection() {
+               return Types.Direction.Row;
        }
-}
\ No newline at end of file
+
+       /** Registers this test's configuration through the base class. */
+       @Override
+       public void setUp() {
+               super.addTestConfiguration();
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java
new file mode 100644
index 0000000..e59b002
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctApproxRowCol.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinct;
+
+import org.apache.sysds.common.Types.ExecType;
+import org.junit.Test;
+
+public class CountDistinctApproxRowCol extends CountDistinctRowColBase {
+
+       private final static String TEST_NAME = "countDistinctApproxRowCol";
+       private final static String TEST_DIR = "functions/countDistinct/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxRowCol.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               super.addTestConfiguration();
+               super.percentTolerance = 0.2;
+       }
+
+       @Test
+       public void testCPSparseLarge() {
+               ExecType ex = ExecType.CP;
+               double tolerance = 9000 * percentTolerance;
+               countDistinctScalarTest(9000, 10000, 5000, 0.1, ex, tolerance);
+       }
+
+       @Test
+       public void testSparkSparseLarge() {
+               ExecType ex = ExecType.SPARK;
+               double tolerance = 9000 * percentTolerance;
+               countDistinctScalarTest(9000, 10000, 5000, 0.1, ex, tolerance);
+       }
+
+       @Test
+       public void testCPSparseSmall() {
+               ExecType ex = ExecType.CP;
+               double tolerance = 9000 * percentTolerance;
+               countDistinctScalarTest(9000, 999, 999, 0.1, ex, tolerance);
+       }
+
+       @Test
+       public void testSparkSparseSmall() {
+               ExecType ex = ExecType.SPARK;
+               double tolerance = 9000 * percentTolerance;
+               countDistinctScalarTest(9000, 999, 999, 0.1, ex, tolerance);
+       }
+
+       @Test
+       public void testCPDenseXSmall() {
+               ExecType ex = ExecType.CP;
+               double tolerance = 5 * percentTolerance;
+               countDistinctScalarTest(5, 5, 10, 1.0, ex, tolerance);
+       }
+
+       @Test
+       public void testSparkDenseXSmall() {
+               ExecType ex = ExecType.SPARK;
+               double tolerance = 5 * percentTolerance;
+               countDistinctScalarTest(5, 10, 5, 1.0, ex, tolerance);
+       }
+
+       @Test
+       public void testCPEmpty() {
+               ExecType ex = ExecType.CP;
+               countDistinctScalarTest(1, 0, 0, 0.1, ex, 0);
+       }
+
+       @Test
+       public void testSparkEmpty() {
+               ExecType ex = ExecType.SPARK;
+               countDistinctScalarTest(1, 0, 0, 0.1, ex, 0);
+       }
+
+       @Test
+       public void testCPSingleValue() {
+               ExecType ex = ExecType.CP;
+               countDistinctScalarTest(1, 1, 1, 1.0, ex, 0);
+       }
+
+       @Test
+       public void testSparkSingleValue() {
+               ExecType ex = ExecType.SPARK;
+               countDistinctScalarTest(1, 1, 1, 1.0, ex, 0);
+       }
+
+       // Corresponding execType=SPARK tests for CP tests in base class
+       //
+       @Test
+       public void testSparkDense1Unique() {
+               ExecType ex = ExecType.SPARK;
+               double tolerance = 0.00001;
+               countDistinctScalarTest(1, 100, 1000, 1.0, ex, tolerance);
+       }
+
+       @Test
+       public void testSparkDense2Unique() {
+               ExecType ex = ExecType.SPARK;
+               double tolerance = 0.00001;
+               countDistinctScalarTest(2, 100, 1000, 1.0, ex, tolerance);
+       }
+
+       @Test
+       public void testSparkDense120Unique() {
+               ExecType ex = ExecType.SPARK;
+               double tolerance = 0.00001 + 120 * percentTolerance;
+               countDistinctScalarTest(120, 100, 1000, 1.0, ex, tolerance);
+       }
+
+       @Override
+       protected String getTestClassDir() {
+               return TEST_CLASS_DIR;
+       }
+
+       @Override
+       protected String getTestName() {
+               return TEST_NAME;
+       }
+
+       @Override
+       protected String getTestDir() {
+               return TEST_DIR;
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
index 9d7e940..041cf51 100644
--- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
@@ -22,13 +22,13 @@ package org.apache.sysds.test.functions.countDistinct;
 import static org.junit.Assert.assertTrue;
 
 import org.apache.sysds.common.Types;
-import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
-import org.junit.Test;
 
 public abstract class CountDistinctBase extends AutomatedTestBase {
+       protected double percentTolerance = 0.0;
+       protected double baseTolerance = 0.0001;
 
        protected abstract String getTestClassDir();
 
@@ -36,86 +36,47 @@ public abstract class CountDistinctBase extends AutomatedTestBase {
 
        protected abstract String getTestDir();
 
-       @Override
-       public void setUp() {
+       protected void addTestConfiguration() {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(getTestName(),
                        new TestConfiguration(getTestClassDir(), getTestName(), 
new String[] {"A.scalar"}));
        }
 
-       protected double percentTolerance = 0.0;
-       protected double baseTolerance = 0.0001;
-
-       @Test
-       public void testSmall() {
-               ExecType ex = ExecType.CP;
-               double tolerance = baseTolerance + 50 * percentTolerance;
-               countDistinctTest(50, 50, 50, 1.0, ex, tolerance);
-       }
-
-       @Test
-       public void testLarge() {
-               ExecType ex = ExecType.CP;
-               double tolerance = baseTolerance + 800 * percentTolerance;
-               countDistinctTest(800, 1000, 1000, 1.0, ex, tolerance);
-       }
-
-       @Test
-       public void testXLarge() {
-               ExecType ex = ExecType.CP;
-               double tolerance = baseTolerance + 1723 * percentTolerance;
-               countDistinctTest(1723, 5000, 2000, 1.0, ex, tolerance);
-       }
-
-       @Test
-       public void test1Unique() {
-               ExecType ex = ExecType.CP;
-               double tolerance = 0.00001;
-               countDistinctTest(1, 100, 1000, 1.0, ex, tolerance);
-       }
-
-       @Test
-       public void test2Unique() {
-               ExecType ex = ExecType.CP;
-               double tolerance = 0.00001;
-               countDistinctTest(2, 100, 1000, 1.0, ex, tolerance);
-       }
+       @Override
+       public abstract void setUp();
 
-       @Test
-       public void test120Unique() {
-               ExecType ex = ExecType.CP;
-               double tolerance = 0.00001 + 120 * percentTolerance;
-               countDistinctTest(120, 100, 1000, 1.0, ex, tolerance);
+       public void countDistinctScalarTest(long numberDistinct, int cols, int 
rows, double sparsity,
+               Types.ExecType instType, double tolerance) {
+               countDistinctTest(Types.Direction.RowCol, numberDistinct, cols, 
rows, sparsity, instType, tolerance);
        }
 
-       @Test
-       public void testSparse500Unique() {
-               ExecType ex = ExecType.CP;
-               double tolerance = 0.00001 + 500 * percentTolerance;
-               countDistinctTest(500, 100, 640000, 0.1, ex, tolerance);
+       public void countDistinctMatrixTest(Types.Direction dir, long 
numberDistinct, int cols, int rows, double sparsity,
+               Types.ExecType instType, double tolerance) {
+               countDistinctTest(dir, numberDistinct, cols, rows, sparsity, 
instType, tolerance);
        }
 
-       @Test
-       public void testSparse120Unique(){
-               ExecType ex = ExecType.CP;
-               double tolerance = 0.00001 + 120 * percentTolerance;
-               countDistinctTest(120, 100, 64000, 0.1, ex, tolerance);
-       }
+       public void countDistinctTest(Types.Direction dir, long numberDistinct, 
int cols, int rows, double sparsity,
+               Types.ExecType instType, double tolerance) {
 
-       public void countDistinctTest(int numberDistinct, int cols, int rows, 
double sparsity,
-               ExecType instType, double tolerance) {
                Types.ExecMode platformOld = setExecMode(instType);
                try {
                        
loadTestConfiguration(getTestConfiguration(getTestName()));
                        String HOME = SCRIPT_DIR + getTestDir();
                        fullDMLScriptName = HOME + getTestName() + ".dml";
-                       String out = output("A");
-                       System.out.println(out);
+                       String outputPath = output("A");
+
                        programArgs = new String[] {"-args", 
String.valueOf(numberDistinct), String.valueOf(rows),
-                               String.valueOf(cols), String.valueOf(sparsity), 
out};
+                               String.valueOf(cols), String.valueOf(sparsity), 
outputPath};
 
                        runTest(true, false, null, -1);
-                       writeExpectedScalar("A", numberDistinct);
+
+                       if(dir.isRowCol()) {
+                               writeExpectedScalar("A", numberDistinct);
+                       }
+                       else {
+                               double[][] expectedMatrix = 
getExpectedMatrixRowOrCol(dir, cols, rows, numberDistinct);
+                               writeExpectedMatrix("A", expectedMatrix);
+                       }
                        compareResults(tolerance);
                }
                catch(Exception e) {
@@ -126,4 +87,22 @@ public abstract class CountDistinctBase extends AutomatedTestBase {
                        rtplatform = platformOld;
                }
        }
-}
\ No newline at end of file
+
+       private double[][] getExpectedMatrixRowOrCol(Types.Direction dir, int 
cols, int rows, long expectedValue) {
+               double[][] expectedResult;
+               if(dir.isRow()) {
+                       expectedResult = new double[rows][1];
+                       for(int i = 0; i < rows; ++i) {
+                               expectedResult[i][0] = expectedValue;
+                       }
+               }
+               else {
+                       expectedResult = new double[1][cols];
+                       for(int i = 0; i < cols; ++i) {
+                               expectedResult[0][i] = expectedValue;
+                       }
+               }
+
+               return expectedResult;
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
similarity index 80%
rename from src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
rename to src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
index 9581bc8..3de4a61 100644
--- a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinct.java
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
@@ -22,11 +22,11 @@ package org.apache.sysds.test.functions.countDistinct;
 import org.apache.sysds.common.Types.ExecType;
 import org.junit.Test;
 
-public class CountDistinct extends CountDistinctBase {
+public class CountDistinctRowCol extends CountDistinctRowColBase {
 
        public String TEST_NAME = "countDistinct";
        public String TEST_DIR = "functions/countDistinct/";
-       public String TEST_CLASS_DIR = TEST_DIR + 
CountDistinct.class.getSimpleName() + "/";
+       public String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctRowCol.class.getSimpleName() + "/";
 
        protected String getTestClassDir() {
                return TEST_CLASS_DIR;
@@ -40,10 +40,16 @@ public class CountDistinct extends CountDistinctBase {
                return TEST_DIR;
        }
 
+       @Override
+       public void setUp() {
+               super.addTestConfiguration();
+               super.percentTolerance = 0.0;
+       }
+
        @Test
        public void testSimple1by1() {
                // test simple 1 by 1.
                ExecType ex = ExecType.CP;
-               countDistinctTest(1, 1, 1, 1.0, ex, 0.00001);
+               countDistinctScalarTest(1, 1, 1, 1.0, ex, 0.00001);
        }
-}
\ No newline at end of file
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java
new file mode 100644
index 0000000..6b20075
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinct;
+
+import org.apache.sysds.common.Types.ExecType;
+import org.junit.Test;
+
+public abstract class CountDistinctRowColBase extends CountDistinctBase {
+       @Test
+       public void testCPDenseSmall() {
+               ExecType ex = ExecType.CP;
+               double tolerance = baseTolerance + 50 * percentTolerance;
+               countDistinctScalarTest(50, 50, 50, 1.0, ex, tolerance);
+       }
+
+       @Test
+       public void testCPDenseLarge() {
+               ExecType ex = ExecType.CP;
+               double tolerance = baseTolerance + 800 * percentTolerance;
+               countDistinctScalarTest(800, 1000, 1000, 1.0, ex, tolerance);
+       }
+
+       @Test
+       public void testCPDenseXLarge() {
+               ExecType ex = ExecType.CP;
+               double tolerance = baseTolerance + 1723 * percentTolerance;
+               countDistinctScalarTest(1723, 5000, 2000, 1.0, ex, tolerance);
+       }
+
+       @Test
+       public void testCPDense1Unique() {
+               ExecType ex = ExecType.CP;
+               double tolerance = 0.00001;
+               countDistinctScalarTest(1, 100, 1000, 1.0, ex, tolerance);
+       }
+
+       @Test
+       public void testCPDense2Unique() {
+               ExecType ex = ExecType.CP;
+               double tolerance = 0.00001;
+               countDistinctScalarTest(2, 100, 1000, 1.0, ex, tolerance);
+       }
+
+       @Test
+       public void testCPDense120Unique() {
+               ExecType ex = ExecType.CP;
+               double tolerance = 0.00001 + 120 * percentTolerance;
+               countDistinctScalarTest(120, 100, 1000, 1.0, ex, tolerance);
+       }
+
+       @Test
+       public void testCPSparse500Unique() {
+               ExecType ex = ExecType.CP;
+               double tolerance = 0.00001 + 500 * percentTolerance;
+               countDistinctScalarTest(500, 100, 640000, 0.1, ex, tolerance);
+       }
+
+       @Test
+       public void testCPSparse120Unique() {
+               ExecType ex = ExecType.CP;
+               double tolerance = 0.00001 + 120 * percentTolerance;
+               countDistinctScalarTest(120, 100, 64000, 0.1, ex, tolerance);
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
new file mode 100644
index 0000000..df2ea8a
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowOrColBase.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinct;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public abstract class CountDistinctRowOrColBase extends CountDistinctBase {
+
+       @Override
+       protected abstract String getTestClassDir();
+
+       @Override
+       protected abstract String getTestName();
+
+       @Override
+       protected abstract String getTestDir();
+
+       protected abstract Types.Direction getDirection();
+
+       protected void addTestConfiguration() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(getTestName(), new 
TestConfiguration(getTestClassDir(), getTestName(), new String[] {"A"}));
+
+               this.percentTolerance = 0.2;
+       }
+
+       @Test
+       public void testCPSparseLarge() {
+               Types.ExecType ex = Types.ExecType.CP;
+
+               int actualDistinctCount = 10;
+               int rows = 10000, cols = 1000;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
+       }
+
+       @Test
+       public void testCPDenseLarge() {
+               Types.ExecType ex = Types.ExecType.CP;
+
+               int actualDistinctCount = 100;
+               int rows = 10000, cols = 1000;
+               double sparsity = 0.9;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, ex, tolerance);
+       }
+
+       @Test
+       public void testCPSparseSmall() {
+               Types.ExecType execType = Types.ExecType.CP;
+
+               int actualDistinctCount = 10;
+               int rows = 1000, cols = 1000;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, execType, tolerance);
+       }
+
+       @Test
+       public void testCPDenseSmall() {
+               Types.ExecType execType = Types.ExecType.CP;
+
+               int actualDistinctCount = 10;
+               int rows = 1000, cols = 1000;
+               double sparsity = 0.9;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, execType, tolerance);
+       }
+
+       @Test
+       public void testSparkSparseLargeMultiBlockAggregation() {
+               Types.ExecType execType = Types.ExecType.SPARK;
+
+               int actualDistinctCount = 10;
+               int rows = 10000, cols = 1001;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, execType, tolerance);
+       }
+
+       @Test
+       public void testSparkDenseLargeMultiBlockAggregation() {
+               Types.ExecType execType = Types.ExecType.SPARK;
+
+               int actualDistinctCount = 10;
+               int rows = 10000, cols = 1001;
+               double sparsity = 0.9;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, execType, tolerance);
+       }
+
+       @Test
+       public void testSparkSparseLargeNoneAggregation() {
+               Types.ExecType execType = Types.ExecType.SPARK;
+
+               int actualDistinctCount = 10;
+               int rows = 10000, cols = 1000;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, execType, tolerance);
+       }
+
+       @Test
+       public void testSparkDenseLargeNoneAggregation() {
+               Types.ExecType execType = Types.ExecType.SPARK;
+
+               int actualDistinctCount = 10;
+               int rows = 10000, cols = 1000;
+               double sparsity = 0.9;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+
+               countDistinctMatrixTest(getDirection(), actualDistinctCount, 
cols, rows, sparsity, execType, tolerance);
+       }
+}
diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml b/src/test/scripts/functions/countDistinct/countDistinct.dml
index a0da780..3b21bc8 100644
--- a/src/test/scripts/functions/countDistinct/countDistinct.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinct.dml
@@ -21,5 +21,4 @@
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,  
seed = 7))
 res = countDistinct(input)
-print(res)
 write(res, $5, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinctApprox.dml b/src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml
similarity index 92%
rename from src/test/scripts/functions/countDistinct/countDistinctApprox.dml
rename to src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml
index eeb5bfc..777a56a 100644
--- a/src/test/scripts/functions/countDistinct/countDistinctApprox.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctApproxCol.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, 
seed = 7))
-res = countDistinctApprox(input)
-write(res, $5, format="text")
\ No newline at end of file
+res = countDistinctApprox(input, dir="c", type="KMV")
+write(res, $5, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml b/src/test/scripts/functions/countDistinct/countDistinctApproxRow.dml
similarity index 92%
copy from src/test/scripts/functions/countDistinct/countDistinct.dml
copy to src/test/scripts/functions/countDistinct/countDistinctApproxRow.dml
index a0da780..38c8b9c 100644
--- a/src/test/scripts/functions/countDistinct/countDistinct.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctApproxRow.dml
@@ -19,7 +19,6 @@
 #
 #-------------------------------------------------------------
 
-input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,  
seed = 7))
-res = countDistinct(input)
-print(res)
+input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, 
seed = 7))
+res = countDistinctApprox(input, dir="r", type="KMV")
 write(res, $5, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinct.dml b/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml
similarity index 92%
copy from src/test/scripts/functions/countDistinct/countDistinct.dml
copy to src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml
index a0da780..2c5b6cf 100644
--- a/src/test/scripts/functions/countDistinct/countDistinct.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctApproxRowCol.dml
@@ -19,7 +19,6 @@
 #
 #-------------------------------------------------------------
 
-input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,  
seed = 7))
-res = countDistinct(input)
-print(res)
+input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, 
seed = 7))
+res = countDistinctApprox(input, dir="rc", type="KMV")
 write(res, $5, format="text")

Reply via email to