You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tg...@apache.org on 2017/07/28 22:25:22 UTC
[1/2] beam git commit: This closes #3579
Repository: beam
Updated Branches:
refs/heads/master a94d680ea -> 1f2634d23
This closes #3579
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/1f2634d2
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/1f2634d2
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/1f2634d2
Branch: refs/heads/master
Commit: 1f2634d23bfa5b32d88cef236f2080df79e5e47d
Parents: a94d680 6ea2eda
Author: Thomas Groh <tg...@google.com>
Authored: Fri Jul 28 15:25:11 2017 -0700
Committer: Thomas Groh <tg...@google.com>
Committed: Fri Jul 28 15:25:11 2017 -0700
----------------------------------------------------------------------
.../beam/runners/direct/DirectRunner.java | 65 +--
.../beam/runners/direct/MultiStepCombine.java | 423 +++++++++++++++++++
.../direct/TransformEvaluatorRegistry.java | 4 +
.../runners/direct/MultiStepCombineTest.java | 228 ++++++++++
4 files changed, 690 insertions(+), 30 deletions(-)
----------------------------------------------------------------------
[2/2] beam git commit: Perform a Multi-step combine in the DirectRunner
Posted by tg...@apache.org.
Perform a Multi-step combine in the DirectRunner
This exercises the entire CombineFn lifecycle for simple combine fns,
expressed as a collection of DoFns.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/6ea2eda2
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/6ea2eda2
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/6ea2eda2
Branch: refs/heads/master
Commit: 6ea2eda2b5dbe4fc0cc2a84b64b76a07c7d0eda8
Parents: a94d680
Author: Thomas Groh <tg...@google.com>
Authored: Thu Jun 15 15:53:46 2017 -0700
Committer: Thomas Groh <tg...@google.com>
Committed: Fri Jul 28 15:25:11 2017 -0700
----------------------------------------------------------------------
.../beam/runners/direct/DirectRunner.java | 65 +--
.../beam/runners/direct/MultiStepCombine.java | 423 +++++++++++++++++++
.../direct/TransformEvaluatorRegistry.java | 4 +
.../runners/direct/MultiStepCombineTest.java | 228 ++++++++++
4 files changed, 690 insertions(+), 30 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/6ea2eda2/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
index c5f29e5..642ce8f 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
@@ -233,36 +233,41 @@ public class DirectRunner extends PipelineRunner<DirectPipelineResult> {
PTransformMatchers.writeWithRunnerDeterminedSharding(),
new WriteWithShardingFactory())); /* Uses a view internally. */
}
- builder = builder.add(
- PTransformOverride.of(
- PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN),
- new ViewOverrideFactory())) /* Uses pardos and GBKs */
- .add(
- PTransformOverride.of(
- PTransformMatchers.urnEqualTo(PTransformTranslation.TEST_STREAM_TRANSFORM_URN),
- new DirectTestStreamFactory(this))) /* primitive */
- // SplittableParMultiDo is implemented in terms of nonsplittable simple ParDos and extra
- // primitives
- .add(
- PTransformOverride.of(
- PTransformMatchers.splittableParDo(), new ParDoMultiOverrideFactory()))
- // state and timer pardos are implemented in terms of simple ParDos and extra primitives
- .add(
- PTransformOverride.of(
- PTransformMatchers.stateOrTimerParDo(), new ParDoMultiOverrideFactory()))
- .add(
- PTransformOverride.of(
- PTransformMatchers.urnEqualTo(
- SplittableParDo.SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN),
- new SplittableParDoViaKeyedWorkItems.OverrideFactory()))
- .add(
- PTransformOverride.of(
- PTransformMatchers.urnEqualTo(SplittableParDo.SPLITTABLE_GBKIKWI_URN),
- new DirectGBKIntoKeyedWorkItemsOverrideFactory())) /* Returns a GBKO */
- .add(
- PTransformOverride.of(
- PTransformMatchers.urnEqualTo(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN),
- new DirectGroupByKeyOverrideFactory())); /* returns two chained primitives. */
+ builder =
+ builder
+ .add(
+ PTransformOverride.of(
+ MultiStepCombine.matcher(), MultiStepCombine.Factory.create()))
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.urnEqualTo(PTransformTranslation.CREATE_VIEW_TRANSFORM_URN),
+ new ViewOverrideFactory())) /* Uses pardos and GBKs */
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.urnEqualTo(PTransformTranslation.TEST_STREAM_TRANSFORM_URN),
+ new DirectTestStreamFactory(this))) /* primitive */
+ // SplittableParMultiDo is implemented in terms of nonsplittable simple ParDos and extra
+ // primitives
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.splittableParDo(), new ParDoMultiOverrideFactory()))
+ // state and timer pardos are implemented in terms of simple ParDos and extra primitives
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.stateOrTimerParDo(), new ParDoMultiOverrideFactory()))
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.urnEqualTo(
+ SplittableParDo.SPLITTABLE_PROCESS_KEYED_ELEMENTS_URN),
+ new SplittableParDoViaKeyedWorkItems.OverrideFactory()))
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.urnEqualTo(SplittableParDo.SPLITTABLE_GBKIKWI_URN),
+ new DirectGBKIntoKeyedWorkItemsOverrideFactory())) /* Returns a GBKO */
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.urnEqualTo(PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN),
+ new DirectGroupByKeyOverrideFactory())); /* returns two chained primitives. */
return builder.build();
}
http://git-wip-us.apache.org/repos/asf/beam/blob/6ea2eda2/runners/direct-java/src/main/java/org/apache/beam/runners/direct/MultiStepCombine.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/MultiStepCombine.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/MultiStepCombine.java
new file mode 100644
index 0000000..6f49e94
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/MultiStepCombine.java
@@ -0,0 +1,423 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.collect.Iterables;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Objects;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.core.construction.CombineTranslation;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.PTransformTranslation.RawPTransform;
+import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.runners.AppliedPTransform;
+import org.apache.beam.sdk.runners.PTransformMatcher;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.Combine.PerKey;
+import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.util.UserCodeException;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.joda.time.Instant;
+
+/**
+ * A {@link Combine} that performs the combine in multiple steps.
+ *
+ * <p>Input values are first partially combined per key and window within each bundle ({@link
+ * CombineInputs}). The resulting accumulators, whose {@link PCollection} is given a {@link KvCoder}
+ * built from the {@link CombineFn}'s accumulator {@link Coder}, are grouped with a {@link
+ * GroupByKey}. Finally each key's accumulators are merged and the output extracted by {@link
+ * MergeAndExtractAccumulatorOutput}.
+ */
+class MultiStepCombine<K, InputT, AccumT, OutputT>
+ extends RawPTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>> {
+ /**
+ * Returns a {@link PTransformMatcher} for transforms this class can replace.
+ *
+ * <p>A transform matches when its URN is {@link PTransformTranslation#COMBINE_TRANSFORM_URN}, its
+ * fn is a {@link CombineFn}, and it has exactly one input whose {@link WindowingStrategy} uses a
+ * non-merging {@code WindowFn} and the {@link DefaultTrigger}, and whose {@link KvCoder} lets the
+ * {@link CombineFn} provide a non-null accumulator {@link Coder}.
+ *
+ * <p>The matcher throws a {@link RuntimeException} if reading the combine fn from the transform
+ * fails with an {@link IOException}, or if the {@link CombineFn} throws {@link
+ * CannotProvideCoderException} instead of returning an accumulator {@link Coder}.
+ */
+ public static PTransformMatcher matcher() {
+ return new PTransformMatcher() {
+ @Override
+ public boolean matches(AppliedPTransform<?, ?, ?> application) {
+ if (PTransformTranslation.COMBINE_TRANSFORM_URN.equals(
+ PTransformTranslation.urnForTransformOrNull(application.getTransform()))) {
+ try {
+ GlobalCombineFn fn = CombineTranslation.getCombineFn(application);
+ return isApplicable(application.getInputs(), fn);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ return false;
+ }
+
+ // Checks the fn type, the input count, the windowing strategy, and accumulator coder
+ // availability.
+ private <K, InputT> boolean isApplicable(
+ Map<TupleTag<?>, PValue> inputs, GlobalCombineFn<InputT, ?, ?> fn) {
+ if (!(fn instanceof CombineFn)) {
+ return false;
+ }
+ if (inputs.size() == 1) {
+ PCollection<KV<K, InputT>> input =
+ (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(inputs.values());
+ WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
+ boolean windowFnApplicable = windowingStrategy.getWindowFn().isNonMerging();
+ // Triggering with count based triggers is not appropriately handled here. Disabling
+ // most triggers is safe, though more broad than is technically required.
+ boolean triggerApplicable = DefaultTrigger.of().equals(windowingStrategy.getTrigger());
+ // A null accumulator coder (or a non-KvCoder input) disables the override; a
+ // CannotProvideCoderException is treated as an error instead.
+ boolean accumulatorCoderAvailable;
+ try {
+ if (input.getCoder() instanceof KvCoder) {
+ KvCoder<K, InputT> kvCoder = (KvCoder<K, InputT>) input.getCoder();
+ Coder<?> accumulatorCoder =
+ fn.getAccumulatorCoder(
+ input.getPipeline().getCoderRegistry(), kvCoder.getValueCoder());
+ accumulatorCoderAvailable = accumulatorCoder != null;
+ } else {
+ accumulatorCoderAvailable = false;
+ }
+ } catch (CannotProvideCoderException e) {
+ throw new RuntimeException(
+ String.format(
+ "Could not construct an accumulator %s for %s. Accumulator %s for a %s may be"
+ + " null, but may not throw an exception",
+ Coder.class.getSimpleName(),
+ fn,
+ Coder.class.getSimpleName(),
+ Combine.class.getSimpleName()),
+ e);
+ }
+ return windowFnApplicable && triggerApplicable && accumulatorCoderAvailable;
+ }
+ return false;
+ }
+ };
+ }
+
+ /**
+ * A {@link PTransformOverrideFactory} that replaces a matched combine with a {@link
+ * MultiStepCombine} using the same {@link CombineFn} and the replaced output's {@link Coder}.
+ */
+ static class Factory<K, InputT, AccumT, OutputT>
+ extends SingleInputOutputOverrideFactory<
+ PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>,
+ PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>> {
+ /** Returns a new {@link Factory}. */
+ public static PTransformOverrideFactory create() {
+ return new Factory<>();
+ }
+
+ private Factory() {}
+
+ /**
+ * Builds the {@link MultiStepCombine} replacement for {@code transform}.
+ *
+ * @throws IllegalStateException if the transform's fn is not a {@link CombineFn}
+ * @throws RuntimeException wrapping an {@link IOException} if the fn cannot be read
+ */
+ @Override
+ public PTransformReplacement<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>
+ getReplacementTransform(
+ AppliedPTransform<
+ PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>,
+ PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>>
+ transform) {
+ try {
+ GlobalCombineFn<?, ?, ?> globalFn = CombineTranslation.getCombineFn(transform);
+ checkState(
+ globalFn instanceof CombineFn,
+ "%s.matcher() should only match %s instances using %s, got %s",
+ MultiStepCombine.class.getSimpleName(),
+ PerKey.class.getSimpleName(),
+ CombineFn.class.getSimpleName(),
+ globalFn.getClass().getName());
+ @SuppressWarnings("unchecked")
+ CombineFn<InputT, AccumT, OutputT> fn = (CombineFn<InputT, AccumT, OutputT>) globalFn;
+ @SuppressWarnings("unchecked")
+ PCollection<KV<K, InputT>> input =
+ (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(transform.getInputs().values());
+ @SuppressWarnings("unchecked")
+ PCollection<KV<K, OutputT>> output =
+ (PCollection<KV<K, OutputT>>) Iterables.getOnlyElement(transform.getOutputs().values());
+ return PTransformReplacement.of(input, new MultiStepCombine<>(fn, output.getCoder()));
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+
+ // ===========================================================================================
+
+ /** The {@link CombineFn} whose steps are spread across the transforms applied by expand. */
+ private final CombineFn<InputT, AccumT, OutputT> combineFn;
+ /** The {@link Coder} set on the final output {@link PCollection}. */
+ private final Coder<KV<K, OutputT>> outputCoder;
+
+ /** Creates a {@link MultiStepCombine} with the given {@link CombineFn} and output coder. */
+ public static <K, InputT, AccumT, OutputT> MultiStepCombine<K, InputT, AccumT, OutputT> of(
+ CombineFn<InputT, AccumT, OutputT> combineFn, Coder<KV<K, OutputT>> outputCoder) {
+ return new MultiStepCombine<>(combineFn, outputCoder);
+ }
+
+ private MultiStepCombine(
+ CombineFn<InputT, AccumT, OutputT> combineFn, Coder<KV<K, OutputT>> outputCoder) {
+ this.combineFn = combineFn;
+ this.outputCoder = outputCoder;
+ }
+
+ /** Returns the direct-runner URN for this transform; always the same non-null string. */
+ @Nullable
+ @Override
+ public String getUrn() {
+ return "urn:beam:directrunner:transforms:multistepcombine:v1";
+ }
+
+ /**
+ * Applies a {@link ParDo} of {@link CombineInputs}, a {@link GroupByKey}, and a {@link
+ * MergeAndExtractAccumulatorOutput} to {@code input}.
+ *
+ * @throws IllegalArgumentException if the input {@link Coder} is not a {@link KvCoder}
+ * @throws IllegalStateException if the {@link CombineFn} cannot provide an accumulator {@link
+ * Coder}
+ */
+ @Override
+ public PCollection<KV<K, OutputT>> expand(PCollection<KV<K, InputT>> input) {
+ checkArgument(
+ input.getCoder() instanceof KvCoder,
+ "Expected input to have a %s of type %s, got %s",
+ Coder.class.getSimpleName(),
+ KvCoder.class.getSimpleName(),
+ input.getCoder());
+ KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
+ Coder<InputT> inputValueCoder = inputCoder.getValueCoder();
+ Coder<AccumT> accumulatorCoder;
+ try {
+ accumulatorCoder =
+ combineFn.getAccumulatorCoder(input.getPipeline().getCoderRegistry(), inputValueCoder);
+ } catch (CannotProvideCoderException e) {
+ throw new IllegalStateException(
+ String.format(
+ "Could not construct an Accumulator Coder with the provided %s %s",
+ CombineFn.class.getSimpleName(), combineFn),
+ e);
+ }
+ return input
+ .apply(
+ ParDo.of(
+ new CombineInputs<>(
+ combineFn,
+ input.getWindowingStrategy().getTimestampCombiner(),
+ inputCoder.getKeyCoder())))
+ // Accumulators cross the GroupByKey, so they are coded with the accumulator coder.
+ .setCoder(KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder))
+ .apply(GroupByKey.<K, AccumT>create())
+ .apply(new MergeAndExtractAccumulatorOutput<K, AccumT, OutputT>(combineFn))
+ .setCoder(outputCoder);
+ }
+
+ /**
+ * Partially combines input values per key and window within a bundle. When the bundle finishes,
+ * it outputs one {@link CombineFn#compact compacted} accumulator per key and window, in that
+ * window, at the timestamp accumulated with the {@link TimestampCombiner}.
+ */
+ private static class CombineInputs<K, InputT, AccumT> extends DoFn<KV<K, InputT>, KV<K, AccumT>> {
+ private final CombineFn<InputT, AccumT, ?> combineFn;
+ private final TimestampCombiner timestampCombiner;
+ private final Coder<K> keyCoder;
+
+ /**
+ * Per-bundle state. Accumulators and output timestamps should only be tracked while a bundle
+ * is being processed, and must be cleared when a bundle is completed.
+ */
+ private transient Map<WindowedStructuralKey<K>, AccumT> accumulators;
+ private transient Map<WindowedStructuralKey<K>, Instant> timestamps;
+
+ private CombineInputs(
+ CombineFn<InputT, AccumT, ?> combineFn,
+ TimestampCombiner timestampCombiner,
+ Coder<K> keyCoder) {
+ this.combineFn = combineFn;
+ this.timestampCombiner = timestampCombiner;
+ this.keyCoder = keyCoder;
+ }
+
+ // Fresh, insertion-ordered state for each bundle.
+ @StartBundle
+ public void startBundle() {
+ accumulators = new LinkedHashMap<>();
+ timestamps = new LinkedHashMap<>();
+ }
+
+ // Adds the element to the accumulator for its (key, window), creating one if needed.
+ @ProcessElement
+ public void processElement(ProcessContext context, BoundedWindow window) {
+ WindowedStructuralKey<K>
+ key = WindowedStructuralKey.create(keyCoder, context.element().getKey(), window);
+ AccumT accumulator = accumulators.get(key);
+ // This element's contribution to the output timestamp of its window.
+ Instant assignedTs = timestampCombiner.assign(window, context.timestamp());
+ if (accumulator == null) {
+ accumulator = combineFn.createAccumulator();
+ accumulators.put(key, accumulator);
+ timestamps.put(key, assignedTs);
+ }
+ accumulators.put(key, combineFn.addInput(accumulator, context.element().getValue()));
+ timestamps.put(key, timestampCombiner.combine(assignedTs, timestamps.get(key)));
+ }
+
+ // Emits every accumulator built in this bundle, then drops the per-bundle state.
+ @FinishBundle
+ public void outputAccumulators(FinishBundleContext context) {
+ for (Map.Entry<WindowedStructuralKey<K>, AccumT> preCombineEntry : accumulators.entrySet()) {
+ context.output(
+ KV.of(preCombineEntry.getKey().getKey(), combineFn.compact(preCombineEntry.getValue())),
+ timestamps.get(preCombineEntry.getKey()),
+ preCombineEntry.getKey().getWindow());
+ }
+ accumulators = null;
+ timestamps = null;
+ }
+ }
+
+ /**
+ * A map key pairing a user key, held as a {@link StructuralKey} built with the key's {@link
+ * Coder}, with a window. Two instances are equal when both the structural keys and the windows
+ * are equal.
+ */
+ static class WindowedStructuralKey<K> {
+ public static <K> WindowedStructuralKey<K> create(
+ Coder<K> keyCoder, K key, BoundedWindow window) {
+ return new WindowedStructuralKey<>(StructuralKey.of(key, keyCoder), window);
+ }
+
+ private final StructuralKey<K> key;
+ private final BoundedWindow window;
+
+ private WindowedStructuralKey(StructuralKey<K> key, BoundedWindow window) {
+ this.key = checkNotNull(key, "key cannot be null");
+ this.window = checkNotNull(window, "Window cannot be null");
+ }
+
+ public K getKey() {
+ return key.getKey();
+ }
+
+ public BoundedWindow getWindow() {
+ return window;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof MultiStepCombine.WindowedStructuralKey)) {
+ return false;
+ }
+ WindowedStructuralKey that = (WindowedStructuralKey<?>) other;
+ return this.window.equals(that.window) && this.key.equals(that.key);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(window, key);
+ }
+ }
+
+ /**
+ * URN of {@link MergeAndExtractAccumulatorOutput}, under which the direct runner registers
+ * {@link MergeAndExtractAccumulatorOutputEvaluatorFactory}.
+ */
+ static final String DIRECT_MERGE_ACCUMULATORS_EXTRACT_OUTPUT_URN =
+ "urn:beam:directrunner:transforms:merge_accumulators_extract_output:v1";
+ /**
+ * A primitive {@link PTransform} that merges iterables of accumulators and extracts the output.
+ *
+ * <p>Required to ensure that Immutability Enforcement is not applied. Accumulators
+ * are explicitly mutable.
+ */
+ static class MergeAndExtractAccumulatorOutput<K, AccumT, OutputT>
+ extends RawPTransform<PCollection<KV<K, Iterable<AccumT>>>, PCollection<KV<K, OutputT>>> {
+ private final CombineFn<?, AccumT, OutputT> combineFn;
+
+ private MergeAndExtractAccumulatorOutput(CombineFn<?, AccumT, OutputT> combineFn) {
+ this.combineFn = combineFn;
+ }
+
+ CombineFn<?, AccumT, OutputT> getCombineFn() {
+ return combineFn;
+ }
+
+ // Primitive expansion: the output is created directly and filled in by the evaluator.
+ @Override
+ public PCollection<KV<K, OutputT>> expand(PCollection<KV<K, Iterable<AccumT>>> input) {
+ return PCollection.createPrimitiveOutputInternal(
+ input.getPipeline(), input.getWindowingStrategy(), input.isBounded());
+ }
+
+ @Nullable
+ @Override
+ public String getUrn() {
+ return DIRECT_MERGE_ACCUMULATORS_EXTRACT_OUTPUT_URN;
+ }
+ }
+
+ /**
+ * Creates a new {@link MergeAccumulatorsAndExtractOutputEvaluator} for each call to {@link
+ * #forApplication}.
+ */
+ static class MergeAndExtractAccumulatorOutputEvaluatorFactory
+ implements TransformEvaluatorFactory {
+ private final EvaluationContext ctxt;
+
+ public MergeAndExtractAccumulatorOutputEvaluatorFactory(EvaluationContext ctxt) {
+ this.ctxt = ctxt;
+ }
+
+ @Nullable
+ @Override
+ public <InputT> TransformEvaluator<InputT> forApplication(
+ AppliedPTransform<?, ?, ?> application, CommittedBundle<?> inputBundle) throws Exception {
+ // Raw casts let createEvaluator bind concrete type variables for the wildcard arguments.
+ return createEvaluator((AppliedPTransform) application, (CommittedBundle) inputBundle);
+ }
+
+ private <K, AccumT, OutputT> TransformEvaluator<KV<K, Iterable<AccumT>>> createEvaluator(
+ AppliedPTransform<
+ PCollection<KV<K, Iterable<AccumT>>>, PCollection<KV<K, OutputT>>,
+ MergeAndExtractAccumulatorOutput<K, AccumT, OutputT>>
+ application,
+ CommittedBundle<KV<K, Iterable<AccumT>>> inputBundle) {
+ return new MergeAccumulatorsAndExtractOutputEvaluator<>(ctxt, application);
+ }
+
+ // No-op.
+ @Override
+ public void cleanup() throws Exception {}
+ }
+
+ /**
+ * Evaluates {@link MergeAndExtractAccumulatorOutput}: for each grouped element, merges the
+ * element's accumulators, extracts the {@link CombineFn} output, and adds it to the output bundle
+ * using {@link WindowedValue#withValue} on the input element.
+ */
+ private static class MergeAccumulatorsAndExtractOutputEvaluator<K, AccumT, OutputT>
+ implements TransformEvaluator<KV<K, Iterable<AccumT>>> {
+ private final AppliedPTransform<
+ PCollection<KV<K, Iterable<AccumT>>>, PCollection<KV<K, OutputT>>,
+ MergeAndExtractAccumulatorOutput<K, AccumT, OutputT>>
+ application;
+ private final CombineFn<?, AccumT, OutputT> combineFn;
+ private final UncommittedBundle<KV<K, OutputT>> output;
+
+ public MergeAccumulatorsAndExtractOutputEvaluator(
+ EvaluationContext ctxt,
+ AppliedPTransform<
+ PCollection<KV<K, Iterable<AccumT>>>, PCollection<KV<K, OutputT>>,
+ MergeAndExtractAccumulatorOutput<K, AccumT, OutputT>>
+ application) {
+ this.application = application;
+ this.combineFn = application.getTransform().getCombineFn();
+ this.output =
+ ctxt.createBundle(
+ (PCollection<KV<K, OutputT>>)
+ Iterables.getOnlyElement(application.getOutputs().values()));
+ }
+
+ @Override
+ public void processElement(WindowedValue<KV<K, Iterable<AccumT>>> element) throws Exception {
+ checkState(
+ element.getWindows().size() == 1,
+ "Expected inputs to %s to be in exactly one window. Got %s",
+ MergeAccumulatorsAndExtractOutputEvaluator.class.getSimpleName(),
+ element.getWindows().size());
+ Iterable<AccumT> inputAccumulators = element.getValue().getValue();
+ try {
+ // The grouped accumulators are merged along with a freshly created accumulator before and
+ // after them. NOTE(review): presumably this deliberately exercises merging accumulators
+ // that never received input, as part of the CombineFn lifecycle; confirm intent.
+ AccumT first = combineFn.createAccumulator();
+ AccumT merged = combineFn.mergeAccumulators(Iterables.concat(Collections.singleton(first),
+ inputAccumulators,
+ Collections.singleton(combineFn.createAccumulator())));
+ OutputT extracted = combineFn.extractOutput(merged);
+ output.add(element.withValue(KV.of(element.getValue().getKey(), extracted)));
+ } catch (Exception e) {
+ // Any exception raised while merging, extracting, or adding the output is wrapped.
+ throw UserCodeException.wrap(e);
+ }
+ }
+
+ // Returns the output bundle in a result that carries no watermark hold.
+ @Override
+ public TransformResult<KV<K, Iterable<AccumT>>> finishBundle() throws Exception {
+ return StepTransformResult.<KV<K, Iterable<AccumT>>>withoutHold(application)
+ .addOutput(output)
+ .build();
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/6ea2eda2/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
index 0c907df..30666db 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java
@@ -26,6 +26,7 @@ import static org.apache.beam.runners.core.construction.PTransformTranslation.WI
import static org.apache.beam.runners.core.construction.SplittableParDo.SPLITTABLE_PROCESS_URN;
import static org.apache.beam.runners.direct.DirectGroupByKey.DIRECT_GABW_URN;
import static org.apache.beam.runners.direct.DirectGroupByKey.DIRECT_GBKO_URN;
+import static org.apache.beam.runners.direct.MultiStepCombine.DIRECT_MERGE_ACCUMULATORS_EXTRACT_OUTPUT_URN;
import static org.apache.beam.runners.direct.ParDoMultiOverrideFactory.DIRECT_STATEFUL_PAR_DO_URN;
import static org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory.DIRECT_TEST_STREAM_URN;
import static org.apache.beam.runners.direct.ViewOverrideFactory.DIRECT_WRITE_VIEW_URN;
@@ -73,6 +74,9 @@ class TransformEvaluatorRegistry implements TransformEvaluatorFactory {
.put(DIRECT_GBKO_URN, new GroupByKeyOnlyEvaluatorFactory(ctxt))
.put(DIRECT_GABW_URN, new GroupAlsoByWindowEvaluatorFactory(ctxt))
.put(DIRECT_TEST_STREAM_URN, new TestStreamEvaluatorFactory(ctxt))
+ .put(
+ DIRECT_MERGE_ACCUMULATORS_EXTRACT_OUTPUT_URN,
+ new MultiStepCombine.MergeAndExtractAccumulatorOutputEvaluatorFactory(ctxt))
// Runners-core primitives
.put(SPLITTABLE_PROCESS_URN, new SplittableProcessElementsEvaluatorFactory<>(ctxt))
http://git-wip-us.apache.org/repos/asf/beam/blob/6ea2eda2/runners/direct-java/src/test/java/org/apache/beam/runners/direct/MultiStepCombineTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/MultiStepCombineTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/MultiStepCombineTest.java
new file mode 100644
index 0000000..0c11a8a
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/MultiStepCombineTest.java
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.is;
+import static org.junit.Assert.assertThat;
+
+import com.google.auto.value.AutoValue;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.Serializable;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.CustomCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.util.VarInt;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TimestampedValue;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link MultiStepCombine}.
+ */
+@RunWith(JUnit4.class)
+public class MultiStepCombineTest implements Serializable {
+ @Rule public transient TestPipeline pipeline = TestPipeline.create();
+
+ private transient KvCoder<String, Long> combinedCoder =
+ KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of());
+
+ @Test
+ public void testMultiStepCombine() {
+ PCollection<KV<String, Long>> combined =
+ pipeline
+ .apply(
+ Create.of(
+ KV.of("foo", 1L),
+ KV.of("bar", 2L),
+ KV.of("bizzle", 3L),
+ KV.of("bar", 4L),
+ KV.of("bizzle", 11L)))
+ .apply(Combine.<String, Long, Long>perKey(new MultiStepCombineFn()));
+
+ PAssert.that(combined)
+ .containsInAnyOrder(KV.of("foo", 1L), KV.of("bar", 6L), KV.of("bizzle", 14L));
+ pipeline.run();
+ }
+
+ @Test
+ public void testMultiStepCombineWindowed() {
+ SlidingWindows windowFn = SlidingWindows.of(Duration.millis(6L)).every(Duration.millis(3L));
+ PCollection<KV<String, Long>> combined =
+ pipeline
+ .apply(
+ Create.timestamped(
+ TimestampedValue.of(KV.of("foo", 1L), new Instant(1L)),
+ TimestampedValue.of(KV.of("bar", 2L), new Instant(2L)),
+ TimestampedValue.of(KV.of("bizzle", 3L), new Instant(3L)),
+ TimestampedValue.of(KV.of("bar", 4L), new Instant(4L)),
+ TimestampedValue.of(KV.of("bizzle", 11L), new Instant(11L))))
+ .apply(Window.<KV<String, Long>>into(windowFn))
+ .apply(Combine.<String, Long, Long>perKey(new MultiStepCombineFn()));
+
+ PAssert.that("Windows should combine only elements in their windows", combined)
+ .inWindow(new IntervalWindow(new Instant(0L), Duration.millis(6L)))
+ .containsInAnyOrder(KV.of("foo", 1L), KV.of("bar", 6L), KV.of("bizzle", 3L));
+ PAssert.that("Elements should appear in all the windows they are assigned to", combined)
+ .inWindow(new IntervalWindow(new Instant(-3L), Duration.millis(6L)))
+ .containsInAnyOrder(KV.of("foo", 1L), KV.of("bar", 2L));
+ PAssert.that(combined)
+ .inWindow(new IntervalWindow(new Instant(6L), Duration.millis(6L)))
+ .containsInAnyOrder(KV.of("bizzle", 11L));
+ PAssert.that(combined)
+ .containsInAnyOrder(
+ KV.of("foo", 1L),
+ KV.of("foo", 1L),
+ KV.of("bar", 6L),
+ KV.of("bar", 2L),
+ KV.of("bar", 4L),
+ KV.of("bizzle", 11L),
+ KV.of("bizzle", 11L),
+ KV.of("bizzle", 3L),
+ KV.of("bizzle", 3L));
+ pipeline.run();
+ }
+
+ @Test
+ public void testMultiStepCombineTimestampCombiner() {
+ TimestampCombiner combiner = TimestampCombiner.LATEST;
+ combinedCoder = KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of());
+ PCollection<KV<String, Long>> combined =
+ pipeline
+ .apply(
+ Create.timestamped(
+ TimestampedValue.of(KV.of("foo", 4L), new Instant(1L)),
+ TimestampedValue.of(KV.of("foo", 1L), new Instant(4L)),
+ TimestampedValue.of(KV.of("bazzle", 4L), new Instant(4L)),
+ TimestampedValue.of(KV.of("foo", 12L), new Instant(12L))))
+ .apply(
+ Window.<KV<String, Long>>into(FixedWindows.of(Duration.millis(5L)))
+ .withTimestampCombiner(combiner))
+ .apply(Combine.<String, Long, Long>perKey(new MultiStepCombineFn()));
+ PCollection<KV<String, TimestampedValue<Long>>> reified =
+ combined.apply(
+ ParDo.of(
+ new DoFn<KV<String, Long>, KV<String, TimestampedValue<Long>>>() {
+ @ProcessElement
+ public void reifyTimestamp(ProcessContext context) {
+ context.output(
+ KV.of(
+ context.element().getKey(),
+ TimestampedValue.of(
+ context.element().getValue(), context.timestamp())));
+ }
+ }));
+
+ PAssert.that(reified)
+ .containsInAnyOrder(
+ KV.of("foo", TimestampedValue.of(5L, new Instant(4L))),
+ KV.of("bazzle", TimestampedValue.of(4L, new Instant(4L))),
+ KV.of("foo", TimestampedValue.of(12L, new Instant(12L))));
+ pipeline.run();
+ }
+
+ private static class MultiStepCombineFn extends CombineFn<Long, MultiStepAccumulator, Long> {
+ @Override
+ public Coder<MultiStepAccumulator> getAccumulatorCoder(
+ CoderRegistry registry, Coder<Long> inputCoder) throws CannotProvideCoderException {
+ return new MultiStepAccumulatorCoder();
+ }
+
+ @Override
+ public MultiStepAccumulator createAccumulator() {
+ return MultiStepAccumulator.of(0L, false);
+ }
+
+ @Override
+ public MultiStepAccumulator addInput(MultiStepAccumulator accumulator, Long input) {
+ return MultiStepAccumulator.of(accumulator.getValue() + input, accumulator.isDeserialized());
+ }
+
+ @Override
+ public MultiStepAccumulator mergeAccumulators(Iterable<MultiStepAccumulator> accumulators) {
+ MultiStepAccumulator result = MultiStepAccumulator.of(0L, false);
+ for (MultiStepAccumulator accumulator : accumulators) {
+ result = result.merge(accumulator);
+ }
+ return result;
+ }
+
+ @Override
+ public Long extractOutput(MultiStepAccumulator accumulator) {
+ assertThat(
+ "Accumulators should have been serialized and deserialized within the Pipeline",
+ accumulator.isDeserialized(),
+ is(true));
+ return accumulator.getValue();
+ }
+ }
+
+ @AutoValue
+ abstract static class MultiStepAccumulator {
+ private static MultiStepAccumulator of(long value, boolean deserialized) {
+ return new AutoValue_MultiStepCombineTest_MultiStepAccumulator(value, deserialized);
+ }
+
+ MultiStepAccumulator merge(MultiStepAccumulator other) {
+ return MultiStepAccumulator.of(
+ this.getValue() + other.getValue(), this.isDeserialized() || other.isDeserialized());
+ }
+
+ abstract long getValue();
+
+ abstract boolean isDeserialized();
+ }
+
+ private static class MultiStepAccumulatorCoder extends CustomCoder<MultiStepAccumulator> {
+ @Override
+ public void encode(MultiStepAccumulator value, OutputStream outStream)
+ throws CoderException, IOException {
+ VarInt.encode(value.getValue(), outStream);
+ }
+
+ @Override
+ public MultiStepAccumulator decode(InputStream inStream) throws CoderException, IOException {
+ return MultiStepAccumulator.of(VarInt.decodeLong(inStream), true);
+ }
+ }
+}