You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/12/02 05:38:55 UTC

spark git commit: [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation

Repository: spark
Updated Branches:
  refs/heads/master a5f02b002 -> d3c90b74e


[SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation

## What changes were proposed in this pull request?
SPARK-18429 introduced a count-min sketch aggregate function for SQL, but the implementation and testing are more complicated than needed. This change simplifies the test cases and removes support for data types that don't have clear equality semantics:

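1. Removed support for floating point, decimal, and boolean types. The SQL aggregate function also no longer accepts date or timestamp inputs; it now accepts integral, string, and binary columns only.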

2. Removed the heavy randomized tests. The underlying CountMinSketch implementation already has good coverage through randomized tests, and the SPARK-18429 implementation merely adds an aggregate function wrapper around CountMinSketch. There is no need for randomized tests at three different levels of the implementation.

## How was this patch tested?
Much of this change consists of simplifying the existing test cases.

Author: Reynold Xin <rx...@databricks.com>

Closes #16093 from rxin/SPARK-18663.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d3c90b74
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d3c90b74
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d3c90b74

Branch: refs/heads/master
Commit: d3c90b74edecc527ee468bead41d1cca0b667668
Parents: a5f02b0
Author: Reynold Xin <rx...@databricks.com>
Authored: Thu Dec 1 21:38:52 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Dec 1 21:38:52 2016 -0800

----------------------------------------------------------------------
 .../spark/util/sketch/CountMinSketch.java       |  22 +-
 .../spark/util/sketch/CountMinSketchImpl.java   |  50 ++-
 .../spark/util/sketch/CountMinSketchSuite.scala |  40 +--
 project/MimaExcludes.scala                      |   8 +-
 .../aggregate/CountMinSketchAgg.scala           |  27 +-
 .../aggregate/ApproximatePercentileSuite.scala  |   2 +-
 .../aggregate/CountMinSketchAggSuite.scala      | 304 ++++++-------------
 .../sql/ApproximatePercentileQuerySuite.scala   |   3 +
 .../spark/sql/CountMinSketchAggQuerySuite.scala | 176 +----------
 9 files changed, 177 insertions(+), 455 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 0011096..f7c22dd 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -17,12 +17,13 @@
 
 package org.apache.spark.util.sketch;
 
+import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 
 /**
- * A Count-min sketch is a probabilistic data structure used for summarizing streams of data in
+ * A Count-min sketch is a probabilistic data structure used for cardinality estimation using
  * sub-linear space.  Currently, supported data types include:
  * <ul>
  *   <li>{@link Byte}</li>
@@ -30,10 +31,6 @@ import java.io.OutputStream;
  *   <li>{@link Integer}</li>
  *   <li>{@link Long}</li>
  *   <li>{@link String}</li>
- *   <li>{@link Float}</li>
- *   <li>{@link Double}</li>
- *   <li>{@link java.math.BigDecimal}</li>
- *   <li>{@link Boolean}</li>
  * </ul>
  * A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters:
  * <ol>
@@ -178,6 +175,11 @@ public abstract class CountMinSketch {
   public abstract void writeTo(OutputStream out) throws IOException;
 
   /**
+   * Serializes this {@link CountMinSketch} and returns the serialized form.
+   */
+  public abstract byte[] toByteArray() throws IOException;
+
+  /**
    * Reads in a {@link CountMinSketch} from an input stream. It is the caller's responsibility to
    * close the stream.
    */
@@ -186,6 +188,16 @@ public abstract class CountMinSketch {
   }
 
   /**
+   * Reads in a {@link CountMinSketch} from a byte array.
+   */
+  public static CountMinSketch readFrom(byte[] bytes) throws IOException {
+    InputStream in = new ByteArrayInputStream(bytes);
+    CountMinSketch cms = readFrom(in);
+    in.close();
+    return cms;
+  }
+
+  /**
    * Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, and random
    * {@code seed}.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index 94ab3a9..045fec3 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -17,15 +17,7 @@
 
 package org.apache.spark.util.sketch;
 
-import java.io.DataInputStream;
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
-import java.io.OutputStream;
-import java.io.Serializable;
-import java.math.BigDecimal;
+import java.io.*;
 import java.util.Arrays;
 import java.util.Random;
 
@@ -153,16 +145,8 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
   public void add(Object item, long count) {
     if (item instanceof String) {
       addString((String) item, count);
-    } else if (item instanceof BigDecimal) {
-      addString(((BigDecimal) item).toString(), count);
     } else if (item instanceof byte[]) {
       addBinary((byte[]) item, count);
-    } else if (item instanceof Float) {
-      addLong(Float.floatToIntBits((Float) item), count);
-    } else if (item instanceof Double) {
-      addLong(Double.doubleToLongBits((Double) item), count);
-    } else if (item instanceof Boolean) {
-      addLong(((Boolean) item) ? 1L : 0L, count);
     } else {
       addLong(Utils.integralToLong(item), count);
     }
@@ -227,6 +211,10 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
     return ((int) hash) % width;
   }
 
+  private static int[] getHashBuckets(String key, int hashCount, int max) {
+    return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
+  }
+
   private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
     int[] result = new int[hashCount];
     int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
@@ -240,18 +228,9 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
   @Override
   public long estimateCount(Object item) {
     if (item instanceof String) {
-      return estimateCountForBinaryItem(Utils.getBytesFromUTF8String((String) item));
-    } else if (item instanceof BigDecimal) {
-      return estimateCountForBinaryItem(
-        Utils.getBytesFromUTF8String(((BigDecimal) item).toString()));
+      return estimateCountForStringItem((String) item);
     } else if (item instanceof byte[]) {
       return estimateCountForBinaryItem((byte[]) item);
-    } else if (item instanceof Float) {
-      return estimateCountForLongItem(Float.floatToIntBits((Float) item));
-    } else if (item instanceof Double) {
-      return estimateCountForLongItem(Double.doubleToLongBits((Double) item));
-    } else if (item instanceof Boolean) {
-      return estimateCountForLongItem(((Boolean) item) ? 1L : 0L);
     } else {
       return estimateCountForLongItem(Utils.integralToLong(item));
     }
@@ -265,6 +244,15 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
     return res;
   }
 
+  private long estimateCountForStringItem(String item) {
+    long res = Long.MAX_VALUE;
+    int[] buckets = getHashBuckets(item, depth, width);
+    for (int i = 0; i < depth; ++i) {
+      res = Math.min(res, table[i][buckets[i]]);
+    }
+    return res;
+  }
+
   private long estimateCountForBinaryItem(byte[] item) {
     long res = Long.MAX_VALUE;
     int[] buckets = getHashBuckets(item, depth, width);
@@ -332,6 +320,14 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
     }
   }
 
+  @Override
+  public byte[] toByteArray() throws IOException {
+    ByteArrayOutputStream out = new ByteArrayOutputStream();
+    writeTo(out);
+    out.close();
+    return out.toByteArray();
+  }
+
   public static CountMinSketchImpl readFrom(InputStream in) throws IOException {
     CountMinSketchImpl sketch = new CountMinSketchImpl();
     sketch.readFrom0(in);

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
----------------------------------------------------------------------
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
index 2c358fc..174eb01 100644
--- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.util.sketch
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
-import java.nio.charset.StandardCharsets
 
 import scala.reflect.ClassTag
 import scala.util.Random
@@ -26,9 +25,9 @@ import scala.util.Random
 import org.scalatest.FunSuite // scalastyle:ignore funsuite
 
 class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
-  private val epsOfTotalCount = 0.0001
+  private val epsOfTotalCount = 0.01
 
-  private val confidence = 0.99
+  private val confidence = 0.9
 
   private val seed = 42
 
@@ -45,12 +44,6 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
   }
 
   def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
-    def getProbeItem(item: T): Any = item match {
-      // Use a string to represent the content of an array of bytes
-      case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
-      case i => identity(i)
-    }
-
     test(s"accuracy - $typeName") {
       // Uses fixed seed to ensure reproducible test execution
       val r = new Random(31)
@@ -63,7 +56,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
 
       val exactFreq = {
         val sampledItems = sampledItemIndices.map(allItems)
-        sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
+        sampledItems.groupBy(identity).mapValues(_.length.toLong)
       }
 
       val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
@@ -74,12 +67,12 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
 
       val probCorrect = {
         val numErrors = allItems.map { item =>
-          val count = exactFreq.getOrElse(getProbeItem(item), 0L)
+          val count = exactFreq.getOrElse(item, 0L)
           val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems
           if (ratio > epsOfTotalCount) 1 else 0
         }.sum
 
-        1D - numErrors.toDouble / numAllItems
+        1.0 - (numErrors.toDouble / numAllItems)
       }
 
       assert(
@@ -96,9 +89,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
 
       val numToMerge = 5
       val numItemsPerSketch = 100000
-      val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) {
-        itemGenerator(r)
-      }
+      val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { itemGenerator(r) }
 
       val sketches = perSketchItems.map { items =>
         val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
@@ -113,11 +104,8 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
       val mergedSketch = sketches.reduce(_ mergeInPlace _)
       checkSerDe(mergedSketch)
 
-      val expectedSketch = {
-        val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
-        perSketchItems.foreach(_.foreach(sketch.add))
-        sketch
-      }
+      val expectedSketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+      perSketchItems.foreach(_.foreach(expectedSketch.add))
 
       perSketchItems.foreach {
         _.foreach { item =>
@@ -142,17 +130,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
 
   testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }
 
-  testItemType[Float]("Float") { _.nextFloat() }
-
-  testItemType[Double]("Double") { _.nextDouble() }
-
-  testItemType[java.math.BigDecimal]("Decimal") { r => new java.math.BigDecimal(r.nextDouble()) }
-
-  testItemType[Boolean]("Boolean") { _.nextBoolean() }
-
-  testItemType[Array[Byte]]("Binary") { r =>
-    Utils.getBytesFromUTF8String(r.nextString(r.nextInt(20)))
-  }
+  testItemType[Array[Byte]]("Byte array") { r => r.nextString(r.nextInt(60)).getBytes }
 
   test("incompatible merge") {
     intercept[IncompatibleMergeException] {

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 4995af0..b113bbf 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -34,6 +34,11 @@ import com.typesafe.tools.mima.core.ProblemFilters._
  */
 object MimaExcludes {
 
+  lazy val v22excludes = v21excludes ++ Seq(
+    // [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray")
+  )
+
   // Exclude rules for 2.1.x
   lazy val v21excludes = v20excludes ++ {
     Seq(
@@ -912,7 +917,8 @@ object MimaExcludes {
   }
 
   def excludes(version: String) = version match {
-    case v if v.startsWith("2.1") => v21excludes
+    case v if v.startsWith("2.2") => v22excludes
+    case v if v.startsWith("2.1") => v22excludes  // TODO: Update this when we bump version to 2.2
     case v if v.startsWith("2.0") => v20excludes
     case _ => Seq()
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
index f5f185f..612c198 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
-
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
@@ -42,9 +40,9 @@ import org.apache.spark.util.sketch.CountMinSketch
 @ExpressionDescription(
   usage = """
     _FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given esp,
-      confidence and seed. The result is an array of bytes, which should be deserialized to a
-      `CountMinSketch` before usage. `CountMinSketch` is useful for equality predicates and join
-      size estimation.
+      confidence and seed. The result is an array of bytes, which can be deserialized to a
+      `CountMinSketch` before usage. Count-min sketch is a probabilistic data structure used for
+      cardinality estimation using sub-linear space.
   """)
 case class CountMinSketchAgg(
     child: Expression,
@@ -75,13 +73,13 @@ case class CountMinSketchAgg(
     } else if (!epsExpression.foldable || !confidenceExpression.foldable ||
       !seedExpression.foldable) {
       TypeCheckFailure(
-        "The eps, confidence or seed provided must be a literal or constant foldable")
+        "The eps, confidence or seed provided must be a literal or foldable")
     } else if (epsExpression.eval() == null || confidenceExpression.eval() == null ||
       seedExpression.eval() == null) {
       TypeCheckFailure("The eps, confidence or seed provided should not be null")
-    } else if (eps <= 0D) {
+    } else if (eps <= 0.0) {
       TypeCheckFailure(s"Relative error must be positive (current value = $eps)")
-    } else if (confidence <= 0D || confidence >= 1D) {
+    } else if (confidence <= 0.0 || confidence >= 1.0) {
       TypeCheckFailure(s"Confidence must be within range (0.0, 1.0) (current value = $confidence)")
     } else {
       TypeCheckSuccess
@@ -97,9 +95,6 @@ case class CountMinSketchAgg(
     // Ignore empty rows
     if (value != null) {
       child.dataType match {
-        // `Decimal` and `UTF8String` are internal types in spark sql, we need to convert them
-        // into acceptable types for `CountMinSketch`.
-        case DecimalType() => buffer.add(value.asInstanceOf[Decimal].toJavaBigDecimal)
         // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary`
         // instead of `addString` to avoid unnecessary conversion.
         case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes)
@@ -115,14 +110,11 @@ case class CountMinSketchAgg(
   override def eval(buffer: CountMinSketch): Any = serialize(buffer)
 
   override def serialize(buffer: CountMinSketch): Array[Byte] = {
-    val out = new ByteArrayOutputStream()
-    buffer.writeTo(out)
-    out.toByteArray
+    buffer.toByteArray
   }
 
   override def deserialize(storageFormat: Array[Byte]): CountMinSketch = {
-    val in = new ByteArrayInputStream(storageFormat)
-    CountMinSketch.readFrom(in)
+    CountMinSketch.readFrom(storageFormat)
   }
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CountMinSketchAgg =
@@ -132,8 +124,7 @@ case class CountMinSketchAgg(
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(TypeCollection(NumericType, StringType, DateType, TimestampType, BooleanType, BinaryType),
-      DoubleType, DoubleType, IntegerType)
+    Seq(TypeCollection(IntegralType, StringType, BinaryType), DoubleType, DoubleType, IntegerType)
   }
 
   override def nullable: Boolean = false

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
index 8456e24..fcb370a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
@@ -86,7 +86,7 @@ class ApproximatePercentileSuite extends SparkFunSuite {
       (headBufferSize + bufferSize) * 2
     }
 
-    val sizePerInputs = Seq(100, 1000, 10000, 100000, 1000000, 10000000).map { count =>
+    Seq(100, 1000, 10000, 100000, 1000000, 10000000).foreach { count =>
       val buffer = new PercentileDigest(relativeError)
       // Worst case, data is linear sorted
       (0 until count).foreach(buffer.add(_))

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
index 6e08e29..1047963 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
@@ -17,199 +17,114 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import java.io.ByteArrayInputStream
-import java.nio.charset.StandardCharsets
+import java.{lang => jl}
 
-import scala.reflect.ClassTag
 import scala.util.Random
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Cast, GenericInternalRow, Literal}
-import org.apache.spark.sql.types.{DecimalType, _}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.sketch.CountMinSketch
 
+/**
+ * Unit test suite for the count-min sketch SQL aggregate funciton [[CountMinSketchAgg]].
+ */
 class CountMinSketchAggSuite extends SparkFunSuite {
   private val childExpression = BoundReference(0, IntegerType, nullable = true)
   private val epsOfTotalCount = 0.0001
   private val confidence = 0.99
   private val seed = 42
-
-  test("serialize and de-serialize") {
-    // Check empty serialize and de-serialize
-    val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence),
-      Literal(seed))
-    val buffer = CountMinSketch.create(epsOfTotalCount, confidence, seed)
-    assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
-
-    // Check non-empty serialize and de-serialize
-    val random = new Random(31)
-    (0 until 10000).map(_ => random.nextInt(100)).foreach { value =>
-      buffer.add(value)
-    }
-    assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
+  private val rand = new Random(seed)
+
+  /** Creates a count-min sketch aggregate expression, using the child expression defined above. */
+  private def cms(eps: jl.Double, confidence: jl.Double, seed: jl.Integer): CountMinSketchAgg = {
+    new CountMinSketchAgg(
+      child = childExpression,
+      epsExpression = Literal(eps, DoubleType),
+      confidenceExpression = Literal(confidence, DoubleType),
+      seedExpression = Literal(seed, IntegerType))
   }
 
-  def testHighLevelInterface[T: ClassTag](
-      dataType: DataType,
-      sampledItemIndices: Array[Int],
-      allItems: Array[T],
-      exactFreq: Map[Any, Long]): Any = {
-    test(s"high level interface, update, merge, eval... - $dataType") {
+  /**
+   * Creates a new test case that compares our aggregate function with a reference implementation
+   * (using the underlying [[CountMinSketch]]).
+   *
+   * This works by splitting the items into two separate groups, aggregates them, and then merges
+   * the two groups back (to emulate partial aggregation), and then compares the result with
+   * that generated by [[CountMinSketch]] directly. This assumes insertion order does not impact
+   * the result in count-min sketch.
+   */
+  private def testDataType[T](dataType: DataType, items: Seq[T]): Unit = {
+    test("test data type " + dataType) {
       val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true),
         Literal(epsOfTotalCount), Literal(confidence), Literal(seed))
       assert(!agg.nullable)
 
-      val group1 = 0 until sampledItemIndices.length / 2
-      val group1Buffer = agg.createAggregationBuffer()
-      group1.foreach { index =>
-        val input = InternalRow(allItems(sampledItemIndices(index)))
-        agg.update(group1Buffer, input)
+      val (seq1, seq2) = items.splitAt(items.size / 2)
+      val buf1 = addToAggregateBuffer(agg, seq1)
+      val buf2 = addToAggregateBuffer(agg, seq2)
+
+      val sketch = agg.createAggregationBuffer()
+      agg.merge(sketch, buf1)
+      agg.merge(sketch, buf2)
+
+      // Validate cardinality estimation against reference implementation.
+      val referenceSketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+      items.foreach { item =>
+        referenceSketch.add(item match {
+          case u: UTF8String => u.getBytes
+          case _ => item
+        })
       }
 
-      val group2 = sampledItemIndices.length / 2 until sampledItemIndices.length
-      val group2Buffer = agg.createAggregationBuffer()
-      group2.foreach { index =>
-        val input = InternalRow(allItems(sampledItemIndices(index)))
-        agg.update(group2Buffer, input)
+      items.foreach { item =>
+        withClue(s"For item $item") {
+          val itemToTest = item match {
+            case u: UTF8String => u.getBytes
+            case _ => item
+          }
+          assert(referenceSketch.estimateCount(itemToTest) == sketch.estimateCount(itemToTest))
+        }
       }
-
-      var mergeBuffer = agg.createAggregationBuffer()
-      agg.merge(mergeBuffer, group1Buffer)
-      agg.merge(mergeBuffer, group2Buffer)
-      checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
-
-      // Merge in a different order
-      mergeBuffer = agg.createAggregationBuffer()
-      agg.merge(mergeBuffer, group2Buffer)
-      agg.merge(mergeBuffer, group1Buffer)
-      checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
-
-      // Merge with an empty partition
-      val emptyBuffer = agg.createAggregationBuffer()
-      agg.merge(mergeBuffer, emptyBuffer)
-      checkResult(agg.eval(mergeBuffer), allItems, exactFreq)
     }
-  }
-
-  def testLowLevelInterface[T: ClassTag](
-      dataType: DataType,
-      sampledItemIndices: Array[Int],
-      allItems: Array[T],
-      exactFreq: Map[Any, Long]): Any = {
-    test(s"low level interface, update, merge, eval... - ${dataType.typeName}") {
-      val inputAggregationBufferOffset = 1
-      val mutableAggregationBufferOffset = 2
 
-      // Phase one, partial mode aggregation
-      val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true),
-        Literal(epsOfTotalCount), Literal(confidence), Literal(seed))
-        .withNewInputAggBufferOffset(inputAggregationBufferOffset)
-        .withNewMutableAggBufferOffset(mutableAggregationBufferOffset)
-
-      val mutableAggBuffer = new GenericInternalRow(
-        new Array[Any](mutableAggregationBufferOffset + 1))
-      agg.initialize(mutableAggBuffer)
-
-      sampledItemIndices.foreach { i =>
-        agg.update(mutableAggBuffer, InternalRow(allItems(i)))
-      }
-      agg.serializeAggregateBufferInPlace(mutableAggBuffer)
-
-      // Serialize the aggregation buffer
-      val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset)
-      val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized))
-
-      // Phase 2: final mode aggregation
-      // Re-initialize the aggregation buffer
-      agg.initialize(mutableAggBuffer)
-      agg.merge(mutableAggBuffer, inputAggBuffer)
-      checkResult(agg.eval(mutableAggBuffer), allItems, exactFreq)
+    def addToAggregateBuffer[T](agg: CountMinSketchAgg, items: Seq[T]): CountMinSketch = {
+      val buf = agg.createAggregationBuffer()
+      items.foreach { item => agg.update(buf, InternalRow(item)) }
+      buf
     }
   }
 
-  private def checkResult[T: ClassTag](
-      result: Any,
-      data: Array[T],
-      exactFreq: Map[Any, Long]): Unit = {
-    result match {
-      case bytesData: Array[Byte] =>
-        val in = new ByteArrayInputStream(bytesData)
-        val cms = CountMinSketch.readFrom(in)
-        val probCorrect = {
-          val numErrors = data.map { i =>
-            val count = exactFreq.getOrElse(getProbeItem(i), 0L)
-            val item = i match {
-              case dec: Decimal => dec.toJavaBigDecimal
-              case str: UTF8String => str.getBytes
-              case _ => i
-            }
-            val ratio = (cms.estimateCount(item) - count).toDouble / data.length
-            if (ratio > epsOfTotalCount) 1 else 0
-          }.sum
+  testDataType[Byte](ByteType, Seq.fill(100) { rand.nextInt(10).toByte })
 
-          1D - numErrors.toDouble / data.length
-        }
+  testDataType[Short](ShortType, Seq.fill(100) { rand.nextInt(10).toShort })
 
-        assert(
-          probCorrect > confidence,
-          s"Confidence not reached: required $confidence, reached $probCorrect"
-        )
-      case _ => fail("unexpected return type")
-    }
-  }
+  testDataType[Int](IntegerType, Seq.fill(100) { rand.nextInt(10) })
 
-  private def getProbeItem[T: ClassTag](item: T): Any = item match {
-    // Use a string to represent the content of an array of bytes
-    case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
-    case i => identity(i)
-  }
+  testDataType[Long](LongType, Seq.fill(100) { rand.nextInt(10) })
 
-  def testItemType[T: ClassTag](dataType: DataType)(itemGenerator: Random => T): Unit = {
-    // Uses fixed seed to ensure reproducible test execution
-    val r = new Random(31)
+  testDataType[UTF8String](StringType, Seq.fill(100) { UTF8String.fromString(rand.nextString(1)) })
 
-    val numAllItems = 1000000
-    val allItems = Array.fill(numAllItems)(itemGenerator(r))
+  testDataType[Array[Byte]](BinaryType, Seq.fill(100) { rand.nextString(1).getBytes() })
 
-    val numSamples = numAllItems / 10
-    val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
+  test("serialize and de-serialize") {
+    // Check empty serialize and de-serialize
+    val agg = cms(epsOfTotalCount, confidence, seed)
+    val buffer = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+    assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
 
-    val exactFreq = {
-      val sampledItems = sampledItemIndices.map(allItems)
-      sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
+    // Check non-empty serialize and de-serialize
+    val random = new Random(31)
+    for (i <- 0 until 10) {
+      buffer.add(random.nextInt(100))
     }
-
-    testLowLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq)
-    testHighLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq)
-  }
-
-  testItemType[Byte](ByteType) { _.nextInt().toByte }
-
-  testItemType[Short](ShortType) { _.nextInt().toShort }
-
-  testItemType[Int](IntegerType) { _.nextInt() }
-
-  testItemType[Long](LongType) { _.nextLong() }
-
-  testItemType[UTF8String](StringType) { r => UTF8String.fromString(r.nextString(r.nextInt(20))) }
-
-  testItemType[Float](FloatType) { _.nextFloat() }
-
-  testItemType[Double](DoubleType) { _.nextDouble() }
-
-  testItemType[Decimal](new DecimalType()) { r => Decimal(r.nextDouble()) }
-
-  testItemType[Boolean](BooleanType) { _.nextBoolean() }
-
-  testItemType[Array[Byte]](BinaryType) { r =>
-    r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8)
+    assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
   }
 
-
-  test("fails analysis if eps, confidence or seed provided is not a literal or constant foldable") {
+  test("fails analysis if eps, confidence or seed provided is not foldable") {
     val wrongEps = new CountMinSketchAgg(
       childExpression,
       epsExpression = AttributeReference("a", DoubleType)(),
@@ -227,88 +142,55 @@ class CountMinSketchAggSuite extends SparkFunSuite {
       seedExpression = AttributeReference("c", IntegerType)())
 
     Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
-      assertEqual(
-        wrongAgg.checkInputDataTypes(),
-        TypeCheckFailure(
-          "The eps, confidence or seed provided must be a literal or constant foldable")
-      )
+      assertResult(
+        TypeCheckFailure("The eps, confidence or seed provided must be a literal or foldable")) {
+        wrongAgg.checkInputDataTypes()
+      }
     }
   }
 
   test("fails analysis if parameters are invalid") {
     // parameters are null
-    val wrongEps = new CountMinSketchAgg(
-      childExpression,
-      epsExpression = Cast(Literal(null), DoubleType),
-      confidenceExpression = Literal(confidence),
-      seedExpression = Literal(seed))
-    val wrongConfidence = new CountMinSketchAgg(
-      childExpression,
-      epsExpression = Literal(epsOfTotalCount),
-      confidenceExpression = Cast(Literal(null), DoubleType),
-      seedExpression = Literal(seed))
-    val wrongSeed = new CountMinSketchAgg(
-      childExpression,
-      epsExpression = Literal(epsOfTotalCount),
-      confidenceExpression = Literal(confidence),
-      seedExpression = Cast(Literal(null), IntegerType))
+    val wrongEps = cms(null, confidence, seed)
+    val wrongConfidence = cms(epsOfTotalCount, null, seed)
+    val wrongSeed = cms(epsOfTotalCount, confidence, null)
 
     Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
-      assertEqual(
-        wrongAgg.checkInputDataTypes(),
-        TypeCheckFailure("The eps, confidence or seed provided should not be null")
-      )
+      assertResult(TypeCheckFailure("The eps, confidence or seed provided should not be null")) {
+        wrongAgg.checkInputDataTypes()
+      }
     }
 
     // parameters are out of the valid range
     Seq(0.0, -1000.0).foreach { invalidEps =>
-      val invalidAgg = new CountMinSketchAgg(
-        childExpression,
-        epsExpression = Literal(invalidEps),
-        confidenceExpression = Literal(confidence),
-        seedExpression = Literal(seed))
-      assertEqual(
-        invalidAgg.checkInputDataTypes(),
-        TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)")
-      )
+      val invalidAgg = cms(invalidEps, confidence, seed)
+      assertResult(
+        TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)")) {
+        invalidAgg.checkInputDataTypes()
+      }
     }
 
     Seq(0.0, 1.0, -2.0, 2.0).foreach { invalidConfidence =>
-      val invalidAgg = new CountMinSketchAgg(
-        childExpression,
-        epsExpression = Literal(epsOfTotalCount),
-        confidenceExpression = Literal(invalidConfidence),
-        seedExpression = Literal(seed))
-      assertEqual(
-        invalidAgg.checkInputDataTypes(),
-        TypeCheckFailure(
-          s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)")
-      )
+      val invalidAgg = cms(epsOfTotalCount, invalidConfidence, seed)
+      assertResult(TypeCheckFailure(
+        s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)")) {
+        invalidAgg.checkInputDataTypes()
+      }
     }
   }
 
-  private def assertEqual[T](left: T, right: T): Unit = {
-    assert(left == right)
-  }
-
   test("null handling") {
     def isEqual(result: Any, other: CountMinSketch): Boolean = {
-      result match {
-        case bytesData: Array[Byte] =>
-          val in = new ByteArrayInputStream(bytesData)
-          val cms = CountMinSketch.readFrom(in)
-          cms.equals(other)
-        case _ => fail("unexpected return type")
-      }
+      other.equals(CountMinSketch.readFrom(result.asInstanceOf[Array[Byte]]))
     }
 
-    val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence),
-      Literal(seed))
+    val agg = cms(epsOfTotalCount, confidence, seed)
     val emptyCms = CountMinSketch.create(epsOfTotalCount, confidence, seed)
     val buffer = new GenericInternalRow(new Array[Any](1))
     agg.initialize(buffer)
     // Empty aggregation buffer
     assert(isEqual(agg.eval(buffer), emptyCms))
+
     // Empty input row
     agg.update(buffer, InternalRow(null))
     assert(isEqual(agg.eval(buffer), emptyCms))

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
index e98092d..62a7534 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
@@ -21,6 +21,9 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
 import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
 import org.apache.spark.sql.test.SharedSQLContext
 
+/**
+ * End-to-end tests for approximate percentile aggregate function.
+ */
 class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d3c90b74/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
index 3e715a3..dea0d4c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
@@ -17,175 +17,29 @@
 
 package org.apache.spark.sql
 
-import java.io.ByteArrayInputStream
-import java.nio.charset.StandardCharsets
-import java.sql.{Date, Timestamp}
-
-import scala.reflect.ClassTag
-import scala.util.Random
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{Decimal, StringType, _}
-import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.sketch.CountMinSketch
 
+/**
+ * End-to-end test suite for count_min_sketch.
+ */
 class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext {
 
-  private val table = "count_min_sketch_table"
-
-  /** Uses fixed seed to ensure reproducible test execution */
-  private val r = new Random(42)
-  private val numAllItems = 1000
-  private val numSamples = numAllItems / 10
-
-  private val eps = 0.1D
-  private val confidence = 0.95D
-  private val seed = 11
-
-  val startDate = DateTimeUtils.fromJavaDate(Date.valueOf("1900-01-01"))
-  val endDate = DateTimeUtils.fromJavaDate(Date.valueOf("2016-01-01"))
-  val startTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("1900-01-01 00:00:00"))
-  val endTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-01-01 00:00:00"))
-
-  test(s"compute count-min sketch for multiple columns of different types") {
-    val (allBytes, sampledByteIndices, exactByteFreq) =
-      generateTestData[Byte] { _.nextInt().toByte }
-    val (allShorts, sampledShortIndices, exactShortFreq) =
-      generateTestData[Short] { _.nextInt().toShort }
-    val (allInts, sampledIntIndices, exactIntFreq) =
-      generateTestData[Int] { _.nextInt() }
-    val (allLongs, sampledLongIndices, exactLongFreq) =
-      generateTestData[Long] { _.nextLong() }
-    val (allStrings, sampledStringIndices, exactStringFreq) =
-      generateTestData[String] { r => r.nextString(r.nextInt(20)) }
-    val (allDates, sampledDateIndices, exactDateFreq) = generateTestData[Date] { r =>
-      DateTimeUtils.toJavaDate(r.nextInt(endDate - startDate) + startDate)
-    }
-    val (allTimestamps, sampledTSIndices, exactTSFreq) = generateTestData[Timestamp] { r =>
-      DateTimeUtils.toJavaTimestamp(r.nextLong() % (endTS - startTS) + startTS)
-    }
-    val (allFloats, sampledFloatIndices, exactFloatFreq) =
-      generateTestData[Float] { _.nextFloat() }
-    val (allDoubles, sampledDoubleIndices, exactDoubleFreq) =
-      generateTestData[Double] { _.nextDouble() }
-    val (allDeciamls, sampledDecimalIndices, exactDecimalFreq) =
-      generateTestData[Decimal] { r => Decimal(r.nextDouble()) }
-    val (allBooleans, sampledBooleanIndices, exactBooleanFreq) =
-      generateTestData[Boolean] { _.nextBoolean() }
-    val (allBinaries, sampledBinaryIndices, exactBinaryFreq) = generateTestData[Array[Byte]] { r =>
-      r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8)
-    }
-
-    val data = (0 until numSamples).map { i =>
-      Row(allBytes(sampledByteIndices(i)),
-        allShorts(sampledShortIndices(i)),
-        allInts(sampledIntIndices(i)),
-        allLongs(sampledLongIndices(i)),
-        allStrings(sampledStringIndices(i)),
-        allDates(sampledDateIndices(i)),
-        allTimestamps(sampledTSIndices(i)),
-        allFloats(sampledFloatIndices(i)),
-        allDoubles(sampledDoubleIndices(i)),
-        allDeciamls(sampledDecimalIndices(i)),
-        allBooleans(sampledBooleanIndices(i)),
-        allBinaries(sampledBinaryIndices(i)))
-    }
+  test("count-min sketch") {
+    import testImplicits._
 
-    val schema = StructType(Seq(
-      StructField("c1", ByteType),
-      StructField("c2", ShortType),
-      StructField("c3", IntegerType),
-      StructField("c4", LongType),
-      StructField("c5", StringType),
-      StructField("c6", DateType),
-      StructField("c7", TimestampType),
-      StructField("c8", FloatType),
-      StructField("c9", DoubleType),
-      StructField("c10", new DecimalType()),
-      StructField("c11", BooleanType),
-      StructField("c12", BinaryType)))
+    val eps = 0.1
+    val confidence = 0.95
+    val seed = 11
 
-    withTempView(table) {
-      val rdd: RDD[Row] = spark.sparkContext.parallelize(data)
-      spark.createDataFrame(rdd, schema).createOrReplaceTempView(table)
+    val items = Seq(1, 1, 2, 2, 2, 2, 3, 4, 5)
+    val sketch = CountMinSketch.readFrom(items.toDF("id")
+      .selectExpr(s"count_min_sketch(id, ${eps}d, ${confidence}d, $seed)")
+      .head().get(0).asInstanceOf[Array[Byte]])
 
-      val cmsSql = schema.fieldNames.map { col =>
-        s"count_min_sketch($col, ${eps}D, ${confidence}D, $seed)"
-      }
-      val result = sql(s"SELECT ${cmsSql.mkString(", ")} FROM $table").head()
-      schema.indices.foreach { i =>
-        val binaryData = result.getAs[Array[Byte]](i)
-        val in = new ByteArrayInputStream(binaryData)
-        val cms = CountMinSketch.readFrom(in)
-        schema.fields(i).dataType match {
-          case ByteType => checkResult(cms, allBytes, exactByteFreq)
-          case ShortType => checkResult(cms, allShorts, exactShortFreq)
-          case IntegerType => checkResult(cms, allInts, exactIntFreq)
-          case LongType => checkResult(cms, allLongs, exactLongFreq)
-          case StringType => checkResult(cms, allStrings, exactStringFreq)
-          case DateType =>
-            checkResult(cms,
-              allDates.map(DateTimeUtils.fromJavaDate),
-              exactDateFreq.map { e =>
-                (DateTimeUtils.fromJavaDate(e._1.asInstanceOf[Date]), e._2)
-              })
-          case TimestampType =>
-            checkResult(cms,
-              allTimestamps.map(DateTimeUtils.fromJavaTimestamp),
-              exactTSFreq.map { e =>
-                (DateTimeUtils.fromJavaTimestamp(e._1.asInstanceOf[Timestamp]), e._2)
-              })
-          case FloatType => checkResult(cms, allFloats, exactFloatFreq)
-          case DoubleType => checkResult(cms, allDoubles, exactDoubleFreq)
-          case DecimalType() => checkResult(cms, allDeciamls, exactDecimalFreq)
-          case BooleanType => checkResult(cms, allBooleans, exactBooleanFreq)
-          case BinaryType => checkResult(cms, allBinaries, exactBinaryFreq)
-        }
-      }
-    }
-  }
-
-  private def checkResult[T: ClassTag](
-      cms: CountMinSketch,
-      data: Array[T],
-      exactFreq: Map[Any, Long]): Unit = {
-    val probCorrect = {
-      val numErrors = data.map { i =>
-        val count = exactFreq.getOrElse(getProbeItem(i), 0L)
-        val item = i match {
-          case dec: Decimal => dec.toJavaBigDecimal
-          case str: UTF8String => str.getBytes
-          case _ => i
-        }
-        val ratio = (cms.estimateCount(item) - count).toDouble / data.length
-        if (ratio > eps) 1 else 0
-      }.sum
-
-      1D - numErrors.toDouble / data.length
-    }
-
-    assert(
-      probCorrect > confidence,
-      s"Confidence not reached: required $confidence, reached $probCorrect"
-    )
-  }
-
-  private def getProbeItem[T: ClassTag](item: T): Any = item match {
-    // Use a string to represent the content of an array of bytes
-    case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
-    case i => identity(i)
-  }
+    val reference = CountMinSketch.create(eps, confidence, seed)
+    items.foreach(reference.add)
 
-  private def generateTestData[T: ClassTag](
-      itemGenerator: Random => T): (Array[T], Array[Int], Map[Any, Long]) = {
-    val allItems = Array.fill(numAllItems)(itemGenerator(r))
-    val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
-    val exactFreq = {
-      val sampledItems = sampledItemIndices.map(allItems)
-      sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
-    }
-    (allItems, sampledItemIndices, exactFreq)
+    assert(sketch == reference)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org