You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/01/27 05:12:37 UTC

spark git commit: [SPARK-12935][SQL] DataFrame API for Count-Min Sketch

Repository: spark
Updated Branches:
  refs/heads/master e7f9199e7 -> ce38a35b7


[SPARK-12935][SQL] DataFrame API for Count-Min Sketch

This PR integrates Count-Min Sketch from spark-sketch into DataFrame. This version resorts to `RDD.aggregate` for building the sketch. A more performant UDAF version can be built in future follow-up PRs.

Author: Cheng Lian <li...@databricks.com>

Closes #10911 from liancheng/cms-df-api.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ce38a35b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ce38a35b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ce38a35b

Branch: refs/heads/master
Commit: ce38a35b764397fcf561ac81de6da96579f5c13e
Parents: e7f9199
Author: Cheng Lian <li...@databricks.com>
Authored: Tue Jan 26 20:12:34 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Jan 26 20:12:34 2016 -0800

----------------------------------------------------------------------
 .../apache/spark/util/sketch/BloomFilter.java   | 10 ++-
 .../spark/util/sketch/CountMinSketch.java       | 26 ++++---
 .../spark/util/sketch/CountMinSketchImpl.java   | 56 ++++++++------
 sql/core/pom.xml                                |  5 ++
 .../spark/sql/DataFrameStatFunctions.scala      | 81 ++++++++++++++++++++
 .../apache/spark/sql/JavaDataFrameSuite.java    | 28 ++++++-
 .../apache/spark/sql/DataFrameStatSuite.scala   | 36 +++++++++
 7 files changed, 205 insertions(+), 37 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ce38a35b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
