You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tg...@apache.org on 2017/04/04 16:30:09 UTC

[2/2] beam git commit: Add GroupIntoBatches

Add GroupIntoBatches

This groups input KV<K, V> elements into output KV<K, Iterable<V>> batches of a specified size.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/1e9089ff
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/1e9089ff
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/1e9089ff

Branch: refs/heads/master
Commit: 1e9089ffdf969792d2fae3ecca829a8a4f3f3884
Parents: 498ce9f
Author: Etienne Chauchot <ec...@gmail.com>
Authored: Wed Jan 18 15:04:48 2017 +0100
Committer: Thomas Groh <tg...@google.com>
Committed: Tue Apr 4 09:29:56 2017 -0700

----------------------------------------------------------------------
 .../beam/sdk/transforms/GroupIntoBatches.java   | 229 ++++++++++++++++++
 .../sdk/transforms/GroupIntoBatchesTest.java    | 232 +++++++++++++++++++
 2 files changed, 461 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/1e9089ff/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
new file mode 100644
index 0000000..095ca2a
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transforms;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Iterables;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.TimeDomain;
+import org.apache.beam.sdk.util.Timer;
+import org.apache.beam.sdk.util.TimerSpec;
+import org.apache.beam.sdk.util.TimerSpecs;
+import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
+import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.StateSpec;
+import org.apache.beam.sdk.util.state.StateSpecs;
+import org.apache.beam.sdk.util.state.ValueState;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A {@link PTransform} that batches inputs to a desired batch size. Batches will contain only
+ * elements of a single key.
+ *
+ * <p>Elements are buffered until there are {@code batchSize} elements
+ * buffered, at which point they are output to the output {@link PCollection}.
+ *
+ * <p>Windows are preserved (batches contain elements from the same window).
+ * Batches may contain elements from more than one bundle
+ *
+ * <p>Example (batch call a webservice and get return codes)
+ *
+ * <pre>{@code
+ *  Pipeline pipeline = Pipeline.create(...);
+ *  ... // KV collection
+ *  long batchSize = 100L;
+ *  pipeline.apply(GroupIntoBatches.<String, String>ofSize(batchSize))
+ * .setCoder(KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(StringUtf8Coder.of())))
+ * .apply(ParDo.of(new DoFn<KV<String, Iterable<String>>, KV<String, String>>() {
+ * {@literal @}ProcessElement
+ * public void processElement(ProcessContext c){
+ * c.output(KV.of(c.element().getKey(), callWebService(c.element().getValue())));
+ * }
+ * }));
+ *  pipeline.run();
+ * }</pre>
+ */
+public class GroupIntoBatches<K, InputT>
+    extends PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, Iterable<InputT>>>> {
+
+  private final long batchSize;
+
+  private GroupIntoBatches(long batchSize) {
+    this.batchSize = batchSize;
+  }
+
+  public static <K, InputT> GroupIntoBatches<K, InputT> ofSize(long batchSize) {
+    return new GroupIntoBatches<>(batchSize);
+  }
+
+  @Override
+  public PCollection<KV<K, Iterable<InputT>>> expand(PCollection<KV<K, InputT>> input) {
+    Duration allowedLateness = input.getWindowingStrategy().getAllowedLateness();
+
+    checkArgument(
+        input.getCoder() instanceof KvCoder,
+        "coder specified in the input PCollection is not a KvCoder");
+    KvCoder inputCoder = (KvCoder) input.getCoder();
+    Coder<K> keyCoder = (Coder<K>) inputCoder.getCoderArguments().get(0);
+    Coder<InputT> valueCoder = (Coder<InputT>) inputCoder.getCoderArguments().get(1);
+
+    return input.apply(
+        ParDo.of(new GroupIntoBatchesDoFn<>(batchSize, allowedLateness, keyCoder, valueCoder)));
+  }
+
+  @VisibleForTesting
+  static class GroupIntoBatchesDoFn<K, InputT>
+      extends DoFn<KV<K, InputT>, KV<K, Iterable<InputT>>> {
+
+    private static final Logger LOGGER = LoggerFactory.getLogger(GroupIntoBatchesDoFn.class);
+    private static final String END_OF_WINDOW_ID = "endOFWindow";
+    private static final String BATCH_ID = "batch";
+    private static final String NUM_ELEMENTS_IN_BATCH_ID = "numElementsInBatch";
+    private static final String KEY_ID = "key";
+    private final long batchSize;
+    private final Duration allowedLateness;
+
+    @TimerId(END_OF_WINDOW_ID)
+    private final TimerSpec timer = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+    @StateId(BATCH_ID)
+    private final StateSpec<Object, BagState<InputT>> batchSpec;
+
+    @StateId(NUM_ELEMENTS_IN_BATCH_ID)
+    private final StateSpec<Object, AccumulatorCombiningState<Long, Long, Long>>
+        numElementsInBatchSpec;
+
+    @StateId(KEY_ID)
+    private final StateSpec<Object, ValueState<K>> keySpec;
+
+    private final long prefetchFrequency;
+
+    GroupIntoBatchesDoFn(
+        long batchSize,
+        Duration allowedLateness,
+        Coder<K> inputKeyCoder,
+        Coder<InputT> inputValueCoder) {
+      this.batchSize = batchSize;
+      this.allowedLateness = allowedLateness;
+      this.batchSpec = StateSpecs.bag(inputValueCoder);
+      this.numElementsInBatchSpec =
+          StateSpecs.combiningValue(
+              VarLongCoder.of(),
+              new Combine.CombineFn<Long, Long, Long>() {
+
+                @Override
+                public Long createAccumulator() {
+                  return 0L;
+                }
+
+                @Override
+                public Long addInput(Long accumulator, Long input) {
+                  return accumulator + input;
+                }
+
+                @Override
+                public Long mergeAccumulators(Iterable<Long> accumulators) {
+                  long sum = 0L;
+                  for (Long accumulator : accumulators) {
+                    sum += accumulator;
+                  }
+                  return sum;
+                }
+
+                @Override
+                public Long extractOutput(Long accumulator) {
+                  return accumulator;
+                }
+              });
+
+      this.keySpec = StateSpecs.value(inputKeyCoder);
+      // prefetch every 20% of batchSize elements. Do not prefetch if batchSize is too little
+      this.prefetchFrequency = ((batchSize / 5) <= 1) ? Long.MAX_VALUE : (batchSize / 5);
+    }
+
+    @ProcessElement
+    public void processElement(
+        @TimerId(END_OF_WINDOW_ID) Timer timer,
+        @StateId(BATCH_ID) BagState<InputT> batch,
+        @StateId(NUM_ELEMENTS_IN_BATCH_ID)
+            AccumulatorCombiningState<Long, Long, Long> numElementsInBatch,
+        @StateId(KEY_ID) ValueState<K> key,
+        ProcessContext c,
+        BoundedWindow window) {
+      Instant windowExpires = window.maxTimestamp().plus(allowedLateness);
+
+      LOGGER.debug(
+          "*** SET TIMER *** to point in time {} for window {}",
+          windowExpires.toString(), window.toString());
+      timer.set(windowExpires);
+      key.write(c.element().getKey());
+      batch.add(c.element().getValue());
+      LOGGER.debug("*** BATCH *** Add element for window {} ", window.toString());
+      // blind add is supported with combiningState
+      numElementsInBatch.add(1L);
+      Long num = numElementsInBatch.read();
+      if (num % prefetchFrequency == 0) {
+        //prefetch data and modify batch state (readLater() modifies this)
+        batch.readLater();
+      }
+      if (num >= batchSize) {
+        LOGGER.debug("*** END OF BATCH *** for window {}", window.toString());
+        flushBatch(c, key, batch, numElementsInBatch);
+      }
+    }
+
+    @OnTimer(END_OF_WINDOW_ID)
+    public void onTimerCallback(
+        OnTimerContext context,
+        @StateId(KEY_ID) ValueState<K> key,
+        @StateId(BATCH_ID) BagState<InputT> batch,
+        @StateId(NUM_ELEMENTS_IN_BATCH_ID)
+            AccumulatorCombiningState<Long, Long, Long> numElementsInBatch,
+        BoundedWindow window) {
+      LOGGER.debug(
+          "*** END OF WINDOW *** for timer timestamp {} in windows {}",
+          context.timestamp(), window.toString());
+      flushBatch(context, key, batch, numElementsInBatch);
+    }
+
+    private void flushBatch(
+        Context c,
+        ValueState<K> key,
+        BagState<InputT> batch,
+        AccumulatorCombiningState<Long, Long, Long> numElementsInBatch) {
+      Iterable<InputT> values = batch.read();
+      // when the timer fires, batch state might be empty
+      if (Iterables.size(values) > 0) {
+        c.output(KV.of(key.read(), values));
+      }
+      batch.clear();
+      LOGGER.debug("*** BATCH *** clear");
+      numElementsInBatch.clear();
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/1e9089ff/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java
new file mode 100644
index 0000000..54e2d5a
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transforms;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import com.google.common.collect.Iterables;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Iterator;
+import org.apache.beam.sdk.coders.IterableCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestStream;
+import org.apache.beam.sdk.testing.UsesStatefulParDo;
+import org.apache.beam.sdk.testing.UsesTestStream;
+import org.apache.beam.sdk.testing.UsesTimersInParDo;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TimestampedValue;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Test Class for {@link GroupIntoBatches}. */
+@RunWith(JUnit4.class)
+public class GroupIntoBatchesTest implements Serializable {
+  private static final int BATCH_SIZE = 5;
+  private static final long NUM_ELEMENTS = 10;
+  private static final int ALLOWED_LATENESS = 0;
+  private static final Logger LOGGER = LoggerFactory.getLogger(GroupIntoBatchesTest.class);
+  @Rule public transient TestPipeline pipeline = TestPipeline.create();
+  private transient ArrayList<KV<String, String>> data = createTestData();
+
+  /** Builds {@code NUM_ELEMENTS} KVs that all share the key {@code "key"}. */
+  private static ArrayList<KV<String, String>> createTestData() {
+    String[] scientists = {
+      "Einstein",
+      "Darwin",
+      "Copernicus",
+      "Pasteur",
+      "Curie",
+      "Faraday",
+      "Newton",
+      "Bohr",
+      "Galilei",
+      "Maxwell"
+    };
+    ArrayList<KV<String, String>> data = new ArrayList<>();
+    for (int i = 0; i < NUM_ELEMENTS; i++) {
+      int index = i % scientists.length;
+      KV<String, String> element = KV.of("key", scientists[index]);
+      data.add(element);
+    }
+    return data;
+  }
+
+  @Test
+  @Category({NeedsRunner.class, UsesTimersInParDo.class, UsesStatefulParDo.class})
+  public void testInGlobalWindow() {
+    PCollection<KV<String, Iterable<String>>> collection =
+        pipeline
+            .apply("Input data", Create.of(data))
+            .apply(GroupIntoBatches.<String, String>ofSize(BATCH_SIZE))
+            //set output coder
+            .setCoder(KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(StringUtf8Coder.of())));
+    PAssert.that("Incorrect batch size in one or more elements", collection)
+        .satisfies(
+            new SerializableFunction<Iterable<KV<String, Iterable<String>>>, Void>() {
+
+              private boolean checkBatchSizes(Iterable<KV<String, Iterable<String>>> listToCheck) {
+                for (KV<String, Iterable<String>> element : listToCheck) {
+                  if (Iterables.size(element.getValue()) != BATCH_SIZE) {
+                    return false;
+                  }
+                }
+                return true;
+              }
+
+              @Override
+              public Void apply(Iterable<KV<String, Iterable<String>>> input) {
+                assertTrue(checkBatchSizes(input));
+                return null;
+              }
+            });
+    PAssert.thatSingleton(
+            "Incorrect collection size",
+            collection.apply("Count", Count.<KV<String, Iterable<String>>>globally()))
+        .isEqualTo(NUM_ELEMENTS / BATCH_SIZE);
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  @Category({
+    NeedsRunner.class,
+    UsesTimersInParDo.class,
+    UsesTestStream.class,
+    UsesStatefulParDo.class
+  })
+  public void testInStreamingMode() {
+    int timestampInterval = 1;
+    Instant startInstant = new Instant(0L);
+    TestStream.Builder<KV<String, String>> streamBuilder =
+        TestStream.create(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
+            .advanceWatermarkTo(startInstant);
+    long offset = 0L;
+    for (KV<String, String> element : data) {
+      streamBuilder =
+          streamBuilder.addElements(
+              TimestampedValue.of(
+                  element,
+                  startInstant.plus(Duration.standardSeconds(offset * timestampInterval))));
+      offset++;
+    }
+    final long windowDuration = 6;
+    TestStream<KV<String, String>> stream =
+        streamBuilder
+            .advanceWatermarkTo(startInstant.plus(Duration.standardSeconds(windowDuration - 1)))
+            .advanceWatermarkTo(startInstant.plus(Duration.standardSeconds(windowDuration + 1)))
+            .advanceWatermarkTo(startInstant.plus(Duration.standardSeconds(NUM_ELEMENTS)))
+            .advanceWatermarkToInfinity();
+
+    PCollection<KV<String, String>> inputCollection =
+        pipeline
+            .apply(stream)
+            .apply(
+                Window.<KV<String, String>>into(
+                        FixedWindows.of(Duration.standardSeconds(windowDuration)))
+                    .withAllowedLateness(Duration.millis(ALLOWED_LATENESS)));
+    inputCollection.apply(
+        ParDo.of(
+            new DoFn<KV<String, String>, Void>() {
+              @ProcessElement
+              public void processElement(ProcessContext c, BoundedWindow window) {
+                // SLF4J only substitutes {} placeholders, so every argument needs one.
+                LOGGER.debug(
+                    "*** ELEMENT: ({},{}) *** with timestamp {} in window {}",
+                    c.element().getKey(),
+                    c.element().getValue(),
+                    c.timestamp(),
+                    window);
+              }
+            }));
+
+    PCollection<KV<String, Iterable<String>>> outputCollection =
+        inputCollection
+            .apply(GroupIntoBatches.<String, String>ofSize(BATCH_SIZE))
+            .setCoder(KvCoder.of(StringUtf8Coder.of(), IterableCoder.of(StringUtf8Coder.of())));
+
+    // elements have the same key and collection is divided into windows,
+    // so Count.perKey values are the number of batches in each window
+    PCollection<KV<String, Long>> countOutput =
+        outputCollection.apply(
+            "Count elements in windows after applying GroupIntoBatches",
+            Count.<String, Iterable<String>>perKey());
+
+    PAssert.that("Wrong number of elements in windows after GroupIntoBatches", countOutput)
+        .satisfies(
+            new SerializableFunction<Iterable<KV<String, Long>>, Void>() {
+
+              @Override
+              public Void apply(Iterable<KV<String, Long>> input) {
+                // PCollection contents are unordered, so gather the per-window counts and check
+                // them without depending on which window is seen first.
+                ArrayList<Long> counts = new ArrayList<>();
+                Iterator<KV<String, Long>> inputIterator = input.iterator();
+                while (inputIterator.hasNext()) {
+                  counts.add(inputIterator.next().getValue());
+                }
+                assertEquals("Wrong number of windows", 2, counts.size());
+                // window duration is 6 and batch size is 5, so the first window holds 2 batches
+                // (flush because batchSize reached and because end of window reached)
+                assertTrue("Wrong number of elements in first window", counts.contains(2L));
+                // collection is 10 elements, so only 4 elements are left for the second window:
+                // it holds a single batch (flush because end of window/collection reached)
+                assertTrue("Wrong number of elements in second window", counts.contains(1L));
+                return null;
+              }
+            });
+
+    PAssert.that("Incorrect output collection after GroupIntoBatches", outputCollection)
+        .satisfies(
+            new SerializableFunction<Iterable<KV<String, Iterable<String>>>, Void>() {
+
+              @Override
+              public Void apply(Iterable<KV<String, Iterable<String>>> input) {
+                // PCollection contents are unordered, so gather the batch sizes and check them
+                // without depending on iteration order.
+                ArrayList<Integer> sizes = new ArrayList<>();
+                Iterator<KV<String, Iterable<String>>> inputIterator = input.iterator();
+                while (inputIterator.hasNext()) {
+                  sizes.add(Iterables.size(inputIterator.next().getValue()));
+                }
+                assertEquals("Wrong number of output batches", 3, sizes.size());
+                // window duration is 6 and batch size is 5, so one batch should be of size 5
+                // (flush because of batchSize reached)
+                assertTrue("Wrong first element batch Size", sizes.contains(5));
+                // there is only one element left in the first window so one batch should be 1
+                // (flush because of end of window reached)
+                assertTrue("Wrong second element batch Size", sizes.contains(1));
+                // collection is 10 elements, there are only 4 left, so one batch should be 4
+                // (flush because end of collection reached)
+                assertTrue("Wrong third element batch Size", sizes.contains(4));
+                return null;
+              }
+            });
+    pipeline.run().waitUntilFinish();
+  }
+}