You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/11/29 21:16:50 UTC

spark git commit: [SPARK-18429][SQL] implement a new Aggregate for CountMinSketch

Repository: spark
Updated Branches:
  refs/heads/master f643fe47f -> d57a594b8


[SPARK-18429][SQL] implement a new Aggregate for CountMinSketch

## What changes were proposed in this pull request?

This PR implements a new aggregate function that generates a count-min sketch; it is a wrapper around `CountMinSketch`.

## How was this patch tested?

Added test cases.

Author: wangzhenhua <wa...@huawei.com>

Closes #15877 from wzhfy/cms.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d57a594b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d57a594b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d57a594b

Branch: refs/heads/master
Commit: d57a594b8b4cb1a6942dbd2fd30c3a97e0dd031e
Parents: f643fe4
Author: wangzhenhua <wa...@huawei.com>
Authored: Tue Nov 29 13:16:46 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Nov 29 13:16:46 2016 -0800

----------------------------------------------------------------------
 .../spark/util/sketch/CountMinSketch.java       |   4 +
 .../spark/util/sketch/CountMinSketchImpl.java   |  30 +-
 .../spark/util/sketch/CountMinSketchSuite.scala |  23 +-
 sql/catalyst/pom.xml                            |   5 +
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../aggregate/CountMinSketchAgg.scala           | 146 +++++++++
 .../aggregate/CountMinSketchAggSuite.scala      | 320 +++++++++++++++++++
 .../spark/sql/CountMinSketchAggQuerySuite.scala | 189 +++++++++++
 8 files changed, 710 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d57a594b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 40fa20c..0011096 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -30,6 +30,10 @@ import java.io.OutputStream;
  *   <li>{@link Integer}</li>
  *   <li>{@link Long}</li>
  *   <li>{@link String}</li>
+ *   <li>{@link Float}</li>
+ *   <li>{@link Double}</li>
+ *   <li>{@link java.math.BigDecimal}</li>
+ *   <li>{@link Boolean}</li>
  * </ul>
  * A {@link CountMinSketch} is initialized with a random seed, and a pair of parameters:
  * <ol>

http://git-wip-us.apache.org/repos/asf/spark/blob/d57a594b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index 2acbb24..94ab3a9 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -25,6 +25,7 @@ import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.io.OutputStream;
 import java.io.Serializable;
+import java.math.BigDecimal;
 import java.util.Arrays;
 import java.util.Random;
 
@@ -152,6 +153,16 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
   public void add(Object item, long count) {
     if (item instanceof String) {
       addString((String) item, count);
+    } else if (item instanceof BigDecimal) {
+      addString(((BigDecimal) item).toString(), count);
+    } else if (item instanceof byte[]) {
+      addBinary((byte[]) item, count);
+    } else if (item instanceof Float) {
+      addLong(Float.floatToIntBits((Float) item), count);
+    } else if (item instanceof Double) {
+      addLong(Double.doubleToLongBits((Double) item), count);
+    } else if (item instanceof Boolean) {
+      addLong(((Boolean) item) ? 1L : 0L, count);
     } else {
       addLong(Utils.integralToLong(item), count);
     }
@@ -216,10 +227,6 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
     return ((int) hash) % width;
   }
 
-  private static int[] getHashBuckets(String key, int hashCount, int max) {
-    return getHashBuckets(Utils.getBytesFromUTF8String(key), hashCount, max);
-  }
-
   private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
     int[] result = new int[hashCount];
     int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