index 00378d5..d392fb1 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -47,10 +47,12 @@ public abstract class BloomFilter {
   public enum Version {
     /**
      * {@code BloomFilter} binary format version 1 (all values written in big-endian order):
-     * - Version number, always 1 (32 bit)
-     * - Total number of words of the underlying bit array (32 bit)
-     * - The words/longs (numWords * 64 bit)
-     * - Number of hash functions (32 bit)
+     * <ul>
+     *   <li>Version number, always 1 (32 bit)</li>
+     *   <li>Total number of words of the underlying bit array (32 bit)</li>
+     *   <li>The words/longs (numWords * 64 bit)</li>
+     *   <li>Number of hash functions (32 bit)</li>
+     * </ul>
      */
     V1(1);
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ce38a35b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
index 00c0b1b..5692e57 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -59,16 +59,22 @@ abstract public class CountMinSketch {
   public enum Version {
     /**
      * {@code CountMinSketch} binary format version 1 (all values written in big-endian order):
-     * - Version number, always 1 (32 bit)
-     * - Total count of added items (64 bit)
-     * - Depth (32 bit)
-     * - Width (32 bit)
-     * - Hash functions (depth * 64 bit)
-     * - Count table
-     *   - Row 0 (width * 64 bit)
-     *   - Row 1 (width * 64 bit)
-     *   - ...
-     *   - Row depth - 1 (width * 64 bit)
+     * <ul>
+     *   <li>Version number, always 1 (32 bit)</li>
+     *   <li>Total count of added items (64 bit)</li>
+     *   <li>Depth (32 bit)</li>
+     *   <li>Width (32 bit)</li>
+     *   <li>Hash functions (depth * 64 bit)</li>
+     *   <li>
+     *     Count table
+     *     <ul>
+     *       <li>Row 0 (width * 64 bit)</li>
+     *       <li>Row 1 (width * 64 bit)</li>
+     *       <li>...</li>
+     *       <li>Row {@code depth - 1} (width * 64 bit)</li>
+     *     </ul>
+     *   </li>
+     * </ul>
      */
     V1(1);
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ce38a35b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index d088096..8cc29e4 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -21,13 +21,16 @@ import java.io.DataInputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
 import java.io.InputStream;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.io.OutputStream;
+import java.io.Serializable;
 import java.io.UnsupportedEncodingException;
 import java.util.Arrays;
 import java.util.Random;
 
-class CountMinSketchImpl extends CountMinSketch {
-  public static final long PRIME_MODULUS = (1L << 31) - 1;
+class CountMinSketchImpl extends CountMinSketch implements Serializable {
+  private static final long PRIME_MODULUS = (1L << 31) - 1;
 
   private int depth;
   private int width;
@@ -37,6 +40,9 @@ class CountMinSketchImpl extends CountMinSketch {
   private double eps;
   private double confidence;
 
+  private CountMinSketchImpl() {
+  }
+
   CountMinSketchImpl(int depth, int width, int seed) {
     this.depth = depth;
     this.width = width;
@@ -55,16 +61,6 @@ class CountMinSketchImpl extends CountMinSketch {
     initTablesWith(depth, width, seed);
   }
 
-  CountMinSketchImpl(int depth, int width, long totalCount, long hashA[], long table[][]) {
-    this.depth = depth;
-    this.width = width;
-    this.eps = 2.0 / width;
-    this.confidence = 1 - 1 / Math.pow(2, depth);
-    this.hashA = hashA;
-    this.table = table;
-    this.totalCount = totalCount;
-  }
-
   @Override
   public boolean equals(Object other) {
     if (other == this) {
@@ -325,27 +321,43 @@ class CountMinSketchImpl extends CountMinSketch {
   }
 
   public static CountMinSketchImpl readFrom(InputStream in) throws IOException {
+    CountMinSketchImpl sketch = new CountMinSketchImpl();
+    sketch.readFrom0(in);
+    return sketch;
+  }
+
+  private void readFrom0(InputStream in) throws IOException {
     DataInputStream dis = new DataInputStream(in);
 
-    // Ignores version number
-    dis.readInt();
+    int version = dis.readInt();
+    if (version != Version.V1.getVersionNumber()) {
+      throw new IOException("Unexpected Count-Min Sketch version number (" + version + ")");
+    }
 
-    long totalCount = dis.readLong();
-    int depth = dis.readInt();
-    int width = dis.readInt();
+    this.totalCount = dis.readLong();
+    this.depth = dis.readInt();
+    this.width = dis.readInt();
+    this.eps = 2.0 / width;
+    this.confidence = 1 - 1 / Math.pow(2, depth);
 
-    long hashA[] = new long[depth];
+    this.hashA = new long[depth];
     for (int i = 0; i < depth; ++i) {
-      hashA[i] = dis.readLong();
+      this.hashA[i] = dis.readLong();
     }
 
-    long table[][] = new long[depth][width];
+    this.table = new long[depth][width];
     for (int i = 0; i < depth; ++i) {
       for (int j = 0; j < width; ++j) {
-        table[i][j] = dis.readLong();
+        this.table[i][j] = dis.readLong();
       }
     }
+  }
+
+  private void writeObject(ObjectOutputStream out) throws IOException {
+    this.writeTo(out);
+  }
 
-    return new CountMinSketchImpl(depth, width, totalCount, hashA, table);
+  private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
+    this.readFrom0(in);
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ce38a35b/sql/core/pom.xml
----------------------------------------------------------------------
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 31b364f..4bb55f6 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -44,6 +44,11 @@
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sketch_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>

http://git-wip-us.apache.org/repos/asf/spark/blob/ce38a35b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index e66aa5f..465b12b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -23,6 +23,8 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.execution.stat._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.sketch.CountMinSketch
 
 /**
  * :: Experimental ::
@@ -309,4 +311,83 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
   def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DataFrame = {
     sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed)
   }
+
+  /**
+   * Builds a Count-min Sketch over a specified column.
+   *
+   * @param colName name of the column over which the sketch is built
+   * @param depth depth of the sketch
+   * @param width width of the sketch
+   * @param seed random seed
+   * @return a [[CountMinSketch]] over column `colName`
+   * @since 2.0.0
+   */
+  def countMinSketch(colName: String, depth: Int, width: Int, seed: Int): CountMinSketch = {
+    countMinSketch(Column(colName), depth, width, seed)
+  }
+
+  /**
+   * Builds a Count-min Sketch over a specified column.
+   *
+   * @param colName name of the column over which the sketch is built
+   * @param eps relative error of the sketch
+   * @param confidence confidence of the sketch
+   * @param seed random seed
+   * @return a [[CountMinSketch]] over column `colName`
+   * @since 2.0.0
+   */
+  def countMinSketch(
+      colName: String, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
+    countMinSketch(Column(colName), eps, confidence, seed)
+  }
+
+  /**
+   * Builds a Count-min Sketch over a specified column.
+   *
+   * @param col the column over which the sketch is built
+   * @param depth depth of the sketch
+   * @param width width of the sketch
+   * @param seed random seed
+   * @return a [[CountMinSketch]] over column `col`
+   * @since 2.0.0
+   */
+  def countMinSketch(col: Column, depth: Int, width: Int, seed: Int): CountMinSketch = {
+    countMinSketch(col, CountMinSketch.create(depth, width, seed))
+  }
+
+  /**
+   * Builds a Count-min Sketch over a specified column.
+   *
+   * @param col the column over which the sketch is built
+   * @param eps relative error of the sketch
+   * @param confidence confidence of the sketch
+   * @param seed random seed
+   * @return a [[CountMinSketch]] over column `col`
+   * @since 2.0.0
+   */
+  def countMinSketch(col: Column, eps: Double, confidence: Double, seed: Int): CountMinSketch = {
+    countMinSketch(col, CountMinSketch.create(eps, confidence, seed))
+  }
+
+  private def countMinSketch(col: Column, zero: CountMinSketch): CountMinSketch = {
+    val singleCol = df.select(col)
+    val colType = singleCol.schema.head.dataType
+
+    require(
+      colType == StringType || colType.isInstanceOf[IntegralType],
+      s"Count-min Sketch only supports string type and integral types, " +
+        s"and does not support type $colType."
+    )
+
+    singleCol.rdd.aggregate(zero)(
+      (sketch: CountMinSketch, row: Row) => {
+        sketch.add(row.get(0))
+        sketch
+      },
+
+      (sketch1: CountMinSketch, sketch2: CountMinSketch) => {
+        sketch1.mergeInPlace(sketch2)
+      }
+    )
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ce38a35b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index ac1607b..9cf94e7 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -35,9 +35,10 @@ import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.*;
-import static org.apache.spark.sql.functions.*;
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.types.*;
+import org.apache.spark.util.sketch.CountMinSketch;
+import static org.apache.spark.sql.functions.*;
 import static org.apache.spark.sql.types.DataTypes.*;
 
 public class JavaDataFrameSuite {
@@ -321,4 +322,29 @@ public class JavaDataFrameSuite {
       Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString());
     Assert.assertEquals(5L, df2.count());
   }
+
+  @Test
+  public void testCountMinSketch() {
+    DataFrame df = context.range(1000);
+
+    CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
+    Assert.assertEquals(1000, sketch1.totalCount());
+    Assert.assertEquals(10, sketch1.depth());
+    Assert.assertEquals(20, sketch1.width());
+
+    CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42);
+    Assert.assertEquals(1000, sketch2.totalCount());
+    Assert.assertEquals(10, sketch2.depth());
+    Assert.assertEquals(20, sketch2.width());
+
+    CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42);
+    Assert.assertEquals(1000, sketch3.totalCount());
+    Assert.assertEquals(0.001, sketch3.relativeError(), 1e-4);
+    Assert.assertEquals(0.99, sketch3.confidence(), 5e-3);
+
+    CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42);
+    Assert.assertEquals(1000, sketch4.totalCount());
+    Assert.assertEquals(0.001, sketch4.relativeError(), 1e-4);
+    Assert.assertEquals(0.99, sketch4.confidence(), 5e-3);
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ce38a35b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 63ad6c4..8f3ea5a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -19,8 +19,11 @@ package org.apache.spark.sql
 
 import java.util.Random
 
+import org.scalatest.Matchers._
+
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.DoubleType
 
 class DataFrameStatSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -210,4 +213,37 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
       sampled.groupBy("key").count().orderBy("key"),
       Seq(Row(0, 6), Row(1, 11)))
   }
+
+  // This test case only verifies that `DataFrame.countMinSketch()` methods do return
+  // `CountMinSketch`es that meet required specs.  Test cases for `CountMinSketch` can be found in
+  // `CountMinSketchSuite` in project spark-sketch.
+  test("countMinSketch") {
+    val df = sqlContext.range(1000)
+
+    val sketch1 = df.stat.countMinSketch("id", depth = 10, width = 20, seed = 42)
+    assert(sketch1.totalCount() === 1000)
+    assert(sketch1.depth() === 10)
+    assert(sketch1.width() === 20)
+
+    val sketch2 = df.stat.countMinSketch($"id", depth = 10, width = 20, seed = 42)
+    assert(sketch2.totalCount() === 1000)
+    assert(sketch2.depth() === 10)
+    assert(sketch2.width() === 20)
+
+    val sketch3 = df.stat.countMinSketch("id", eps = 0.001, confidence = 0.99, seed = 42)
+    assert(sketch3.totalCount() === 1000)
+    assert(sketch3.relativeError() === 0.001)
+    assert(sketch3.confidence() === 0.99 +- 5e-3)
+
+    val sketch4 = df.stat.countMinSketch($"id", eps = 0.001, confidence = 0.99, seed = 42)
+    assert(sketch4.totalCount() === 1000)
+    assert(sketch4.relativeError() === 0.001 +- 1e-4)
+    assert(sketch4.confidence() === 0.99 +- 5e-3)
+
+    intercept[IllegalArgumentException] {
+      df.select('id cast DoubleType as 'id)
+        .stat
+        .countMinSketch('id, depth = 10, width = 20, seed = 42)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org