Repository: spark
Updated Branches:
  refs/heads/master a5f02b002 -> d3c90b74e


[SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation

## What changes were proposed in this pull request?
SPARK-18429 introduced a count-min sketch aggregate function for SQL, but the 
implementation and testing are more complicated than needed. This change 
simplifies the test cases and removes support for data types that don't have 
clear equality semantics:

1. Removed support for floating point and decimal types.

2. Removed the heavy randomized tests. The underlying CountMinSketch 
implementation already has good test coverage through randomized tests, 
and the SPARK-18429 implementation only adds an aggregate function wrapper 
around CountMinSketch. There is no need for randomized tests at three different 
levels of the implementation.

## How was this patch tested?
A lot of the change is to simplify test cases.

Author: Reynold Xin <r...@databricks.com>

Closes #16093 from rxin/SPARK-18663.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d3c90b74
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d3c90b74
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d3c90b74

Branch: refs/heads/master
Commit: d3c90b74edecc527ee468bead41d1cca0b667668
Parents: a5f02b0
Author: Reynold Xin <r...@databricks.com>
Authored: Thu Dec 1 21:38:52 2016 -0800
Committer: Reynold Xin <r...@databricks.com>
Committed: Thu Dec 1 21:38:52 2016 -0800

----------------------------------------------------------------------
 .../spark/util/sketch/CountMinSketch.java       |  22 +-
 .../spark/util/sketch/CountMinSketchImpl.java   |  50 ++-
 .../spark/util/sketch/CountMinSketchSuite.scala |  40 +--
 project/MimaExcludes.scala                      |   8 +-
 .../aggregate/CountMinSketchAgg.scala           |  27 +-
 .../aggregate/ApproximatePercentileSuite.scala  |   2 +-
 .../aggregate/CountMinSketchAggSuite.scala      | 304 ++++++-------------
 .../sql/ApproximatePercentileQuerySuite.scala   |   3 +
 .../spark/sql/CountMinSketchAggQuerySuite.scala | 176 +----------
 9 files changed, 177 insertions(+), 455 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
----------------------------------------------------------------------
diff --git 
a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java 
b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 0011096..f7c22dd 100644
--- 
a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ 
b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -17,12 +17,13 @@
 
 package org.apache.spark.util.sketch;
 
+import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 
 /**
- * A Count-min sketch is a probabilistic data structure used for summarizing 
streams of data in
+ * A Count-min sketch is a probabilistic data structure used for cardinality 
estimation using
  * sub-linear space.  Currently, supported data types include:
  * <ul>
  *   <li>{@link Byte}</li>
@@ -30,10 +31,6 @@ import java.io.OutputStream;
  *   <li>{@link Integer}</li>
  *   <li>{@link Long}</li>
  *   <li>{@link String}</li>
- *   <li>{@link Float}</li>
- *   <li>{@link Double}</li>
- *   <li>{@link java.math.BigDecimal}</li>
- *   <li>{@link Boolean}</li>
  * </ul>
  * A {@link CountMinSketch} is initialized with a random seed, and a pair of 
parameters:
  * <ol>
@@ -178,6 +175,11 @@ public abstract class CountMinSketch {
   public abstract void writeTo(OutputStream out) throws IOException;
 
   /**
+   * Serializes this {@link CountMinSketch} and returns the serialized form.
+   */
+  public abstract byte[] toByteArray() throws IOException;
+
+  /**
    * Reads in a {@link CountMinSketch} from an input stream. It is the 
caller's responsibility to
    * close the stream.
    */
@@ -186,6 +188,16 @@ public abstract class CountMinSketch {
   }
 
   /**
+   * Reads in a {@link CountMinSketch} from a byte array.
+   */
+  public static CountMinSketch readFrom(byte[] bytes) throws IOException {
+    InputStream in = new ByteArrayInputStream(bytes);
+    CountMinSketch cms = readFrom(in);
+    in.close();
+    return cms;
+  }
+
+  /**
    * Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, 
and random
    * {@code seed}.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
----------------------------------------------------------------------
diff --git 
a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
 
b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index 94ab3a9..045fec3 100644
--- 
a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ 
b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -17,15 +17,7 @@
 
 package org.apache.spark.util.sketch;
 
-import java.io.DataInputStream;
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import java.io.OutputStream;
-import java.io.Serializable;
-import java.math.BigDecimal;
+import java.io.*;
 import java.util.Arrays;
 import java.util.Random;
 
@@ -153,16 +145,8 @@ class CountMinSketchImpl extends CountMinSketch implements 
Serializable {
   public void add(Object item, long count) {
     if (item instanceof String) {
       addString((String) item, count);
-    } else if (item instanceof BigDecimal) {
-      addString(((BigDecimal) item).toString(), count);
     } else if (item instanceof byte[]) {
       addBinary((byte[]) item, count);
-    } else if (item instanceof Float) {
-      addLong(Float.floatToIntBits((Float) item), count);
-    } else if (item instanceof Double) {
-      addLong(Double.doubleToLongBits((Double) item), count);
-    } else if (item instanceof Boolean) {
-      addLong(((Boolean) item) ? 1L : 0L, count);
     } else {
       addLong(Utils.integralToLong(item), count);
     }
@@ -227,6 +211,10 @@ class CountMinSketchImpl extends CountMinSketch implements 
Serializable {
     return ((int) hash) % width;
   }
 
+  private static int[] getHashBuckets(String key, int hashCount, int max) {
+    return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
+  }
+
   private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
     int[] result = new int[hashCount];
     int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, 
b.length, 0);
@@ -240,18 +228,9 @@ class CountMinSketchImpl extends CountMinSketch implements 
Serializable {
   @Override
   public long estimateCount(Object item) {
     if (item instanceof String) {
-      return estimateCountForBinaryItem(Utils.getBytesFromUTF8String((String) 
item));
-    } else if (item instanceof BigDecimal) {
-      return estimateCountForBinaryItem(
-        Utils.getBytesFromUTF8String(((BigDecimal) item).toString()));
+      return estimateCountForStringItem((String) item);
     } else if (item instanceof byte[]) {
       return estimateCountForBinaryItem((byte[]) item);
-    } else if (item instanceof Float) {
-      return estimateCountForLongItem(Float.floatToIntBits((Float) item));
-    } else if (item instanceof Double) {
-      return estimateCountForLongItem(Double.doubleToLongBits((Double) item));
-    } else if (item instanceof Boolean) {
-      return estimateCountForLongItem(((Boolean) item) ? 1L : 0L);
     } else {
       return estimateCountForLongItem(Utils.integralToLong(item));
     }
@@ -265,6 +244,15 @@ class CountMinSketchImpl extends CountMinSketch implements 
Serializable {
     return res;
   }
 
+  private long estimateCountForStringItem(String item) {
+    long res = Long.MAX_VALUE;
+    int[] buckets = getHashBuckets(item, depth, width);
+    for (int i = 0; i < depth; ++i) {
+      res = Math.min(res, table[i][buckets[i]]);
+    }
+    return res;
+  }
+
   private long estimateCountForBinaryItem(byte[] item) {
     long res = Long.MAX_VALUE;
     int[] buckets = getHashBuckets(item, depth, width);
@@ -332,6 +320,14 @@ class CountMinSketchImpl extends CountMinSketch implements 
Serializable {
     }
   }
 
+  @Override
+  public byte[] toByteArray() throws IOException {
+    ByteArrayOutputStream out = new ByteArrayOutputStream();
+    writeTo(out);
+    out.close();
+    return out.toByteArray();
+  }
+
   public static CountMinSketchImpl readFrom(InputStream in) throws IOException 
{
     CountMinSketchImpl sketch = new CountMinSketchImpl();
     sketch.readFrom0(in);

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
----------------------------------------------------------------------
diff --git 
a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
 
b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
index 2c358fc..174eb01 100644
--- 
a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
+++ 
b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.util.sketch
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
-import java.nio.charset.StandardCharsets
 
 import scala.reflect.ClassTag
 import scala.util.Random
@@ -26,9 +25,9 @@ import scala.util.Random
 import org.scalatest.FunSuite // scalastyle:ignore funsuite
 
 class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
-  private val epsOfTotalCount = 0.0001
+  private val epsOfTotalCount = 0.01
 
-  private val confidence = 0.99
+  private val confidence = 0.9
 
   private val seed = 42
 
@@ -45,12 +44,6 @@ class CountMinSketchSuite extends FunSuite { // 
scalastyle:ignore funsuite
   }
 
   def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): 
Unit = {
-    def getProbeItem(item: T): Any = item match {
-      // Use a string to represent the content of an array of bytes
-      case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
-      case i => identity(i)
-    }
-
     test(s"accuracy - $typeName") {
       // Uses fixed seed to ensure reproducible test execution
       val r = new Random(31)
@@ -63,7 +56,7 @@ class CountMinSketchSuite extends FunSuite { // 
scalastyle:ignore funsuite
 
       val exactFreq = {
         val sampledItems = sampledItemIndices.map(allItems)
-        sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
+        sampledItems.groupBy(identity).mapValues(_.length.toLong)
       }
 
       val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
@@ -74,12 +67,12 @@ class CountMinSketchSuite extends FunSuite { // 
scalastyle:ignore funsuite
 
       val probCorrect = {
         val numErrors = allItems.map { item =>
-          val count = exactFreq.getOrElse(getProbeItem(item), 0L)
+          val count = exactFreq.getOrElse(item, 0L)
           val ratio = (sketch.estimateCount(item) - count).toDouble / 
numAllItems
           if (ratio > epsOfTotalCount) 1 else 0
         }.sum
 
-        1D - numErrors.toDouble / numAllItems
+        1.0 - (numErrors.toDouble / numAllItems)
       }
 
       assert(
@@ -96,9 +89,7 @@ class CountMinSketchSuite extends FunSuite { // 
scalastyle:ignore funsuite
 
       val numToMerge = 5
       val numItemsPerSketch = 100000
-      val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) {
-        itemGenerator(r)
-      }
+      val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { 
itemGenerator(r) }
 
       val sketches = perSketchItems.map { items =>
         val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
@@ -113,11 +104,8 @@ class CountMinSketchSuite extends FunSuite { // 
scalastyle:ignore funsuite
       val mergedSketch = sketches.reduce(_ mergeInPlace _)
       checkSerDe(mergedSketch)
 
-      val expectedSketch = {
-        val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
-        perSketchItems.foreach(_.foreach(sketch.add))
-        sketch
-      }
+      val expectedSketch = CountMinSketch.create(epsOfTotalCount, confidence, 
seed)
+      perSketchItems.foreach(_.foreach(expectedSketch.add))
 
       perSketchItems.foreach {
         _.foreach { item =>
@@ -142,17 +130,7 @@ class CountMinSketchSuite extends FunSuite { // 
scalastyle:ignore funsuite
 
   testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }
 
-  testItemType[Float]("Float") { _.nextFloat() }
-
-  testItemType[Double]("Double") { _.nextDouble() }
-
-  testItemType[java.math.BigDecimal]("Decimal") { r => new 
java.math.BigDecimal(r.nextDouble()) }
-
-  testItemType[Boolean]("Boolean") { _.nextBoolean() }
-
-  testItemType[Array[Byte]]("Binary") { r =>
-    Utils.getBytesFromUTF8String(r.nextString(r.nextInt(20)))
-  }
+  testItemType[Array[Byte]]("Byte array") { r => 
r.nextString(r.nextInt(60)).getBytes }
 
   test("incompatible merge") {
     intercept[IncompatibleMergeException] {

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 4995af0..b113bbf 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -34,6 +34,11 @@ import com.typesafe.tools.mima.core.ProblemFilters._
  */
 object MimaExcludes {
 
+  lazy val v22excludes = v21excludes ++ Seq(
+    // [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation
+    
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray")
+  )
+
   // Exclude rules for 2.1.x
   lazy val v21excludes = v20excludes ++ {
     Seq(
@@ -912,7 +917,8 @@ object MimaExcludes {
   }
 
   def excludes(version: String) = version match {
-    case v if v.startsWith("2.1") => v21excludes
+    case v if v.startsWith("2.2") => v22excludes
+    case v if v.startsWith("2.1") => v22excludes  // TODO: Update this when we 
bump version to 2.2
     case v if v.startsWith("2.0") => v20excludes
     case _ => Seq()
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
index f5f185f..612c198 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, 
TypeCheckSuccess}
@@ -42,9 +40,9 @@ import org.apache.spark.util.sketch.CountMinSketch
 @ExpressionDescription(
   usage = """
     _FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a 
column with the given esp,
-      confidence and seed. The result is an array of bytes, which should be 
deserialized to a
-      `CountMinSketch` before usage. `CountMinSketch` is useful for equality 
predicates and join
-      size estimation.
+      confidence and seed. The result is an array of bytes, which can be 
deserialized to a
+      `CountMinSketch` before usage. Count-min sketch is a probabilistic data 
structure used for
+      cardinality estimation using sub-linear space.
   """)
 case class CountMinSketchAgg(
     child: Expression,
@@ -75,13 +73,13 @@ case class CountMinSketchAgg(
     } else if (!epsExpression.foldable || !confidenceExpression.foldable ||
       !seedExpression.foldable) {
       TypeCheckFailure(
-        "The eps, confidence or seed provided must be a literal or constant 
foldable")
+        "The eps, confidence or seed provided must be a literal or foldable")
     } else if (epsExpression.eval() == null || confidenceExpression.eval() == 
null ||
       seedExpression.eval() == null) {
       TypeCheckFailure("The eps, confidence or seed provided should not be 
null")
-    } else if (eps <= 0D) {
+    } else if (eps <= 0.0) {
       TypeCheckFailure(s"Relative error must be positive (current value = 
$eps)")
-    } else if (confidence <= 0D || confidence >= 1D) {
+    } else if (confidence <= 0.0 || confidence >= 1.0) {
       TypeCheckFailure(s"Confidence must be within range (0.0, 1.0) (current 
value = $confidence)")
     } else {
       TypeCheckSuccess
@@ -97,9 +95,6 @@ case class CountMinSketchAgg(
     // Ignore empty rows
     if (value != null) {
       child.dataType match {
-        // `Decimal` and `UTF8String` are internal types in spark sql, we need 
to convert them
-        // into acceptable types for `CountMinSketch`.
-        case DecimalType() => 
buffer.add(value.asInstanceOf[Decimal].toJavaBigDecimal)
         // For string type, we can get bytes of our `UTF8String` directly, and 
call the `addBinary`
         // instead of `addString` to avoid unnecessary conversion.
         case StringType => 
buffer.addBinary(value.asInstanceOf[UTF8String].getBytes)
@@ -115,14 +110,11 @@ case class CountMinSketchAgg(
   override def eval(buffer: CountMinSketch): Any = serialize(buffer)
 
   override def serialize(buffer: CountMinSketch): Array[Byte] = {
-    val out = new ByteArrayOutputStream()
-    buffer.writeTo(out)
-    out.toByteArray
+    buffer.toByteArray
   }
 
   override def deserialize(storageFormat: Array[Byte]): CountMinSketch = {
-    val in = new ByteArrayInputStream(storageFormat)
-    CountMinSketch.readFrom(in)
+    CountMinSketch.readFrom(storageFormat)
   }
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): 
CountMinSketchAgg =
@@ -132,8 +124,7 @@ case class CountMinSketchAgg(
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(TypeCollection(NumericType, StringType, DateType, TimestampType, 
BooleanType, BinaryType),
-      DoubleType, DoubleType, IntegerType)
+    Seq(TypeCollection(IntegralType, StringType, BinaryType), DoubleType, 
DoubleType, IntegerType)
   }
 
   override def nullable: Boolean = false

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
index 8456e24..fcb370a 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
@@ -86,7 +86,7 @@ class ApproximatePercentileSuite extends SparkFunSuite {
       (headBufferSize + bufferSize) * 2
     }
 
-    val sizePerInputs = Seq(100, 1000, 10000, 100000, 1000000, 10000000).map { 
count =>
+    Seq(100, 1000, 10000, 100000, 1000000, 10000000).foreach { count =>
       val buffer = new PercentileDigest(relativeError)
       // Worst case, data is linear sorted
       (0 until count).foreach(buffer.add(_))

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
index 6e08e29..1047963 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
@@ -17,199 +17,114 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import java.io.ByteArrayInputStream
-import java.nio.charset.StandardCharsets
+import java.{lang => jl}
 
-import scala.reflect.ClassTag
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, 
BoundReference, Cast, GenericInternalRow, Literal}
-import org.apache.spark.sql.types.{DecimalType, _}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.sketch.CountMinSketch
 
+/**
+ * Unit test suite for the count-min sketch SQL aggregate funciton 
[[CountMinSketchAgg]].
+ */
 class CountMinSketchAggSuite extends SparkFunSuite {
   private val childExpression = BoundReference(0, IntegerType, nullable = true)
   private val epsOfTotalCount = 0.0001
   private val confidence = 0.99
   private val seed = 42
-
-  test("serialize and de-serialize") {
-    // Check empty serialize and de-serialize
-    val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), 
Literal(confidence),
-      Literal(seed))
-    val buffer = CountMinSketch.create(epsOfTotalCount, confidence, seed)
-    assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
-
-    // Check non-empty serialize and de-serialize
-    val random = new Random(31)
-    (0 until 10000).map(_ => random.nextInt(100)).foreach { value =>
-      buffer.add(value)
-    }
-    assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
+  private val rand = new Random(seed)
+
+  /** Creates a count-min sketch aggregate expression, using the child 
expression defined above. */
+  private def cms(eps: jl.Double, confidence: jl.Double, seed: jl.Integer): 
CountMinSketchAgg = {
+    new CountMinSketchAgg(
+      child = childExpression,
+      epsExpression = Literal(eps, DoubleType),
+      confidenceExpression = Literal(confidence, DoubleType),
+      seedExpression = Literal(seed, IntegerType))
   }
 
-  def testHighLevelInterface[T: ClassTag](
-      dataType: DataType,
-      sampledItemIndices: Array[Int],
-      allItems: Array[T],
-      exactFreq: Map[Any, Long]): Any = {
-    test(s"high level interface, update, merge, eval... - $dataType") {
+  /**
+   * Creates a new test case that compares our aggregate function with a 
reference implementation
+   * (using the underlying [[CountMinSketch]]).
+   *
+   * This works by splitting the items into two separate groups, aggregates 
them, and then merges
+   * the two groups back (to emulate partial aggregation), and then compares 
the result with
+   * that generated by [[CountMinSketch]] directly. This assumes insertion 
order does not impact
+   * the result in count-min sketch.
+   */
+  private def testDataType[T](dataType: DataType, items: Seq[T]): Unit = {
+    test("test data type " + dataType) {
       val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = 
true),
         Literal(epsOfTotalCount), Literal(confidence), Literal(seed))
       assert(!agg.nullable)
 
-      val group1 = 0 until sampledItemIndices.length / 2
-      val group1Buffer = agg.createAggregationBuffer()
-      group1.foreach { index =>
-        val input = InternalRow(allItems(sampledItemIndices(index)))
-        agg.update(group1Buffer, input)
+      val (seq1, seq2) = items.splitAt(items.size / 2)
+      val buf1 = addToAggregateBuffer(agg, seq1)
+      val buf2 = addToAggregateBuffer(agg, seq2)
+
+      val sketch = agg.createAggregationBuffer()
+      agg.merge(sketch, buf1)
+      agg.merge(sketch, buf2)
+
+      // Validate cardinality estimation against reference implementation.
+      val referenceSketch = CountMinSketch.create(epsOfTotalCount, confidence, 
seed)
+      items.foreach { item =>
+        referenceSketch.add(item match {
+          case u: UTF8String => u.getBytes
+          case _ => item
+        })
       }
 
-      val group2 = sampledItemIndices.length / 2 until 
sampledItemIndices.length
-      val group2Buffer = agg.createAggregationBuffer()
-      group2.foreach { index =>
-        val input = InternalRow(allItems(sampledItemIndices(index)))
-        agg.update(group2Buffer, input)
+      items.foreach { item =>
+        withClue(s"For item $item") {
+          val itemToTest = item match {
+            case u: UTF8String => u.getBytes
+            case _ => item
+          }
+          assert(referenceSketch.estimateCount(itemToTest) == 
sketch.estimateCount(itemToTest))
+        }
       }
-
-      var mergeBuffer = agg.createAggregationBuffer()
-      agg.merge(mergeBuffer, group1Buffer)
-      agg.merge(mergeBuffer, group2Buffer)
-      checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
-
-      // Merge in a different order
-      mergeBuffer = agg.createAggregationBuffer()
-      agg.merge(mergeBuffer, group2Buffer)
-      agg.merge(mergeBuffer, group1Buffer)
-      checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
-
-      // Merge with an empty partition
-      val emptyBuffer = agg.createAggregationBuffer()
-      agg.merge(mergeBuffer, emptyBuffer)
-      checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
     }
-  }
-
-  def testLowLevelInterface[T: ClassTag](
-      dataType: DataType,
-      sampledItemIndices: Array[Int],
-      allItems: Array[T],
-      exactFreq: Map[Any, Long]): Any = {
-    test(s"low level interface, update, merge, eval... - 
${dataType.typeName}") {
-      val inputAggregationBufferOffset = 1
-      val mutableAggregationBufferOffset = 2
 
-      // Phase one, partial mode aggregation
-      val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = 
true),
-        Literal(epsOfTotalCount), Literal(confidence), Literal(seed))
-        .withNewInputAggBufferOffset(inputAggregationBufferOffset)
-        .withNewMutableAggBufferOffset(mutableAggregationBufferOffset)
-
-      val mutableAggBuffer = new GenericInternalRow(
-        new Array[Any](mutableAggregationBufferOffset + 1))
-      agg.initialize(mutableAggBuffer)
-
-      sampledItemIndices.foreach { i =>
-        agg.update(mutableAggBuffer, InternalRow(allItems(i)))
-      }
-      agg.serializeAggregateBufferInPlace(mutableAggBuffer)
-
-      // Serialize the aggregation buffer
-      val serialized = 
mutableAggBuffer.getBinary(mutableAggregationBufferOffset)
-      val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized))
-
-      // Phase 2: final mode aggregation
-      // Re-initialize the aggregation buffer
-      agg.initialize(mutableAggBuffer)
-      agg.merge(mutableAggBuffer, inputAggBuffer)
-      checkResult(agg.eval(mutableAggBuffer), allItems, exactFreq)
+    def addToAggregateBuffer[T](agg: CountMinSketchAgg, items: Seq[T]): 
CountMinSketch = {
+      val buf = agg.createAggregationBuffer()
+      items.foreach { item => agg.update(buf, InternalRow(item)) }
+      buf
     }
   }
 
-  private def checkResult[T: ClassTag](
-      result: Any,
-      data: Array[T],
-      exactFreq: Map[Any, Long]): Unit = {
-    result match {
-      case bytesData: Array[Byte] =>
-        val in = new ByteArrayInputStream(bytesData)
-        val cms = CountMinSketch.readFrom(in)
-        val probCorrect = {
-          val numErrors = data.map { i =>
-            val count = exactFreq.getOrElse(getProbeItem(i), 0L)
-            val item = i match {
-              case dec: Decimal => dec.toJavaBigDecimal
-              case str: UTF8String => str.getBytes
-              case _ => i
-            }
-            val ratio = (cms.estimateCount(item) - count).toDouble / 
data.length
-            if (ratio > epsOfTotalCount) 1 else 0
-          }.sum
+  testDataType[Byte](ByteType, Seq.fill(100) { rand.nextInt(10).toByte })
 
-          1D - numErrors.toDouble / data.length
-        }
+  testDataType[Short](ShortType, Seq.fill(100) { rand.nextInt(10).toShort })
 
-        assert(
-          probCorrect > confidence,
-          s"Confidence not reached: required $confidence, reached $probCorrect"
-        )
-      case _ => fail("unexpected return type")
-    }
-  }
+  testDataType[Int](IntegerType, Seq.fill(100) { rand.nextInt(10) })
 
-  private def getProbeItem[T: ClassTag](item: T): Any = item match {
-    // Use a string to represent the content of an array of bytes
-    case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
-    case i => identity(i)
-  }
+  testDataType[Long](LongType, Seq.fill(100) { rand.nextInt(10) })
 
-  def testItemType[T: ClassTag](dataType: DataType)(itemGenerator: Random => 
T): Unit = {
-    // Uses fixed seed to ensure reproducible test execution
-    val r = new Random(31)
+  testDataType[UTF8String](StringType, Seq.fill(100) { 
UTF8String.fromString(rand.nextString(1)) })
 
-    val numAllItems = 1000000
-    val allItems = Array.fill(numAllItems)(itemGenerator(r))
+  testDataType[Array[Byte]](BinaryType, Seq.fill(100) { 
rand.nextString(1).getBytes() })
 
-    val numSamples = numAllItems / 10
-    val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
+  test("serialize and de-serialize") {
+    // Check empty serialize and de-serialize
+    val agg = cms(epsOfTotalCount, confidence, seed)
+    val buffer = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+    assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
 
-    val exactFreq = {
-      val sampledItems = sampledItemIndices.map(allItems)
-      sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
+    // Check non-empty serialize and de-serialize
+    val random = new Random(31)
+    for (i <- 0 until 10) {
+      buffer.add(random.nextInt(100))
     }
-
-    testLowLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq)
-    testHighLevelInterface[T](dataType, sampledItemIndices, allItems, 
exactFreq)
-  }
-
-  testItemType[Byte](ByteType) { _.nextInt().toByte }
-
-  testItemType[Short](ShortType) { _.nextInt().toShort }
-
-  testItemType[Int](IntegerType) { _.nextInt() }
-
-  testItemType[Long](LongType) { _.nextLong() }
-
-  testItemType[UTF8String](StringType) { r => 
UTF8String.fromString(r.nextString(r.nextInt(20))) }
-
-  testItemType[Float](FloatType) { _.nextFloat() }
-
-  testItemType[Double](DoubleType) { _.nextDouble() }
-
-  testItemType[Decimal](new DecimalType()) { r => Decimal(r.nextDouble()) }
-
-  testItemType[Boolean](BooleanType) { _.nextBoolean() }
-
-  testItemType[Array[Byte]](BinaryType) { r =>
-    r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8)
+    assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
   }
 
-
-  test("fails analysis if eps, confidence or seed provided is not a literal or 
constant foldable") {
+  test("fails analysis if eps, confidence or seed provided is not foldable") {
     val wrongEps = new CountMinSketchAgg(
       childExpression,
       epsExpression = AttributeReference("a", DoubleType)(),
@@ -227,88 +142,55 @@ class CountMinSketchAggSuite extends SparkFunSuite {
       seedExpression = AttributeReference("c", IntegerType)())
 
     Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
-      assertEqual(
-        wrongAgg.checkInputDataTypes(),
-        TypeCheckFailure(
-          "The eps, confidence or seed provided must be a literal or constant foldable")
-      )
+      assertResult(
+        TypeCheckFailure("The eps, confidence or seed provided must be a literal or foldable")) {
+        wrongAgg.checkInputDataTypes()
+      }
     }
   }
 
   test("fails analysis if parameters are invalid") {
     // parameters are null
-    val wrongEps = new CountMinSketchAgg(
-      childExpression,
-      epsExpression = Cast(Literal(null), DoubleType),
-      confidenceExpression = Literal(confidence),
-      seedExpression = Literal(seed))
-    val wrongConfidence = new CountMinSketchAgg(
-      childExpression,
-      epsExpression = Literal(epsOfTotalCount),
-      confidenceExpression = Cast(Literal(null), DoubleType),
-      seedExpression = Literal(seed))
-    val wrongSeed = new CountMinSketchAgg(
-      childExpression,
-      epsExpression = Literal(epsOfTotalCount),
-      confidenceExpression = Literal(confidence),
-      seedExpression = Cast(Literal(null), IntegerType))
+    val wrongEps = cms(null, confidence, seed)
+    val wrongConfidence = cms(epsOfTotalCount, null, seed)
+    val wrongSeed = cms(epsOfTotalCount, confidence, null)
 
     Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
-      assertEqual(
-        wrongAgg.checkInputDataTypes(),
-        TypeCheckFailure("The eps, confidence or seed provided should not be null")
-      )
+      assertResult(TypeCheckFailure("The eps, confidence or seed provided should not be null")) {
+        wrongAgg.checkInputDataTypes()
+      }
     }
 
     // parameters are out of the valid range
     Seq(0.0, -1000.0).foreach { invalidEps =>
-      val invalidAgg = new CountMinSketchAgg(
-        childExpression,
-        epsExpression = Literal(invalidEps),
-        confidenceExpression = Literal(confidence),
-        seedExpression = Literal(seed))
-      assertEqual(
-        invalidAgg.checkInputDataTypes(),
-        TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)")
-      )
+      val invalidAgg = cms(invalidEps, confidence, seed)
+      assertResult(
+        TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)")) {
+        invalidAgg.checkInputDataTypes()
+      }
     }
 
     Seq(0.0, 1.0, -2.0, 2.0).foreach { invalidConfidence =>
-      val invalidAgg = new CountMinSketchAgg(
-        childExpression,
-        epsExpression = Literal(epsOfTotalCount),
-        confidenceExpression = Literal(invalidConfidence),
-        seedExpression = Literal(seed))
-      assertEqual(
-        invalidAgg.checkInputDataTypes(),
-        TypeCheckFailure(
-          s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)")
-      )
+      val invalidAgg = cms(epsOfTotalCount, invalidConfidence, seed)
+      assertResult(TypeCheckFailure(
+        s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)")) {
+        invalidAgg.checkInputDataTypes()
+      }
     }
   }
 
-  private def assertEqual[T](left: T, right: T): Unit = {
-    assert(left == right)
-  }
-
   test("null handling") {
     def isEqual(result: Any, other: CountMinSketch): Boolean = {
-      result match {
-        case bytesData: Array[Byte] =>
-          val in = new ByteArrayInputStream(bytesData)
-          val cms = CountMinSketch.readFrom(in)
-          cms.equals(other)
-        case _ => fail("unexpected return type")
-      }
+      other.equals(CountMinSketch.readFrom(result.asInstanceOf[Array[Byte]]))
     }
 
-    val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence),
-      Literal(seed))
+    val agg = cms(epsOfTotalCount, confidence, seed)
     val emptyCms = CountMinSketch.create(epsOfTotalCount, confidence, seed)
     val buffer = new GenericInternalRow(new Array[Any](1))
     agg.initialize(buffer)
     // Empty aggregation buffer
     assert(isEqual(agg.eval(buffer), emptyCms))
+
     // Empty input row
     agg.update(buffer, InternalRow(null))
     assert(isEqual(agg.eval(buffer), emptyCms))

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
index e98092d..62a7534 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
@@ -21,6 +21,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
 import org.apache.spark.sql.test.SharedSQLContext
 
+/**
+ * End-to-end tests for approximate percentile aggregate function.
+ */
 class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
index 3e715a3..dea0d4c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
@@ -17,175 +17,29 @@
 
 package org.apache.spark.sql
 
-import java.io.ByteArrayInputStream
-import java.nio.charset.StandardCharsets
-import java.sql.{Date, Timestamp}
-
-import scala.reflect.ClassTag
-import scala.util.Random
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{Decimal, StringType, _}
-import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.sketch.CountMinSketch
 
+/**
+ * End-to-end test suite for count_min_sketch.
+ */
 class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext {
 
-  private val table = "count_min_sketch_table"
-
-  /** Uses fixed seed to ensure reproducible test execution */
-  private val r = new Random(42)
-  private val numAllItems = 1000
-  private val numSamples = numAllItems / 10
-
-  private val eps = 0.1D
-  private val confidence = 0.95D
-  private val seed = 11
-
-  val startDate = DateTimeUtils.fromJavaDate(Date.valueOf("1900-01-01"))
-  val endDate = DateTimeUtils.fromJavaDate(Date.valueOf("2016-01-01"))
-  val startTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("1900-01-01 00:00:00"))
-  val endTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-01-01 00:00:00"))
-
-  test(s"compute count-min sketch for multiple columns of different types") {
-    val (allBytes, sampledByteIndices, exactByteFreq) =
-      generateTestData[Byte] { _.nextInt().toByte }
-    val (allShorts, sampledShortIndices, exactShortFreq) =
-      generateTestData[Short] { _.nextInt().toShort }
-    val (allInts, sampledIntIndices, exactIntFreq) =
-      generateTestData[Int] { _.nextInt() }
-    val (allLongs, sampledLongIndices, exactLongFreq) =
-      generateTestData[Long] { _.nextLong() }
-    val (allStrings, sampledStringIndices, exactStringFreq) =
-      generateTestData[String] { r => r.nextString(r.nextInt(20)) }
-    val (allDates, sampledDateIndices, exactDateFreq) = generateTestData[Date] { r =>
-      DateTimeUtils.toJavaDate(r.nextInt(endDate - startDate) + startDate)
-    }
-    val (allTimestamps, sampledTSIndices, exactTSFreq) = generateTestData[Timestamp] { r =>
-      DateTimeUtils.toJavaTimestamp(r.nextLong() % (endTS - startTS) + startTS)
-    }
-    val (allFloats, sampledFloatIndices, exactFloatFreq) =
-      generateTestData[Float] { _.nextFloat() }
-    val (allDoubles, sampledDoubleIndices, exactDoubleFreq) =
-      generateTestData[Double] { _.nextDouble() }
-    val (allDeciamls, sampledDecimalIndices, exactDecimalFreq) =
-      generateTestData[Decimal] { r => Decimal(r.nextDouble()) }
-    val (allBooleans, sampledBooleanIndices, exactBooleanFreq) =
-      generateTestData[Boolean] { _.nextBoolean() }
-    val (allBinaries, sampledBinaryIndices, exactBinaryFreq) = generateTestData[Array[Byte]] { r =>
-      r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8)
-    }
-
-    val data = (0 until numSamples).map { i =>
-      Row(allBytes(sampledByteIndices(i)),
-        allShorts(sampledShortIndices(i)),
-        allInts(sampledIntIndices(i)),
-        allLongs(sampledLongIndices(i)),
-        allStrings(sampledStringIndices(i)),
-        allDates(sampledDateIndices(i)),
-        allTimestamps(sampledTSIndices(i)),
-        allFloats(sampledFloatIndices(i)),
-        allDoubles(sampledDoubleIndices(i)),
-        allDeciamls(sampledDecimalIndices(i)),
-        allBooleans(sampledBooleanIndices(i)),
-        allBinaries(sampledBinaryIndices(i)))
-    }
+  test("count-min sketch") {
+    import testImplicits._
 
-    val schema = StructType(Seq(
-      StructField("c1", ByteType),
-      StructField("c2", ShortType),
-      StructField("c3", IntegerType),
-      StructField("c4", LongType),
-      StructField("c5", StringType),
-      StructField("c6", DateType),
-      StructField("c7", TimestampType),
-      StructField("c8", FloatType),
-      StructField("c9", DoubleType),
-      StructField("c10", new DecimalType()),
-      StructField("c11", BooleanType),
-      StructField("c12", BinaryType)))
+    val eps = 0.1
+    val confidence = 0.95
+    val seed = 11
 
-    withTempView(table) {
-      val rdd: RDD[Row] = spark.sparkContext.parallelize(data)
-      spark.createDataFrame(rdd, schema).createOrReplaceTempView(table)
+    val items = Seq(1, 1, 2, 2, 2, 2, 3, 4, 5)
+    val sketch = CountMinSketch.readFrom(items.toDF("id")
+      .selectExpr(s"count_min_sketch(id, ${eps}d, ${confidence}d, $seed)")
+      .head().get(0).asInstanceOf[Array[Byte]])
 
-      val cmsSql = schema.fieldNames.map { col =>
-        s"count_min_sketch($col, ${eps}D, ${confidence}D, $seed)"
-      }
-      val result = sql(s"SELECT ${cmsSql.mkString(", ")} FROM $table").head()
-      schema.indices.foreach { i =>
-        val binaryData = result.getAs[Array[Byte]](i)
-        val in = new ByteArrayInputStream(binaryData)
-        val cms = CountMinSketch.readFrom(in)
-        schema.fields(i).dataType match {
-          case ByteType => checkResult(cms, allBytes, exactByteFreq)
-          case ShortType => checkResult(cms, allShorts, exactShortFreq)
-          case IntegerType => checkResult(cms, allInts, exactIntFreq)
-          case LongType => checkResult(cms, allLongs, exactLongFreq)
-          case StringType => checkResult(cms, allStrings, exactStringFreq)
-          case DateType =>
-            checkResult(cms,
-              allDates.map(DateTimeUtils.fromJavaDate),
-              exactDateFreq.map { e =>
-                (DateTimeUtils.fromJavaDate(e._1.asInstanceOf[Date]), e._2)
-              })
-          case TimestampType =>
-            checkResult(cms,
-              allTimestamps.map(DateTimeUtils.fromJavaTimestamp),
-              exactTSFreq.map { e =>
-                (DateTimeUtils.fromJavaTimestamp(e._1.asInstanceOf[Timestamp]), e._2)
-              })
-          case FloatType => checkResult(cms, allFloats, exactFloatFreq)
-          case DoubleType => checkResult(cms, allDoubles, exactDoubleFreq)
-          case DecimalType() => checkResult(cms, allDeciamls, exactDecimalFreq)
-          case BooleanType => checkResult(cms, allBooleans, exactBooleanFreq)
-          case BinaryType => checkResult(cms, allBinaries, exactBinaryFreq)
-        }
-      }
-    }
-  }
-
-  private def checkResult[T: ClassTag](
-      cms: CountMinSketch,
-      data: Array[T],
-      exactFreq: Map[Any, Long]): Unit = {
-    val probCorrect = {
-      val numErrors = data.map { i =>
-        val count = exactFreq.getOrElse(getProbeItem(i), 0L)
-        val item = i match {
-          case dec: Decimal => dec.toJavaBigDecimal
-          case str: UTF8String => str.getBytes
-          case _ => i
-        }
-        val ratio = (cms.estimateCount(item) - count).toDouble / data.length
-        if (ratio > eps) 1 else 0
-      }.sum
-
-      1D - numErrors.toDouble / data.length
-    }
-
-    assert(
-      probCorrect > confidence,
-      s"Confidence not reached: required $confidence, reached $probCorrect"
-    )
-  }
-
-  private def getProbeItem[T: ClassTag](item: T): Any = item match {
-    // Use a string to represent the content of an array of bytes
-    case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
-    case i => identity(i)
-  }
+    val reference = CountMinSketch.create(eps, confidence, seed)
+    items.foreach(reference.add)
 
-  private def generateTestData[T: ClassTag](
-      itemGenerator: Random => T): (Array[T], Array[Int], Map[Any, Long]) = {
-    val allItems = Array.fill(numAllItems)(itemGenerator(r))
-    val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
-    val exactFreq = {
-      val sampledItems = sampledItemIndices.map(allItems)
-      sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
-    }
-    (allItems, sampledItemIndices, exactFreq)
+    assert(sketch == reference)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to