You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2018/01/24 23:13:49 UTC
[beam] 01/04: [BEAM-2728] Add Count-Min Sketch in sketching
extension
This is an automated email from the ASF dual-hosted git repository.
lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
commit 34745056d063ac2f3c77d71f75d209639bb63cc2
Author: ArnaudFnr <ar...@gmail.com>
AuthorDate: Tue Nov 28 11:37:13 2017 +0100
[BEAM-2728] Add Count-Min Sketch in sketching extension
---
sdks/java/extensions/sketching/pom.xml | 1 +
.../extensions/sketching/SketchFrequencies.java | 544 +++++++++++++++++++++
.../sketching/SketchFrequenciesTest.java | 204 ++++++++
3 files changed, 749 insertions(+)
diff --git a/sdks/java/extensions/sketching/pom.xml b/sdks/java/extensions/sketching/pom.xml
index f0538ae..7390218 100755
--- a/sdks/java/extensions/sketching/pom.xml
+++ b/sdks/java/extensions/sketching/pom.xml
@@ -39,6 +39,7 @@
<artifactId>beam-sdks-java-core</artifactId>
</dependency>
+ <!-- Library containing sketches' implementation -->
<dependency>
<groupId>com.clearspring.analytics</groupId>
<artifactId>stream</artifactId>
diff --git a/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/SketchFrequencies.java b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/SketchFrequencies.java
new file mode 100644
index 0000000..9872bcd
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/main/java/org/apache/beam/sdk/extensions/sketching/SketchFrequencies.java
@@ -0,0 +1,544 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sketching;
+
+import com.clearspring.analytics.stream.frequency.CountMinSketch;
+import com.clearspring.analytics.stream.frequency.FrequencyMergeException;
+import com.google.auto.value.AutoValue;
+import com.google.common.hash.Hashing;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.Iterator;
+
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.CustomCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+
+/**
+ * {@code PTransform}s to compute the estimate frequency of each element in a stream.
+ *
+ * <p>This class uses the Count-min Sketch structure which allows very efficient queries
+ * on the data stream summarization.
+ *
+ * <h2>References</h2>
+ *
+ * <p>The implementation comes from <a href="https://github.com/addthis/stream-lib">
+ * Addthis' Stream-lib library</a>. <br>
+ * The papers and other useful information about Count-Min Sketch are available on <a
+ * href="https://sites.google.com/site/countminsketch/">this website</a>. <br>
+ *
+ * <h2>Parameters</h2>
+ *
+ * <p>Two parameters can be tuned in order to control the accuracy of the computation:
+ *
+ * <ul>
+ * <li><b>Relative Error:</b> <br>
+ * The relative error "{@code epsilon}" controls the accuracy of the estimation.
+ * By default, the relative error is around {@code 1%} of the total count.
+ * <li><b>Confidence</b> <br>
+ * The relative error can be guaranteed only with a certain "{@code confidence}",
+ * between 0 and 1 (1 being of course impossible). <br>
+ * The default value is set to 0.999 meaning that we can guarantee
+ * that the relative error will not exceed 1% of the total count in 99.9% of cases.
+ * </ul>
+ *
+ * <p>These two parameters will determine the size of the Count-min sketch, which is
+ * a two-dimensional array with depth and width defined as follows :
+ * <ul>
+ * <li>{@code width = ceil(2 / epsilon)}</li>
+ * <li>{@code depth = ceil(-log(1 - confidence) / log(2))}</li>
+ * </ul>
+ *
+ * <p>With the default values, this gives a width of 200 and a depth of 10.
+ *
+ * <p><b>WARNING:</b> The relative error concerns the total number of distinct elements
+ * in the stream. Thus, an element having 1000 occurrences in a stream of 1 million distinct
+ * elements will have 1% of 1 million as relative error, i.e. 10 000. This means the frequency
+ * is 1000 +/- 10 000 for this element. It is therefore clear that the relative error must
+ * be really low in very large streams. <br>
+ * Also keep in mind that this algorithm works well on highly skewed data but gives poor
+ * results if the elements are evenly distributed.
+ *
+ * <h2>Examples</h2>
+ *
+ * <p>There are 2 ways of using this class:
+ *
+ * <ul>
+ * <li>Use the {@link PTransform}s that return a {@link PCollection} singleton that contains
+ * a Count-min sketch for querying the estimate number of hits of the elements.
+ * <li>Use the {@link CountMinSketchFn} {@code CombineFn} that is exposed in order to make
+ * advanced processing involving the Count-Min sketch.
+ * </ul>
+ *
+ * <h3>Example 1: simple default use</h3>
+ *
+ * <p>The simplest use is simply to call the {@link #globally()} or {@link #perKey()} method in
+ * order to retrieve the sketch with an estimate number of hits for each element in the stream.
+ *
+ * <pre><code>
+ * {@literal PCollection<MyObject>} pc = ...;
+ * {@literal PCollection<Sketch<MyObject>>} countMinSketch = pc.apply(SketchFrequencies
+ * {@literal .<MyObject>}globally()); //{@literal .<MyObject>}perKey();
+ * }
+ * </code></pre>
+ *
+ * <h3>Example 2: tune accuracy parameters</h3>
+ *
+ * <p>One can tune the {@code epsilon} and {@code confidence} parameters in order to
+ * control accuracy and memory. <br>
+ * The tuning works exactly the same for {@link #globally()} and {@link #perKey()}.
+ *
+ * <pre><code>
+ * double eps = 0.001;
+ * double conf = 0.9999;
+ * {@literal PCollection<MyObject>} pc = ...;
+ * {@literal PCollection<Sketch<MyObject>>} countMinSketch = pc.apply(SketchFrequencies
+ * {@literal .<MyObject>}globally() //{@literal .<MyObject>}perKey()
+ * .withRelativeError(eps)
+ * .withConfidence(conf));
+ * }
+ * </code></pre>
+ *
+ * <h3>Example 3: query the resulting sketch</h3>
+ *
+ * <p>This example shows how to query the resulting {@link Sketch}.
+ * To estimate the number of hits of an element, one has to use
+ * {@link Sketch#estimateCount(Object, Coder)} method and to provide
+ * the coder for the element type. <br>
+ * For instance, one can build a KV Pair linking each element to an estimation
+ * of its frequency, using the sketch as side input of a {@link ParDo}. <br>
+ *
+ * <pre><code>
+ * {@literal PCollection<MyObject>} pc = ...;
+ * {@literal PCollection<Sketch<MyObject>>} countMinSketch = pc.apply(SketchFrequencies
+ * {@literal .<MyObject>}globally());
+ *
+ * // Retrieve the coder for MyObject
+ * final {@literal Coder<MyObject>} coder = pc.getCoder();
+ * // build a View of the sketch so it can be passed as a sideInput
+ * final {@literal PCollectionView<Sketch<MyObject>>} sketchView = countMinSketch.apply(View
+ * {@literal .<Sketch<MyObject>>}asSingleton());
+ *
+ * {@literal PCollection<KV<MyObject, Long>>} pairs = pc.apply(ParDo.of(
+ * {@literal new DoFn<MyObject, KV<MyObject, Long>>()} {
+ * {@literal @ProcessElement}
+ * public void processElement(ProcessContext c) {
+ * MyObject elem = c.element();
+ * {@literal Sketch<MyObject>} sketch = c.sideInput(sketchView);
+ * c.output(KV.of(elem, sketch.estimateCount(elem, coder)));
+ * }}).withSideInputs(sketchView));
+ * }
+ * </code></pre>
+ *
+ * <h3>Example 4 : Using the CombineFn</h3>
+ *
+ * <p>The {@code CombineFn} does the same thing as the {@code PTransform}s but
+ * it can be used for doing stateful processing or in
+ * {@link org.apache.beam.sdk.transforms.CombineFns.ComposedCombineFn}.
+ *
+ * <p>This example is not really interesting but it shows how you can properly create
+ * a {@link CountMinSketchFn}. One must always specify a coder using the {@link
+ * CountMinSketchFn#create(Coder)} method.
+ *
+ * <pre><code>
+ * double eps = 0.0001;
+ * double conf = 0.9999;
+ * {@literal PCollection<MyObject>} input = ...;
+ * {@literal PCollection<Sketch<MyObject>>} output = input.apply(Combine.globally(CountMinSketchFn
+ * {@literal .<MyObject>}create(new MyObjectCoder())
+ * .withAccuracy(eps, conf)));
+ * }
+ * </code></pre>
+ *
+ * <p><b>Warning: this class is experimental.</b> <br>
+ * Its API is subject to change in future versions of Beam.
+ */
+@Experimental
+public final class SketchFrequencies {
+
+ /**
+ * Creates the {@link PTransform} that builds a single Count-min sketch summarizing the
+ * frequency of the elements of the whole stream, using the default accuracy parameters.
+ *
+ * <p>It returns a {@code PCollection<Sketch<InputT>>} whose {@link Sketch} can be queried
+ * in order to obtain estimations of the elements' frequencies.
+ *
+ * @param <InputT> the type of the elements in the input {@link PCollection}
+ */
+ public static <InputT> GlobalSketch<InputT> globally() {
+ final GlobalSketch.Builder<InputT> defaults = GlobalSketch.<InputT>builder();
+ return defaults.build();
+ }
+
+ /**
+ * Per-key variant of {@link #globally()}: builds one Count-min sketch for each key of a
+ * {@code PCollection<KV<K, V>>} and returns a {@code PCollection<KV<K, Sketch<V>>>}.
+ *
+ * @param <K> type of the keys mapping the elements
+ * @param <V> type of the values being combined per key
+ */
+ public static <K, V> PerKeySketch<K, V> perKey() {
+ final PerKeySketch.Builder<K, V> defaults = PerKeySketch.<K, V>builder();
+ return defaults.build();
+ }
+
+ /**
+ * {@link PTransform} behind {@link #globally()}: combines the whole input into a single
+ * {@link Sketch}.
+ *
+ * @param <InputT> the type of the elements in the input {@link PCollection}
+ */
+ @AutoValue
+ public abstract static class GlobalSketch<InputT>
+ extends PTransform<PCollection<InputT>, PCollection<Sketch<InputT>>> {
+
+ /** Relative error ({@code epsilon}) passed to the {@link CountMinSketchFn}. */
+ abstract double relativeError();
+
+ /** Confidence passed to the {@link CountMinSketchFn}. */
+ abstract double confidence();
+
+ abstract Builder<InputT> toBuilder();
+
+ /** Returns a builder preset with a relative error of 0.01 and a confidence of 0.999. */
+ static <InputT> Builder<InputT> builder() {
+ Builder<InputT> preset = new AutoValue_SketchFrequencies_GlobalSketch.Builder<InputT>();
+ preset = preset.setRelativeError(0.01);
+ preset = preset.setConfidence(0.999);
+ return preset;
+ }
+
+ /** AutoValue builder of {@link GlobalSketch}. */
+ @AutoValue.Builder
+ abstract static class Builder<InputT> {
+ abstract Builder<InputT> setRelativeError(double eps);
+
+ abstract Builder<InputT> setConfidence(double conf);
+
+ abstract GlobalSketch<InputT> build();
+ }
+
+ /** Returns a copy of this transform that uses {@code eps} as relative error. */
+ public GlobalSketch<InputT> withRelativeError(double eps) {
+ final Builder<InputT> copy = toBuilder();
+ return copy.setRelativeError(eps).build();
+ }
+
+ /** Returns a copy of this transform that uses {@code conf} as confidence. */
+ public GlobalSketch<InputT> withConfidence(double conf) {
+ final Builder<InputT> copy = toBuilder();
+ return copy.setConfidence(conf).build();
+ }
+
+ @Override
+ public PCollection<Sketch<InputT>> expand(PCollection<InputT> input) {
+ // The combiner encodes elements with the input's coder and validates the accuracy values.
+ final CountMinSketchFn<InputT> sketchFn = CountMinSketchFn
+ .<InputT>create(input.getCoder())
+ .withAccuracy(this.relativeError(), this.confidence());
+ return input.apply("Compute Count-Min Sketch",
+ Combine.<InputT, Sketch<InputT>>globally(sketchFn));
+ }
+ }
+
+ /**
+ * Implementation of {@link #perKey()}.
+ *
+ * <p>The default accuracy is the same as for {@link GlobalSketch}: a relative error of 0.01
+ * and a confidence of 0.999.
+ *
+ * @param <K> type of the keys mapping the elements
+ * @param <V> type of the values being combined per key
+ */
+ @AutoValue
+ public abstract static class PerKeySketch<K, V>
+ extends PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Sketch<V>>>> {
+
+ /** Relative error ({@code epsilon}) passed to the {@link CountMinSketchFn}. */
+ abstract double relativeError();
+
+ /** Confidence passed to the {@link CountMinSketchFn}. */
+ abstract double confidence();
+
+ abstract Builder<K, V> toBuilder();
+
+ /** Returns a builder preset with a relative error of 0.01 and a confidence of 0.999. */
+ static <K, V> Builder<K, V> builder() {
+ // Same defaults as GlobalSketch and CountMinSketchFn.create, as the class docs state.
+ return new AutoValue_SketchFrequencies_PerKeySketch.Builder<K, V>()
+ .setRelativeError(0.01)
+ .setConfidence(0.999);
+ }
+
+ /** AutoValue builder of {@link PerKeySketch}. */
+ @AutoValue.Builder
+ abstract static class Builder<K, V> {
+ abstract Builder<K, V> setRelativeError(double eps);
+
+ abstract Builder<K, V> setConfidence(double conf);
+
+ abstract PerKeySketch<K, V> build();
+ }
+
+ /** Returns a copy of this transform that uses {@code eps} as relative error. */
+ public PerKeySketch<K, V> withRelativeError(double eps) {
+ return toBuilder().setRelativeError(eps).build();
+ }
+
+ /** Returns a copy of this transform that uses {@code conf} as confidence. */
+ public PerKeySketch<K, V> withConfidence(double conf) {
+ return toBuilder().setConfidence(conf).build();
+ }
+
+ @Override
+ public PCollection<KV<K, Sketch<V>>> expand(PCollection<KV<K, V>> input) {
+ // Only the values are sketched, so the combiner needs the value coder of the KV input.
+ KvCoder<K, V> inputCoder = (KvCoder<K, V>) input.getCoder();
+ return input.apply("Compute Count-Min Sketch perKey",
+ Combine.<K, V, Sketch<V>>perKey(CountMinSketchFn
+ .<V>create(inputCoder.getValueCoder())
+ .withAccuracy(this.relativeError(), this.confidence())));
+ }
+ }
+
+ /**
+ * Implements the {@link CombineFn} of {@link SketchFrequencies} transforms.
+ *
+ * <p>Both the accumulator and the output are a {@link Sketch}; each element is added to the
+ * sketch together with the input coder given to {@link #create(Coder)}.
+ *
+ * @param <InputT> the type of the elements in the input {@link PCollection}
+ */
+ public static class CountMinSketchFn<InputT>
+ extends CombineFn<InputT, Sketch<InputT>, Sketch<InputT>> {
+
+ private final Coder<InputT> inputCoder;
+ private final int depth;
+ private final int width;
+ private final double epsilon;
+ private final double confidence;
+
+ private CountMinSketchFn(final Coder<InputT> coder, double eps, double confidence) {
+ this.epsilon = eps;
+ this.confidence = confidence;
+ // width and depth are only reported as display data; Sketch recomputes them itself.
+ this.width = (int) Math.ceil(2 / eps);
+ this.depth = (int) Math.ceil(-Math.log(1 - confidence) / Math.log(2));
+ this.inputCoder = coder;
+ }
+
+ /**
+ * Returns a {@link CountMinSketchFn} combiner with the given input coder, a relative error
+ * of 0.01 and a confidence of 0.999.
+ *
+ * @param coder the coder that encodes the elements' type
+ * @throws IllegalArgumentException if the coder is not deterministic
+ */
+ public static <InputT> CountMinSketchFn<InputT>create(Coder<InputT> coder) {
+ try {
+ coder.verifyDeterministic();
+ } catch (Coder.NonDeterministicException e) {
+ throw new IllegalArgumentException("Coder is not deterministic ! " + e.getMessage(), e);
+ }
+ return new CountMinSketchFn<>(coder, 0.01, 0.999);
+ }
+
+ /**
+ * Returns a new {@link CountMinSketchFn} combiner with new precision accuracy parameters
+ * {@code epsilon} and {@code confidence}.
+ *
+ * <p>Keep in mind that the lower the {@code epsilon} value, the greater the width,
+ * and the greater the confidence, the greater the depth.
+ *
+ * @param epsilon the error relative to the total number of distinct elements
+ * @param confidence the confidence in the result to not exceed the relative error
+ * @throws IllegalArgumentException if {@code epsilon} is not strictly positive or if
+ * {@code confidence} is not strictly between 0 and 1 (NaN is rejected for both)
+ */
+ public CountMinSketchFn<InputT> withAccuracy(double epsilon, double confidence) {
+ // Written as negated positive checks so that NaN fails them too.
+ if (!(epsilon > 0D)) {
+ throw new IllegalArgumentException("The relative error must be positive");
+ }
+
+ if (!(confidence > 0D && confidence < 1D)) {
+ throw new IllegalArgumentException("The confidence must be comprised between 0 and 1");
+ }
+ return new CountMinSketchFn<InputT>(this.inputCoder, epsilon, confidence);
+ }
+
+ @Override public Sketch<InputT> createAccumulator() {
+ return new Sketch<InputT>(this.epsilon, this.confidence);
+ }
+
+ @Override public Sketch<InputT> addInput(Sketch<InputT> accumulator, InputT element) {
+ accumulator.add(element, inputCoder);
+
+ return accumulator;
+ }
+
+ /**
+ * Merges all the accumulators into the first one and returns it. An empty iterable yields
+ * a fresh, empty accumulator.
+ */
+ @Override public Sketch<InputT> mergeAccumulators(Iterable<Sketch<InputT>> accumulators) {
+ Iterator<Sketch<InputT>> it = accumulators.iterator();
+ if (!it.hasNext()) {
+ // Nothing to merge: return an empty sketch rather than failing on it.next().
+ return createAccumulator();
+ }
+ Sketch<InputT> first = it.next();
+ CountMinSketch mergedSketches = first.sketch;
+ try {
+ while (it.hasNext()) {
+ mergedSketches = CountMinSketch.merge(mergedSketches, it.next().sketch);
+ }
+ } catch (FrequencyMergeException e) {
+ // Should never happen because every instantiated accumulator is of the same type.
+ // The cause is kept so the original stack trace is not lost.
+ throw new IllegalStateException("The accumulators cannot be merged !" + e.getMessage(), e);
+ }
+ first.sketch = mergedSketches;
+
+ return first;
+ }
+
+ /** Output the whole structure so it can be queried, reused or stored easily. */
+ @Override public Sketch<InputT> extractOutput(Sketch<InputT> accumulator) {
+ return accumulator;
+ }
+
+
+ /** Always returns a {@link CountMinSketchCoder}; both arguments are ignored. */
+ @Override public Coder<Sketch<InputT>> getAccumulatorCoder(CoderRegistry registry,
+ Coder<InputT> inputCoder) {
+ return new CountMinSketchCoder<InputT>();
+ }
+
+ @Override
+ public void populateDisplayData(DisplayData.Builder builder) {
+ super.populateDisplayData(builder);
+ builder
+ .add(DisplayData.item("width", width)
+ .withLabel("width of the Count-Min sketch array"))
+ .add(DisplayData.item("depth", depth)
+ .withLabel("depth of the Count-Min sketch array"))
+ .add(DisplayData.item("eps", epsilon)
+ .withLabel("relative error to the total number of elements"))
+ .add(DisplayData.item("conf", confidence)
+ .withLabel("confidence in the relative error"));
+ }
+ }
+
+ /**
+ * Wrapper of StreamLib's Count-Min Sketch to fit with Beam requirements.
+ *
+ * <p>Elements are not stored as such: each one is encoded with the coder supplied by the
+ * caller and the hash of the encoded bytes is what is added to, or looked up in, the
+ * underlying {@link CountMinSketch}.
+ *
+ * @param <T> the type of the elements tracked by the sketch
+ */
+ public static class Sketch<T> implements Serializable {
+
+ static final int SEED = 123456;
+
+ int width;
+ int depth;
+ CountMinSketch sketch;
+
+ /**
+ * Creates an empty sketch whose dimensions are derived from the accuracy parameters:
+ * {@code width = ceil(2 / eps)} and {@code depth = ceil(-log(1 - confidence) / log(2))}.
+ *
+ * @param eps the relative error
+ * @param confidence the confidence in the relative error
+ */
+ public Sketch(double eps, double confidence) {
+ this.width = (int) Math.ceil(2 / eps);
+ this.depth = (int) Math.ceil(-Math.log(1 - confidence) / Math.log(2));
+ sketch = new CountMinSketch(depth, width, SEED);
+ }
+
+ private Sketch(int width, int depth, CountMinSketch sketch) {
+ this.sketch = sketch;
+ this.width = width;
+ this.depth = depth;
+ }
+
+ /** Adds {@code count} occurrences of {@code element}, encoded with {@code coder}. */
+ public void add(T element, long count, Coder<T> coder) {
+ sketch.add(hashElement(element, coder), count);
+ }
+
+ /** Adds one occurrence of {@code element}, encoded with {@code coder}. */
+ public void add(T element, Coder<T> coder) {
+ add(element, 1L, coder);
+ }
+
+ /**
+ * Encodes the element with the coder and hashes the resulting bytes with murmur3_128.
+ *
+ * @throws IllegalStateException if the element cannot be encoded
+ */
+ private long hashElement(T element, Coder<T> coder) {
+ try {
+ byte[] elemBytes = CoderUtils.encodeToByteArray(coder, element);
+ return Hashing.murmur3_128().hashBytes(elemBytes).asLong();
+ } catch (CoderException e) {
+ throw new IllegalStateException("The input value cannot be encoded: " + e.getMessage(), e);
+ }
+ }
+
+ public int getWidth() {
+ return this.width;
+ }
+
+ public int getDepth() {
+ return this.depth;
+ }
+
+ /**
+ * Returns the estimated frequency of {@code element}.
+ *
+ * @param element the element to look up
+ * @param coder the coder used to encode the element; it must be the one used when adding
+ */
+ public long estimateCount(T element, Coder<T> coder) {
+ return sketch.estimateCount(hashElement(element, coder));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+
+ // A wildcard avoids an unchecked cast; the type argument plays no part in the comparison.
+ final Sketch<?> other = (Sketch<?>) o;
+
+ if (depth != other.depth) {
+ return false;
+ }
+ if (width != other.width) {
+ return false;
+ }
+ return sketch.equals(other.sketch);
+ }
+
+ /**
+ * Hash code consistent with {@link #equals(Object)}. Only the dimensions are used, so
+ * equal sketches always hash alike whatever the wrapped sketch's own hash code does.
+ */
+ @Override
+ public int hashCode() {
+ return 31 * width + depth;
+ }
+ }
+
+ /**
+ * Coder for the {@link Sketch} wrapper: writes the width, then the depth, then the
+ * serialized {@link CountMinSketch}, and reads them back in the same order.
+ */
+ static class CountMinSketchCoder<T> extends CustomCoder<Sketch<T>> {
+
+ private static final ByteArrayCoder BYTE_ARRAY_CODER = ByteArrayCoder.of();
+ private static final BigEndianIntegerCoder INT_CODER = BigEndianIntegerCoder.of();
+
+
+ @Override
+ public void encode(Sketch<T> value, OutputStream outStream) throws IOException {
+ if (value == null) {
+ throw new CoderException("cannot encode a null Count-min Sketch");
+ }
+ INT_CODER.encode(value.width, outStream);
+ INT_CODER.encode(value.depth, outStream);
+ final byte[] serializedSketch = CountMinSketch.serialize(value.sketch);
+ BYTE_ARRAY_CODER.encode(serializedSketch, outStream);
+ }
+
+ @Override
+ public Sketch<T> decode(InputStream inStream) throws IOException {
+ final int decodedWidth = INT_CODER.decode(inStream);
+ final int decodedDepth = INT_CODER.decode(inStream);
+ final CountMinSketch decodedSketch =
+ CountMinSketch.deserialize(BYTE_ARRAY_CODER.decode(inStream));
+ return new Sketch<T>(decodedWidth, decodedDepth, decodedSketch);
+ }
+
+ @Override
+ public boolean isRegisterByteSizeObserverCheap(Sketch<T> value) {
+ return true;
+ }
+
+ @Override
+ protected long getEncodedElementByteSize(Sketch<T> value) throws IOException {
+ if (value == null) {
+ throw new CoderException("cannot encode a null Count-min Sketch");
+ }
+ // 8L is for the sketch's size (long).
+ // 4L * 4 is for depth and width (ints), each counted twice: once in Sketch<T> and once
+ // in the Count-Min sketch.
+ final long fixedBytes = 8L + 4L * 4;
+ // Factorization of the sizes of table (long[depth][width]) and hashA (long[depth]).
+ final long tableBytes = 8L * value.depth * (value.width + 1);
+ // NOTE(review): any length prefix added by ByteArrayCoder is not counted here - confirm.
+ return fixedBytes + tableBytes;
+ }
+ }
+}
diff --git a/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/SketchFrequenciesTest.java b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/SketchFrequenciesTest.java
new file mode 100644
index 0000000..ea773e6
--- /dev/null
+++ b/sdks/java/extensions/sketching/src/test/java/org/apache/beam/sdk/extensions/sketching/SketchFrequenciesTest.java
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sketching;
+
+import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.avro.Schema;
+import org.apache.avro.SchemaBuilder;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.generic.GenericRecord;
+import org.apache.beam.sdk.coders.AvroCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.extensions.sketching.SketchFrequencies.CountMinSketchFn;
+import org.apache.beam.sdk.extensions.sketching.SketchFrequencies.Sketch;
+import org.apache.beam.sdk.testing.CoderProperties;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.WithKeys;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * Tests for {@link SketchFrequencies}.
+ */
+public class SketchFrequenciesTest implements Serializable {
+
+ @Rule public final transient TestPipeline tp = TestPipeline.create();
+
+ // Stream in which each value v in [1, 9] occurs exactly v times.
+ private List<Long> smallStream = Arrays.asList(
+ 1L,
+ 2L, 2L,
+ 3L, 3L, 3L,
+ 4L, 4L, 4L, 4L,
+ 5L, 5L, 5L, 5L, 5L,
+ 6L, 6L, 6L, 6L, 6L, 6L,
+ 7L, 7L, 7L, 7L, 7L, 7L, 7L,
+ 8L, 8L, 8L, 8L, 8L, 8L, 8L, 8L,
+ 9L, 9L, 9L, 9L, 9L, 9L, 9L, 9L, 9L);
+
+ // Every distinct value of smallStream (7L included, so that element is verified too).
+ private Long[] distinctElems = {1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L};
+ // Each value occurs as many times as its own value, so the frequencies equal the elements.
+ private Long[] frequencies = distinctElems.clone();
+
+ /** Sketches the stream under a single key with default accuracy and checks the counts. */
+ @Test
+ public void perKeyDefault() {
+ PCollection<Long> input = tp.apply(Create.of(smallStream));
+ PCollection<Sketch<Long>> perKeySketch = input
+ .apply(WithKeys.<Integer, Long>of(1))
+ .apply(SketchFrequencies.<Integer, Long>perKey())
+ .apply(Values.<Sketch<Long>>create());
+
+ VerifyStreamFrequencies<Long> verifier =
+ new VerifyStreamFrequencies<Long>(input.getCoder(), distinctElems, frequencies);
+ PAssert.thatSingleton("Verify number of hits", perKeySketch).satisfies(verifier);
+
+ tp.run();
+ }
+
+ /** Sketches the whole stream with explicit accuracy parameters and checks the counts. */
+ @Test
+ public void globallyWithTunedParameters() {
+ final double eps = 0.01;
+ final double conf = 0.8;
+ PCollection<Long> input = tp.apply(Create.of(smallStream));
+ SketchFrequencies.GlobalSketch<Long> tuned = SketchFrequencies
+ .<Long>globally()
+ .withRelativeError(eps)
+ .withConfidence(conf);
+ PCollection<Sketch<Long>> globalSketch = input.apply(tuned);
+
+ VerifyStreamFrequencies<Long> verifier =
+ new VerifyStreamFrequencies<Long>(input.getCoder(), distinctElems, frequencies);
+ PAssert.thatSingleton("Verify number of hits", globalSketch).satisfies(verifier);
+
+ tp.run();
+ }
+
+ /** Merging identical sketches must add up the counts of each element. */
+ @Test
+ public void merge() {
+ final double eps = 0.01;
+ final double conf = 0.8;
+ final long nOccurrences = 2L;
+ final int size = 3;
+ final Coder<Integer> coder = VarIntCoder.of();
+
+ // nOccurrences sketches, each containing every value of [0, size) once
+ List<Sketch<Integer>> sketches = new ArrayList<>();
+ long built = 0;
+ while (built < nOccurrences) {
+ Sketch<Integer> partial = new Sketch<Integer>(eps, conf);
+ for (int value = 0; value < size; value++) {
+ partial.add(value, coder);
+ }
+ sketches.add(partial);
+ built++;
+ }
+
+ CountMinSketchFn<Integer> fn = CountMinSketchFn.create(coder).withAccuracy(eps, conf);
+ Sketch<Integer> merged = fn.mergeAccumulators(sketches);
+ for (int value = 0; value < size; value++) {
+ Assert.assertEquals(nOccurrences, merged.estimateCount(value, coder));
+ }
+ }
+
+ /** Adds Avro records to a sketch and reads back the count of each of them. */
+ @Test
+ public void customObject() {
+ final int nUsers = 10;
+ final long occurrences = 2L; // occurrence of each user in the stream
+ final double eps = 0.01;
+ final double conf = 0.8;
+ Schema schema =
+ SchemaBuilder.record("User")
+ .fields()
+ .requiredString("Pseudo")
+ .requiredInt("Age")
+ .endRecord();
+ Coder<GenericRecord> coder = AvroCoder.of(schema);
+ Sketch<GenericRecord> sketch = new Sketch<>(eps, conf);
+
+ int userId = 1;
+ while (userId <= nUsers) {
+ GenericData.Record user = new GenericData.Record(schema);
+ user.put("Pseudo", "User" + userId);
+ user.put("Age", userId);
+ sketch.add(user, occurrences, coder);
+ long estimate = sketch.estimateCount(user, coder);
+ Assert.assertEquals("Test API", occurrences, estimate);
+ userId++;
+ }
+ }
+
+ @Test
+ public void testCoder() throws Exception {
+ Sketch<Integer> cMSketch = new Sketch<Integer>(0.01, 0.8);
+ Coder<Integer> coder = VarIntCoder.of();
+ for (int i = 0; i < 3; i++) {
+ cMSketch.add(i, coder);
+ }
+
+ CoderProperties.<Sketch<Integer>>coderDecodeEncodeEqual(
+ new SketchFrequencies.CountMinSketchCoder<Integer>(), cMSketch);
+ }
+
+ /** The combiner must expose its dimensions and accuracy parameters as display data. */
+ @Test
+ public void testDisplayData() {
+ final double eps = 0.01;
+ final double conf = 0.8;
+ final int expectedWidth = (int) Math.ceil(2 / eps);
+ final int expectedDepth = (int) Math.ceil(-Math.log(1 - conf) / Math.log(2));
+
+ final CountMinSketchFn<Integer> fn =
+ CountMinSketchFn.create(VarIntCoder.of()).withAccuracy(eps, conf);
+ final DisplayData displayData = DisplayData.from(fn);
+
+ assertThat(displayData, hasDisplayItem("width", expectedWidth));
+ assertThat(displayData, hasDisplayItem("depth", expectedDepth));
+ assertThat(displayData, hasDisplayItem("eps", eps));
+ assertThat(displayData, hasDisplayItem("conf", conf));
+ }
+
+ static class VerifyStreamFrequencies<T> implements SerializableFunction<Sketch<T>, Void> {
+
+ Coder<T> coder;
+ Long[] expectedHits;
+ T[] elements;
+
+ VerifyStreamFrequencies(Coder<T> coder, T[] elements, Long[] expectedHits) {
+ this.coder = coder;
+ this.elements = elements;
+ this.expectedHits = expectedHits;
+ }
+
+ @Override
+ public Void apply(Sketch<T> sketch) {
+ for (int i = 0; i < elements.length; i++) {
+ Assert.assertEquals((long) expectedHits[i], sketch.estimateCount(elements[i], coder));
+ }
+ return null;
+ }
+ }
+}
--
To stop receiving notification emails like this one, please contact
lcwik@apache.org.