This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 8095f4167f [SYSTEMDS-3463] Add unique() built-in function
8095f4167f is described below

commit 8095f4167f21983bedc024f1aab54bfa837e5992
Author: Badrul Chowdhury <[email protected]>
AuthorDate: Sun Nov 27 19:20:52 2022 -0800

    [SYSTEMDS-3463] Add unique() built-in function
    
    This patch converts the existing unique() function from a DML script to a
    built-in. The script-based approach relies on sorting, which is
    computationally expensive, especially for large multi-block inputs.
    The new approach instead uses a new data sketch for the unique()
    function. This first patch creates the framework for the new unique
    sketch and implements the CP RowCol case; the remaining cases (CP
    Row/Col and Spark RowCol/Row/Col) will be implemented in subsequent
    patches.
    
    Closes #1740
---
 scripts/builtin/unique.dml                         |  45 -------
 .../java/org/apache/sysds/common/Builtins.java     |   2 +-
 src/main/java/org/apache/sysds/common/Types.java   |   3 +-
 .../org/apache/sysds/lops/PartialAggregate.java    |  11 ++
 .../org/apache/sysds/parser/DMLTranslator.java     |  23 +++-
 .../ParameterizedBuiltinFunctionExpression.java    |  71 +++++++++-
 .../sysds/runtime/functionobjects/Builtin.java     |   2 +-
 .../runtime/instructions/CPInstructionParser.java  |   5 +-
 .../runtime/instructions/InstructionUtils.java     |  13 ++
 .../cp/AggregateUnaryCPInstruction.java            |  60 ++++++---
 .../matrix/data/LibMatrixCountDistinct.java        |   6 +-
 .../sysds/runtime/matrix/data/LibMatrixSketch.java | 117 +++++++++++++++++
 .../matrix/operators/CountDistinctOperator.java    |  13 +-
 .../matrix/operators/UnarySketchOperator.java      |  44 +++++++
 .../systemds/operator/algorithm/builtin/unique.py  |  45 -------
 .../test/functions/builtin/BuiltinUniqueTest.java  | 114 ----------------
 .../sysds/test/functions/unique/UniqueBase.java    |  64 +++++++++
 .../sysds/test/functions/unique/UniqueRowCol.java  | 145 +++++++++++++++++++++
 src/test/scripts/functions/builtin/unique.R        |  27 ----
 .../unique.dml => unique/uniqueRowCol.dml}         |   6 +-
 20 files changed, 543 insertions(+), 273 deletions(-)

