You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/01/23 09:34:57 UTC

spark git commit: [SPARK-12933][SQL] Initial implementation of Count-Min sketch

Repository: spark
Updated Branches:
  refs/heads/master 5af5a0216 -> 1c690ddaf


[SPARK-12933][SQL] Initial implementation of Count-Min sketch

This PR adds an initial implementation of the Count-Min sketch, contained in a new module, spark-sketch, under `common/sketch`. The implementation is based on the [`CountMinSketch` class in stream-lib][1].

As required by the [design doc][2], spark-sketch should have no external dependency.
Two classes, `Murmur3_x86_32` and `Platform`, are copied to spark-sketch from spark-unsafe for hashing facilities. They'll also be used in the upcoming Bloom filter implementation.

The following features will be added in future follow-up PRs:

- Serialization support
- DataFrame API integration

[1]: https://github.com/addthis/stream-lib/blob/aac6b4d23a8686b000f80baa447e0922ecac3bcb/src/main/java/com/clearspring/analytics/stream/frequency/CountMinSketch.java
[2]: https://issues.apache.org/jira/secure/attachment/12782378/BloomFilterandCount-MinSketchinSpark2.0.pdf

Author: Cheng Lian <li...@databricks.com>

Closes #10851 from liancheng/count-min-sketch.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1c690dda
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1c690dda
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1c690dda

Branch: refs/heads/master
Commit: 1c690ddafa8376c55cbc5b7a7a750200abfbe2a6
Parents: 5af5a02
Author: Cheng Lian <li...@databricks.com>
Authored: Sat Jan 23 00:34:55 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sat Jan 23 00:34:55 2016 -0800

----------------------------------------------------------------------
 common/sketch/pom.xml                           |  42 +++
 .../spark/util/sketch/CountMinSketch.java       | 132 +++++++++
 .../spark/util/sketch/CountMinSketchImpl.java   | 268 +++++++++++++++++++
 .../spark/util/sketch/Murmur3_x86_32.java       | 126 +++++++++
 .../org/apache/spark/util/sketch/Platform.java  | 172 ++++++++++++
 .../spark/util/sketch/CountMinSketchSuite.scala | 112 ++++++++
 dev/sparktestsupport/modules.py                 |  12 +
 pom.xml                                         |   1 +
 project/SparkBuild.scala                        |  39 ++-
 9 files changed, 892 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1c690dda/common/sketch/pom.xml
----------------------------------------------------------------------
diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml
new file mode 100644
index 0000000..67723fa
--- /dev/null
+++ b/common/sketch/pom.xml
@@ -0,0 +1,42 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.10</artifactId>
+    <version>2.0.0-SNAPSHOT</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <groupId>org.apache.spark</groupId>
+  <artifactId>spark-sketch_2.10</artifactId>
+  <packaging>jar</packaging>
+  <name>Spark Project Sketch</name>
+  <url>http://spark.apache.org/</url>
+  <properties>
+    <sbt.project.name>sketch</sbt.project.name>
+  </properties>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+  </build>
+</project>

