You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2018/01/24 23:13:49 UTC

[beam] 01/04: [BEAM-2728] Add Count-Min Sketch in sketching extension

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 34745056d063ac2f3c77d71f75d209639bb63cc2
Author: ArnaudFnr <ar...@gmail.com>
AuthorDate: Tue Nov 28 11:37:13 2017 +0100

    [BEAM-2728] Add Count-Min Sketch in sketching extension
---
 sdks/java/extensions/sketching/pom.xml             |   1 +
 .../extensions/sketching/SketchFrequencies.java    | 544 +++++++++++++++++++++
 .../sketching/SketchFrequenciesTest.java           | 204 ++++++++
 3 files changed, 749 insertions(+)

diff --git a/sdks/java/extensions/sketching/pom.xml b/sdks/java/extensions/sketching/pom.xml
index f0538ae..7390218 100755
--- a/sdks/java/extensions/sketching/pom.xml
+++ b/sdks/java/extensions/sketching/pom.xml
@@ -39,6 +39,7 @@
       <artifactId>beam-sdks-java-core</artifactId>
     </dependency>
 
+    <!-- Library containing sketches' implementation -->
     <dependency>
       <groupId>com.clearspring.analytics</groupId>
       <artifactId>stream</artifactId>
diff --git a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/SketchFrequencies.java b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/SketchFrequencies.java
new file mode 100644
index 0000000..9872bcd
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/SketchFrequencies.java
@@ -0,0 +1,544 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sketching;
+
+import com.clearspring.analytics.stream.frequency.CountMinSketch;
+import com.clearspring.analytics.stream.frequency.FrequencyMergeException;
+import com.google.auto.value.AutoValue;
+import com.google.common.hash.Hashing;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.CustomCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.util.VarInt;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+
+/**
+ * {@code PTransform}s to compute the estimate frequency of each element in a stream.
+ *
+ * <p>This class uses the Count-min Sketch structure which allows very efficient queries
+ * on the data stream summarization.
+ *
+ * <h2>References</h2>
+ *
+ * <p>The implementation comes from <a href="https://github.com/addthis/stream-lib">
+ * Addthis' Stream-lib library</a>. <br>
+ * The papers and other useful information about Count-Min Sketch are available on <a
+ * href="https://sites.google.com/site/countminsketch/">this website</a>. <br>
+ *
+ * <h2>Parameters</h2>
+ *
+ * <p>Two parameters can be tuned in order to control the accuracy of the computation:
+ *
+ * <ul>
+ *   <li><b>Relative Error:</b> <br>
+ *       The relative error "{@code epsilon}" controls the accuracy of the estimation.
+ *       By default, the relative error is {@code 1%} of the total count.
+ *   <li><b>Confidence:</b> <br>
+ *       The relative error can be guaranteed only with a certain "{@code confidence}",
+ *       strictly between 0 and 1 (1 being of course impossible). <br>
+ *       The default value is set to 0.999 meaning that we can guarantee
+ *       that the relative error will not exceed 1% of the total count in 99.9% of cases.
+ * </ul>
+ *
+ * <p>These two parameters will determine the size of the Count-min sketch, which is
+ * a two-dimensional array with depth and width defined as follows:
+ * <ul>
+ *   <li>{@code width = ceil(2 / epsilon)}</li>
+ *   <li>{@code depth = ceil(-log(1 - confidence) / log(2))}</li>
+ * </ul>
+ *
+ * <p>With the default values, this gives a width of 200 and a depth of 10.
+ *
+ * <p><b>WARNING:</b> The relative error is relative to the total number of elements in the
+ * stream (the sum of all the counts), not to the frequency of the queried element. Thus, an
+ * element having 1000 occurrences in a stream of 1 million elements will have 1% of 1 million
+ * as relative error, i.e. 10 000. This means the frequency is 1000 +/- 10 000 for this element.
+ * Therefore the relative error must be really low in very large streams. <br>
+ * Also keep in mind that this algorithm works well on highly skewed data but gives poor
+ * results if the elements are evenly distributed.
+ *
+ * <h2>Examples</h2>
+ *
+ * <p>There are 2 ways of using this class:
+ *
+ * <ul>
+ *   <li>Use the {@link PTransform}s that return a {@link PCollection} singleton that contains
+ *       a Count-min sketch for querying the estimate number of hits of the elements.
+ *   <li>Use the {@link CountMinSketchFn} {@code CombineFn} that is exposed in order to make
+ *       advanced processing involving the Count-Min sketch.
+ * </ul>
+ *
+ * <h3>Example 1: simple default use</h3>
+ *
+ * <p>The simplest use is simply to call the {@link #globally()} or {@link #perKey()} method in
+ * order to retrieve the sketch with an estimate number of hits for each element in the stream.
+ * {@link #perKey()} works the same way on a {@code PCollection<KV<K, V>>} and builds one sketch
+ * per key.
+ *
+ * <pre><code>
+ * {@literal PCollection<MyObject>} pc = ...;
+ * {@literal PCollection<Sketch<MyObject>>} countMinSketch = pc.apply(SketchFrequencies
+ * {@literal        .<MyObject>}globally());
+ * </code></pre>
+ *
+ * <h3>Example 2: tune accuracy parameters</h3>
+ *
+ * <p>One can tune the {@code epsilon} and {@code confidence} parameters in order to
+ * control accuracy and memory. <br>
+ * The tuning works exactly the same for {@link #globally()} and {@link #perKey()}.
+ *
+ * <pre><code>
+ *  double eps = 0.001;
+ *  double conf = 0.9999;
+ * {@literal PCollection<MyObject>} pc = ...;
+ * {@literal PCollection<Sketch<MyObject>>} countMinSketch = pc.apply(SketchFrequencies
+ * {@literal  .<MyObject>}globally()
+ *            .withRelativeError(eps)
+ *            .withConfidence(conf));
+ * </code></pre>
+ *
+ * <h3>Example 3: query the resulting sketch</h3>
+ *
+ * <p>This example shows how to query the resulting {@link Sketch}.
+ * To estimate the number of hits of an element, one has to use
+ * {@link Sketch#estimateCount(Object, Coder)} method and to provide
+ * the coder for the element type. <br>
+ * For instance, one can build a KV Pair linking each element to an estimation
+ * of its frequency, using the sketch as side input of a {@link ParDo}. <br>
+ *
+ * <pre><code>
+ * {@literal PCollection<MyObject>} pc = ...;
+ * {@literal PCollection<Sketch<MyObject>>} countMinSketch = pc.apply(SketchFrequencies
+ * {@literal       .<MyObject>}globally());
+ *
+ * // Retrieve the coder for MyObject
+ * final {@literal Coder<MyObject>} coder = pc.getCoder();
+ * // build a View of the sketch so it can be passed as a side input
+ * final {@literal PCollectionView<Sketch<MyObject>>} sketchView = countMinSketch.apply(View
+ * {@literal       .<Sketch<MyObject>>}asSingleton());
+ *
+ * {@literal PCollection<KV<MyObject, Long>>} pairs = pc.apply(ParDo.of(
+ *        {@literal new DoFn<MyObject, KV<MyObject, Long>>()} {
+ *          {@literal @ProcessElement}
+ *           public void processElement(ProcessContext c) {
+ *             MyObject elem = c.element();
+ *             {@literal Sketch<MyObject>} sketch = c.sideInput(sketchView);
+ *             c.output(KV.of(elem, sketch.estimateCount(elem, coder)));
+ *            }}).withSideInputs(sketchView));
+ * </code></pre>
+ *
+ * <h3>Example 4: Using the CombineFn</h3>
+ *
+ * <p>The {@code CombineFn} does the same thing as the {@code PTransform}s but
+ * it can be used for doing stateful processing or in
+ * {@link org.apache.beam.sdk.transforms.CombineFns.ComposedCombineFn}.
+ *
+ * <p>This example is not really interesting but it shows how you can properly create
+ * a {@link CountMinSketchFn}. One must always specify a coder using the {@link
+ * CountMinSketchFn#create(Coder)} method.
+ *
+ * <pre><code>
+ *  double eps = 0.0001;
+ *  double conf = 0.9999;
+ * {@literal PCollection<MyObject>} input = ...;
+ * {@literal PCollection<Sketch<MyObject>>} output = input.apply(Combine.globally(CountMinSketchFn
+ * {@literal    .<MyObject>}create(new MyObjectCoder())
+ *              .withAccuracy(eps, conf)));
+ * </code></pre>
+ *
+ * <p><b>Warning: this class is experimental.</b> <br>
+ * Its API is subject to change in future versions of Beam.
+ */
+@Experimental
+public final class SketchFrequencies {
+
+  /** Default relative error ({@code epsilon}) shared by every transform and combiner. */
+  private static final double DEFAULT_RELATIVE_ERROR = 0.01;
+
+  /** Default confidence in the relative error shared by every transform and combiner. */
+  private static final double DEFAULT_CONFIDENCE = 0.999;
+
+  /** Only static factories and nested classes: not meant to be instantiated. */
+  private SketchFrequencies() {}
+
+  /** Width of the sketch array for a relative error {@code eps}: {@code ceil(2 / eps)}. */
+  private static int widthFor(double eps) {
+    return (int) Math.ceil(2 / eps);
+  }
+
+  /**
+   * Depth of the sketch array for a {@code confidence}:
+   * {@code ceil(-log(1 - confidence) / log(2))}.
+   */
+  private static int depthFor(double confidence) {
+    return (int) Math.ceil(-Math.log(1 - confidence) / Math.log(2));
+  }
+
+  /**
+   * Create the {@link PTransform} that will build a Count-min sketch for keeping track
+   * of the frequency of the elements in the whole stream.
+   *
+   * <p>It returns a {@code PCollection<{@link Sketch}>} that can be queried in order to
+   * obtain estimations of the elements' frequencies.
+   *
+   * @param <InputT> the type of the elements in the input {@link PCollection}
+   */
+  public static <InputT> GlobalSketch<InputT> globally() {
+    return GlobalSketch.<InputT>builder().build();
+  }
+
+  /**
+   * Like {@link #globally()} but per key, i.e a Count-min sketch per key
+   * in {@code PCollection<KV<K, V>>} and returns a
+   * {@code PCollection<KV<K, {@link Sketch}<V>>>}.
+   *
+   * @param <K> type of the keys mapping the elements
+   * @param <V> type of the values being combined per key
+   */
+  public static <K, V> PerKeySketch<K, V> perKey() {
+    return PerKeySketch.<K, V>builder().build();
+  }
+
+  /**
+   * Implementation of {@link #globally()}.
+   *
+   * @param <InputT> the type of the elements in the input {@link PCollection}
+   */
+  @AutoValue
+  public abstract static class GlobalSketch<InputT>
+          extends PTransform<PCollection<InputT>, PCollection<Sketch<InputT>>> {
+
+    abstract double relativeError();
+
+    abstract double confidence();
+
+    abstract Builder<InputT> toBuilder();
+
+    static <InputT> Builder<InputT> builder() {
+      return new AutoValue_SketchFrequencies_GlobalSketch.Builder<InputT>()
+              .setRelativeError(DEFAULT_RELATIVE_ERROR)
+              .setConfidence(DEFAULT_CONFIDENCE);
+    }
+
+    @AutoValue.Builder
+    abstract static class Builder<InputT> {
+      abstract Builder<InputT> setRelativeError(double eps);
+
+      abstract Builder<InputT> setConfidence(double conf);
+
+      abstract GlobalSketch<InputT> build();
+    }
+
+    /**
+     * Returns a copy of this transform using the relative error {@code eps}. The value is
+     * validated when the transform is expanded.
+     */
+    public GlobalSketch<InputT> withRelativeError(double eps) {
+      return toBuilder().setRelativeError(eps).build();
+    }
+
+    /**
+     * Returns a copy of this transform using the confidence {@code conf}. The value is
+     * validated when the transform is expanded.
+     */
+    public GlobalSketch<InputT> withConfidence(double conf) {
+      return toBuilder().setConfidence(conf).build();
+    }
+
+    @Override
+    public PCollection<Sketch<InputT>> expand(PCollection<InputT> input) {
+      return input.apply("Compute Count-Min Sketch",
+                      Combine.<InputT, Sketch<InputT>>globally(CountMinSketchFn
+                              .<InputT>create(input.getCoder())
+                              .withAccuracy(this.relativeError(), this.confidence())));
+    }
+  }
+
+  /**
+   * Implementation of {@link #perKey()}.
+   *
+   * @param <K> type of the keys mapping the elements
+   * @param <V> type of the values being combined per key
+   */
+  @AutoValue
+  public abstract static class PerKeySketch<K, V>
+          extends PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Sketch<V>>>> {
+
+    abstract double relativeError();
+
+    abstract double confidence();
+
+    abstract Builder<K, V> toBuilder();
+
+    static <K, V> Builder<K, V> builder() {
+      // Same defaults as globally(): the accuracy tuning is documented as identical for both.
+      return new AutoValue_SketchFrequencies_PerKeySketch.Builder<K, V>()
+              .setRelativeError(DEFAULT_RELATIVE_ERROR)
+              .setConfidence(DEFAULT_CONFIDENCE);
+    }
+
+    @AutoValue.Builder
+    abstract static class Builder<K, V> {
+      abstract Builder<K, V> setRelativeError(double eps);
+
+      abstract Builder<K, V> setConfidence(double conf);
+
+      abstract PerKeySketch<K, V> build();
+    }
+
+    /**
+     * Returns a copy of this transform using the relative error {@code eps}. The value is
+     * validated when the transform is expanded.
+     */
+    public PerKeySketch<K, V> withRelativeError(double eps) {
+      return toBuilder().setRelativeError(eps).build();
+    }
+
+    /**
+     * Returns a copy of this transform using the confidence {@code conf}. The value is
+     * validated when the transform is expanded.
+     */
+    public PerKeySketch<K, V> withConfidence(double conf) {
+      return toBuilder().setConfidence(conf).build();
+    }
+
+    /**
+     * Builds one sketch of the values per key.
+     *
+     * @throws IllegalArgumentException if the input coder is not a {@link KvCoder}
+     */
+    @Override
+    public PCollection<KV<K, Sketch<V>>> expand(PCollection<KV<K, V>> input) {
+      Coder<KV<K, V>> coder = input.getCoder();
+      if (!(coder instanceof KvCoder)) {
+        throw new IllegalArgumentException(
+                "SketchFrequencies.perKey() requires a KvCoder for its input, but got " + coder);
+      }
+      // Only the values are hashed into the sketches, so only their coder is needed.
+      Coder<V> valueCoder = ((KvCoder<K, V>) coder).getValueCoder();
+      return input.apply("Compute Count-Min Sketch perKey",
+              Combine.<K, V, Sketch<V>>perKey(CountMinSketchFn
+                      .<V>create(valueCoder)
+                      .withAccuracy(this.relativeError(), this.confidence())));
+    }
+  }
+
+  /**
+   * Implements the {@link CombineFn} of {@link SketchFrequencies} transforms.
+   *
+   * @param <InputT> the type of the elements in the input {@link PCollection}
+   */
+  public static class CountMinSketchFn<InputT>
+          extends CombineFn<InputT, Sketch<InputT>, Sketch<InputT>> {
+
+    private final Coder<InputT> inputCoder;
+    private final int depth;
+    private final int width;
+    private final double epsilon;
+    private final double confidence;
+
+    private CountMinSketchFn(final Coder<InputT> coder, double eps, double confidence) {
+      this.epsilon = eps;
+      this.confidence = confidence;
+      this.width = widthFor(eps);
+      this.depth = depthFor(confidence);
+      this.inputCoder = coder;
+    }
+
+    /**
+     * Returns a {@link CountMinSketchFn} combiner with the given input coder and the default
+     * accuracy parameters (relative error 0.01, confidence 0.999).
+     *
+     * @param coder the coder that encodes the elements' type; it must be deterministic because
+     *     elements are hashed from their encoded bytes
+     * @param <InputT> the type of the elements to combine
+     * @throws IllegalArgumentException if {@code coder} is not deterministic
+     */
+    public static <InputT> CountMinSketchFn<InputT> create(Coder<InputT> coder) {
+      try {
+        coder.verifyDeterministic();
+      } catch (Coder.NonDeterministicException e) {
+        throw new IllegalArgumentException("Coder is not deterministic ! " + e.getMessage(), e);
+      }
+      return new CountMinSketchFn<>(coder, DEFAULT_RELATIVE_ERROR, DEFAULT_CONFIDENCE);
+    }
+
+    /**
+     * Returns a new {@link CountMinSketchFn} combiner with new precision accuracy parameters
+     * {@code epsilon} and {@code confidence}.
+     *
+     * <p>Keep in mind that the lower the {@code epsilon} value, the greater the width,
+     * and the greater the confidence, the greater the depth.
+     *
+     * @param epsilon the error relative to the total number of elements
+     * @param confidence the confidence in the result to not exceed the relative error
+     * @throws IllegalArgumentException if {@code epsilon} is not positive or if
+     *     {@code confidence} is not strictly between 0 and 1 (NaN is rejected for both)
+     */
+    public CountMinSketchFn<InputT> withAccuracy(double epsilon, double confidence) {
+      // Negated comparisons so that NaN is rejected as well.
+      if (!(epsilon > 0D)) {
+        throw new IllegalArgumentException("The relative error must be positive");
+      }
+
+      if (!(confidence > 0D && confidence < 1D)) {
+        throw new IllegalArgumentException("The confidence must be comprised between 0 and 1");
+      }
+      return new CountMinSketchFn<InputT>(this.inputCoder, epsilon, confidence);
+    }
+
+    @Override public Sketch<InputT> createAccumulator() {
+      return new Sketch<InputT>(this.epsilon, this.confidence);
+    }
+
+    @Override public Sketch<InputT> addInput(Sketch<InputT> accumulator, InputT element) {
+      // The element is hashed from its encoding with the coder given to create().
+      accumulator.add(element, inputCoder);
+
+      return accumulator;
+    }
+
+    /**
+     * Merges all the accumulators into the first one, which is returned. An empty input
+     * yields a fresh, empty accumulator.
+     */
+    @Override public Sketch<InputT> mergeAccumulators(Iterable<Sketch<InputT>> accumulators) {
+      Iterator<Sketch<InputT>> it = accumulators.iterator();
+      if (!it.hasNext()) {
+        return createAccumulator();
+      }
+      Sketch<InputT> first = it.next();
+      CountMinSketch mergedSketches = first.sketch;
+      try {
+        while (it.hasNext()) {
+          mergedSketches = CountMinSketch.merge(mergedSketches, it.next().sketch);
+        }
+      } catch (FrequencyMergeException e) {
+        // Not expected: every accumulator of this fn is built from the same epsilon,
+        // confidence and seed.
+        throw new IllegalStateException(
+                "The accumulators cannot be merged: " + e.getMessage(), e);
+      }
+      first.sketch = mergedSketches;
+
+      return first;
+    }
+
+    /** Output the whole structure so it can be queried, reused or stored easily. */
+    @Override public Sketch<InputT> extractOutput(Sketch<InputT> accumulator) {
+      return accumulator;
+    }
+
+    @Override public Coder<Sketch<InputT>> getAccumulatorCoder(CoderRegistry registry,
+                                                               Coder<InputT> inputCoder) {
+      return new CountMinSketchCoder<InputT>();
+    }
+
+    /** The output is the accumulator itself, so it is encoded with the same coder. */
+    @Override public Coder<Sketch<InputT>> getDefaultOutputCoder(CoderRegistry registry,
+                                                                 Coder<InputT> inputCoder) {
+      return new CountMinSketchCoder<InputT>();
+    }
+
+    @Override
+    public void populateDisplayData(DisplayData.Builder builder) {
+      super.populateDisplayData(builder);
+      builder
+              .add(DisplayData.item("width", width)
+                      .withLabel("width of the Count-Min sketch array"))
+              .add(DisplayData.item("depth", depth)
+                      .withLabel("depth of the Count-Min sketch array"))
+              .add(DisplayData.item("eps", epsilon)
+                      .withLabel("relative error to the total number of elements"))
+              .add(DisplayData.item("conf", confidence)
+                      .withLabel("confidence in the relative error"));
+    }
+  }
+
+  /**
+   * Wrapper of StreamLib's Count-Min Sketch to fit with Beam requirements.
+   *
+   * <p>Elements are not stored. Each one is encoded with the {@link Coder} passed to
+   * {@link #add(Object, Coder)} or {@link #estimateCount(Object, Coder)}, and the Murmur3
+   * 128-bit hash of those bytes (converted with {@code HashCode.asLong()}) is fed to the
+   * sketch. The same coder must therefore be used to add and to query an element.
+   *
+   * @param <T> the type of the elements whose frequencies are estimated
+   */
+  public static class Sketch<T> implements Serializable {
+
+    static final int SEED = 123456;
+
+    int width;
+    int depth;
+    CountMinSketch sketch;
+
+    /**
+     * Creates an empty sketch sized for the relative error {@code eps} and the
+     * {@code confidence}.
+     */
+    public Sketch(double eps, double confidence) {
+      this.width = widthFor(eps);
+      this.depth = depthFor(confidence);
+      this.sketch = new CountMinSketch(depth, width, SEED);
+    }
+
+    private Sketch(int width, int depth, CountMinSketch sketch) {
+      this.sketch = sketch;
+      this.width = width;
+      this.depth = depth;
+    }
+
+    /** Records {@code count} occurrences of {@code element}, hashed from its encoding. */
+    public void add(T element, long count, Coder<T> coder) {
+      sketch.add(hashElement(element, coder), count);
+    }
+
+    /** Records a single occurrence of {@code element}, hashed from its encoding. */
+    public void add(T element, Coder<T> coder) {
+      add(element, 1L, coder);
+    }
+
+    private long hashElement(T element, Coder<T> coder) {
+      try {
+        byte[] elemBytes = CoderUtils.encodeToByteArray(coder, element);
+        return Hashing.murmur3_128().hashBytes(elemBytes).asLong();
+      } catch (CoderException e) {
+        throw new IllegalStateException("The input value cannot be encoded: " + e.getMessage(), e);
+      }
+    }
+
+    public int getWidth() {
+      return this.width;
+    }
+
+    public int getDepth() {
+      return this.depth;
+    }
+
+    /**
+     * Returns the estimated frequency of {@code element}. {@code coder} must be the coder that
+     * was used when adding the element.
+     */
+    public long estimateCount(T element, Coder<T> coder) {
+      return sketch.estimateCount(hashElement(element, coder));
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+
+      final Sketch<?> other = (Sketch<?>) o;
+
+      if (depth != other.depth) {
+        return false;
+      }
+      if (width != other.width) {
+        return false;
+      }
+      return sketch.equals(other.sketch);
+    }
+
+    @Override
+    public int hashCode() {
+      // Only the dimensions are hashed. This keeps hashCode consistent with equals whether or
+      // not CountMinSketch itself overrides hashCode.
+      return 31 * width + depth;
+    }
+  }
+
+  /**
+   * Coder for the {@link Sketch} class.
+   *
+   * <p>Encodes the width and depth (big-endian ints), followed by the StreamLib serialization
+   * of the underlying {@link CountMinSketch} as a length-prefixed byte array.
+   */
+  static class CountMinSketchCoder<T> extends CustomCoder<Sketch<T>> {
+
+    private static final ByteArrayCoder BYTE_ARRAY_CODER = ByteArrayCoder.of();
+    private static final BigEndianIntegerCoder INT_CODER = BigEndianIntegerCoder.of();
+
+
+    @Override
+    public void encode(Sketch<T> value, OutputStream outStream) throws IOException {
+      if (value == null) {
+        throw new CoderException("cannot encode a null Count-min Sketch");
+      }
+      INT_CODER.encode(value.width, outStream);
+      INT_CODER.encode(value.depth, outStream);
+      BYTE_ARRAY_CODER.encode(CountMinSketch.serialize(value.sketch), outStream);
+    }
+
+    @Override
+    public Sketch<T> decode(InputStream inStream) throws IOException {
+      int width = INT_CODER.decode(inStream);
+      int depth = INT_CODER.decode(inStream);
+      byte[] sketchBytes = BYTE_ARRAY_CODER.decode(inStream);
+      CountMinSketch sketch = CountMinSketch.deserialize(sketchBytes);
+      return new Sketch<T>(width, depth, sketch);
+    }
+
+    @Override
+    public boolean isRegisterByteSizeObserverCheap(Sketch<T> value) {
+      return true;
+    }
+
+    @Override
+    protected long getEncodedElementByteSize(Sketch<T> value) throws IOException {
+      if (value == null) {
+        throw new CoderException("cannot encode a null Count-min Sketch");
+      }
+      // Serialized CountMinSketch (assumes StreamLib's serialize() layout):
+      // 8L for its size (long), 4L * 2 for its depth and width (ints), and
+      // 8L * depth * (width + 1) for table (long[depth][width]) and hashA (long[depth]).
+      long sketchBytes = 8L + 4L * 2 + 8L * value.depth * (value.width + 1);
+      // 4L * 2 for the wrapper's width and depth (ints), then the VarInt length prefix that
+      // ByteArrayCoder writes in front of the nested byte array, then the array itself.
+      return 4L * 2 + VarInt.getLength(sketchBytes) + sketchBytes;
+    }
+  }
+}
diff --git a/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/SketchFrequenciesTest.java b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/SketchFrequenciesTest.java
new file mode 100644
index 0000000..ea773e6
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/SketchFrequenciesTest.java
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sketching;
+
+import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.avro.Schema;
+import org.apache.avro.SchemaBuilder;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.beam.sdk.coders.AvroCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.extensions.sketching.SketchFrequencies.CountMinSketchFn;
+import org.apache.beam.sdk.extensions.sketching.SketchFrequencies.Sketch;
+import org.apache.beam.sdk.testing.CoderProperties;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.WithKeys;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * Tests for {@link SketchFrequencies}.
+ */
+public class SketchFrequenciesTest implements Serializable {
+
+  @Rule public final transient TestPipeline tp = TestPipeline.create();
+
+  /** Stream in which each element {@code i} in [1, 9] occurs exactly {@code i} times. */
+  private List<Long> smallStream = Arrays.asList(
+          1L,
+          2L, 2L,
+          3L, 3L, 3L,
+          4L, 4L, 4L, 4L,
+          5L, 5L, 5L, 5L, 5L,
+          6L, 6L, 6L, 6L, 6L, 6L,
+          7L, 7L, 7L, 7L, 7L, 7L, 7L,
+          8L, 8L, 8L, 8L, 8L, 8L, 8L, 8L,
+          9L, 9L, 9L, 9L, 9L, 9L, 9L, 9L, 9L);
+
+  /** Every distinct element of {@link #smallStream}. */
+  private Long[] distinctElems = {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L};
+
+  /** Expected frequencies: in {@link #smallStream} each element's frequency equals its value. */
+  private Long[] frequencies = distinctElems.clone();
+
+  @Test
+  public void perKeyDefault() {
+    PCollection<Long> stream = tp.apply(Create.of(smallStream));
+    PCollection<Sketch<Long>> sketch = stream
+            .apply(WithKeys.<Integer, Long>of(1))
+            .apply(SketchFrequencies.<Integer, Long>perKey())
+            .apply(Values.<Sketch<Long>>create());
+
+    Coder<Long> coder = stream.getCoder();
+
+    PAssert.thatSingleton("Verify number of hits", sketch)
+            .satisfies(new VerifyStreamFrequencies<Long>(coder, distinctElems, frequencies));
+
+    tp.run();
+  }
+
+  @Test
+  public void globallyWithTunedParameters() {
+    double eps = 0.01;
+    double conf = 0.8;
+    PCollection<Long> stream = tp.apply(Create.of(smallStream));
+    PCollection<Sketch<Long>> sketch = stream
+            .apply(SketchFrequencies
+                    .<Long>globally()
+                    .withRelativeError(eps)
+                    .withConfidence(conf));
+
+    Coder<Long> coder = stream.getCoder();
+
+    PAssert.thatSingleton("Verify number of hits", sketch)
+            .satisfies(new VerifyStreamFrequencies<Long>(coder, distinctElems, frequencies));
+
+    tp.run();
+  }
+
+  @Test
+  public void merge() {
+    double eps = 0.01;
+    double conf = 0.8;
+    long nOccurrences = 2L;
+    int size = 3;
+
+    List<Sketch<Integer>> sketches = new ArrayList<>();
+    Coder<Integer> coder = VarIntCoder.of();
+
+    // n sketches each containing [0, 1, 2]
+    for (int i = 0; i < nOccurrences; i++) {
+      Sketch<Integer> sketch = new Sketch<Integer>(eps, conf);
+      for (int j = 0; j < size; j++) {
+        sketch.add(j, coder);
+      }
+      sketches.add(sketch);
+    }
+
+    CountMinSketchFn<Integer> fn = CountMinSketchFn.create(coder).withAccuracy(eps, conf);
+    Sketch<Integer> merged = fn.mergeAccumulators(sketches);
+    for (int i = 0; i < size; i++) {
+      Assert.assertEquals(nOccurrences, merged.estimateCount(i, coder));
+    }
+  }
+
+  @Test
+  public void customObject() {
+    int nUsers = 10;
+    long occurrences = 2L; // occurrence of each user in the stream
+    double eps = 0.01;
+    double conf = 0.8;
+    Sketch<GenericRecord> sketch = new Sketch<>(eps, conf);
+    Schema schema =
+            SchemaBuilder.record("User")
+                    .fields()
+                    .requiredString("Pseudo")
+                    .requiredInt("Age")
+                    .endRecord();
+    Coder<GenericRecord> coder = AvroCoder.of(schema);
+
+    for (int i = 1; i <= nUsers; i++) {
+      GenericData.Record newRecord = new GenericData.Record(schema);
+      newRecord.put("Pseudo", "User" + i);
+      newRecord.put("Age", i);
+      sketch.add(newRecord, occurrences, coder);
+      Assert.assertEquals("Test API", occurrences, sketch.estimateCount(newRecord, coder));
+    }
+  }
+
+  @Test
+  public void testCoder() throws Exception {
+    Sketch<Integer> cMSketch = new Sketch<Integer>(0.01, 0.8);
+    Coder<Integer> coder = VarIntCoder.of();
+    for (int i = 0; i < 3; i++) {
+      cMSketch.add(i, coder);
+    }
+
+    CoderProperties.<Sketch<Integer>>coderDecodeEncodeEqual(
+            new SketchFrequencies.CountMinSketchCoder<Integer>(), cMSketch);
+  }
+
+  @Test
+  public void testDisplayData() {
+    double eps = 0.01;
+    double conf = 0.8;
+    int width = (int) Math.ceil(2 / eps);
+    int depth = (int) Math.ceil(-Math.log(1 - conf) / Math.log(2));
+
+    final CountMinSketchFn<Integer> fn =
+            CountMinSketchFn.create(VarIntCoder.of()).withAccuracy(eps, conf);
+
+    assertThat(DisplayData.from(fn), hasDisplayItem("width", width));
+    assertThat(DisplayData.from(fn), hasDisplayItem("depth", depth));
+    assertThat(DisplayData.from(fn), hasDisplayItem("eps", eps));
+    assertThat(DisplayData.from(fn), hasDisplayItem("conf", conf));
+  }
+
+  /**
+   * Asserts that a sketch estimates {@code expectedHits[i]} occurrences for each
+   * {@code elements[i]}.
+   */
+  static class VerifyStreamFrequencies<T> implements SerializableFunction<Sketch<T>, Void> {
+
+    Coder<T> coder;
+    Long[] expectedHits;
+    T[] elements;
+
+    VerifyStreamFrequencies(Coder<T> coder, T[] elements, Long[] expectedHits) {
+      this.coder = coder;
+      this.elements = elements;
+      this.expectedHits = expectedHits;
+    }
+
+    @Override
+    public Void apply(Sketch<T> sketch) {
+      for (int i = 0; i < elements.length; i++) {
+        Assert.assertEquals((long) expectedHits[i], sketch.estimateCount(elements[i], coder));
+      }
+      return null;
+    }
+  }
+}

-- 
To stop receiving notification emails like this one, please contact
lcwik@apache.org.