diff --git a/scripts/builtin/unique.dml b/scripts/builtin/unique.dml
deleted file mode 100644
index 57e01949b6..0000000000
--- a/scripts/builtin/unique.dml
+++ /dev/null
@@ -1,45 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-# Builtin function that implements unique operation on vectors
-#
-# INPUT:
-# -------------------------------------------------------
-# X     input vector
-# -------------------------------------------------------
-#
-# OUTPUT:
-# -------------------------------------------------------------------
-# R     matrix with only unique rows
-# -------------------------------------------------------------------
-
-m_unique = function(matrix[double] X)
-  return (matrix[double] R)
-{
-  R = X
-  if(nrow(X) > 1) {
-    # sort-based approach (a generic alternative would be transformencode)
-    X_sorted = order(target=X, by=1, decreasing=FALSE, index.return=FALSE);
-    temp = X_sorted[1:nrow(X_sorted)-1,] != X_sorted[2:nrow(X_sorted),];
-    mask = rbind(matrix(1, 1, 1), temp);
-    R = removeEmpty(target = X_sorted, margin = "rows", select = mask);
-  }
-}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 5afef9c308..4a0d045367 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -290,7 +290,6 @@ public enum Builtins {
        TRANS("t", false),
        TSNE("tSNE", true),
        TYPEOF("typeof", false),
-       UNIQUE("unique", true),
        UNIVAR("univar", true),
        UNION("union", true),
        VAR("var", false),
@@ -344,6 +343,7 @@ public enum Builtins {
        TRANSFORMENCODE("transformencode", false, true),
        TRANSFORMMETA("transformmeta", false, true),
        UNDER_SAMPLING("underSampling", true),
+       UNIQUE("unique", false, true),
        UPPER_TRI("upper.tri", false, true),
        XDUMMY1("xdummy1", true), //error handling test
        XDUMMY2("xdummy2", true); //error handling test
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index 7c3a3f1e53..b441f7263b 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -199,7 +199,8 @@ public class Types
                TRACE(6), MEAN(7), VAR(8),
                MAXINDEX(9), MININDEX(10),
                COUNT_DISTINCT(11), ROW_COUNT_DISTINCT(12), 
COL_COUNT_DISTINCT(13),
-               COUNT_DISTINCT_APPROX(14), COUNT_DISTINCT_APPROX_ROW(15), 
COUNT_DISTINCT_APPROX_COL(16);
+               COUNT_DISTINCT_APPROX(14), COUNT_DISTINCT_APPROX_ROW(15), 
COUNT_DISTINCT_APPROX_COL(16),
+               UNIQUE(17);
 
                @Override
                public String toString() {
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java 
b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index 1a7d22b989..467c7c69b0 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -374,6 +374,17 @@ public class PartialAggregate extends Lop
 
                        case COUNT_DISTINCT_APPROX_COL:
                                return "uacdapc";
+
+                       case UNIQUE: {
+                               switch (dir) {
+                                       case RowCol: return "unique";
+                                       case Row: return "uniquer";
+                                       case Col: return "uniquec";
+                                       default:
+                                               throw new 
LopsException("PartialAggregate.getOpcode() - "
+                                                               + "Unknown 
aggregate direction: " + dir);
+                               }
+                       }
                }
                
                //should never come here for normal compilation
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 0c3a6dfd8f..06deb8ad7b 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2041,7 +2041,7 @@ public class DMLTranslator
                                break;
 
                        case COUNT_DISTINCT:
-                       case COUNT_DISTINCT_APPROX:
+                       case COUNT_DISTINCT_APPROX: {
                                Direction dir = Direction.RowCol;  // Default 
direction
                                DataType dataType = DataType.SCALAR;  // 
Default output data type
 
@@ -2063,6 +2063,7 @@ public class DMLTranslator
                                currBuiltinOp = new 
AggUnaryOp(target.getName(), dataType, target.getValueType(),
                                                
AggOp.valueOf(source.getOpCode().name()), dir, paramHops.get("data"));
                                break;
+                       }
 
                        case COUNT_DISTINCT_APPROX_ROW:
                                currBuiltinOp = new 
AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
@@ -2074,6 +2075,26 @@ public class DMLTranslator
                                                
AggOp.valueOf(source.getOpCode().name()), Direction.Col, paramHops.get("data"));
                                break;
 
+                       case UNIQUE:
+                               Direction dir = Direction.RowCol;
+                               DataType dataType = DataType.MATRIX;
+
+                               LiteralOp dirOp = (LiteralOp) 
paramHops.get("dir");
+                               if (dirOp != null) {
+                                       String dirString = 
dirOp.getStringValue().toUpperCase();
+                                       if 
(dirString.equals(Direction.RowCol.toString())) {
+                                               dir = Direction.RowCol;
+                                       } else if 
(dirString.equals(Direction.Row.toString())) {
+                                               dir = Direction.Row;
+                                       } else if 
(dirString.equals(Direction.Col.toString())) {
+                                               dir = Direction.Col;
+                                       }
+                               }
+
+                               currBuiltinOp = new 
AggUnaryOp(target.getName(), dataType, target.getValueType(),
+                                               
AggOp.valueOf(source.getOpCode().name()), dir, paramHops.get("data"));
+                               break;
+
                        default:
                                throw new 
ParseException(source.printErrorLocation() + 
                                        
"processParameterizedBuiltinFunctionExpression() -- Unknown operation: " + 
source.getOpCode());
diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 7ef19badde..293ca7312e 100644
--- 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -259,6 +259,10 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                        validateCountDistinctApprox(output, conditional, true);
                        break;
 
+               case UNIQUE:
+                       validateUnique(output, conditional);
+                       break;
+
                default: //always unconditional (because unsupported operation)
                        //handle common issue of transformencode
                        if( getOpCode()==Builtins.TRANSFORMENCODE )
@@ -398,7 +402,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
 
                checkStringParam(true, fname, "dir", conditional);
                // Check data value of "dir" parameter
-               validateAggregationDirection(dataId, output);
+               validateCountDistinctAggregationDirection(dataId, output);
        }
 
        private void validateCountDistinctApprox(DataIdentifier output, boolean 
conditional, boolean isDirectionAlias) {
@@ -464,11 +468,11 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                if (!isDirectionAlias) {
                        checkStringParam(true, fname, "dir", conditional);
                        // Check data value of "dir" parameter
-                       validateAggregationDirection(dataId, output);
+                       validateCountDistinctAggregationDirection(dataId, 
output);
                }
        }
 
-       private void validateAggregationDirection(Identifier dataId, 
DataIdentifier output) {
+       private void validateCountDistinctAggregationDirection(Identifier 
dataId, DataIdentifier output) {
                HashMap<String, Expression> varParams = getVarParams();
                if (varParams.containsKey("dir")) {
                        String inputDirectionString = 
varParams.get("dir").toString().toUpperCase();
@@ -512,6 +516,67 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                }
        }
 
+       private void validateUnique(DataIdentifier output, boolean conditional) 
{
+               HashMap<String, Expression> varParams = getVarParams();
+
+               // "data" is the only parameter that is allowed to be unnamed
+               if (varParams.containsKey(null)) {
+                       varParams.put("data", varParams.remove(null));
+               }
+
+               // Validate the number of parameters
+               String fname = getOpCode().getName();
+               String usageMessage = "function " + fname + " takes at least 1 
and at most 2 parameters";
+               if (varParams.size() < 1) {
+                       raiseValidateError("Too few parameters: " + 
usageMessage, conditional);
+               }
+
+               if (varParams.size() > 2) {
+                       raiseValidateError("Too many parameters: " + 
usageMessage, conditional);
+               }
+
+               // Check parameter names are valid
+               Set<String> validParameterNames = CollectionUtils.asSet("data", 
"dir");
+               checkInvalidParameters(getOpCode(), varParams, 
validParameterNames);
+
+               // Check parameter expression data types match expected
+               checkDataType(false, fname, "data", DataType.MATRIX, 
conditional);
+               checkDataValueType(false, fname, "data", DataType.MATRIX, 
ValueType.FP64, conditional);
+
+               // We need the dimensions of the input matrix to determine the 
output matrix characteristics
+               // Validate data parameter, lookup previously defined var or 
resolve expression
+               Identifier dataId = varParams.get("data").getOutput();
+               if (dataId == null) {
+                       raiseValidateError("Cannot parse input parameter 
\"data\" to function " + fname, conditional);
+               }
+
+               checkStringParam(true, fname, "dir", conditional);
+               // Check data value of "dir" parameter
+               validateUniqueAggregationDirection(dataId, output);
+       }
+
+       private void validateUniqueAggregationDirection(Identifier dataId, 
DataIdentifier output) {
+               HashMap<String, Expression> varParams = getVarParams();
+               if (varParams.containsKey("dir")) {
+                       String inputDirectionString = 
varParams.get("dir").toString().toUpperCase();
+
+                       // unrecognized value for "dir" parameter
+                       if 
(!inputDirectionString.equals(Types.Direction.Row.toString())
+                                       && 
!inputDirectionString.equals(Types.Direction.Col.toString())
+                                       && 
!inputDirectionString.equals(Types.Direction.RowCol.toString())) {
+                               raiseValidateError("Invalid argument: " + 
inputDirectionString + " is not recognized");
+                       }
+               }
+
+               // rc/r/c -> unique return value is the same as the input in 
the worst case
+               // default to dir="rc"
+               output.setDataType(DataType.MATRIX);
+               output.setDimensions(dataId.getDim1(), dataId.getDim2());
+               output.setBlocksize(dataId.getBlocksize());
+               output.setValueType(ValueType.FP64);
+               output.setNnz(dataId.getNnz());
+       }
+
        private void checkStringParam(boolean optional, String fname, String 
pname, boolean conditional) {
                Expression param = getVarParam(pname);
                if (param == null) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java 
b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
index b904e60d9c..30114eff18 100644
--- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java
@@ -51,7 +51,7 @@ public class Builtin extends ValueFunction
                MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, 
LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX,
                STOP, CEIL, FLOOR, CUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, 
INVERSE, SPROP, SIGMOID, EVAL, LIST,
                TYPEOF, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, 
DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE,
-               MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX}
+               MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE}
 
 
        public BuiltinCode bFunc;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 78dec00c24..eeeed7c5a1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -121,6 +121,9 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "uacdap"  , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uacdapr" , 
CPType.AggregateUnary);
                String2CPInstructionType.put( "uacdapc" , 
CPType.AggregateUnary);
+               String2CPInstructionType.put( "unique"  , 
CPType.AggregateUnary);
+               String2CPInstructionType.put( "uniquer" , 
CPType.AggregateUnary);
+               String2CPInstructionType.put( "uniquec" , 
CPType.AggregateUnary);
 
                String2CPInstructionType.put( "uaggouterchain", 
CPType.UaggOuterChain);
                
