This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new e4491b32c3 [SYSTEMDS-3458] Adding Spark backend support for 
countDistinct() builtin
e4491b32c3 is described below

commit e4491b32c373677857a781141d9d79864a922456
Author: Badrul Chowdhury <[email protected]>
AuthorDate: Sat Nov 12 18:00:37 2022 -0800

    [SYSTEMDS-3458] Adding Spark backend support for countDistinct() builtin
    
    This patch adds support for running countDistinct() builtin on the Spark
    backend. The implementation adds a new CountDistinctFunctionSketch to
    the sketches library. The sketch is based on an optimized HashMap lookup:
        Let [sign (1) | exponent (11) | fraction (52)] represent the bits
        of a double value
        - Each double is split into 2 parts:
            - PartA: [sign (1) | exponent (11)] and
            - PartB: [fraction (52)]
        - PartA is used as the key into the HashMap, and PartB is stored
        in the Set<Long> value
        - Many values in blkIn may share the same sign and exponent, so we
        need to store multiple PartB values for any given PartA key
        - The HashMap is then serialized into the sketch like so:
            row_i: [exponent_i, N_i, fraction_i0, fraction_i1, ..,
            fraction_iN, 0, .., 0]
    
            - exponent_i: the [sign | exponent] bits (PartA) of the original
            double input, stored as a short int
            - N_i: we store the number of fractions per exponent value to
            avoid dealing with jagged matrices
    
    NB: Although countDistinct() does not strictly require us to keep track
    of the individual input elements, we must do so in the sketch so we can
    support the union() op for the MULTI_BLOCK case (see notes below)
    
    In the worst case, blkIn can have 1000x1000 = 1e6 unique values.
    Since a key has only 12 bits, there are at most 2^12 = 4096 keys.
    Therefore, the raw payload (excluding Java object/boxing overhead)
    is bounded above by:
        Max of
            (4096 * 2) + (1e6 * 8) bytes   -> 4096 keys sharing 1e6 fraction parts
                and 2 + (1e6 * 8) bytes    -> a single key with 1e6 unique fraction parts
        <= (1e6 * 2) + (1e6 * 8) = 1e6 * 10 bytes = 10 MB
    
    Notes to Reviewers:
        - The current implementation only supports the SINGLE_BLOCK case,
        i.e. the current implementation works for input matrices of max
        size 1000 x 1000. Larger inputs (along either dimension) will throw
        a `NotImplementedException`. MULTI_BLOCK aggregation support will
        be added in a future patch.
        - This patch adds support for only the default RowCol direction;
        we will add support for Row/Col direction in a future patch.
    
    The tests currently cover unit tests only, not integration tests
    
    Closes #1727
    Closes #1726
---
 .../runtime/instructions/InstructionUtils.java     |   1 +
 .../runtime/instructions/SPInstructionParser.java  |  12 +-
 .../spark/AggregateUnarySketchSPInstruction.java   |   9 +-
 .../runtime/instructions/spark/SPInstruction.java  |   4 +-
 .../matrix/data/LibMatrixCountDistinct.java        |  14 +-
 ...tApproxSketch.java => CountDistinctSketch.java} |  11 +-
 .../countdistinct/CountDistinctFunctionSketch.java | 162 +++++++++++++++++++++
 .../data/sketch/countdistinctapprox/KMVSketch.java |   3 +-
 .../countDistinct/CountDistinctRowCol.java         |   9 +-
 .../countDistinct/CountDistinctRowColBase.java     |  14 ++
 10 files changed, 219 insertions(+), 20 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index c27d5a70ae..4f4fac1e38 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -896,6 +896,7 @@ public class InstructionUtils
                        case "uak+":
                        case"uark+":
                        case "uack+":    return "ak+";
+                       case "uacd":
                        case "ua+":
                        case "uar+":
                        case "uac+":     return "a+";
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 9496cf465e..0ff745de57 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -126,14 +126,14 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "uac*"    , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uatrace" , 
SPType.AggregateUnary);
                String2SPInstructionType.put( "uaktrace", 
SPType.AggregateUnary);
-               String2SPInstructionType.put( "uacd"    , 
SPType.AggregateUnary);
-               String2SPInstructionType.put( "uacdr"   , 
SPType.AggregateUnary);
-               String2SPInstructionType.put( "uacdc"   , 
SPType.AggregateUnary);
 
                // Aggregate unary sketch operators
