You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/01/26 02:58:14 UTC

spark git commit: [SPARK-12936][SQL] Initial bloom filter implementation

Repository: spark
Updated Branches:
  refs/heads/master be375fcbd -> 109061f7a


[SPARK-12936][SQL] Initial bloom filter implementation

This PR adds an initial implementation of bloom filter in the newly added sketch module.  The implementation is based on the [`BloomFilter` class in guava](https://code.google.com/p/guava-libraries/source/browse/guava/src/com/google/common/hash/BloomFilter.java).

Some differences from the design doc:

* expose `bitSize` instead of `sizeInBytes` to the user.
* always require the `expectedInsertions` parameter when creating a bloom filter.

Author: Wenchen Fan <we...@databricks.com>

Closes #10883 from cloud-fan/bloom-filter.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/109061f7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/109061f7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/109061f7

Branch: refs/heads/master
Commit: 109061f7ad27225669cbe609ec38756b31d4e1b9
Parents: be375fc
Author: Wenchen Fan <we...@databricks.com>
Authored: Mon Jan 25 17:58:11 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Mon Jan 25 17:58:11 2016 -0800

----------------------------------------------------------------------
 .../org/apache/spark/util/sketch/BitArray.java  |  94 +++++++++++
 .../apache/spark/util/sketch/BloomFilter.java   | 153 +++++++++++++++++
 .../spark/util/sketch/BloomFilterImpl.java      | 164 +++++++++++++++++++
 .../spark/util/sketch/BitArraySuite.scala       |  77 +++++++++
 .../spark/util/sketch/BloomFilterSuite.scala    | 114 +++++++++++++
 5 files changed, 602 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/109061f7/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
new file mode 100644
index 0000000..1bc665a
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.util.Arrays;
+
+/**
+ * A fixed-size bit set backed by a {@code long[]} that also caches how many bits are set.
+ * The capacity is the requested number of bits rounded up to a whole number of 64-bit words.
+ */
+public final class BitArray {
+  private final long[] data;
+  private long bitCount;
+
+  /**
+   * Returns how many 64-bit words are needed to hold {@code numBits} bits.
+   *
+   * @throws IllegalArgumentException if that word count exceeds {@link Integer#MAX_VALUE}
+   */
+  static int numWords(long numBits) {
+    long numWords = (long) Math.ceil(numBits / 64.0);
+    if (numWords > Integer.MAX_VALUE) {
+      throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits");
+    }
+    return (int) numWords;
+  }
+
+  /**
+   * Creates a bit array with every bit cleared.
+   *
+   * @param numBits requested number of bits (must be positive)
+   * @throws IllegalArgumentException if {@code numBits} is not positive, or needs more words
+   *     than a Java array can hold
+   */
+  BitArray(long numBits) {
+    if (numBits <= 0) {
+      throw new IllegalArgumentException("numBits must be positive");
+    }
+    // A newly allocated long[] is zero-filled, so no bit is set and nothing needs counting.
+    this.data = new long[numWords(numBits)];
+    this.bitCount = 0;
+  }
+
+  /** Sets the bit at {@code index}. Returns true if the bit changed value. */
+  boolean set(long index) {
+    if (!get(index)) {
+      // `index >>> 6` selects the word; a long shift only uses the low 6 bits of its distance,
+      // so `1L << index` selects bit (index % 64) inside that word.
+      data[(int) (index >>> 6)] |= (1L << index);
+      bitCount++;
+      return true;
+    }
+    return false;
+  }
+
+  /** Returns true if the bit at {@code index} is set. */
+  boolean get(long index) {
+    return (data[(int) (index >>> 6)] & (1L << index)) != 0;
+  }
+
+  /** Number of bits (always a multiple of 64, because storage is word-aligned). */
+  long bitSize() {
+    return (long) data.length * Long.SIZE;
+  }
+
+  /** Number of set bits (1s) */
+  long cardinality() {
+    return bitCount;
+  }
+
+  /**
+   * Combines the two BitArrays using bitwise OR. Only this array is modified.
+   *
+   * @param array the array whose bits are OR-ed into this one
+   * @throws IllegalArgumentException if the two arrays do not have the same number of words
+   */
+  void putAll(BitArray array) {
+    // Checked explicitly rather than with `assert`: assertions are disabled unless the JVM is
+    // started with -ea, and a length mismatch would otherwise drop bits or fail halfway through.
+    if (data.length != array.data.length) {
+      throw new IllegalArgumentException("BitArrays must be of equal length when merging");
+    }
+    // Recount from scratch, since OR-ing can set any number of previously clear bits.
+    long bitCount = 0;
+    for (int i = 0; i < data.length; i++) {
+      data[i] |= array.data[i];
+      bitCount += Long.bitCount(data[i]);
+    }
+    this.bitCount = bitCount;
+  }
+
+  /** Two BitArrays are equal when their backing words are identical. */
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    // `instanceof` is already false for null, so no separate null check is needed.
+    if (!(o instanceof BitArray)) return false;
+
+    return Arrays.equals(data, ((BitArray) o).data);
+  }
+
+  @Override
+  public int hashCode() {
+    return Arrays.hashCode(data);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/109061f7/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
new file mode 100644
index 0000000..38949c6
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilter.java
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+/**
+ * A Bloom filter is a space-efficient probabilistic data structure that is used to test whether
+ * an element is a member of a set. It returns false when the element is definitely not in the
+ * set, and returns true when the element is probably in the set.
+ *
+ * Internally a Bloom filter is initialized with 2 pieces of information: how much space to use
+ * (number of bits) and how many hash values to calculate for each record. To get as low a false
+ * positive probability as possible, users should call {@link BloomFilter#create} to
+ * automatically pick the best combination of these 2 parameters.
+ *
+ * Currently the following data types are supported:
+ * <ul>
+ *   <li>{@link Byte}</li>
+ *   <li>{@link Short}</li>
+ *   <li>{@link Integer}</li>
+ *   <li>{@link Long}</li>
+ *   <li>{@link String}</li>
+ * </ul>
+ *
+ * The implementation is largely based on the {@code BloomFilter} class from guava.
+ */
+public abstract class BloomFilter {
+  /**
+   * Returns the false positive probability, i.e. the probability that
+   * {@linkplain #mightContain(Object)} will erroneously return {@code true} for an object that
+   * has not actually been put in the {@code BloomFilter}.
+   *
+   * <p>Ideally, this number should be close to the {@code fpp} parameter
+   * passed in to create this bloom filter, or smaller. If it is
+   * significantly higher, it is usually the case that too many elements (more than
+   * expected) have been put in the {@code BloomFilter}, degenerating it.
+   */
+  public abstract double expectedFpp();
+
+  /**
+   * Returns the number of bits in the underlying bit array.
+   */
+  public abstract long bitSize();
+
+  /**
+   * Puts an element into this {@code BloomFilter}. Ensures that subsequent invocations of
+   * {@link #mightContain(Object)} with the same element will always return {@code true}.
+   *
+   * @param item the element to add; it must be one of the supported data types
+   * @return true if the bloom filter's bits changed as a result of this operation. If the bits
+   *     changed, this is <i>definitely</i> the first time {@code object} has been added to the
+   *     filter. If the bits haven't changed, this <i>might</i> be the first time {@code object}
+   *     has been added to the filter. Note that {@code put(t)} always returns the
+   *     <i>opposite</i> result to what {@code mightContain(t)} would have returned at the time
+   *     it is called.
+   */
+  public abstract boolean put(Object item);
+
+  /**
+   * Determines whether a given bloom filter is compatible with this bloom filter. For two
+   * bloom filters to be compatible, they must have the same bit size; {@code BloomFilterImpl}
+   * additionally requires them to use the same number of hash functions.
+   *
+   * @param other The bloom filter to check for compatibility.
+   */
+  public abstract boolean isCompatible(BloomFilter other);
+
+  /**
+   * Combines this bloom filter with another bloom filter by performing a bitwise OR of the
+   * underlying data. The mutations happen to <b>this</b> instance. Callers must ensure the
+   * bloom filters are appropriately sized to avoid saturating them.
+   *
+   * @param other The bloom filter to combine this bloom filter with. It is not mutated.
+   * @return the bloom filter holding the merged result
+   * @throws IncompatibleMergeException if {@code isCompatible(other) == false}
+   */
+  public abstract BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException;
+
+  /**
+   * Returns {@code true} if the element <i>might</i> have been put in this Bloom filter,
+   * {@code false} if this is <i>definitely</i> not the case.
+   */
+  public abstract boolean mightContain(Object item);
+
+  /**
+   * Computes the optimal k (number of hashes per element inserted in Bloom filter), given the
+   * expected insertions and total number of bits in the Bloom filter.
+   *
+   * See http://en.wikipedia.org/wiki/File:Bloom_filter_fp_probability.svg for the formula.
+   *
+   * @param n expected insertions (must be positive)
+   * @param m total number of bits in Bloom filter (must be positive)
+   * @return the rounded optimum, but never less than 1
+   */
+  private static int optimalNumOfHashFunctions(long n, long m) {
+    // (m / n) * log(2), but avoid truncation due to division!
+    return Math.max(1, (int) Math.round((double) m / n * Math.log(2)));
+  }
+
+  /**
+   * Computes m (total bits of Bloom filter) which is expected to achieve, for the specified
+   * expected insertions, the required false positive probability.
+   *
+   * See http://en.wikipedia.org/wiki/Bloom_filter#Probability_of_false_positives for the formula.
+   *
+   * @param n expected insertions (must be positive)
+   * @param p false positive rate (must be 0 < p < 1)
+   * @return the computed bit count, truncated towards zero (so it can be 0 for small inputs)
+   */
+  private static long optimalNumOfBits(long n, double p) {
+    return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2)));
+  }
+
+  /** False positive probability used by {@link #create(long)}. */
+  static final double DEFAULT_FPP = 0.03;
+
+  /**
+   * Creates a {@link BloomFilter} with given {@code expectedNumItems} and the default {@code fpp}.
+   *
+   * @throws IllegalArgumentException if {@code expectedNumItems} is not positive
+   */
+  public static BloomFilter create(long expectedNumItems) {
+    return create(expectedNumItems, DEFAULT_FPP);
+  }
+
+  /**
+   * Creates a {@link BloomFilter} with given {@code expectedNumItems} and {@code fpp}, it will pick
+   * an optimal {@code numBits} and {@code numHashFunctions} for the bloom filter.
+   *
+   * @throws IllegalArgumentException if {@code fpp} is not strictly between 0 and 1, or if
+   *     {@code expectedNumItems} or the computed number of bits is not positive
+   */
+  public static BloomFilter create(long expectedNumItems, double fpp) {
+    // Explicit checks instead of `assert`, which is a no-op unless the JVM runs with -ea.
+    // The comparisons are negated so that NaN is rejected as well.
+    if (!(fpp > 0.0)) {
+      throw new IllegalArgumentException("False positive probability must be > 0.0");
+    }
+    if (!(fpp < 1.0)) {
+      throw new IllegalArgumentException("False positive probability must be < 1.0");
+    }
+    long numBits = optimalNumOfBits(expectedNumItems, fpp);
+    return create(expectedNumItems, numBits);
+  }
+
+  /**
+   * Creates a {@link BloomFilter} with given {@code expectedNumItems} and {@code numBits}, it will
+   * pick an optimal {@code numHashFunctions} which can minimize {@code fpp} for the bloom filter.
+   *
+   * @throws IllegalArgumentException if {@code expectedNumItems} or {@code numBits} is not
+   *     positive
+   */
+  public static BloomFilter create(long expectedNumItems, long numBits) {
+    if (expectedNumItems <= 0) {
+      throw new IllegalArgumentException("Expected insertions must be > 0");
+    }
+    if (numBits <= 0) {
+      throw new IllegalArgumentException("number of bits must be > 0");
+    }
+    int numHashFunctions = optimalNumOfHashFunctions(expectedNumItems, numBits);
+    return new BloomFilterImpl(numHashFunctions, numBits);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/109061f7/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
----------------------------------------------------------------------
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
new file mode 100644
index 0000000..bbd6cf7
--- /dev/null
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/BloomFilterImpl.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch;
+
+import java.io.UnsupportedEncodingException;
+
+/**
+ * {@link BloomFilter} implementation that stores its bits in a {@link BitArray} and derives all
+ * probe positions for an item from two 32-bit Murmur3 hashes.
+ */
+public class BloomFilterImpl extends BloomFilter {
+
+  private final int numHashFunctions;
+  private final BitArray bits;
+
+  /**
+   * @param numHashFunctions number of bit positions probed per item
+   * @param numBits requested size of the bit array (the array rounds it up to whole words)
+   */
+  BloomFilterImpl(int numHashFunctions, long numBits) {
+    this.numHashFunctions = numHashFunctions;
+    this.bits = new BitArray(numBits);
+  }
+
+  /** Estimates the false positive probability as (fraction of set bits) ^ numHashFunctions. */
+  @Override
+  public double expectedFpp() {
+    return Math.pow((double) bits.cardinality() / bits.bitSize(), numHashFunctions);
+  }
+
+  @Override
+  public long bitSize() {
+    return bits.bitSize();
+  }
+
+  /**
+   * Hashes a supported item to a 64-bit value that packs two 32-bit hashes: the first in the
+   * upper half, the second in the lower half. Strings are hashed over their UTF-8 bytes;
+   * Long, Integer, Short and Byte are widened to a long and hashed as such.
+   *
+   * @throws IllegalArgumentException if the item is of an unsupported type
+   */
+  private static long hashObjectToLong(Object item) {
+    if (item instanceof String) {
+      try {
+        byte[] bytes = ((String) item).getBytes("utf-8");
+        return hashBytesToLong(bytes);
+      } catch (UnsupportedEncodingException e) {
+        throw new RuntimeException("Only support utf-8 string", e);
+      }
+    } else {
+      long longValue;
+
+      if (item instanceof Long) {
+        longValue = (Long) item;
+      } else if (item instanceof Integer) {
+        longValue = ((Integer) item).longValue();
+      } else if (item instanceof Short) {
+        longValue = ((Short) item).longValue();
+      } else if (item instanceof Byte) {
+        longValue = ((Byte) item).longValue();
+      } else {
+        throw new IllegalArgumentException(
+          "Support for " + item.getClass().getName() + " not implemented"
+        );
+      }
+
+      // The second hash call is given the first hash where the first call is given 0.
+      int h1 = Murmur3_x86_32.hashLong(longValue, 0);
+      int h2 = Murmur3_x86_32.hashLong(longValue, h1);
+      // Mask h2 so that sign extension cannot overwrite the upper half holding h1.
+      return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);
+    }
+  }
+
+  /** Same packing as {@link #hashObjectToLong(Object)}, computed over a byte array. */
+  private static long hashBytesToLong(byte[] bytes) {
+    int h1 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0);
+    int h2 = Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, h1);
+    return (((long) h1) << 32) | (h2 & 0xFFFFFFFFL);
+  }
+
+  /**
+   * Sets the {@code numHashFunctions} bits derived from {@code item}.
+   *
+   * @return true if at least one bit was newly set
+   */
+  @Override
+  public boolean put(Object item) {
+    long bitSize = bits.bitSize();
+
+    // Here we first hash the input element into 2 int hash values, h1 and h2, then produce n hash
+    // values by `h1 + i * h2` with 1 <= i <= numHashFunctions.
+    // Note that `CountMinSketch` use a different strategy for long type, it hash the input long
+    // element with every i to produce n hash values.
+    long hash64 = hashObjectToLong(item);
+    int h1 = (int) (hash64 >> 32);
+    int h2 = (int) hash64;
+
+    boolean bitsChanged = false;
+    for (int i = 1; i <= numHashFunctions; i++) {
+      int combinedHash = h1 + (i * h2);
+      // Flip all the bits if it's negative, which always yields a non-negative number.
+      if (combinedHash < 0) {
+        combinedHash = ~combinedHash;
+      }
+      bitsChanged |= bits.set(combinedHash % bitSize);
+    }
+    return bitsChanged;
+  }
+
+  /** Probes the same bit positions as {@link #put(Object)}; false as soon as one is clear. */
+  @Override
+  public boolean mightContain(Object item) {
+    long bitSize = bits.bitSize();
+    long hash64 = hashObjectToLong(item);
+    int h1 = (int) (hash64 >> 32);
+    int h2 = (int) hash64;
+
+    for (int i = 1; i <= numHashFunctions; i++) {
+      int combinedHash = h1 + (i * h2);
+      // Flip all the bits if it's negative, which always yields a non-negative number.
+      if (combinedHash < 0) {
+        combinedHash = ~combinedHash;
+      }
+      if (!bits.get(combinedHash % bitSize)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Returns true only if {@code other} is a non-null {@code BloomFilterImpl} with the same bit
+   * size and the same number of hash functions as this filter.
+   */
+  @Override
+  public boolean isCompatible(BloomFilter other) {
+    if (other == null) {
+      return false;
+    }
+
+    if (!(other instanceof BloomFilterImpl)) {
+      return false;
+    }
+
+    BloomFilterImpl that = (BloomFilterImpl) other;
+    return this.bitSize() == that.bitSize() && this.numHashFunctions == that.numHashFunctions;
+  }
+
+  /**
+   * ORs the bits of {@code other} into this filter and returns this filter.
+   *
+   * @throws IncompatibleMergeException if {@code other} is null, is not a
+   *     {@code BloomFilterImpl}, or differs in bit size or number of hash functions
+   */
+  @Override
+  public BloomFilter mergeInPlace(BloomFilter other) throws IncompatibleMergeException {
+    // Duplicates the logic of `isCompatible` here to provide better error message.
+    if (other == null) {
+      throw new IncompatibleMergeException("Cannot merge null bloom filter");
+    }
+
+    // The parameter's static type is already BloomFilter, so the check must be against the
+    // concrete class; otherwise a foreign subclass would reach the cast below and fail with a
+    // ClassCastException instead of the documented IncompatibleMergeException.
+    if (!(other instanceof BloomFilterImpl)) {
+      throw new IncompatibleMergeException(
+        "Cannot merge bloom filter of class " + other.getClass().getName()
+      );
+    }
+
+    BloomFilterImpl that = (BloomFilterImpl) other;
+
+    if (this.bitSize() != that.bitSize()) {
+      throw new IncompatibleMergeException("Cannot merge bloom filters with different bit size");
+    }
+
+    if (this.numHashFunctions != that.numHashFunctions) {
+      throw new IncompatibleMergeException(
+        "Cannot merge bloom filters with different number of hash functions");
+    }
+
+    this.bits.putAll(that.bits);
+    return this;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/109061f7/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala
----------------------------------------------------------------------
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala
new file mode 100644
index 0000000..ff728f0
--- /dev/null
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BitArraySuite.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch
+
+import scala.util.Random
+
+import org.scalatest.FunSuite // scalastyle:ignore funsuite
+
+class BitArraySuite extends FunSuite { // scalastyle:ignore funsuite
+
+  // Rejects a non-positive size and a size that needs more than Int.MaxValue words.
+  test("error case when create BitArray") {
+    intercept[IllegalArgumentException] {
+      new BitArray(0)
+    }
+    intercept[IllegalArgumentException] {
+      new BitArray(64L * Integer.MAX_VALUE + 1)
+    }
+  }
+
+  test("bitSize") {
+    // BitArray is word-aligned, so every requested size is rounded up to a multiple of 64 bits
+    // (for example 65~128 bits need 2 longs, which is 128 bits).
+    val requestedToActual = Seq(64 -> 64, 65 -> 128, 127 -> 128, 128 -> 128)
+    for ((requested, actual) <- requestedToActual) {
+      assert(new BitArray(requested).bitSize() == actual)
+    }
+  }
+
+  test("set") {
+    val bits = new BitArray(64)
+
+    // `set` reports true only when the bit actually changed.
+    val changedOnFirstSet = bits.set(1)
+    val changedOnRepeatedSet = bits.set(1)
+    val changedOnOtherBit = bits.set(2)
+
+    assert(changedOnFirstSet)
+    assert(!changedOnRepeatedSet)
+    assert(changedOnOtherBit)
+  }
+
+  test("normal operation") {
+    val numBits = 320
+    // A fixed seed keeps the generated positions, and thus the test, deterministic.
+    val rng = new Random(37)
+
+    val bits = new BitArray(numBits)
+    val positions = Seq.fill(100)(rng.nextInt(numBits).toLong).distinct
+
+    for (pos <- positions) {
+      bits.set(pos)
+    }
+
+    assert(positions.forall(pos => bits.get(pos)))
+    assert(bits.cardinality() == positions.size)
+  }
+
+  test("merge") {
+    val numBits = 64 * 6
+    // A fixed seed keeps the generated positions, and thus the test, deterministic.
+    val rng = new Random(37)
+
+    def drawPositions(): Seq[Long] = Seq.fill(100)(rng.nextInt(numBits).toLong).distinct
+
+    val target = new BitArray(numBits)
+    val source = new BitArray(numBits)
+
+    // Draw the target's positions first, then the source's, from the same generator.
+    val targetPositions = drawPositions()
+    val sourcePositions = drawPositions()
+
+    targetPositions.foreach(pos => target.set(pos))
+    sourcePositions.foreach(pos => source.set(pos))
+
+    target.putAll(source)
+
+    // After the merge the target holds exactly the union of both position sets.
+    val union = (targetPositions ++ sourcePositions).distinct
+    assert(targetPositions.forall(pos => target.get(pos)))
+    assert(sourcePositions.forall(pos => target.get(pos)))
+    assert(target.cardinality() == union.length)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/109061f7/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
----------------------------------------------------------------------
diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
new file mode 100644
index 0000000..d2de509
--- /dev/null
+++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/BloomFilterSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.sketch
+
+import scala.reflect.ClassTag
+import scala.util.Random
+
+import org.scalatest.FunSuite // scalastyle:ignore funsuite
+
+class BloomFilterSuite extends FunSuite { // scalastyle:ignore funsuite
+  // Tolerance allowed between a measured false positive probability and the requested one.
+  private final val EPSILON = 0.01
+
+  /**
+   * Registers a test that inserts the first tenth of `numItems` generated items and then checks
+   * that there are no false negatives and that both the estimated and the observed false
+   * positive probability stay within EPSILON of the requested one.
+   */
+  def testAccuracy[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = {
+    test(s"accuracy - $typeName") {
+      // A fixed seed keeps the generated items, and thus the test, deterministic.
+      val rng = new Random(37)
+      val targetFpp = 0.05
+      val insertedCount = numItems / 10
+
+      val items = Array.fill(numItems)(itemGen(rng))
+      val (inserted, notInserted) = items.splitAt(insertedCount)
+
+      val filter = BloomFilter.create(insertedCount, targetFpp)
+      inserted.foreach(item => filter.put(item))
+
+      // false negative is not allowed.
+      assert(inserted.forall(item => filter.mightContain(item)))
+
+      // The number of inserted items doesn't exceed `expectedNumItems`, so the `expectedFpp`
+      // should not be significantly higher than the one we passed in to create this bloom filter.
+      assert(filter.expectedFpp() - targetFpp < EPSILON)
+
+      // Every hit among the items that were never inserted is counted as a false positive.
+      val falsePositives = notInserted.count(item => filter.mightContain(item))
+      val observedFpp = falsePositives.toDouble / notInserted.length
+
+      // Also check the actual fpp is not significantly higher than we expected.
+      assert(observedFpp - targetFpp < EPSILON)
+    }
+  }
+
+  /**
+   * Registers a test that fills two filters with `numItems / 2` generated items each, merges the
+   * second into the first, and checks the merged filter still reports every inserted item.
+   */
+  def testMergeInPlace[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = {
+    test(s"mergeInPlace - $typeName") {
+      // A fixed seed keeps the generated items, and thus the test, deterministic.
+      val rng = new Random(37)
+      val halfSize = numItems / 2
+
+      val leftItems = Array.fill(halfSize)(itemGen(rng))
+      val rightItems = Array.fill(halfSize)(itemGen(rng))
+
+      // Builds a filter sized for `numItems` and puts the given items into it.
+      def filledFilter(items: Array[T]): BloomFilter = {
+        val filter = BloomFilter.create(numItems)
+        items.foreach(item => filter.put(item))
+        filter
+      }
+
+      val merged = filledFilter(leftItems)
+      val other = filledFilter(rightItems)
+      merged.mergeInPlace(other)
+
+      // After merge, `merged` has `numItems` items which doesn't exceed `expectedNumItems`, so the
+      // `expectedFpp` should not be significantly higher than the default one.
+      assert(merged.expectedFpp() - BloomFilter.DEFAULT_FPP < EPSILON)
+
+      assert(leftItems.forall(item => merged.mightContain(item)))
+      assert(rightItems.forall(item => merged.mightContain(item)))
+    }
+  }
+
+  /** Registers both the accuracy test and the merge test for one item type. */
+  def testItemType[T: ClassTag](typeName: String, numItems: Int)(itemGen: Random => T): Unit = {
+    testAccuracy[T](typeName, numItems)(itemGen)
+    testMergeInPlace[T](typeName, numItems)(itemGen)
+  }
+
+  testItemType[Byte]("Byte", 160)(rng => rng.nextInt().toByte)
+
+  testItemType[Short]("Short", 1000)(rng => rng.nextInt().toShort)
+
+  testItemType[Int]("Int", 100000)(rng => rng.nextInt())
+
+  testItemType[Long]("Long", 100000)(rng => rng.nextLong())
+
+  testItemType[String]("String", 100000)(rng => rng.nextString(rng.nextInt(512)))
+
+  test("incompatible merge") {
+    // A null argument is rejected.
+    intercept[IncompatibleMergeException] {
+      val filter = BloomFilter.create(1000)
+      filter.mergeInPlace(null)
+    }
+
+    // Filters created with different numbers of bits cannot be merged.
+    intercept[IncompatibleMergeException] {
+      BloomFilter.create(1000, 6400).mergeInPlace(BloomFilter.create(1000, 3200))
+    }
+
+    // Same number of bits, but a different expected item count, which leads `create` to pick a
+    // different number of hash functions.
+    intercept[IncompatibleMergeException] {
+      BloomFilter.create(1000, 6400).mergeInPlace(BloomFilter.create(2000, 6400))
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org