@@ -215,7 +218,7 @@ public class CPInstructionParser extends InstructionParser
                String2CPInstructionType.put( "list",   CPType.BuiltinNary);
                
                // Parameterized Builtin Functions
-               String2CPInstructionType.put( "autoDiff" , 
CPType.ParameterizedBuiltin);
+               String2CPInstructionType.put( "autoDiff" ,      
CPType.ParameterizedBuiltin);
                String2CPInstructionType.put("paramserv",       
CPType.ParameterizedBuiltin);
                String2CPInstructionType.put( "nvlist",         
CPType.ParameterizedBuiltin);
                String2CPInstructionType.put( "cdf",            
CPType.ParameterizedBuiltin);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 4f4fac1e38..5026175b59 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -104,6 +104,7 @@ import 
org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
 import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
 import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
 import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.UnarySketchOperator;
 
 
 public class InstructionUtils 
@@ -453,6 +454,18 @@ public class InstructionUtils
                        aggun = new 
CountDistinctOperator(AggregateUnaryCPInstruction.AUType.COUNT_DISTINCT_APPROX,
                                        Direction.Col, 
ReduceRow.getReduceRowFnObject());
                }
+               else if ( opcode.equalsIgnoreCase("unique") ) {
+                       AggregateOperator agg = new AggregateOperator(0, 
Builtin.getBuiltinFnObject("unique"));
+                       aggun = new UnarySketchOperator(agg, 
ReduceAll.getReduceAllFnObject(), Direction.RowCol, numThreads);
+               }
+               else if ( opcode.equalsIgnoreCase("uniquer") ) {
+                       AggregateOperator agg = new AggregateOperator(0, 
Builtin.getBuiltinFnObject("unique"));
+                       aggun = new UnarySketchOperator(agg, 
ReduceCol.getReduceColFnObject(), Direction.Row, numThreads);
+               }
+               else if ( opcode.equalsIgnoreCase("uniquec") ) {
+                       AggregateOperator agg = new AggregateOperator(0, 
Builtin.getBuiltinFnObject("unique"));
+                       aggun = new UnarySketchOperator(agg, 
ReduceRow.getReduceRowFnObject(), Direction.Col, numThreads);
+               }
 
                return aggun;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
index 6fc0107520..030fe5f5cf 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/AggregateUnaryCPInstruction.java
@@ -20,7 +20,6 @@
 package org.apache.sysds.runtime.instructions.cp;
 
 import org.apache.sysds.api.DMLScript;
-import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -29,33 +28,27 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.data.BasicTensorBlock;
 import org.apache.sysds.runtime.data.TensorBlock;
 import org.apache.sysds.runtime.functionobjects.Builtin;
-import org.apache.sysds.runtime.functionobjects.ReduceAll;
-import org.apache.sysds.runtime.functionobjects.ReduceCol;
-import org.apache.sysds.runtime.functionobjects.ReduceRow;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
-import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
 import org.apache.sysds.runtime.lineage.LineageDedupUtils;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
+import org.apache.sysds.runtime.matrix.data.LibMatrixSketch;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
-import 
org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.SmallestPriorityQueue;
 import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