@@ -233,7 +240,18 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
   @Override
   public long estimateCount(Object item) {
     if (item instanceof String) {
-      return estimateCountForStringItem((String) item);
+      return estimateCountForBinaryItem(Utils.getBytesFromUTF8String((String) item));
+    } else if (item instanceof BigDecimal) {
+      return estimateCountForBinaryItem(
+        Utils.getBytesFromUTF8String(((BigDecimal) item).toString()));
+    } else if (item instanceof byte[]) {
+      return estimateCountForBinaryItem((byte[]) item);
+    } else if (item instanceof Float) {
+      return estimateCountForLongItem(Float.floatToIntBits((Float) item));
+    } else if (item instanceof Double) {
+      return estimateCountForLongItem(Double.doubleToLongBits((Double) item));
+    } else if (item instanceof Boolean) {
+      return estimateCountForLongItem(((Boolean) item) ? 1L : 0L);
     } else {
       return estimateCountForLongItem(Utils.integralToLong(item));
     }
@@ -247,7 +265,7 @@ class CountMinSketchImpl extends CountMinSketch implements Serializable {
     return res;
   }
 
-  private long estimateCountForStringItem(String item) {
+  private long estimateCountForBinaryItem(byte[] item) {
     long res = Long.MAX_VALUE;
     int[] buckets = getHashBuckets(item, depth, width);
     for (int i = 0; i < depth; ++i) {

http://git-wip-us.apache.org/repos/asf/spark/blob/d57a594b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
----------------------------------------------------------------------
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
index b9c7f5c..2c358fc 100644
--- a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.util.sketch
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.nio.charset.StandardCharsets
 
 import scala.reflect.ClassTag
 import scala.util.Random
@@ -44,6 +45,12 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
   }
 
   def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
+    def getProbeItem(item: T): Any = item match {
+      // Use a string to represent the content of an array of bytes
+      case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
+      case i => identity(i)
+    }
+
     test(s"accuracy - $typeName") {
       // Uses fixed seed to ensure reproducible test execution
       val r = new Random(31)
@@ -56,7 +63,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
 
       val exactFreq = {
         val sampledItems = sampledItemIndices.map(allItems)
-        sampledItems.groupBy(identity).mapValues(_.length.toLong)
+        sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
       }
 
       val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
@@ -67,7 +74,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
 
       val probCorrect = {
         val numErrors = allItems.map { item =>
-          val count = exactFreq.getOrElse(item, 0L)
+          val count = exactFreq.getOrElse(getProbeItem(item), 0L)
           val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems
           if (ratio > epsOfTotalCount) 1 else 0
         }.sum
@@ -135,6 +142,18 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
 
   testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }
 
+  testItemType[Float]("Float") { _.nextFloat() }
+
+  testItemType[Double]("Double") { _.nextDouble() }
+
+  testItemType[java.math.BigDecimal]("Decimal") { r => new java.math.BigDecimal(r.nextDouble()) }
+
+  testItemType[Boolean]("Boolean") { _.nextBoolean() }
+
+  testItemType[Array[Byte]]("Binary") { r =>
+    Utils.getBytesFromUTF8String(r.nextString(r.nextInt(20)))
+  }
+
   test("incompatible merge") {
     intercept[IncompatibleMergeException] {
       CountMinSketch.create(10, 10, 1).mergeInPlace(null)

http://git-wip-us.apache.org/repos/asf/spark/blob/d57a594b/sql/catalyst/pom.xml
----------------------------------------------------------------------
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index f118a9a..82a5a85 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -62,6 +62,11 @@
       <version>${project.version}</version>
     </dependency>
     <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sketch_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
       <groupId>org.scalacheck</groupId>
       <artifactId>scalacheck_${scala.binary.version}</artifactId>
       <scope>test</scope>

http://git-wip-us.apache.org/repos/asf/spark/blob/d57a594b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 2636afe..e41f1ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -262,6 +262,7 @@ object FunctionRegistry {
     expression[VarianceSamp]("var_samp"),
     expression[CollectList]("collect_list"),
     expression[CollectSet]("collect_set"),
+    expression[CountMinSketchAgg]("count_min_sketch"),
 
     // string functions
     expression[Ascii]("ascii"),

http://git-wip-us.apache.org/repos/asf/spark/blob/d57a594b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
new file mode 100644
index 0000000..1bfae9e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.sketch.CountMinSketch
+
+/**
+ * This function returns a count-min sketch of a column with the given eps, confidence and seed.
+ * A count-min sketch is a probabilistic data structure used for summarizing streams of data in
+ * sub-linear space, which is useful for equality predicates and join size estimation.
+ * The result returned by the function is an array of bytes, which should be deserialized to a
+ * `CountMinSketch` before usage.
+ *
+ * @param child child expression that can produce column value with `child.eval(inputRow)`
+ * @param epsExpression relative error, must be positive (NaN is rejected)
+ * @param confidenceExpression confidence, must be within (0.0, 1.0) (NaN is rejected)
+ * @param seedExpression random seed
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(col, eps, confidence, seed) - Returns a count-min sketch of a column with the given eps,
+      confidence and seed. The result is an array of bytes, which should be deserialized to a
+      `CountMinSketch` before usage. `CountMinSketch` is useful for equality predicates and join
+      size estimation.
+  """)
+case class CountMinSketchAgg(
+    child: Expression,
+    epsExpression: Expression,
+    confidenceExpression: Expression,
+    seedExpression: Expression,
+    override val mutableAggBufferOffset: Int,
+    override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[CountMinSketch] {
+
+  def this(
+      child: Expression,
+      epsExpression: Expression,
+      confidenceExpression: Expression,
+      seedExpression: Expression) = {
+    this(child, epsExpression, confidenceExpression, seedExpression, 0, 0)
+  }
+
+  // Mark as lazy so that they are not evaluated during tree transformation.
+  private lazy val eps: Double = epsExpression.eval().asInstanceOf[Double]
+  private lazy val confidence: Double = confidenceExpression.eval().asInstanceOf[Double]
+  private lazy val seed: Int = seedExpression.eval().asInstanceOf[Int]
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else if (!epsExpression.foldable || !confidenceExpression.foldable ||
+      !seedExpression.foldable) {
+      TypeCheckFailure(
+        "The eps, confidence or seed provided must be a literal or constant foldable")
+    } else if (epsExpression.eval() == null || confidenceExpression.eval() == null ||
+      seedExpression.eval() == null) {
+      TypeCheckFailure("The eps, confidence or seed provided should not be null")
+    } else if (!(eps > 0D)) { // negated form so that NaN is rejected as well
+      TypeCheckFailure(s"Relative error must be positive (current value = $eps)")
+    } else if (!(confidence > 0D && confidence < 1D)) { // negated form also rejects NaN
+      TypeCheckFailure(s"Confidence must be within range (0.0, 1.0) (current value = $confidence)")
+    } else {
+      TypeCheckSuccess
+    }
+  }
+
+  override def createAggregationBuffer(): CountMinSketch = {
+    CountMinSketch.create(eps, confidence, seed)
+  }
+
+  override def update(buffer: CountMinSketch, input: InternalRow): Unit = {
+    val value = child.eval(input)
+    // Null input values do not contribute to the sketch
+    if (value != null) {
+      child.dataType match {
+        // `Decimal` and `UTF8String` are internal types in spark sql, we need to convert them
+        // into acceptable types for `CountMinSketch`.
+        case DecimalType() => buffer.add(value.asInstanceOf[Decimal].toJavaBigDecimal)
+        // For string type, we can get bytes of our `UTF8String` directly, and call the `addBinary`
+        // instead of `addString` to avoid unnecessary conversion.
+        case StringType => buffer.addBinary(value.asInstanceOf[UTF8String].getBytes)
+        case _ => buffer.add(value)
+      }
+    }
+  }
+
+  override def merge(buffer: CountMinSketch, input: CountMinSketch): Unit = {
+    buffer.mergeInPlace(input)
+  }
+
+  override def eval(buffer: CountMinSketch): Any = serialize(buffer)
+
+  override def serialize(buffer: CountMinSketch): Array[Byte] = {
+    val out = new ByteArrayOutputStream()
+    buffer.writeTo(out)
+    out.toByteArray
+  }
+
+  override def deserialize(storageFormat: Array[Byte]): CountMinSketch = {
+    val in = new ByteArrayInputStream(storageFormat)
+    CountMinSketch.readFrom(in)
+  }
+
+  override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): CountMinSketchAgg =
+    copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+  override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): CountMinSketchAgg =
+    copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+  override def inputTypes: Seq[AbstractDataType] = {
+    Seq(TypeCollection(NumericType, StringType, DateType, TimestampType, BooleanType, BinaryType),
+      DoubleType, DoubleType, IntegerType)
+  }
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = BinaryType
+
+  override def children: Seq[Expression] =
+    Seq(child, epsExpression, confidenceExpression, seedExpression)
+
+  override def prettyName: String = "count_min_sketch"
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d57a594b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
new file mode 100644
index 0000000..6e08e29
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAggSuite.scala
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.io.ByteArrayInputStream
+import java.nio.charset.StandardCharsets
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Cast, GenericInternalRow, Literal}
+import org.apache.spark.sql.types.{DecimalType, _}
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.sketch.CountMinSketch
+
+class CountMinSketchAggSuite extends SparkFunSuite {
+  private val childExpression = BoundReference(0, IntegerType, nullable = true)
+  private val epsOfTotalCount = 0.0001
+  private val confidence = 0.99
+  private val seed = 42
+
+  test("serialize and de-serialize") {
+    // Check empty serialize and de-serialize
+    val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence),
+      Literal(seed))
+    val buffer = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+    assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
+
+    // Check non-empty serialize and de-serialize
+    val random = new Random(31)
+    (0 until 10000).foreach { _ =>
+      buffer.add(random.nextInt(100))
+    }
+    assert(buffer.equals(agg.deserialize(agg.serialize(buffer))))
+  }
+
+  def testHighLevelInterface[T: ClassTag](
+      dataType: DataType,
+      sampledItemIndices: Array[Int],
+      allItems: Array[T],
+      exactFreq: Map[Any, Long]): Unit = {
+    test(s"high level interface, update, merge, eval... - ${dataType.typeName}") {
+      val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true),
+        Literal(epsOfTotalCount), Literal(confidence), Literal(seed))
+      assert(!agg.nullable)
+
+      val group1 = 0 until sampledItemIndices.length / 2
+      val group1Buffer = agg.createAggregationBuffer()
+      group1.foreach { index =>
+        val input = InternalRow(allItems(sampledItemIndices(index)))
+        agg.update(group1Buffer, input)
+      }
+
+      val group2 = sampledItemIndices.length / 2 until sampledItemIndices.length
+      val group2Buffer = agg.createAggregationBuffer()
+      group2.foreach { index =>
+        val input = InternalRow(allItems(sampledItemIndices(index)))
+        agg.update(group2Buffer, input)
+      }
+
+      val mergeBuffer1 = agg.createAggregationBuffer()
+      agg.merge(mergeBuffer1, group1Buffer)
+      agg.merge(mergeBuffer1, group2Buffer)
+      checkResult(agg.eval(mergeBuffer1), allItems, exactFreq)
+
+      // Merge in a different order; accuracy must not depend on the merge order
+      val mergeBuffer2 = agg.createAggregationBuffer()
+      agg.merge(mergeBuffer2, group2Buffer)
+      agg.merge(mergeBuffer2, group1Buffer)
+      checkResult(agg.eval(mergeBuffer2), allItems, exactFreq)
+
+      // Merge with an empty partition; accuracy must still hold
+      val emptyBuffer = agg.createAggregationBuffer()
+      agg.merge(mergeBuffer2, emptyBuffer)
+      checkResult(agg.eval(mergeBuffer2), allItems, exactFreq)
+    }
+  }
+
+  def testLowLevelInterface[T: ClassTag](
+      dataType: DataType,
+      sampledItemIndices: Array[Int],
+      allItems: Array[T],
+      exactFreq: Map[Any, Long]): Unit = {
+    test(s"low level interface, update, merge, eval... - ${dataType.typeName}") {
+      val inputAggregationBufferOffset = 1
+      val mutableAggregationBufferOffset = 2
+
+      // Phase 1: partial mode aggregation
+      val agg = new CountMinSketchAgg(BoundReference(0, dataType, nullable = true),
+        Literal(epsOfTotalCount), Literal(confidence), Literal(seed))
+        .withNewInputAggBufferOffset(inputAggregationBufferOffset)
+        .withNewMutableAggBufferOffset(mutableAggregationBufferOffset)
+
+      val mutableAggBuffer = new GenericInternalRow(
+        new Array[Any](mutableAggregationBufferOffset + 1))
+      agg.initialize(mutableAggBuffer)
+
+      sampledItemIndices.foreach { i =>
+        agg.update(mutableAggBuffer, InternalRow(allItems(i)))
+      }
+      agg.serializeAggregateBufferInPlace(mutableAggBuffer)
+
+      // Read back the serialized aggregation buffer as the input of the final phase
+      val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset)
+      val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized))
+
+      // Phase 2: final mode aggregation
+      // Re-initialize the aggregation buffer
+      agg.initialize(mutableAggBuffer)
+      agg.merge(mutableAggBuffer, inputAggBuffer)
+      checkResult(agg.eval(mutableAggBuffer), allItems, exactFreq)
+    }
+  }
+
+  private def checkResult[T: ClassTag](
+      result: Any,
+      data: Array[T],
+      exactFreq: Map[Any, Long]): Unit = {
+    result match {
+      case bytesData: Array[Byte] =>
+        val in = new ByteArrayInputStream(bytesData)
+        val cms = CountMinSketch.readFrom(in)
+        val probCorrect = {
+          val numErrors = data.map { i =>
+            val count = exactFreq.getOrElse(getProbeItem(i), 0L)
+            val item = i match {
+              case dec: Decimal => dec.toJavaBigDecimal
+              case str: UTF8String => str.getBytes
+              case _ => i
+            }
+            val ratio = (cms.estimateCount(item) - count).toDouble / data.length
+            if (ratio > epsOfTotalCount) 1 else 0
+          }.sum
+
+          1D - numErrors.toDouble / data.length
+        }
+
+        assert(
+          probCorrect > confidence,
+          s"Confidence not reached: required $confidence, reached $probCorrect"
+        )
+      case _ => fail("unexpected return type")
+    }
+  }
+
+  private def getProbeItem[T: ClassTag](item: T): Any = item match {
+    // Arrays use reference equality, so key byte arrays by their UTF-8 string content
+    case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
+    case other => other
+  }
+
+  def testItemType[T: ClassTag](dataType: DataType)(itemGenerator: Random => T): Unit = {
+    // Uses fixed seed to ensure reproducible test execution
+    val r = new Random(31)
+
+    val numAllItems = 1000000
+    val allItems = Array.fill(numAllItems)(itemGenerator(r))
+
+    val numSamples = numAllItems / 10
+    val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
+
+    val exactFreq = {
+      val sampledItems = sampledItemIndices.map(allItems)
+      sampledItems.groupBy(getProbeItem).mapValues(_.length.toLong)
+    }
+
+    testLowLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq)
+    testHighLevelInterface[T](dataType, sampledItemIndices, allItems, exactFreq)
+  }
+
+  testItemType[Byte](ByteType) { _.nextInt().toByte }
+
+  testItemType[Short](ShortType) { _.nextInt().toShort }
+
+  testItemType[Int](IntegerType) { _.nextInt() }
+
+  testItemType[Long](LongType) { _.nextLong() }
+
+  testItemType[UTF8String](StringType) { r => UTF8String.fromString(r.nextString(r.nextInt(20))) }
+
+  testItemType[Float](FloatType) { _.nextFloat() }
+
+  testItemType[Double](DoubleType) { _.nextDouble() }
+
+  testItemType[Decimal](new DecimalType()) { r => Decimal(r.nextDouble()) }
+
+  testItemType[Boolean](BooleanType) { _.nextBoolean() }
+
+  testItemType[Array[Byte]](BinaryType) { r =>
+    r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8)
+  }
+
+
+  test("fails analysis if eps, confidence or seed provided is not a literal or constant foldable") {
+    val wrongEps = new CountMinSketchAgg(
+      childExpression,
+      epsExpression = AttributeReference("a", DoubleType)(),
+      confidenceExpression = Literal(confidence),
+      seedExpression = Literal(seed))
+    val wrongConfidence = new CountMinSketchAgg(
+      childExpression,
+      epsExpression = Literal(epsOfTotalCount),
+      confidenceExpression = AttributeReference("b", DoubleType)(),
+      seedExpression = Literal(seed))
+    val wrongSeed = new CountMinSketchAgg(
+      childExpression,
+      epsExpression = Literal(epsOfTotalCount),
+      confidenceExpression = Literal(confidence),
+      seedExpression = AttributeReference("c", IntegerType)())
+
+    Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
+      assertEqual(
+        wrongAgg.checkInputDataTypes(),
+        TypeCheckFailure(
+          "The eps, confidence or seed provided must be a literal or constant foldable")
+      )
+    }
+  }
+
+  test("fails analysis if parameters are invalid") {
+    // parameters are null
+    val wrongEps = new CountMinSketchAgg(
+      childExpression,
+      epsExpression = Cast(Literal(null), DoubleType),
+      confidenceExpression = Literal(confidence),
+      seedExpression = Literal(seed))
+    val wrongConfidence = new CountMinSketchAgg(
+      childExpression,
+      epsExpression = Literal(epsOfTotalCount),
+      confidenceExpression = Cast(Literal(null), DoubleType),
+      seedExpression = Literal(seed))
+    val wrongSeed = new CountMinSketchAgg(
+      childExpression,
+      epsExpression = Literal(epsOfTotalCount),
+      confidenceExpression = Literal(confidence),
+      seedExpression = Cast(Literal(null), IntegerType))
+
+    Seq(wrongEps, wrongConfidence, wrongSeed).foreach { wrongAgg =>
+      assertEqual(
+        wrongAgg.checkInputDataTypes(),
+        TypeCheckFailure("The eps, confidence or seed provided should not be null")
+      )
+    }
+
+    // parameters are out of the valid range
+    Seq(0.0, -1000.0).foreach { invalidEps =>
+      val invalidAgg = new CountMinSketchAgg(
+        childExpression,
+        epsExpression = Literal(invalidEps),
+        confidenceExpression = Literal(confidence),
+        seedExpression = Literal(seed))
+      assertEqual(
+        invalidAgg.checkInputDataTypes(),
+        TypeCheckFailure(s"Relative error must be positive (current value = $invalidEps)")
+      )
+    }
+
+    Seq(0.0, 1.0, -2.0, 2.0).foreach { invalidConfidence =>
+      val invalidAgg = new CountMinSketchAgg(
+        childExpression,
+        epsExpression = Literal(epsOfTotalCount),
+        confidenceExpression = Literal(invalidConfidence),
+        seedExpression = Literal(seed))
+      assertEqual(
+        invalidAgg.checkInputDataTypes(),
+        TypeCheckFailure(
+          s"Confidence must be within range (0.0, 1.0) (current value = $invalidConfidence)")
+      )
+    }
+  }
+
+  private def assertEqual[T](left: T, right: T): Unit = {
+    assert(left == right)
+  }
+
+  test("null handling") {
+    def isEqual(result: Any, other: CountMinSketch): Boolean = {
+      result match {
+        case bytesData: Array[Byte] =>
+          val in = new ByteArrayInputStream(bytesData)
+          val cms = CountMinSketch.readFrom(in)
+          cms.equals(other)
+        case _ => fail("unexpected return type")
+      }
+    }
+
+    val agg = new CountMinSketchAgg(childExpression, Literal(epsOfTotalCount), Literal(confidence),
+      Literal(seed))
+    val emptyCms = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+    val buffer = new GenericInternalRow(new Array[Any](1))
+    agg.initialize(buffer)
+    // Empty aggregation buffer
+    assert(isEqual(agg.eval(buffer), emptyCms))
+    // A null input value is ignored
+    agg.update(buffer, InternalRow(null))
+    assert(isEqual(agg.eval(buffer), emptyCms))
+
+    // A non-null input value changes the sketch
+    agg.update(buffer, InternalRow(0))
+    assert(!isEqual(agg.eval(buffer), emptyCms))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d57a594b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
new file mode 100644
index 0000000..4cc5060
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.ByteArrayInputStream
+import java.nio.charset.StandardCharsets
+import java.sql.{Date, Timestamp}
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{Decimal, StringType, _}
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.sketch.CountMinSketch
+
+class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext {
+
+  private val table = "count_min_sketch_table"
+
+  /** Uses fixed seed to ensure reproducible test execution */
+  private val r = new Random(42)
+  private val numAllItems = 1000
+  private val numSamples = numAllItems / 10
+
+  private val eps = 0.1D
+  private val confidence = 0.95D
+  private val seed = 11
+
+  // Bounds, as internal Catalyst values, of the ranges random dates/timestamps are drawn from.
+  val startDate = DateTimeUtils.fromJavaDate(Date.valueOf("1900-01-01"))
+  val endDate = DateTimeUtils.fromJavaDate(Date.valueOf("2016-01-01"))
+  val startTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("1900-01-01 00:00:00"))
+  val endTS = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-01-01 00:00:00"))
+
+  test("compute count-min sketch for multiple columns of different types") {
+    val (allBytes, sampledByteIndices, exactByteFreq) =
+      generateTestData[Byte] { _.nextInt().toByte }
+    val (allShorts, sampledShortIndices, exactShortFreq) =
+      generateTestData[Short] { _.nextInt().toShort }
+    val (allInts, sampledIntIndices, exactIntFreq) =
+      generateTestData[Int] { _.nextInt() }
+    val (allLongs, sampledLongIndices, exactLongFreq) =
+      generateTestData[Long] { _.nextLong() }
+    val (allStrings, sampledStringIndices, exactStringFreq) =
+      generateTestData[String] { r => r.nextString(r.nextInt(20)) }
+    val (allDates, sampledDateIndices, exactDateFreq) = generateTestData[Date] { r =>
+      DateTimeUtils.toJavaDate(r.nextInt(endDate - startDate) + startDate)
+    }
+    val (allTimestamps, sampledTSIndices, exactTSFreq) = generateTestData[Timestamp] { r =>
+      // `%` keeps the sign of the dividend, so a negative `nextLong()` would land before
+      // startTS; shift negative remainders into [0, range) to stay within [startTS, endTS).
+      val range = endTS - startTS
+      val offset = r.nextLong() % range
+      DateTimeUtils.toJavaTimestamp((if (offset < 0) offset + range else offset) + startTS)
+    }
+    val (allFloats, sampledFloatIndices, exactFloatFreq) =
+      generateTestData[Float] { _.nextFloat() }
+    val (allDoubles, sampledDoubleIndices, exactDoubleFreq) =
+      generateTestData[Double] { _.nextDouble() }
+    val (allDecimals, sampledDecimalIndices, exactDecimalFreq) =
+      generateTestData[Decimal] { r => Decimal(r.nextDouble()) }
+    val (allBooleans, sampledBooleanIndices, exactBooleanFreq) =
+      generateTestData[Boolean] { _.nextBoolean() }
+    val (allBinaries, sampledBinaryIndices, exactBinaryFreq) = generateTestData[Array[Byte]] { r =>
+      r.nextString(r.nextInt(20)).getBytes(StandardCharsets.UTF_8)
+    }
+
+    // One row per sample; column i holds the sampled item of the i-th generated type.
+    val data = (0 until numSamples).map { i =>
+      Row(allBytes(sampledByteIndices(i)),
+        allShorts(sampledShortIndices(i)),
+        allInts(sampledIntIndices(i)),
+        allLongs(sampledLongIndices(i)),
+        allStrings(sampledStringIndices(i)),
+        allDates(sampledDateIndices(i)),
+        allTimestamps(sampledTSIndices(i)),
+        allFloats(sampledFloatIndices(i)),
+        allDoubles(sampledDoubleIndices(i)),
+        allDecimals(sampledDecimalIndices(i)),
+        allBooleans(sampledBooleanIndices(i)),
+        allBinaries(sampledBinaryIndices(i)))
+    }
+
+    val schema = StructType(Seq(
+      StructField("c1", ByteType),
+      StructField("c2", ShortType),
+      StructField("c3", IntegerType),
+      StructField("c4", LongType),
+      StructField("c5", StringType),
+      StructField("c6", DateType),
+      StructField("c7", TimestampType),
+      StructField("c8", FloatType),
+      StructField("c9", DoubleType),
+      // NOTE(review): the no-arg DecimalType constructor presumably means DecimalType(10, 0),
+      // which would round the generated values in [0, 1) when the rows are converted, so the
+      // c10 check may not exercise the sketch meaningfully -- confirm.
+      StructField("c10", new DecimalType()),
+      StructField("c11", BooleanType),
+      StructField("c12", BinaryType)))
+
+    withTempView(table) {
+      val rdd: RDD[Row] = spark.sparkContext.parallelize(data)
+      spark.createDataFrame(rdd, schema).createOrReplaceTempView(table)
+      val cmsSql = schema.fieldNames.map(col => s"count_min_sketch($col, $eps, $confidence, $seed)")
+        .mkString(", ")
+      val result = sql(s"SELECT $cmsSql FROM $table").head()
+      schema.indices.foreach { i =>
+        // Each output column is a serialized CountMinSketch.
+        val binaryData = result.getAs[Array[Byte]](i)
+        val in = new ByteArrayInputStream(binaryData)
+        val cms = CountMinSketch.readFrom(in)
+        schema.fields(i).dataType match {
+          case ByteType => checkResult(cms, allBytes, exactByteFreq)
+          case ShortType => checkResult(cms, allShorts, exactShortFreq)
+          case IntegerType => checkResult(cms, allInts, exactIntFreq)
+          case LongType => checkResult(cms, allLongs, exactLongFreq)
+          case StringType => checkResult(cms, allStrings, exactStringFreq)
+          case DateType =>
+            // Probe with the internal (Int) representation Spark stores for dates.
+            checkResult(cms,
+              allDates.map(DateTimeUtils.fromJavaDate),
+              exactDateFreq.map { e =>
+                (DateTimeUtils.fromJavaDate(e._1.asInstanceOf[Date]), e._2)
+              })
+          case TimestampType =>
+            // Probe with the internal (Long) representation Spark stores for timestamps.
+            checkResult(cms,
+              allTimestamps.map(DateTimeUtils.fromJavaTimestamp),
+              exactTSFreq.map { e =>
+                (DateTimeUtils.fromJavaTimestamp(e._1.asInstanceOf[Timestamp]), e._2)
+              })
+          case FloatType => checkResult(cms, allFloats, exactFloatFreq)
+          case DoubleType => checkResult(cms, allDoubles, exactDoubleFreq)
+          case DecimalType() => checkResult(cms, allDecimals, exactDecimalFreq)
+          case BooleanType => checkResult(cms, allBooleans, exactBooleanFreq)
+          case BinaryType => checkResult(cms, allBinaries, exactBinaryFreq)
+        }
+      }
+    }
+  }
+
+  /**
+   * Probes `cms` with every item in `data` and asserts that strictly fewer than a
+   * `1 - confidence` fraction of them are over-estimated by more than `eps`, where the
+   * over-estimate is divided by `data.length`.
+   *
+   * @param cms the sketch deserialized from a `count_min_sketch` result column
+   * @param data every candidate item, sampled or not; each one is probed against `cms`
+   * @param exactFreq exact frequencies of the sampled items, keyed by [[getProbeItem]]
+   */
+  private def checkResult[T: ClassTag](
+      cms: CountMinSketch,
+      data: Array[T],
+      exactFreq: Map[Any, Long]): Unit = {
+    val numErrors = data.count { i =>
+      val count = exactFreq.getOrElse(getProbeItem(i), 0L)
+      // Convert the probe into the form the sketch is queried with.
+      val item = i match {
+        case dec: Decimal => dec.toJavaBigDecimal
+        case str: UTF8String => str.getBytes
+        case _ => i
+      }
+      // NOTE(review): count-min error bounds are usually stated relative to the total count
+      // added to the sketch (numSamples here); dividing by data.length (numAllItems) is a
+      // looser check. Kept as-is -- confirm whether that is intended.
+      val ratio = (cms.estimateCount(item) - count).toDouble / data.length
+      ratio > eps
+    }
+    val probCorrect = 1D - numErrors.toDouble / data.length
+
+    assert(
+      probCorrect > confidence,
+      s"Confidence not reached: required $confidence, reached $probCorrect"
+    )
+  }
+
+  /**
+   * Returns the key used to look an item up in the exact-frequency maps. Arrays compare by
+   * reference on the JVM, so byte arrays are keyed by their UTF-8 decoded content instead.
+   */
+  private def getProbeItem[T: ClassTag](item: T): Any = item match {
+    case bytes: Array[Byte] => new String(bytes, StandardCharsets.UTF_8)
+    case other => other
+  }
+
+  /**
+   * Generates `numAllItems` random items and draws `numSamples` indices into them (with
+   * replacement) using the suite's shared, fixed-seed `Random`.
+   *
+   * @param itemGenerator produces one random item from the given `Random`
+   * @return (all generated items, sampled indices, exact frequency of each sampled item keyed
+   *         by [[getProbeItem]])
+   */
+  private def generateTestData[T: ClassTag](
+      itemGenerator: Random => T): (Array[T], Array[Int], Map[Any, Long]) = {
+    val allItems = Array.fill(numAllItems)(itemGenerator(r))
+    val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
+    // Build a strict map: `mapValues` would return a view that re-evaluates on every lookup.
+    val exactFreq = sampledItemIndices.map(allItems).groupBy(getProbeItem).map {
+      case (key, items) => key -> items.length.toLong
+    }
+    (allItems, sampledItemIndices, exactFreq)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org