-               String2SPInstructionType.put( "uacdap" , 
SPType.AggregateUnarySketch);
-               String2SPInstructionType.put( "uacdapr", 
SPType.AggregateUnarySketch);
-               String2SPInstructionType.put( "uacdapc", 
SPType.AggregateUnarySketch);
+               String2SPInstructionType.put( "uacd"    , 
SPType.AggregateUnarySketch);
+               String2SPInstructionType.put( "uacdr"   , 
SPType.AggregateUnarySketch);
+               String2SPInstructionType.put( "uacdc"   , 
SPType.AggregateUnarySketch);
+               String2SPInstructionType.put( "uacdap"  , 
SPType.AggregateUnarySketch);
+               String2SPInstructionType.put( "uacdapr" , 
SPType.AggregateUnarySketch);
+               String2SPInstructionType.put( "uacdapc" , 
SPType.AggregateUnarySketch);
 
                //binary aggregate operators (matrix multiplication operators)
                String2SPInstructionType.put( "mapmm"      , SPType.MAPMM);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
index 703828e3a1..767e4b0c0b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
@@ -66,7 +66,14 @@ public class AggregateUnarySketchSPInstruction extends 
UnarySPInstruction {
         AggBinaryOp.SparkAggType aggtype = 
AggBinaryOp.SparkAggType.valueOf(parts[3]);
 
         CountDistinctOperator cdop = null;
-        if (opcode.equals("uacdap")) {
+        if (opcode.equals("uacd")) {
+            cdop = new CountDistinctOperator(CountDistinctOperatorTypes.COUNT, 
Types.Direction.RowCol,
+                    ReduceAll.getReduceAllFnObject(), 
Hash.HashType.LinearHash);
+        } else if (opcode.equals("uacdr")) {
+            throw new NotImplementedException("uacdr has not been implemented 
yet");
+        } else if (opcode.equals("uacdc")) {
+            throw new NotImplementedException("uacdc has not been implemented 
yet");
+        } else if (opcode.equals("uacdap")) {
             cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, 
Types.Direction.RowCol,
                     ReduceAll.getReduceAllFnObject(), 
Hash.HashType.LinearHash);
         } else if (opcode.equals("uacdapr")) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
index 830ba4dae3..263923de05 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
@@ -32,12 +32,12 @@ public abstract class SPInstruction extends Instruction {
        public enum SPType {
                MAPMM, MAPMMCHAIN, CPMM, RMM, TSMM, TSMM2, PMM, ZIPMM, PMAPMM, 
//matrix multiplication instructions
                MatrixIndexing, Reorg, Binary, Ternary,
-               AggregateUnary, AggregateTernary, Reblock, CSVReblock, 
LIBSVMReblock,
+               AggregateUnary, AggregateUnarySketch, AggregateTernary, 
Reblock, CSVReblock, LIBSVMReblock,
                Builtin, Unary, BuiltinNary, MultiReturnBuiltin, Checkpoint, 
Compression, DeCompression, Cast,
                CentralMoment, Covariance, QSort, QPick,
                ParameterizedBuiltin, MAppend, RAppend, GAppend, 
GAlignedAppend, Rand,
                MatrixReshape, Ctable, Quaternary, CumsumAggregate, 
CumsumOffset, BinUaggChain, UaggOuterChain,
-               Write, SpoofFused, Dnn, AggregateUnarySketch
+               Write, SpoofFused, Dnn
        }
 
        protected final SPType _sptype;
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
index ee97c8d340..ccddb4db80 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
@@ -31,6 +31,8 @@ import org.apache.sysds.api.DMLException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.data.*;
 import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
+import org.apache.sysds.runtime.matrix.data.sketch.CountDistinctSketch;
+import 
org.apache.sysds.runtime.matrix.data.sketch.countdistinct.CountDistinctFunctionSketch;
 import 
org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.KMVSketch;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
@@ -355,7 +357,9 @@ public interface LibMatrixCountDistinct {
        }
 
        static MatrixBlock countDistinctValuesFromSketch(CorrMatrixBlock arg0, 
CountDistinctOperator op) {
-               if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
+               if(op.getOperatorType() == CountDistinctOperatorTypes.COUNT)
+                       return new 
CountDistinctFunctionSketch(op).getValueFromSketch(arg0);
+               else if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
                        return new KMVSketch(op).getValueFromSketch(arg0);
                else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
                        throw new NotImplementedException("Not implemented 
yet");
@@ -364,7 +368,9 @@ public interface LibMatrixCountDistinct {
        }
 
        static CorrMatrixBlock createSketch(MatrixBlock blkIn, 
CountDistinctOperator op) {
-               if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
+               if(op.getOperatorType() == CountDistinctOperatorTypes.COUNT)
+                       return new 
CountDistinctFunctionSketch(op).create(blkIn);
+               else if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
                        return new KMVSketch(op).create(blkIn);
                else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
                        throw new NotImplementedException("Not implemented 
yet");
@@ -373,7 +379,9 @@ public interface LibMatrixCountDistinct {
        }
 
        static CorrMatrixBlock unionSketch(CorrMatrixBlock arg0, 
CorrMatrixBlock arg1, CountDistinctOperator op) {
-               if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
+               if(op.getOperatorType() == CountDistinctOperatorTypes.COUNT)
+                       return new CountDistinctFunctionSketch(op).union(arg0, 
arg1);
+               else if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
                        return new KMVSketch(op).union(arg0, arg1);
                else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
                        throw new NotImplementedException("Not implemented 
yet");
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/CountDistinctSketch.java
similarity index 85%
rename from 
src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
rename to 
src/main/java/org/apache/sysds/runtime/matrix/data/sketch/CountDistinctSketch.java
index d5df3b241a..47e7d0a50b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/CountDistinctSketch.java
@@ -17,22 +17,21 @@
  * under the License.
  */
 
-package org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox;
+package org.apache.sysds.runtime.matrix.data.sketch;
 
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
 // Package private
-abstract class CountDistinctApproxSketch implements MatrixSketch {
-       CountDistinctOperator op;
+public abstract class CountDistinctSketch implements MatrixSketch {
+       public final CountDistinctOperator op;
 
-       CountDistinctApproxSketch(Operator op) {
+       public CountDistinctSketch(Operator op) {
                if(!(op instanceof CountDistinctOperator)) {
                        throw new DMLRuntimeException(
-                               String.format("Cannot create %s with given 
operator", CountDistinctApproxSketch.class.getSimpleName()));
+                               String.format("Cannot create %s with given 
operator", CountDistinctSketch.class.getSimpleName()));
                }
 
                this.op = (CountDistinctOperator) op;
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinct/CountDistinctFunctionSketch.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinct/CountDistinctFunctionSketch.java
new file mode 100644
index 0000000000..efdcfa69a6
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinct/CountDistinctFunctionSketch.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data.sketch.countdistinct;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.sketch.CountDistinctSketch;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Exact (non-approximate) count-distinct sketch.
+ *
+ * Each double is split into its [ sign | exponent ] bits (12 bits, used as map key) and its
+ * fraction bits (52 bits, stored in a per-key set). The map is serialized into a matrix block
+ * with one row per key:
+ *   row_i: [ key_i, N_i, fraction_i0, .., fraction_i(N_i - 1), 0, .., 0 ]
+ * where N_i is the number of distinct fractions for key_i. All fractions are kept (not only
+ * their count) so that a future union() can merge sketches of different blocks.
+ */
+public class CountDistinctFunctionSketch extends CountDistinctSketch {
+
+       private static final Log LOG = LogFactory.getLog(CountDistinctFunctionSketch.class.getName());
+
+       // IEEE-754 double layout: [ sign (1) | exponent (11) | fraction (52) ]
+       private static final int FRACTION_BITS = 52;
+       private static final int SIGN_EXPONENT_BITS = 12;
+
+       public CountDistinctFunctionSketch(Operator op) {
+               super(op);
+       }
+
+       /**
+        * Computes the exact number of distinct values in blkIn by building its sketch and
+        * reading the count back from it.
+        *
+        * @param blkIn input matrix block
+        * @return 1x1 matrix block holding the distinct count
+        */
+       @Override
+       public MatrixBlock getValue(MatrixBlock blkIn) {
+               return getValueFromSketch(create(blkIn));
+       }
+
+       /**
+        * Reads the distinct count from a sketch produced by {@link #create(MatrixBlock)}: the sum
+        * of the per-key fraction counts N_i stored in column 1 of the correction block.
+        *
+        * @param blkIn sketch whose correction block holds the serialized key/fraction rows
+        * @return 1x1 matrix block holding the distinct count
+        */
+       @Override
+       public MatrixBlock getValueFromSketch(CorrMatrixBlock blkIn) {
+               MatrixBlock blkInCorr = blkIn.getCorrection();
+               MatrixBlock blkOut = new MatrixBlock(1, 1, false);
+
+               long res = 0;
+               for (int i = 0; i < blkInCorr.getNumRows(); ++i) {
+                       res += blkInCorr.getValue(i, 1);
+               }
+
+               blkOut.setValue(0, 0, res);
+               return blkOut;
+       }
+
+       /**
+        * Builds an exact sketch of all distinct values in blkIn (zeros included).
+        *
+        * Memory: blkIn has at most 1000x1000 = 1e6 unique values and there are at most 2^12 keys,
+        * so the raw payload is bounded by (1e6 * 2) + (1e6 * 8) bytes = 10 MB, excluding Java
+        * object/boxing overhead of the intermediate HashMap/HashSet.
+        *
+        * @param blkIn input matrix block
+        * @return sketch wrapping blkIn together with the serialized key/fraction rows
+        */
+       @Override
+       public CorrMatrixBlock create(MatrixBlock blkIn) {
+               int R = blkIn.getNumRows();
+               int C = blkIn.getNumColumns();
+
+               // Key   -> [ sign (1) | exponent (11) ]: 12 bits, so a short is sufficient
+               // Value -> set of [ fraction (52) ] parts, each needing a long
+               Map<Short, Set<Long>> bitMap = new HashMap<>();
+
+               if (blkIn.isEmpty()) {
+                       // An empty block holds only zeros, i.e. exactly one distinct value (0.0), whose
+                       // key and fraction bits are both 0. A block without any cells has no values.
+                       if ((long) R * C > 0)
+                               bitMap.computeIfAbsent((short) 0, k -> new HashSet<>()).add(0L);
+               }
+               else {
+                       for (int i = 0; i < R; ++i) {
+                               for (int j = 0; j < C; ++j) {
+                                       // Use the IEEE-754 bit pattern (NaNs canonicalized, as in Double.equals);
+                                       // a (long) cast would truncate and merge e.g. 1.2 and 1.7.
+                                       long bits = Double.doubleToLongBits(blkIn.getValue(i, j));
+                                       short key = (short) extractRightKBitsFromIndex(bits, FRACTION_BITS, SIGN_EXPONENT_BITS);
+                                       long fraction = extractRightKBitsFromIndex(bits, 0, FRACTION_BITS);
+                                       bitMap.computeIfAbsent(key, k -> new HashSet<>()).add(fraction);
+                               }
+                       }
+               }
+
+               // Row width is driven by the largest fraction set of any single key. Starting from 0
+               // (rather than blocksize^2) avoids allocating ~1e6 padding columns for every sketch.
+               int maxWidth = 0;
+               for (Set<Long> fractions : bitMap.values())
+                       maxWidth = Math.max(maxWidth, fractions.size());
+
+               MatrixBlock blkOutCorr = serializeInputMatrixBlock(bitMap, maxWidth);
+
+               // The sketch contains all relevant info; blkIn is only carried along by the CorrMatrixBlock
+               return new CorrMatrixBlock(blkIn, blkOutCorr);
+       }
+
+       /**
+        * Returns the k bits of n starting at bit position startingIndex (0 = least significant
+        * bit), right-aligned. Requires 0 < k < 64.
+        */
+       private static long extractRightKBitsFromIndex(long n, int startingIndex, int k) {
+               // 1L, not 1: an int shift by k >= 32 wraps around (1 << 52 == 1 << 20)
+               long kMask = (1L << k) - 1;
+               return kMask & (n >>> startingIndex);
+       }
+
+       /**
+        * Serializes the key -> fractions map into a matrix block with one row per key:
+        *   row_i: [ key_i, N_i, fraction_i0, .., fraction_i(N_i - 1), 0, .., 0 ]
+        * N_i is stored explicitly so the zero padding of shorter rows is unambiguous (a genuine
+        * 0 fraction looks like padding). Keys (12 bits) and fractions (52 bits) are exactly
+        * representable as doubles.
+        */
+       private MatrixBlock serializeInputMatrixBlock(Map<Short, Set<Long>> bitMap, int maxWidth) {
+               // Sparse, because every row is padded to the widest key: with a skewed distribution a
+               // dense block would allocate (#keys x maxWidth) cells, almost all of them zero.
+               MatrixBlock blkOut = new MatrixBlock(bitMap.size(), maxWidth + 2, true);
+
+               int i = 0;
+               for (Map.Entry<Short, Set<Long>> entry : bitMap.entrySet()) {
+                       short key = entry.getKey();
+                       Set<Long> fractions = entry.getValue();
+
+                       blkOut.setValue(i, 0, key);
+                       blkOut.setValue(i, 1, fractions.size());
+
+                       int j = 2;
+                       for (long fraction : fractions)
+                               blkOut.setValue(i, j++, fraction);
+
+                       ++i;
+               }
+               return blkOut;
+       }
+
+       @Override
+       public CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock arg1) {
+               throw new NotImplementedException("MULTI_BLOCK aggregation is not supported yet");
+       }
+
+       @Override
+       public CorrMatrixBlock intersection(CorrMatrixBlock arg0, CorrMatrixBlock arg1) {
+               // Fail loudly instead of returning null, consistent with union()
+               throw new NotImplementedException("Intersection is not supported yet");
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
index 31e7d15c5d..98bfca0df9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
@@ -28,6 +28,7 @@ import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.sketch.CountDistinctSketch;
 import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.utils.Hash;
@@ -42,7 +43,7 @@ import org.apache.sysds.utils.Hash;
  * TODO: Add multi-threaded version
  *
  */
-public class KMVSketch extends CountDistinctApproxSketch {
+public class KMVSketch extends CountDistinctSketch {
 
        private static final Log LOG = 
LogFactory.getLog(KMVSketch.class.getName());
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
index 8f7d9acb8f..f1393ac05e 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
@@ -47,9 +47,16 @@ public class CountDistinctRowCol extends 
CountDistinctRowColBase {
        }
 
        @Test
-       public void testSimple1by1() {
+       public void testSimple1by1CP() {
                // test simple 1 by 1.
                ExecType ex = ExecType.CP;
                countDistinctScalarTest(1, 1, 1, 1.0, ex, 0.00001);
        }
+
+       @Test
+       public void testSimple1by1Spark() {
+               // Single-cell input, executed on the Spark backend
+               countDistinctScalarTest(1, 1, 1, 1.0, ExecType.SPARK, 0.00001);
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java
index 6b2007548a..5a7a61c6ce 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java
@@ -30,6 +30,13 @@ public abstract class CountDistinctRowColBase extends 
CountDistinctBase {
                countDistinctScalarTest(50, 50, 50, 1.0, ex, tolerance);
        }
 
+       @Test
+       public void testSparkDenseSmall() {
+               // Same parameters and tolerance as the preceding CP test, but on the Spark backend
+               final double tol = baseTolerance + 50 * percentTolerance;
+               countDistinctScalarTest(50, 50, 50, 1.0, ExecType.SPARK, tol);
+       }
+
        @Test
        public void testCPDenseLarge() {
                ExecType ex = ExecType.CP;
@@ -37,6 +44,13 @@ public abstract class CountDistinctRowColBase extends 
CountDistinctBase {
                countDistinctScalarTest(800, 1000, 1000, 1.0, ex, tolerance);
        }
 
+       @Test
+       public void testSparkDenseLarge() {
+               // Same parameters and tolerance as testCPDenseLarge, but on the Spark backend
+               final double tol = baseTolerance + 800 * percentTolerance;
+               countDistinctScalarTest(800, 1000, 1000, 1.0, ExecType.SPARK, tol);
+       }
+
        @Test
        public void testCPDenseXLarge() {
                ExecType ex = ExecType.CP;

Reply via email to