+import org.apache.sysds.runtime.matrix.operators.UnarySketchOperator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.utils.Explain;
 
-import java.util.HashSet;
-import java.util.Set;
-
 public class AggregateUnaryCPInstruction extends UnaryCPInstruction {
        // private static final Log LOG = 
LogFactory.getLog(AggregateUnaryCPInstruction.class.getName());
 
        public enum AUType {
                NROW, NCOL, LENGTH, EXISTS, LINEAGE, 
-               COUNT_DISTINCT, COUNT_DISTINCT_APPROX,
+               COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE,
                DEFAULT;
                public boolean isMeta() {
                        return this != DEFAULT;
@@ -107,6 +100,13 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
                                .parseAggregateUnaryRowIndexOperator(opcode, 
Integer.parseInt(parts[4]), Integer.parseInt(parts[3]));
                        return new AggregateUnaryCPInstruction(aggun, in1, out, 
AUType.DEFAULT, opcode, str);
                }
+               else if(opcode.equalsIgnoreCase("unique")
+                               || opcode.equalsIgnoreCase("uniquer")
+                               || opcode.equalsIgnoreCase("uniquec")){
+                       AggregateUnaryOperator aggun = 
InstructionUtils.parseBasicAggregateUnaryOperator(opcode,
+                                       Integer.parseInt(parts[3]));
+                       return new AggregateUnaryCPInstruction(aggun, in1, out, 
AUType.UNIQUE, opcode, str);
+               }
                else { //DEFAULT BEHAVIOR
                        AggregateUnaryOperator aggun = InstructionUtils
                                .parseBasicAggregateUnaryOperator(opcode, 
Integer.parseInt(parts[3]));
@@ -116,7 +116,7 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
        
        @Override
        public void processInstruction( ExecutionContext ec ) {
-               String output_name = output.getName();
+               String outputName = output.getName();
                String opcode = getOpcode();
                
                switch( _type ) {
@@ -163,7 +163,7 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
                                }
                                
                                //create and set output scalar
-                               ec.setScalarOutput(output_name, new 
IntObject(rval));
+                               ec.setScalarOutput(outputName, new 
IntObject(rval));
                                break;
                        }
                        case EXISTS: {
@@ -172,7 +172,7 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
                                        
ec.getScalarInput(input1).getStringValue();
                                boolean rval = 
ec.getVariables().keySet().contains(varName);
                                //create and set output scalar
-                               ec.setScalarOutput(output_name, new 
BooleanObject(rval));
+                               ec.setScalarOutput(outputName, new 
BooleanObject(rval));
                                break;
                        }
                        case LINEAGE: {
@@ -184,7 +184,7 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
                                LineageItem li = ec.getLineageItem(input1);
                                String out = !DMLScript.LINEAGE_DEDUP ? 
Explain.explain(li) :
                                        Explain.explain(li) + 
LineageDedupUtils.mergeExplainDedupBlocks(ec);
-                               ec.setScalarOutput(output_name, new 
StringObject(out));
+                               ec.setScalarOutput(outputName, new 
StringObject(out));
                                break;
                        }
                        case COUNT_DISTINCT:
@@ -203,18 +203,38 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
                                if (op.getDirection().isRowCol()) {
                                        long res = (long) 
LibMatrixCountDistinct.estimateDistinctValues(input, op).getValue(0, 0);
                                        ec.releaseMatrixInput(input1.getName());
-                                       ec.setScalarOutput(output_name, new 
IntObject(res));
+                                       ec.setScalarOutput(outputName, new 
IntObject(res));
                                } else {  // Row/Col
                                        // Note that for each row, the max 
number of distinct values < NNZ < max number of columns = 1000:
                                        // Since count distinct approximate 
estimates are unreliable for values < 1024,
                                        // we will force a naive count.
                                        MatrixBlock res = 
LibMatrixCountDistinct.estimateDistinctValues(input, op);
                                        ec.releaseMatrixInput(input1.getName());
-                                       ec.setMatrixOutput(output_name, res);
+                                       ec.setMatrixOutput(outputName, res);
+                               }
+
+                               break;
+                       }
+
+                       case UNIQUE: {
+                               
if(!ec.getVariables().keySet().contains(input1.getName())) {
+                                       throw new DMLRuntimeException("Variable 
'" + input1.getName() + "' does not exist.");
+                               }
+                               MatrixBlock input = 
ec.getMatrixInput(input1.getName());
+
+                               // Operator type: test and cast
+                               if (!(_optr instanceof UnarySketchOperator)) {
+                                       throw new DMLRuntimeException("Operator 
should be instance of "
+                                                       + 
UnarySketchOperator.class.getSimpleName());
                                }
+                               UnarySketchOperator op = (UnarySketchOperator) 
_optr;
 
+                               MatrixBlock res = 
LibMatrixSketch.getUniqueValues(input, op.getDirection());
+                               ec.releaseMatrixInput(input1.getName());
+                               ec.setMatrixOutput(outputName, res);
                                break;
                        }
+
                        default: {
                                AggregateUnaryOperator au_op = 
(AggregateUnaryOperator) _optr;
                                if (input1.getDataType() == DataType.MATRIX) {
@@ -226,10 +246,10 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
                                        ec.releaseMatrixInput(input1.getName());
                                        if (output.getDataType() == 
DataType.SCALAR) {
                                                DoubleObject ret = new 
DoubleObject(resultBlock.getValue(0, 0));
-                                               ec.setScalarOutput(output_name, 
ret);
+                                               ec.setScalarOutput(outputName, 
ret);
                                        } else {
                                                // since the computed value is 
a scalar, allocate a "temp" output matrix
-                                               ec.setMatrixOutput(output_name, 
resultBlock);
+                                               ec.setMatrixOutput(outputName, 
resultBlock);
                                        }
                                } 
                                else if (input1.getDataType() == 
DataType.TENSOR) {
@@ -240,10 +260,10 @@ public class AggregateUnaryCPInstruction extends 
UnaryCPInstruction {
 
                                        ec.releaseTensorInput(input1.getName());
                                        if(output.getDataType() == 
DataType.SCALAR)
-                                               ec.setScalarOutput(output_name, 
ScalarObjectFactory.createScalarObject(
+                                               ec.setScalarOutput(outputName, 
ScalarObjectFactory.createScalarObject(
                                                        input1.getValueType(), 
resultBlock.get(new int[]{0, 0})));
                                        else
-                                               ec.setTensorOutput(output_name, 
new TensorBlock(resultBlock));
+                                               ec.setTensorOutput(outputName, 
new TensorBlock(resultBlock));
                                }
                                else {
                                        throw new DMLRuntimeException(opcode + 
" only supported on matrix or tensor.");
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
index 72bcd64b43..c70f72ad3f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
@@ -29,7 +29,11 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
-import org.apache.sysds.runtime.data.*;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockCOO;
+import org.apache.sysds.runtime.data.SparseBlockCSR;
+import org.apache.sysds.runtime.data.SparseBlockFactory;
 import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
 import org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch;
 import org.apache.sysds.runtime.matrix.data.sketch.SketchFactory;
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java
new file mode 100644
index 0000000000..3793564dbe
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixSketch.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+
+public class LibMatrixSketch {
+
+       /** Orientation of the input; decides whether the output is a column or a row vector. */
+       private enum MatrixShape {
+               SKINNY,  // rows >= cols -> Rx1 column vector output
+               WIDE,    // rows < cols  -> 1xC row vector output
+       }
+
+       /**
+        * Computes the distinct cell values of the given matrix block.
+        *
+        * Only {@link Types.Direction#RowCol} is implemented: every distinct value of the whole
+        * matrix is returned exactly once, as a column vector when the input has at least as many
+        * rows as columns and as a row vector otherwise. The order of the values is unspecified
+        * (they are collected in a {@link HashSet}).
+        *
+        * @param blkIn input matrix block
+        * @param dir   aggregation direction
+        * @return a vector containing each distinct value of {@code blkIn} once
+        * @throws NotImplementedException for the Row and Col directions
+        * @throws IllegalArgumentException for any other unrecognized direction
+        */
+       public static MatrixBlock getUniqueValues(MatrixBlock blkIn, Types.Direction dir) {
+
+               int R = blkIn.getNumRows();
+               int C = blkIn.getNumColumns();
+               List<HashSet<Double>> hashSets = new ArrayList<>();
+
+               MatrixShape matrixShape = (R >= C) ? MatrixShape.SKINNY : MatrixShape.WIDE;
+               MatrixBlock blkOut;
+               switch (dir)
+               {
+                       case RowCol:
+                               HashSet<Double> hashSet = new HashSet<>();
+                               // TODO optimize for sparse and compressed inputs
+                               for (int i=0; i<R; ++i) {
+                                       for (int j=0; j<C; ++j) {
+                                               hashSet.add(blkIn.getValue(i, j));
+                                       }
+                               }
+                               hashSets.add(hashSet);
+                               blkOut = serializeRowCol(hashSets, dir, matrixShape);
+                               break;
+
+                       case Row:
+                       case Col:
+                               throw new NotImplementedException("Unique Row/Col has not been implemented yet");
+
+                       default:
+                               throw new IllegalArgumentException("Unrecognized direction: " + dir);
+               }
+
+               return blkOut;
+       }
+
+       /**
+        * Materializes a RowCol unique sketch as an output vector.
+        *
+        * @param hashSets    sketch metadata; element 0 holds the distinct values
+        * @param dir         must be {@link Types.Direction#RowCol}
+        * @param matrixShape orientation of the original input
+        * @return an Rx1 (SKINNY) or 1xC (WIDE) block with exactly one cell per distinct value
+        * @throws IllegalArgumentException if {@code dir} is not RowCol or the metadata is empty
+        */
+       private static MatrixBlock serializeRowCol(List<HashSet<Double>> hashSets, Types.Direction dir, MatrixShape matrixShape) {
+
+               if (dir != Types.Direction.RowCol) {
+                       throw new IllegalArgumentException("Unrecognized direction: " + dir);
+               }
+
+               if (hashSets.isEmpty()) {
+                       throw new IllegalArgumentException("Corrupt sketch: metadata cannot be empty");
+               }
+
+               HashSet<Double> hashSet = hashSets.get(0);
+               int numValues = hashSet.size();
+
+               // Always emit a vector with exactly one cell per distinct value. Reshaping large
+               // results into a fixed-height block needs more cells than there are values, which
+               // either exhausts the iterator (NoSuchElementException) or leaves padding cells
+               // that would read back as spurious zeros.
+               boolean skinny = (matrixShape == MatrixShape.SKINNY);
+               int R = skinny ? numValues : 1;
+               int C = skinny ? 1 : numValues;
+
+               MatrixBlock blkOut = new MatrixBlock(R, C, false);
+               Iterator<Double> iter = hashSet.iterator();
+               for (int k=0; k<numValues; ++k) {
+                       double value = iter.next();
+                       if (skinny)
+                               blkOut.setValue(k, 0, value);
+                       else
+                               blkOut.setValue(0, k, value);
+               }
+
+               return blkOut;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
index c33accf943..d3deba14de 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/CountDistinctOperator.java
@@ -26,15 +26,14 @@ import org.apache.sysds.runtime.functionobjects.Plus;
 import 
org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction.AUType;
 import org.apache.sysds.utils.Hash.HashType;
 
-public class CountDistinctOperator extends AggregateUnaryOperator {
+public class CountDistinctOperator extends UnarySketchOperator {
        private static final long serialVersionUID = 7615123453265129670L;
 
        private final CountDistinctOperatorTypes operatorType;
-       private final Types.Direction direction;
        private final HashType hashType;
 
        public CountDistinctOperator(AUType opType, Types.Direction direction, 
IndexFunction indexFunction) {
-               super(new AggregateOperator(0, Plus.getPlusFnObject()), 
indexFunction, 1);
+               super(new AggregateOperator(0, Plus.getPlusFnObject()), 
indexFunction, direction, 1);
 
                switch(opType) {
                        case COUNT_DISTINCT:
@@ -47,15 +46,13 @@ public class CountDistinctOperator extends 
AggregateUnaryOperator {
                                throw new DMLRuntimeException(opType + " not 
supported for CountDistinct Operator");
                }
                this.hashType = HashType.LinearHash;
-               this.direction = direction;
        }
 
        public CountDistinctOperator(CountDistinctOperatorTypes operatorType, 
Types.Direction direction,
                                                                 IndexFunction 
indexFunction, HashType hashType) {
-               super(new AggregateOperator(0, Plus.getPlusFnObject()), 
indexFunction, 1);
+               super(new AggregateOperator(0, Plus.getPlusFnObject()), 
indexFunction, direction, 1);
 
                this.operatorType = operatorType;
-               this.direction = direction;
                this.hashType = hashType;
        }
 
@@ -66,8 +63,4 @@ public class CountDistinctOperator extends 
AggregateUnaryOperator {
        public HashType getHashType() {
                return hashType;
        }
-
-       public Types.Direction getDirection() {
-               return direction;
-       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/UnarySketchOperator.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnarySketchOperator.java
new file mode 100644
index 0000000000..0716c45afd
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnarySketchOperator.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.operators;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.functionobjects.IndexFunction;
+
+/**
+ * Aggregate-unary operator for sketch-based computations (for example
+ * {@link CountDistinctOperator}) that also carries the direction the sketch is computed along.
+ */
+public class UnarySketchOperator extends AggregateUnaryOperator {
+       private static final long serialVersionUID = 7615123453265129671L;
+
+       /** Direction along which the sketch is computed (as supplied by the caller). */
+       private final Types.Direction direction;
+
+       /**
+        * Creates a sketch operator without an explicit thread count.
+        *
+        * @param aop           aggregate operator passed through to {@link AggregateUnaryOperator}
+        * @param indexFunction index function passed through to {@link AggregateUnaryOperator}
+        * @param direction     direction of the sketch
+        */
+       public UnarySketchOperator(AggregateOperator aop, IndexFunction indexFunction, Types.Direction direction) {
+               super(aop, indexFunction);
+               this.direction = direction;
+       }
+
+       /**
+        * Creates a sketch operator with an explicit thread count.
+        *
+        * @param aop           aggregate operator passed through to {@link AggregateUnaryOperator}
+        * @param indexFunction index function passed through to {@link AggregateUnaryOperator}
+        * @param direction     direction of the sketch
+        * @param numThreads    thread count passed through to {@link AggregateUnaryOperator}
+        */
+       public UnarySketchOperator(AggregateOperator aop, IndexFunction indexFunction, Types.Direction direction,
+               int numThreads) {
+               super(aop, indexFunction, numThreads);
+               this.direction = direction;
+       }
+
+       /** @return the direction this sketch operator was constructed with */
+       public Types.Direction getDirection() {
+               return direction;
+       }
+}
diff --git a/src/main/python/systemds/operator/algorithm/builtin/unique.py 
b/src/main/python/systemds/operator/algorithm/builtin/unique.py
deleted file mode 100644
index fd77b1fd55..0000000000
--- a/src/main/python/systemds/operator/algorithm/builtin/unique.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# -------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-# -------------------------------------------------------------
-
-# Autogenerated By   : src/main/python/generator/generator.py
-# Autogenerated From : scripts/builtin/unique.dml
-
-from typing import Dict, Iterable
-
-from systemds.operator import OperationNode, Matrix, Frame, List, MultiReturn, 
Scalar
-from systemds.script_building.dag import OutputType
-from systemds.utils.consts import VALID_INPUT_TYPES
-
-
-def unique(X: Matrix):
-    """
-     Builtin function that implements unique operation on vectors
-    
-    
-    
-    :param X: input vector
-    :return: matrix with only unique rows
-    """
-
-    params_dict = {'X': X}
-    return Matrix(X.sds_context,
-        'unique',
-        named_input_nodes=params_dict)
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinUniqueTest.java 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinUniqueTest.java
deleted file mode 100644
index 7d36d79b08..0000000000
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinUniqueTest.java
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysds.test.functions.builtin;
-
-import org.apache.sysds.common.Types;
-import org.apache.sysds.common.Types.ExecType;
-import org.apache.sysds.runtime.matrix.data.MatrixValue;
-import org.apache.sysds.test.AutomatedTestBase;
-import org.apache.sysds.test.TestConfiguration;
-import org.apache.sysds.test.TestUtils;
-import org.junit.Test;
-
-import java.util.HashMap;
-
-public class BuiltinUniqueTest extends AutomatedTestBase {
-       private final static String TEST_NAME = "unique";
-       private final static String TEST_DIR = "functions/builtin/";
-       private static final String TEST_CLASS_DIR = TEST_DIR + 
BuiltinUniqueTest.class.getSimpleName() + "/";
-
-       @Override
-       public void setUp() {
-               TestUtils.clearAssertionInformation();
-               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
-       }
-
-       @Test
-       public void testUnique1CP() {
-               double[][] X = 
{{1},{1},{6},{9},{4},{2},{0},{9},{0},{0},{4},{4}};
-               runUniqueTest(X, ExecType.CP);
-       }
-
-       @Test
-       public void testUnique1SP() {
-               double[][] X = 
{{1},{1},{6},{9},{4},{2},{0},{9},{0},{0},{4},{4}};
-               runUniqueTest(X,ExecType.SPARK);
-       }
-
-       @Test
-       public void testUnique2CP() {
-               double[][] X = {{0}};
-               runUniqueTest(X, ExecType.CP);
-       }
-
-       @Test
-       public void testUnique2SP() {
-               double[][] X = {{0}};
-               runUniqueTest(X, ExecType.SPARK);
-       }
-
-       @Test
-       public void testUnique3CP() {
-               double[][] X = {{1, 2, 3}, {2, 3, 4}, {1, 2, 3}};
-               runUniqueTest(X, ExecType.CP);
-       }
-
-//     @Test
-//     public void testUnique3SP() { //This fails?
-//             double[][] X = {{1, 2, 3}, {2, 3, 4}, {1, 2, 3}};
-//             runUniqueTest(X, ExecType.SPARK);
-//     }
-
-       @Test
-       public void testUnique4CP() {
-               double[][] X = {{1.5, 2}, {7, 3}, {1, 3}, {1.5, 2}, {-1, 
-2.32}, {-1, 0.1}, {1, 3}, {-1, 0.1}};
-               runUniqueTest(X, ExecType.CP);
-       }
-
-//     @Test
-//     public void testUnique4SP() { //This fails?
-//             double[][] X = {{1.5, 2}, {7, 3}, {1, 3}, {1.5, 2}, {-1, 
-2.32}, {-1, 0.1}, {1, 3}, {-1, 0.1}};
-//             runUniqueTest(X, ExecType.SPARK);
-//     }
-
-       private void runUniqueTest(double[][] X, ExecType instType) {
-               Types.ExecMode platformOld = setExecMode(instType);
-               try {
-                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
-                       String HOME = SCRIPT_DIR + TEST_DIR;
-                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[]{ "-args", input("X"), 
output("R")};
-                       fullRScriptName = HOME + TEST_NAME + ".R";
-                       rCmd = "Rscript" + " " + fullRScriptName + " " + 
inputDir() + " " + expectedDir();
-
-                       writeInputMatrixWithMTD("X", X, true);
-
-                       runTest(true, false, null, -1);
-                       runRScript(true);
-
-                       HashMap<MatrixValue.CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("R");
-                       HashMap<MatrixValue.CellIndex, Double> rfile  = 
readRMatrixFromExpectedDir("R");
-                       TestUtils.compareMatrices(dmlfile, rfile, 1e-10, "dml", 
"expected");
-               }
-               finally {
-                       rtplatform = platformOld;
-               }
-       }
-}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java 
b/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java
new file mode 100644
index 0000000000..6b78a60290
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/unique/UniqueBase.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.unique;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.util.HashMap;
+
+/**
+ * Shared driver for unique() tests: runs a DML script on an input matrix and compares
+ * its output against an expected matrix.
+ */
+public abstract class UniqueBase extends AutomatedTestBase {
+
+       /** @return test name; also the base name of the DML script that is executed */
+       protected abstract String getTestName();
+
+       /** @return script directory of the test, relative to SCRIPT_DIR */
+       protected abstract String getTestDir();
+
+       /** @return class-specific directory used for the test configuration */
+       protected abstract String getTestClassDir();
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               String testName = getTestName();
+               TestConfiguration config = new TestConfiguration(getTestClassDir(), testName, new String[] {"A"});
+               addTestConfiguration(testName, config);
+       }
+
+       /**
+        * Runs the test script on {@code inputMatrix} and checks the result.
+        *
+        * @param inputMatrix    matrix written as input "I"
+        * @param expectedMatrix matrix the output "A" is compared against
+        * @param instType       execution type to run under; the previous mode is restored afterwards
+        * @param epsilon        tolerance for the comparison
+        */
+       protected void uniqueTest(double[][] inputMatrix, double[][] expectedMatrix,
+               Types.ExecType instType, double epsilon) {
+               Types.ExecMode previousMode = setExecMode(instType);
+               try {
+                       String testName = getTestName();
+                       loadTestConfiguration(getTestConfiguration(testName));
+                       fullDMLScriptName = SCRIPT_DIR + getTestDir() + testName + ".dml";
+                       programArgs = new String[] {"-args", input("I"), output("A")};
+
+                       writeInputMatrixWithMTD("I", inputMatrix, true);
+                       runTest(true, false, null, -1);
+
+                       // Compare while allowing the output rows to appear in any order.
+                       writeExpectedMatrix("A", expectedMatrix);
+                       compareResultsRowsOutOfOrder(epsilon);
+               }
+               finally {
+                       rtplatform = previousMode;
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/unique/UniqueRowCol.java 
b/src/test/java/org/apache/sysds/test/functions/unique/UniqueRowCol.java
new file mode 100644
index 0000000000..a8b2fc1ba7
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/unique/UniqueRowCol.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.unique;
+
+import org.apache.sysds.common.Types;
+import org.junit.Test;
+
+/** Tests unique() over all cells of a matrix (RowCol direction) in CP mode. */
+public class UniqueRowCol extends UniqueBase {
+       private final static String TEST_NAME = "uniqueRowCol";
+       private final static String TEST_DIR = "functions/unique/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + UniqueRowCol.class.getSimpleName() + "/";
+
+       @Override
+       protected String getTestName() { return TEST_NAME; }
+
+       @Override
+       protected String getTestDir() { return TEST_DIR; }
+
+       @Override
+       protected String getTestClassDir() { return TEST_CLASS_DIR; }
+
+       @Test
+       public void testBaseCase1CP() {
+               uniqueTest(new double[][] {{0}}, new double[][] {{0}}, Types.ExecType.CP, 0.0);
+       }
+
+       @Test
+       public void testBaseCase2CP() {
+               uniqueTest(new double[][] {{1}}, new double[][] {{1}}, Types.ExecType.CP, 0.0);
+       }
+
+       @Test
+       public void testSkinnySmallCP() {
+               double[][] in = {{1},{1},{6},{9},{4},{2},{0},{9},{0},{0},{4},{4}};
+               double[][] expected = {{1},{6},{9},{4},{2},{0}};
+               uniqueTest(in, expected, Types.ExecType.CP, 0.0);
+       }
+
+       @Test
+       public void testWideSmallCP() {
+               double[][] in = {{1,1,6,9,4,2,0,9,0,0,4,4}};
+               double[][] expected = {{1,6,9,4,2,0}};
+               uniqueTest(in, expected, Types.ExecType.CP, 0.0);
+       }
+
+       @Test
+       public void testSquareLargeCP() {
+               // 1000 x 1000 input made of four 500 x 500 quadrants:
+               //   [ 1 | 2 ]
+               //   [ 2 | 1 ]
+               double[][] in = new double[1000][1000];
+               fill(in, 0, 500, 0, 500, 1);
+               fill(in, 500, 1000, 500, 1000, 1);
+               fill(in, 500, 1000, 0, 500, 2);
+               fill(in, 0, 500, 500, 1000, 2);
+               // rows >= cols, so the result is expected as a column vector
+               uniqueTest(in, new double[][] {{1},{2}}, Types.ExecType.CP, 0.0);
+       }
+
+       @Test
+       public void testSkinnyLargeCP() {
+               // 2000 x 2 input: the first 1000 rows are [1, 2], the last 1000 rows are [2, 1]
+               double[][] in = new double[2000][2];
+               fill(in, 0, 1000, 0, 1, 1);
+               fill(in, 0, 1000, 1, 2, 2);
+               fill(in, 1000, 2000, 0, 1, 2);
+               fill(in, 1000, 2000, 1, 2, 1);
+               uniqueTest(in, new double[][] {{1}, {2}}, Types.ExecType.CP, 0.0);
+       }
+
+       @Test
+       public void testWideLargeCP() {
+               // 2 x 2000 input:
+               //   row 0: 1000 x 1 followed by 1000 x 2
+               //   row 1: 1000 x 2 followed by 1000 x 1
+               double[][] in = new double[2][2000];
+               fill(in, 0, 1, 0, 1000, 1);
+               fill(in, 0, 1, 1000, 2000, 2);
+               fill(in, 1, 2, 0, 1000, 2);
+               fill(in, 1, 2, 1000, 2000, 1);
+               uniqueTest(in, new double[][] {{1,2}}, Types.ExecType.CP, 0.0);
+       }
+
+       /** Sets every cell in rows [rowFrom, rowTo) and columns [colFrom, colTo) of m to value. */
+       private static void fill(double[][] m, int rowFrom, int rowTo, int colFrom, int colTo, double value) {
+               for (int r = rowFrom; r < rowTo; r++) {
+                       for (int c = colFrom; c < colTo; c++) {
+                               m[r][c] = value;
+                       }
+               }
+       }
+}
diff --git a/src/test/scripts/functions/builtin/unique.R 
b/src/test/scripts/functions/builtin/unique.R
deleted file mode 100644
index 6f4c17895e..0000000000
--- a/src/test/scripts/functions/builtin/unique.R
+++ /dev/null
@@ -1,27 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-args<-commandArgs(TRUE)
-options(digits=22)
-library("Matrix")
-
-X = as.matrix(readMM(paste(args[1], "X.mtx", sep="")));
-R = unique(X[order(X[,1]),]);
-writeMM(as(R, "CsparseMatrix"), paste(args[2], "R", sep=""));
\ No newline at end of file
diff --git a/src/test/scripts/functions/builtin/unique.dml 
b/src/test/scripts/functions/unique/uniqueRowCol.dml
similarity index 92%
rename from src/test/scripts/functions/builtin/unique.dml
rename to src/test/scripts/functions/unique/uniqueRowCol.dml
index 55b5aab378..2022342418 100644
--- a/src/test/scripts/functions/builtin/unique.dml
+++ b/src/test/scripts/functions/unique/uniqueRowCol.dml
@@ -19,6 +19,6 @@
 #
 #-------------------------------------------------------------
 
-X = read($1);
-R = unique(X = X);
-write(R, $2);
+input = read($1);
+res = unique(input);
+write(res, $2, format="text");

Reply via email to