http://git-wip-us.apache.org/repos/asf/spark/blob/1c690dda/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
new file mode 100644
index 0000000..21b161b
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.io.InputStream;
+import java.io.OutputStream;
+
+/**
+ * A Count-Min sketch is a probabilistic data structure used for summarizing streams of data in
+ * sub-linear space.  Currently, supported data types include:
+ * <ul>
+ *   <li>{@link Byte}</li>
+ *   <li>{@link Short}</li>
+ *   <li>{@link Integer}</li>
+ *   <li>{@link Long}</li>
+ *   <li>{@link String}</li>
+ * </ul>
+ * Each {@link CountMinSketch} is initialized with a random seed, and a pair
+ * of parameters:
+ * <ol>
+ *   <li>relative error (or {@code eps}), and</li>
+ *   <li>confidence (or {@code delta})</li>
+ * </ol>
+ * Suppose you want to estimate the number of times an element {@code x} has appeared in a data
+ * stream so far.  With probability {@code delta}, the estimate of this frequency is within the
+ * range {@code true frequency <= estimate <= true frequency + eps * N}, where {@code N} is the
+ * total count of items that have appeared in the data stream so far.
+ *
+ * Under the cover, a {@link CountMinSketch} is essentially a two-dimensional {@code long} array
+ * with depth {@code d} and width {@code w}, where
+ * <ul>
+ *   <li>{@code d = ceil(-log(1 - confidence) / log(2))}</li>
+ *   <li>{@code w = ceil(2 / eps)}</li>
+ * </ul>
+ *
+ * See http://www.eecs.harvard.edu/~michaelm/CS222/countmin.pdf for technical details,
+ * including proofs of the estimates and error bounds used in this implementation.
+ *
+ * This implementation is largely based on the {@code CountMinSketch} class from stream-lib.
+ */
+abstract public class CountMinSketch {
+  /**
+   * Returns the relative error (or {@code eps}) of this {@link CountMinSketch}.
+   */
+  public abstract double relativeError();
+
+  /**
+   * Returns the confidence (or {@code delta}) of this {@link CountMinSketch}.
+   */
+  public abstract double confidence();
+
+  /**
+   * Depth of this {@link CountMinSketch}.
+   */
+  public abstract int depth();
+
+  /**
+   * Width of this {@link CountMinSketch}.
+   */
+  public abstract int width();
+
+  /**
+   * Total count of items added to this {@link CountMinSketch} so far.
+   */
+  public abstract long totalCount();
+
+  /**
+   * Increments the count of {@code item} by 1.
+   */
+  public abstract void add(Object item);
+
+  /**
+   * Increments the count of {@code item} by {@code count}.
+   */
+  public abstract void add(Object item, long count);
+
+  /**
+   * Returns the estimated frequency of {@code item}.
+   */
+  public abstract long estimateCount(Object item);
+
+  /**
+   * Merges another {@link CountMinSketch} with this one in place.
+   *
+   * Note that only Count-Min sketches with the same {@code depth}, {@code width}, and random seed
+   * can be merged.
+   */
+  public abstract CountMinSketch mergeInPlace(CountMinSketch other);
+
+  /**
+   * Writes out this {@link CountMinSketch} to an output stream in binary format.
+   */
+  public abstract void writeTo(OutputStream out);
+
+  /**
+   * Reads in a {@link CountMinSketch} from an input stream.
+   */
+  public static CountMinSketch readFrom(InputStream in) {
+    throw new UnsupportedOperationException("Not implemented yet");
+  }
+
+  /**
+   * Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, and random
+   * {@code seed}.
+   */
+  public static CountMinSketch create(int depth, int width, int seed) {
+    return new CountMinSketchImpl(depth, width, seed);
+  }
+
+  /**
+   * Creates a {@link CountMinSketch} with given relative error ({@code eps}), {@code confidence},
+   * and random {@code seed}.
+   */
+  public static CountMinSketch create(double eps, double confidence, int seed) {
+    return new CountMinSketchImpl(eps, confidence, seed);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1c690dda/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
new file mode 100644
index 0000000..e9fdbe3
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.io.OutputStream;
+import java.io.UnsupportedEncodingException;
+import java.util.Arrays;
+import java.util.Random;
+
+class CountMinSketchImpl extends CountMinSketch {
+  public static final long PRIME_MODULUS = (1L << 31) - 1;
+
+  private int depth;
+  private int width;
+  private long[][] table;
+  private long[] hashA;
+  private long totalCount;
+  private double eps;
+  private double confidence;
+
+  public CountMinSketchImpl(int depth, int width, int seed) {
+    this.depth = depth;
+    this.width = width;
+    this.eps = 2.0 / width;
+    this.confidence = 1 - 1 / Math.pow(2, depth);
+    initTablesWith(depth, width, seed);
+  }
+
+  public CountMinSketchImpl(double eps, double confidence, int seed) {
+    // 2/w = eps ; w = 2/eps
+    // 1/2^depth <= 1-confidence ; depth >= -log2 (1-confidence)
+    this.eps = eps;
+    this.confidence = confidence;
+    this.width = (int) Math.ceil(2 / eps);
+    this.depth = (int) Math.ceil(-Math.log(1 - confidence) / Math.log(2));
+    initTablesWith(depth, width, seed);
+  }
+
+  private void initTablesWith(int depth, int width, int seed) {
+    this.table = new long[depth][width];
+    this.hashA = new long[depth];
+    Random r = new Random(seed);
+    // We're using linear hash functions
+    // of the form (a*x+b) mod p.
+    // a,b are chosen independently for each hash function.
+    // However we can set b = 0 as all it does is shift the results
+    // without compromising their uniformity or independence with
+    // the other hashes.
+    for (int i = 0; i < depth; ++i) {
+      hashA[i] = r.nextInt(Integer.MAX_VALUE);
+    }
+  }
+
+  @Override
+  public double relativeError() {
+    return eps;
+  }
+
+  @Override
+  public double confidence() {
+    return confidence;
+  }
+
+  @Override
+  public int depth() {
+    return depth;
+  }
+
+  @Override
+  public int width() {
+    return width;
+  }
+
+  @Override
+  public long totalCount() {
+    return totalCount;
+  }
+
+  @Override
+  public void add(Object item) {
+    add(item, 1);
+  }
+
+  @Override
+  public void add(Object item, long count) {
+    if (item instanceof String) {
+      addString((String) item, count);
+    } else {
+      long longValue;
+
+      if (item instanceof Long) {
+        longValue = (Long) item;
+      } else if (item instanceof Integer) {
+        longValue = ((Integer) item).longValue();
+      } else if (item instanceof Short) {
+        longValue = ((Short) item).longValue();
+      } else if (item instanceof Byte) {
+        longValue = ((Byte) item).longValue();
+      } else {
+        throw new IllegalArgumentException(
+          "Support for " + item.getClass().getName() + " not implemented"
+        );
+      }
+
+      addLong(longValue, count);
+    }
+  }
+
+  private void addString(String item, long count) {
+    if (count < 0) {
+      throw new IllegalArgumentException("Negative increments not implemented");
+    }
+
+    int[] buckets = getHashBuckets(item, depth, width);
+
+    for (int i = 0; i < depth; ++i) {
+      table[i][buckets[i]] += count;
+    }
+
+    totalCount += count;
+  }
+
+  private void addLong(long item, long count) {
+    if (count < 0) {
+      throw new IllegalArgumentException("Negative increments not implemented");
+    }
+
+    for (int i = 0; i < depth; ++i) {
+      table[i][hash(item, i)] += count;
+    }
+
+    totalCount += count;
+  }
+
+  private int hash(long item, int count) {
+    long hash = hashA[count] * item;
+    // A super fast way of computing x mod (2^p - 1)
+    // See http://www.cs.princeton.edu/courses/archive/fall09/cos521/Handouts/universalclasses.pdf
+    // page 149, right after Proposition 7.
+    hash += hash >> 32;
+    hash &= PRIME_MODULUS;
+    // Doing "%" after (int) conversion is ~2x faster than %'ing longs.
+    return ((int) hash) % width;
+  }
+
+  private static int[] getHashBuckets(String key, int hashCount, int max) {
+    byte[] b;
+    try {
+      b = key.getBytes("UTF-8");
+    } catch (UnsupportedEncodingException e) {
+      throw new RuntimeException(e);
+    }
+    return getHashBuckets(b, hashCount, max);
+  }
+
+  private static int[] getHashBuckets(byte[] b, int hashCount, int max) {
+    int[] result = new int[hashCount];
+    int hash1 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, 0);
+    int hash2 = Murmur3_x86_32.hashUnsafeBytes(b, Platform.BYTE_ARRAY_OFFSET, b.length, hash1);
+    for (int i = 0; i < hashCount; i++) {
+      result[i] = Math.abs((hash1 + i * hash2) % max);
+    }
+    return result;
+  }
+
+  @Override
+  public long estimateCount(Object item) {
+    if (item instanceof String) {
+      return estimateCountForStringItem((String) item);
+    } else {
+      long longValue;
+
+      if (item instanceof Long) {
+        longValue = (Long) item;
+      } else if (item instanceof Integer) {
+        longValue = ((Integer) item).longValue();
+      } else if (item instanceof Short) {
+        longValue = ((Short) item).longValue();
+      } else if (item instanceof Byte) {
+        longValue = ((Byte) item).longValue();
+      } else {
+        throw new IllegalArgumentException(
+            "Support for " + item.getClass().getName() + " not implemented"
+        );
+      }
+
+      return estimateCountForLongItem(longValue);
+    }
+  }
+
+  private long estimateCountForLongItem(long item) {
+    long res = Long.MAX_VALUE;
+    for (int i = 0; i < depth; ++i) {
+      res = Math.min(res, table[i][hash(item, i)]);
+    }
+    return res;
+  }
+
+  private long estimateCountForStringItem(String item) {
+    long res = Long.MAX_VALUE;
+    int[] buckets = getHashBuckets(item, depth, width);
+    for (int i = 0; i < depth; ++i) {
+      res = Math.min(res, table[i][buckets[i]]);
+    }
+    return res;
+  }
+
+  @Override
+  public CountMinSketch mergeInPlace(CountMinSketch other) {
+    if (other == null) {
+      throw new CMSMergeException("Cannot merge null estimator");
+    }
+
+    if (!(other instanceof CountMinSketchImpl)) {
+      throw new CMSMergeException("Cannot merge estimator of class " + other.getClass().getName());
+    }
+
+    CountMinSketchImpl that = (CountMinSketchImpl) other;
+
+    if (this.depth != that.depth) {
+      throw new CMSMergeException("Cannot merge estimators of different depth");
+    }
+
+    if (this.width != that.width) {
+      throw new CMSMergeException("Cannot merge estimators of different width");
+    }
+
+    if (!Arrays.equals(this.hashA, that.hashA)) {
+      throw new CMSMergeException("Cannot merge estimators of different seed");
+    }
+
+    for (int i = 0; i < this.table.length; ++i) {
+      for (int j = 0; j < this.table[i].length; ++j) {
+        this.table[i][j] = this.table[i][j] + that.table[i][j];
+      }
+    }
+
+    this.totalCount += that.totalCount;
+
+    return this;
+  }
+
+  @Override
+  public void writeTo(OutputStream out) {
+    throw new UnsupportedOperationException("Not implemented yet");
+  }
+
+  protected static class CMSMergeException extends RuntimeException {
+    public CMSMergeException(String message) {
+      super(message);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1c690dda/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java
new file mode 100644
index 0000000..3d1f28b
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Murmur3_x86_32.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+/**
+ * 32-bit Murmur3 hasher.  This is based on Guava's Murmur3_32HashFunction.
+ */
+// This class is duplicated from `org.apache.spark.unsafe.hash.Murmur3_x86_32` to make sure
+// spark-sketch has no external dependencies.
+final class Murmur3_x86_32 {
+  private static final int C1 = 0xcc9e2d51;
+  private static final int C2 = 0x1b873593;
+
+  private final int seed;
+
+  public Murmur3_x86_32(int seed) {
+    this.seed = seed;
+  }
+
+  @Override
+  public String toString() {
+    return "Murmur3_32(seed=" + seed + ")";
+  }
+
+  public int hashInt(int input) {
+    return hashInt(input, seed);
+  }
+
+  public static int hashInt(int input, int seed) {
+    int k1 = mixK1(input);
+    int h1 = mixH1(seed, k1);
+
+    return fmix(h1, 4);
+  }
+
+  public int hashUnsafeWords(Object base, long offset, int lengthInBytes) {
+    return hashUnsafeWords(base, offset, lengthInBytes, seed);
+  }
+
+  public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
+    // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
+    assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
+    int h1 = hashBytesByInt(base, offset, lengthInBytes, seed);
+    return fmix(h1, lengthInBytes);
+  }
+
+  public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
+    assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
+    int lengthAligned = lengthInBytes - lengthInBytes % 4;
+    int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
+    for (int i = lengthAligned; i < lengthInBytes; i++) {
+      int halfWord = Platform.getByte(base, offset + i);
+      int k1 = mixK1(halfWord);
+      h1 = mixH1(h1, k1);
+    }
+    return fmix(h1, lengthInBytes);
+  }
+
+  private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
+    assert (lengthInBytes % 4 == 0);
+    int h1 = seed;
+    for (int i = 0; i < lengthInBytes; i += 4) {
+      int halfWord = Platform.getInt(base, offset + i);
+      int k1 = mixK1(halfWord);
+      h1 = mixH1(h1, k1);
+    }
+    return h1;
+  }
+
+  public int hashLong(long input) {
+    return hashLong(input, seed);
+  }
+
+  public static int hashLong(long input, int seed) {
+    int low = (int) input;
+    int high = (int) (input >>> 32);
+
+    int k1 = mixK1(low);
+    int h1 = mixH1(seed, k1);
+
+    k1 = mixK1(high);
+    h1 = mixH1(h1, k1);
+
+    return fmix(h1, 8);
+  }
+
+  private static int mixK1(int k1) {
+    k1 *= C1;
+    k1 = Integer.rotateLeft(k1, 15);
+    k1 *= C2;
+    return k1;
+  }
+
+  private static int mixH1(int h1, int k1) {
+    h1 ^= k1;
+    h1 = Integer.rotateLeft(h1, 13);
+    h1 = h1 * 5 + 0xe6546b64;
+    return h1;
+  }
+
+  // Finalization mix - force all bits of a hash block to avalanche
+  private static int fmix(int h1, int length) {
+    h1 ^= length;
+    h1 ^= h1 >>> 16;
+    h1 *= 0x85ebca6b;
+    h1 ^= h1 >>> 13;
+    h1 *= 0xc2b2ae35;
+    h1 ^= h1 >>> 16;
+    return h1;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1c690dda/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java
new file mode 100644
index 0000000..75d6a6b
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/Platform.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.lang.reflect.Field;
+
+import sun.misc.Unsafe;
+
+// This class is duplicated from `org.apache.spark.unsafe.Platform` to make sure spark-sketch has no
+// external dependencies.
+final class Platform {
+
+  private static final Unsafe _UNSAFE;
+
+  public static final int BYTE_ARRAY_OFFSET;
+
+  public static final int INT_ARRAY_OFFSET;
+
+  public static final int LONG_ARRAY_OFFSET;
+
+  public static final int DOUBLE_ARRAY_OFFSET;
+
+  public static int getInt(Object object, long offset) {
+    return _UNSAFE.getInt(object, offset);
+  }
+
+  public static void putInt(Object object, long offset, int value) {
+    _UNSAFE.putInt(object, offset, value);
+  }
+
+  public static boolean getBoolean(Object object, long offset) {
+    return _UNSAFE.getBoolean(object, offset);
+  }
+
+  public static void putBoolean(Object object, long offset, boolean value) {
+    _UNSAFE.putBoolean(object, offset, value);
+  }
+
+  public static byte getByte(Object object, long offset) {
+    return _UNSAFE.getByte(object, offset);
+  }
+
+  public static void putByte(Object object, long offset, byte value) {
+    _UNSAFE.putByte(object, offset, value);
+  }
+
+  public static short getShort(Object object, long offset) {
+    return _UNSAFE.getShort(object, offset);
+  }
+
+  public static void putShort(Object object, long offset, short value) {
+    _UNSAFE.putShort(object, offset, value);
+  }
+
+  public static long getLong(Object object, long offset) {
+    return _UNSAFE.getLong(object, offset);
+  }
+
+  public static void putLong(Object object, long offset, long value) {
+    _UNSAFE.putLong(object, offset, value);
+  }
+
+  public static float getFloat(Object object, long offset) {
+    return _UNSAFE.getFloat(object, offset);
+  }
+
+  public static void putFloat(Object object, long offset, float value) {
+    _UNSAFE.putFloat(object, offset, value);
+  }
+
+  public static double getDouble(Object object, long offset) {
+    return _UNSAFE.getDouble(object, offset);
+  }
+
+  public static void putDouble(Object object, long offset, double value) {
+    _UNSAFE.putDouble(object, offset, value);
+  }
+
+  public static Object getObjectVolatile(Object object, long offset) {
+    return _UNSAFE.getObjectVolatile(object, offset);
+  }
+
+  public static void putObjectVolatile(Object object, long offset, Object value) {
+    _UNSAFE.putObjectVolatile(object, offset, value);
+  }
+
+  public static long allocateMemory(long size) {
+    return _UNSAFE.allocateMemory(size);
+  }
+
+  public static void freeMemory(long address) {
+    _UNSAFE.freeMemory(address);
+  }
+
+  public static void copyMemory(
+    Object src, long srcOffset, Object dst, long dstOffset, long length) {
+    // Check if dstOffset is before or after srcOffset to determine if we should copy
+    // forward or backwards. This is necessary in case src and dst overlap.
+    if (dstOffset < srcOffset) {
+      while (length > 0) {
+        long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
+        _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
+        length -= size;
+        srcOffset += size;
+        dstOffset += size;
+      }
+    } else {
+      srcOffset += length;
+      dstOffset += length;
+      while (length > 0) {
+        long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
+        srcOffset -= size;
+        dstOffset -= size;
+        _UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
+        length -= size;
+      }
+
+    }
+  }
+
+  /**
+   * Raises an exception bypassing compiler checks for checked exceptions.
+   */
+  public static void throwException(Throwable t) {
+    _UNSAFE.throwException(t);
+  }
+
+  /**
+   * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to
+   * allow safepoint polling during a large copy.
+   */
+  private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L;
+
+  static {
+    sun.misc.Unsafe unsafe;
+    try {
+      Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe");
+      unsafeField.setAccessible(true);
+      unsafe = (sun.misc.Unsafe) unsafeField.get(null);
+    } catch (Throwable cause) {
+      unsafe = null;
+    }
+    _UNSAFE = unsafe;
+
+    if (_UNSAFE != null) {
+      BYTE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(byte[].class);
+      INT_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(int[].class);
+      LONG_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(long[].class);
+      DOUBLE_ARRAY_OFFSET = _UNSAFE.arrayBaseOffset(double[].class);
+    } else {
+      BYTE_ARRAY_OFFSET = 0;
+      INT_ARRAY_OFFSET = 0;
+      LONG_ARRAY_OFFSET = 0;
+      DOUBLE_ARRAY_OFFSET = 0;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1c690dda/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
----------------------------------------------------------------------
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
new file mode 100644
index 0000000..ec5b4ed
--- /dev/null
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.scalatest.FunSuite // scalastyle:ignore funsuite
+
+class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite
+  private val epsOfTotalCount = 0.0001
+
+  private val confidence = 0.99
+
+  private val seed = 42
+
+  def testAccuracy[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
+    test(s"accuracy - $typeName") {
+      val r = new Random()
+
+      val numAllItems = 1000000
+      val allItems = Array.fill(numAllItems)(itemGenerator(r))
+
+      val numSamples = numAllItems / 10
+      val sampledItemIndices = Array.fill(numSamples)(r.nextInt(numAllItems))
+
+      val exactFreq = {
+        val sampledItems = sampledItemIndices.map(allItems)
+        sampledItems.groupBy(identity).mapValues(_.length.toLong)
+      }
+
+      val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+      sampledItemIndices.foreach(i => sketch.add(allItems(i)))
+
+      val probCorrect = {
+        val numErrors = allItems.map { item =>
+          val count = exactFreq.getOrElse(item, 0L)
+          val ratio = (sketch.estimateCount(item) - count).toDouble / numAllItems
+          if (ratio > epsOfTotalCount) 1 else 0
+        }.sum
+
+        1D - numErrors.toDouble / numAllItems
+      }
+
+      assert(
+        probCorrect > confidence,
+        s"Confidence not reached: required $confidence, reached $probCorrect"
+      )
+    }
+  }
+
+  def testMergeInPlace[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
+    test(s"mergeInPlace - $typeName") {
+      val r = new Random()
+      val numToMerge = 5
+      val numItemsPerSketch = 100000
+      val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) {
+        itemGenerator(r)
+      }
+
+      val sketches = perSketchItems.map { items =>
+        val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+        items.foreach(sketch.add)
+        sketch
+      }
+
+      val mergedSketch = sketches.reduce(_ mergeInPlace _)
+
+      val expectedSketch = {
+        val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed)
+        perSketchItems.foreach(_.foreach(sketch.add))
+        sketch
+      }
+
+      perSketchItems.foreach {
+        _.foreach { item =>
+          assert(mergedSketch.estimateCount(item) === expectedSketch.estimateCount(item))
+        }
+      }
+    }
+  }
+
+  def testItemType[T: ClassTag](typeName: String)(itemGenerator: Random => T): Unit = {
+    testAccuracy[T](typeName)(itemGenerator)
+    testMergeInPlace[T](typeName)(itemGenerator)
+  }
+
+  testItemType[Byte]("Byte") { _.nextInt().toByte }
+
+  testItemType[Short]("Short") { _.nextInt().toShort }
+
+  testItemType[Int]("Int") { _.nextInt() }
+
+  testItemType[Long]("Long") { _.nextLong() }
+
+  testItemType[String]("String") { r => r.nextString(r.nextInt(20)) }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1c690dda/dev/sparktestsupport/modules.py
----------------------------------------------------------------------
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index efe58ea..032c061 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -113,6 +113,18 @@ hive_thriftserver = Module(
 )
 
 
+sketch = Module(
+    name="sketch",
+    dependencies=[],
+    source_file_regexes=[
+        "common/sketch/",
+    ],
+    sbt_test_goals=[
+        "sketch/test"
+    ]
+)
+
+
 graphx = Module(
     name="graphx",
     dependencies=[],

http://git-wip-us.apache.org/repos/asf/spark/blob/1c690dda/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index f08642f..fb77506 100644
--- a/pom.xml
+++ b/pom.xml
@@ -86,6 +86,7 @@
   </mailingLists>
 
   <modules>
+    <module>common/sketch</module>
     <module>tags</module>
     <module>core</module>
     <module>graphx</module>

http://git-wip-us.apache.org/repos/asf/spark/blob/1c690dda/project/SparkBuild.scala
----------------------------------------------------------------------
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 3927b88..4224a65 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -34,13 +34,24 @@ object BuildCommons {
 
   private val buildLocation = file(".").getAbsoluteFile.getParentFile
 
-  val allProjects@Seq(catalyst, core, graphx, hive, hiveThriftServer, mllib, repl,
-    sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka,
-    streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe, testTags) =
-    Seq("catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
-      "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink",
-      "streaming-flume", "streaming-akka", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
-      "streaming-zeromq", "launcher", "unsafe", "test-tags").map(ProjectRef(buildLocation, _))
+  val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer) = Seq(
+    "catalyst", "sql", "hive", "hive-thriftserver"
+  ).map(ProjectRef(buildLocation, _))
+
+  val streamingProjects@Seq(
+    streaming, streamingFlumeSink, streamingFlume, streamingAkka, streamingKafka, streamingMqtt,
+    streamingTwitter, streamingZeromq
+  ) = Seq(
+    "streaming", "streaming-flume-sink", "streaming-flume", "streaming-akka", "streaming-kafka",
+    "streaming-mqtt", "streaming-twitter", "streaming-zeromq"
+  ).map(ProjectRef(buildLocation, _))
+
+  val allProjects@Seq(
+    core, graphx, mllib, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _*
+  ) = Seq(
+    "core", "graphx", "mllib", "repl", "network-common", "network-shuffle", "launcher", "unsafe",
+    "test-tags", "sketch"
+  ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects
 
   val optionallyEnabledProjects@Seq(yarn, java8Tests, sparkGangliaLgpl,
     streamingKinesisAsl, dockerIntegrationTests) =
@@ -232,11 +243,15 @@ object SparkBuild extends PomBuild {
   /* Enable tests settings for all projects except examples, assembly and tools */
   (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
 
-  // TODO: remove streamingAkka from this list after 2.0.0
-  allProjects.filterNot(x => Seq(spark, hive, hiveThriftServer, catalyst, repl,
-    networkCommon, networkShuffle, networkYarn, unsafe, streamingAkka, testTags).contains(x)).foreach {
-      x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)
-    }
+  // TODO: remove streamingAkka and sketch from this list after 2.0.0
+  allProjects.filterNot { x =>
+    Seq(
+      spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn,
+      unsafe, streamingAkka, testTags, sketch
+    ).contains(x)
+  }.foreach { x =>
+    enable(MimaBuild.mimaSettings(sparkHome, x))(x)
+  }
 
   /* Unsafe settings */
   enable(Unsafe.settings)(unsafe)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org