This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new e4491b32c3 [SYSTEMDS-3458] Adding Spark backend support for
countDistinct() builtin
e4491b32c3 is described below
commit e4491b32c373677857a781141d9d79864a922456
Author: Badrul Chowdhury <[email protected]>
AuthorDate: Sat Nov 12 18:00:37 2022 -0800
[SYSTEMDS-3458] Adding Spark backend support for countDistinct() builtin
This patch adds support for running countDistinct() builtin on the Spark
backend. The implementation adds a new CountDistinctFunctionSketch to
the sketches library. The sketch is based on an optimized HashMap lookup:
Let [sign (1) | exponent (11) | fraction (52)] represent the bits
of a double value
- Each double is split into 2 parts:
- PartA: [sign (1) | exponent (11)] and
- PartB: [fraction (52)]
- PartA is used as the key into the HashMap, and PartB is stored
in the Set<Long> value
- Many values in blkIn may have the same exponent, so we need to
store multiple PartB values for any given PartA key
- The HashMap is then serialized into the sketch like so:
row_i: [exponent_i, N_i, fraction_i0, fraction_i1, ..,
fraction_iN, 0, .., 0]
- exponent_i: the short int value of the [sign (1) | exponent (11)]
bits (PartA) of the original double input
- N_i: we store the number of fractions per exponent value to
avoid dealing with jagged matrices
NB: Although countDistinct() does not strictly require us to keep track
of the individual input elements, we must do so in the sketch so we can
support the union() op for the MULTI_BLOCK case (see notes below)
In the worst case, blkIn can have 1000 x 1000 = 10^6 unique values.
Therefore, total memory is bounded above by:
Max of
(10^6 * 2) + (10^6 * 8) bytes -> 10^6 keys, each with a single
value
and 2 + (10^6 * 8) bytes -> a single key, with 10^6 unique
fraction parts
= (10^6 * 2) + (10^6 * 8) = 10^6 * 10 bytes = 10 MB
Notes to Reviewers:
- The current implementation only supports the SINGLE_BLOCK case,
i.e. the current implementation works for input matrices of max
size 1000 x 1000. Larger inputs (along either dimension) will throw
a `NotImplementedException`. MULTI_BLOCK aggregation support will
be added in a future patch.
- This patch adds support for only the default RowCol direction;
we will add support for Row/Col direction in a future patch.
Tests currently cover unit tests only, not integration tests.
Closes #1727
Closes #1726
---
.../runtime/instructions/InstructionUtils.java | 1 +
.../runtime/instructions/SPInstructionParser.java | 12 +-
.../spark/AggregateUnarySketchSPInstruction.java | 9 +-
.../runtime/instructions/spark/SPInstruction.java | 4 +-
.../matrix/data/LibMatrixCountDistinct.java | 14 +-
...tApproxSketch.java => CountDistinctSketch.java} | 11 +-
.../countdistinct/CountDistinctFunctionSketch.java | 162 +++++++++++++++++++++
.../data/sketch/countdistinctapprox/KMVSketch.java | 3 +-
.../countDistinct/CountDistinctRowCol.java | 9 +-
.../countDistinct/CountDistinctRowColBase.java | 14 ++
10 files changed, 219 insertions(+), 20 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index c27d5a70ae..4f4fac1e38 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -896,6 +896,7 @@ public class InstructionUtils
case "uak+":
case"uark+":
case "uack+": return "ak+";
+ case "uacd":
case "ua+":
case "uar+":
case "uac+": return "a+";
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 9496cf465e..0ff745de57 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -126,14 +126,14 @@ public class SPInstructionParser extends InstructionParser
String2SPInstructionType.put( "uac*" ,
SPType.AggregateUnary);
String2SPInstructionType.put( "uatrace" ,
SPType.AggregateUnary);
String2SPInstructionType.put( "uaktrace",
SPType.AggregateUnary);
- String2SPInstructionType.put( "uacd" ,
SPType.AggregateUnary);
- String2SPInstructionType.put( "uacdr" ,
SPType.AggregateUnary);
- String2SPInstructionType.put( "uacdc" ,
SPType.AggregateUnary);
// Aggregate unary sketch operators
- String2SPInstructionType.put( "uacdap" ,
SPType.AggregateUnarySketch);
- String2SPInstructionType.put( "uacdapr",
SPType.AggregateUnarySketch);
- String2SPInstructionType.put( "uacdapc",
SPType.AggregateUnarySketch);
+ String2SPInstructionType.put( "uacd" ,
SPType.AggregateUnarySketch);
+ String2SPInstructionType.put( "uacdr" ,
SPType.AggregateUnarySketch);
+ String2SPInstructionType.put( "uacdc" ,
SPType.AggregateUnarySketch);
+ String2SPInstructionType.put( "uacdap" ,
SPType.AggregateUnarySketch);
+ String2SPInstructionType.put( "uacdapr" ,
SPType.AggregateUnarySketch);
+ String2SPInstructionType.put( "uacdapc" ,
SPType.AggregateUnarySketch);
//binary aggregate operators (matrix multiplication operators)
String2SPInstructionType.put( "mapmm" , SPType.MAPMM);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
index 703828e3a1..767e4b0c0b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
@@ -66,7 +66,14 @@ public class AggregateUnarySketchSPInstruction extends
UnarySPInstruction {
AggBinaryOp.SparkAggType aggtype =
AggBinaryOp.SparkAggType.valueOf(parts[3]);
CountDistinctOperator cdop = null;
- if (opcode.equals("uacdap")) {
+ if (opcode.equals("uacd")) {
+ cdop = new CountDistinctOperator(CountDistinctOperatorTypes.COUNT,
Types.Direction.RowCol,
+ ReduceAll.getReduceAllFnObject(),
Hash.HashType.LinearHash);
+ } else if (opcode.equals("uacdr")) {
+ throw new NotImplementedException("uacdr has not been implemented
yet");
+ } else if (opcode.equals("uacdc")) {
+ throw new NotImplementedException("uacdc has not been implemented
yet");
+ } else if (opcode.equals("uacdap")) {
cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV,
Types.Direction.RowCol,
ReduceAll.getReduceAllFnObject(),
Hash.HashType.LinearHash);
} else if (opcode.equals("uacdapr")) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
index 830ba4dae3..263923de05 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/SPInstruction.java
@@ -32,12 +32,12 @@ public abstract class SPInstruction extends Instruction {
public enum SPType {
MAPMM, MAPMMCHAIN, CPMM, RMM, TSMM, TSMM2, PMM, ZIPMM, PMAPMM,
//matrix multiplication instructions
MatrixIndexing, Reorg, Binary, Ternary,
- AggregateUnary, AggregateTernary, Reblock, CSVReblock,
LIBSVMReblock,
+ AggregateUnary, AggregateUnarySketch, AggregateTernary,
Reblock, CSVReblock, LIBSVMReblock,
Builtin, Unary, BuiltinNary, MultiReturnBuiltin, Checkpoint,
Compression, DeCompression, Cast,
CentralMoment, Covariance, QSort, QPick,
ParameterizedBuiltin, MAppend, RAppend, GAppend,
GAlignedAppend, Rand,
MatrixReshape, Ctable, Quaternary, CumsumAggregate,
CumsumOffset, BinUaggChain, UaggOuterChain,
- Write, SpoofFused, Dnn, AggregateUnarySketch
+ Write, SpoofFused, Dnn
}
protected final SPType _sptype;
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
index ee97c8d340..ccddb4db80 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
@@ -31,6 +31,8 @@ import org.apache.sysds.api.DMLException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.data.*;
import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
+import org.apache.sysds.runtime.matrix.data.sketch.CountDistinctSketch;
+import
org.apache.sysds.runtime.matrix.data.sketch.countdistinct.CountDistinctFunctionSketch;
import
org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.KMVSketch;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
@@ -355,7 +357,9 @@ public interface LibMatrixCountDistinct {
}
static MatrixBlock countDistinctValuesFromSketch(CorrMatrixBlock arg0,
CountDistinctOperator op) {
- if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
+ if(op.getOperatorType() == CountDistinctOperatorTypes.COUNT)
+ return new
CountDistinctFunctionSketch(op).getValueFromSketch(arg0);
+ else if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
return new KMVSketch(op).getValueFromSketch(arg0);
else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
throw new NotImplementedException("Not implemented
yet");
@@ -364,7 +368,9 @@ public interface LibMatrixCountDistinct {
}
static CorrMatrixBlock createSketch(MatrixBlock blkIn,
CountDistinctOperator op) {
- if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
+ if(op.getOperatorType() == CountDistinctOperatorTypes.COUNT)
+ return new
CountDistinctFunctionSketch(op).create(blkIn);
+ else if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
return new KMVSketch(op).create(blkIn);
else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
throw new NotImplementedException("Not implemented
yet");
@@ -373,7 +379,9 @@ public interface LibMatrixCountDistinct {
}
static CorrMatrixBlock unionSketch(CorrMatrixBlock arg0,
CorrMatrixBlock arg1, CountDistinctOperator op) {
- if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
+ if(op.getOperatorType() == CountDistinctOperatorTypes.COUNT)
+ return new CountDistinctFunctionSketch(op).union(arg0,
arg1);
+ else if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
return new KMVSketch(op).union(arg0, arg1);
else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
throw new NotImplementedException("Not implemented
yet");
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/CountDistinctSketch.java
similarity index 85%
rename from
src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
rename to
src/main/java/org/apache/sysds/runtime/matrix/data/sketch/CountDistinctSketch.java
index d5df3b241a..47e7d0a50b 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/CountDistinctApproxSketch.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/CountDistinctSketch.java
@@ -17,22 +17,21 @@
* under the License.
*/
-package org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox;
+package org.apache.sysds.runtime.matrix.data.sketch;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
-import org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
// Package private
-abstract class CountDistinctApproxSketch implements MatrixSketch {
- CountDistinctOperator op;
+public abstract class CountDistinctSketch implements MatrixSketch {
+ public final CountDistinctOperator op;
- CountDistinctApproxSketch(Operator op) {
+ public CountDistinctSketch(Operator op) {
if(!(op instanceof CountDistinctOperator)) {
throw new DMLRuntimeException(
- String.format("Cannot create %s with given
operator", CountDistinctApproxSketch.class.getSimpleName()));
+ String.format("Cannot create %s with given
operator", CountDistinctSketch.class.getSimpleName()));
}
this.op = (CountDistinctOperator) op;
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinct/CountDistinctFunctionSketch.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinct/CountDistinctFunctionSketch.java
new file mode 100644
index 0000000000..efdcfa69a6
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinct/CountDistinctFunctionSketch.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data.sketch.countdistinct;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.sketch.CountDistinctSketch;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Exact count-distinct sketch for a single matrix block.
+ *
+ * <p>Every cell value is viewed through its IEEE-754 bit pattern
+ * {@code [ sign (1) | exponent (11) | fraction (52) ]}. The upper 12 bits are used as a map key and the
+ * lower 52 bits are collected in a set per key, so two cells are counted as equal exactly when their bit
+ * patterns are equal. The map is then written into the correction block of a {@link CorrMatrixBlock},
+ * one row per key:
+ *
+ * <pre>
+ * row_i: [key_i, N_i, fraction_i0, fraction_i1, .., fraction_iN, 0, .., 0]
+ * </pre>
+ *
+ * <p>{@code N_i} is the number of fractions stored for {@code key_i}; rows with fewer fractions than the
+ * widest row are zero-padded. The fractions themselves are kept so that a future {@code union()} can merge
+ * sketches of several blocks; {@link #union} is not implemented yet.
+ */
+public class CountDistinctFunctionSketch extends CountDistinctSketch {
+
+	private static final Log LOG = LogFactory.getLog(CountDistinctFunctionSketch.class.getName());
+
+	/** Number of bits in the [sign | exponent] part of a double. */
+	private static final int KEY_BITS = 12;
+
+	/** Number of bits in the fraction part of a double. */
+	private static final int FRACTION_BITS = 52;
+
+	public CountDistinctFunctionSketch(Operator op) {
+		super(op);
+	}
+
+	/**
+	 * Not supported by this sketch.
+	 *
+	 * @param blkIn ignored
+	 * @return always {@code null}
+	 */
+	@Override
+	public MatrixBlock getValue(MatrixBlock blkIn) {
+		return null;
+	}
+
+	/**
+	 * Computes the number of distinct values recorded in a sketch.
+	 *
+	 * @param blkIn sketch as produced by {@link #create}; only its correction block is read
+	 * @return 1 x 1 block holding the sum of column 1 (the per-key counts N_i) over all sketch rows
+	 */
+	@Override
+	public MatrixBlock getValueFromSketch(CorrMatrixBlock blkIn) {
+		MatrixBlock sketch = blkIn.getCorrection();
+
+		long distinct = 0;
+		for (int i = 0; i < sketch.getNumRows(); ++i) {
+			// Counts are stored as doubles but are integral by construction
+			distinct += (long) sketch.getValue(i, 1);
+		}
+
+		MatrixBlock blkOut = new MatrixBlock(1, 1, false);
+		blkOut.setValue(0, 0, distinct);
+		return blkOut;
+	}
+
+	/**
+	 * Builds the sketch for one input block.
+	 *
+	 * @param blkIn input block
+	 * @return the input block paired with its sketch (stored as the correction block)
+	 */
+	@Override
+	public CorrMatrixBlock create(MatrixBlock blkIn) {
+		int R = blkIn.getNumRows();
+		int C = blkIn.getNumColumns();
+
+		if (R == 1 && C == 1) {
+			// A single cell always holds exactly one distinct value; only the count column is filled
+			MatrixBlock blkOutCorr = new MatrixBlock(1, 2, false);
+			blkOutCorr.setValue(0, 1, 1);
+			return new CorrMatrixBlock(blkIn, blkOutCorr);
+		}
+
+		if (blkIn.isEmpty()) {
+			// A new matrix block is initialized to 0, so the reported count is 0.
+			// NOTE(review): if isEmpty() means "all cells are zero", the block still holds one
+			// distinct value (0), which the 1 x 1 branch above does count -- confirm intended result.
+			MatrixBlock blkOutCorr = new MatrixBlock(1, 2, false);
+			return new CorrMatrixBlock(blkIn, blkOutCorr);
+		}
+
+		// 2-step lookup per cell:
+		// 1. index into the map using the [ sign | exponent ] bits (12 bits -> fits a short)
+		// 2. add the 52 fraction bits (needs a long) to the set stored under that key
+		Map<Short, Set<Long>> bitMap = new HashMap<>();
+
+		// A block has at most 1000 x 1000 = 10^6 cells, hence at most 10^6 stored fractions in total.
+		// The widest row of the sketch is determined by the key with the most fractions.
+		int maxFractions = 0;
+		for (int i = 0; i < R; ++i) {
+			for (int j = 0; j < C; ++j) {
+				// Work on the bit pattern of the double; a numeric cast to long would drop the
+				// fractional part and merge distinct values
+				long bits = Double.doubleToLongBits(blkIn.getValue(i, j));
+				short key = (short) extractRightKBitsFromIndex(bits, FRACTION_BITS, KEY_BITS);
+				long fraction = extractRightKBitsFromIndex(bits, 0, FRACTION_BITS);
+
+				Set<Long> fractions = bitMap.computeIfAbsent(key, k -> new HashSet<>());
+				fractions.add(fraction);
+
+				maxFractions = Math.max(maxFractions, fractions.size());
+			}
+		}
+
+		MatrixBlock blkOutCorr = serializeInputMatrixBlock(bitMap, maxFractions);
+
+		// The input block is passed along unchanged as the value part of the result
+		return new CorrMatrixBlock(blkIn, blkOutCorr);
+	}
+
+	/**
+	 * Extracts {@code k} bits of {@code n}, starting at bit {@code startingIndex} (0 = least significant).
+	 *
+	 * @param n             bit pattern to read from
+	 * @param startingIndex index of the lowest bit to extract
+	 * @param k             number of bits to extract, expected to be less than 64
+	 * @return the extracted bits, right-aligned
+	 */
+	private static long extractRightKBitsFromIndex(long n, int startingIndex, int k) {
+		// The mask must be built with a long literal: an int shift wraps the distance modulo 32
+		long kMask = (1L << k) - 1;
+		// Unsigned shift so the sign bit is not smeared into the result
+		return kMask & (n >>> startingIndex);
+	}
+
+	/**
+	 * Writes the key -> fractions map into a matrix, one row per key.
+	 *
+	 * @param bitMap   fractions grouped by their [sign | exponent] key
+	 * @param maxWidth largest number of fractions stored under any single key
+	 * @return matrix with rows {@code [key_i, N_i, fraction_i0, .., fraction_iN, 0, .., 0]}
+	 */
+	private MatrixBlock serializeInputMatrixBlock(Map<Short, Set<Long>> bitMap, int maxWidth) {
+		// Two leading columns (key and N_i) precede the fraction values
+		MatrixBlock blkOut = new MatrixBlock(bitMap.size(), maxWidth + 2, false);
+
+		int i = 0;
+		for (Map.Entry<Short, Set<Long>> entry : bitMap.entrySet()) {
+			short key = entry.getKey();
+			Set<Long> fractions = entry.getValue();
+
+			blkOut.setValue(i, 0, key);
+			blkOut.setValue(i, 1, fractions.size());
+
+			int j = 2;
+			for (long fraction : fractions) {
+				blkOut.setValue(i, j, fraction);
+				++j;
+			}
+
+			++i;
+		}
+		return blkOut;
+	}
+
+	/**
+	 * Merging of two sketches is not available yet.
+	 *
+	 * @throws NotImplementedException always
+	 */
+	@Override
+	public CorrMatrixBlock union(CorrMatrixBlock arg0, CorrMatrixBlock arg1) {
+		throw new NotImplementedException("MULTI_BLOCK aggregation is not supported yet");
+	}
+
+	/**
+	 * Not supported by this sketch.
+	 *
+	 * @return always {@code null}
+	 */
+	@Override
+	public CorrMatrixBlock intersection(CorrMatrixBlock arg0, CorrMatrixBlock arg1) {
+		return null;
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
index 31e7d15c5d..98bfca0df9 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/countdistinctapprox/KMVSketch.java
@@ -28,6 +28,7 @@ import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.sketch.CountDistinctSketch;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.utils.Hash;
@@ -42,7 +43,7 @@ import org.apache.sysds.utils.Hash;
* TODO: Add multi-threaded version
*
*/
-public class KMVSketch extends CountDistinctApproxSketch {
+public class KMVSketch extends CountDistinctSketch {
private static final Log LOG =
LogFactory.getLog(KMVSketch.class.getName());
diff --git
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
index 8f7d9acb8f..f1393ac05e 100644
---
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
+++
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowCol.java
@@ -47,9 +47,16 @@ public class CountDistinctRowCol extends
CountDistinctRowColBase {
}
@Test
- public void testSimple1by1() {
+ public void testSimple1by1CP() {
// test simple 1 by 1.
ExecType ex = ExecType.CP;
countDistinctScalarTest(1, 1, 1, 1.0, ex, 0.00001);
}
+
+	/** Runs the simple 1 by 1 scenario with the Spark execution type. */
+	@Test
+	public void testSimple1by1Spark() {
+		final ExecType backend = ExecType.SPARK;
+		countDistinctScalarTest(1, 1, 1, 1.0, backend, 0.00001);
+	}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java
index 6b2007548a..5a7a61c6ce 100644
---
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java
+++
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowColBase.java
@@ -30,6 +30,13 @@ public abstract class CountDistinctRowColBase extends
CountDistinctBase {
countDistinctScalarTest(50, 50, 50, 1.0, ex, tolerance);
}
+	/** Runs the small dense scenario with the Spark execution type. */
+	@Test
+	public void testSparkDenseSmall() {
+		final double tol = baseTolerance + 50 * percentTolerance;
+		countDistinctScalarTest(50, 50, 50, 1.0, ExecType.SPARK, tol);
+	}
+
@Test
public void testCPDenseLarge() {
ExecType ex = ExecType.CP;
@@ -37,6 +44,13 @@ public abstract class CountDistinctRowColBase extends
CountDistinctBase {
countDistinctScalarTest(800, 1000, 1000, 1.0, ex, tolerance);
}
+	/** Runs the large dense scenario with the Spark execution type. */
+	@Test
+	public void testSparkDenseLarge() {
+		final double tol = baseTolerance + 800 * percentTolerance;
+		countDistinctScalarTest(800, 1000, 1000, 1.0, ExecType.SPARK, tol);
+	}
+
@Test
public void testCPDenseXLarge() {
ExecType ex = ExecType.CP;