You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bc...@apache.org on 2016/03/17 22:17:08 UTC
[1/2] incubator-beam git commit: [BEAM-96] Add composed `CombineFn`
builders in `CombineFns`
Repository: incubator-beam
Updated Branches:
refs/heads/master ac63fd6d4 -> c30326007
[BEAM-96] Add composed `CombineFn` builders in `CombineFns`
* `compose()` or `composeKeyed()` are used to start composition
* `with()` is used to add an input-transformation, a `CombineFn`
and an output `TupleTag`.
* A non-`CombineFn` initial builder is used to ensure that every
composition includes at least one item
* Duplicate output tags are not allowed in the same composition
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/23b43780
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/23b43780
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/23b43780
Branch: refs/heads/master
Commit: 23b437802546f32a167b38f8d0bc7a566abde224
Parents: ac63fd6
Author: Pei He <pe...@google.com>
Authored: Fri Mar 4 13:54:34 2016 -0800
Committer: bchambers <bc...@google.com>
Committed: Thu Mar 17 13:54:40 2016 -0700
----------------------------------------------------------------------
.../dataflow/sdk/transforms/CombineFns.java | 1100 ++++++++++++++++++
.../dataflow/sdk/transforms/CombineFnsTest.java | 413 +++++++
2 files changed, 1513 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/23b43780/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java
----------------------------------------------------------------------
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java
new file mode 100644
index 0000000..656c010
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/transforms/CombineFns.java
@@ -0,0 +1,1100 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.transforms;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException;
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.coders.CoderException;
+import com.google.cloud.dataflow.sdk.coders.CoderRegistry;
+import com.google.cloud.dataflow.sdk.coders.StandardCoder;
+import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn;
+import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn;
+import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.GlobalCombineFn;
+import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.PerKeyCombineFn;
+import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.CombineFnWithContext;
+import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context;
+import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext;
+import com.google.cloud.dataflow.sdk.util.PropertyNames;
+import com.google.cloud.dataflow.sdk.values.TupleTag;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+/**
+ * Static utility methods that create combine function instances.
+ */
+public class CombineFns {
+
+  /**
+   * Returns a {@link ComposeKeyedCombineFnBuilder} to construct a composed
+   * {@link PerKeyCombineFn}.
+   *
+   * <p>The same {@link TupleTag} cannot be used in a composition multiple times.
+   *
+   * <p>Example:
+   * <pre>{@code
+   * PCollection<KV<K, Integer>> latencies = ...;
+   *
+   * TupleTag<Integer> maxLatencyTag = new TupleTag<Integer>();
+   * TupleTag<Double> meanLatencyTag = new TupleTag<Double>();
+   *
+   * SimpleFunction<Integer, Integer> identityFn =
+   *     new SimpleFunction<Integer, Integer>() {
+   *       @Override
+   *       public Integer apply(Integer input) {
+   *         return input;
+   *       }};
+   * PCollection<KV<K, CoCombineResult>> maxAndMean = latencies.apply(
+   *     Combine.perKey(
+   *         CombineFns.composeKeyed()
+   *             .with(identityFn, new MaxIntegerFn(), maxLatencyTag)
+   *             .with(identityFn, new MeanFn<Integer>(), meanLatencyTag)));
+   *
+   * PCollection<T> finalResultCollection = maxAndMean
+   *     .apply(ParDo.of(
+   *         new DoFn<KV<K, CoCombineResult>, T>() {
+   *           @Override
+   *           public void processElement(ProcessContext c) throws Exception {
+   *             KV<K, CoCombineResult> e = c.element();
+   *             Integer maxLatency = e.getValue().get(maxLatencyTag);
+   *             Double meanLatency = e.getValue().get(meanLatencyTag);
+   *             .... Do Something ....
+   *             c.output(...some T...);
+   *           }
+   *         }));
+   * }</pre>
+   *
+   * @return a new, stateless builder; each of its {@code with} overloads starts a fresh
+   *     composition
+   */
+  public static ComposeKeyedCombineFnBuilder composeKeyed() {
+    return new ComposeKeyedCombineFnBuilder();
+  }
+
+  /**
+   * Returns a {@link ComposeCombineFnBuilder} to construct a composed
+   * {@link GlobalCombineFn}.
+   *
+   * <p>The same {@link TupleTag} cannot be used in a composition multiple times.
+   *
+   * <p>Example:
+   * <pre>{@code
+   * PCollection<Integer> globalLatencies = ...;
+   *
+   * TupleTag<Integer> maxLatencyTag = new TupleTag<Integer>();
+   * TupleTag<Double> meanLatencyTag = new TupleTag<Double>();
+   *
+   * SimpleFunction<Integer, Integer> identityFn =
+   *     new SimpleFunction<Integer, Integer>() {
+   *       @Override
+   *       public Integer apply(Integer input) {
+   *         return input;
+   *       }};
+   * PCollection<CoCombineResult> maxAndMean = globalLatencies.apply(
+   *     Combine.globally(
+   *         CombineFns.compose()
+   *             .with(identityFn, new MaxIntegerFn(), maxLatencyTag)
+   *             .with(identityFn, new MeanFn<Integer>(), meanLatencyTag)));
+   *
+   * PCollection<T> finalResultCollection = maxAndMean
+   *     .apply(ParDo.of(
+   *         new DoFn<CoCombineResult, T>() {
+   *           @Override
+   *           public void processElement(ProcessContext c) throws Exception {
+   *             CoCombineResult e = c.element();
+   *             Integer maxLatency = e.get(maxLatencyTag);
+   *             Double meanLatency = e.get(meanLatencyTag);
+   *             .... Do Something ....
+   *             c.output(...some T...);
+   *           }
+   *         }));
+   * }</pre>
+   *
+   * @return a new, stateless builder; each of its {@code with} overloads starts a fresh
+   *     composition
+   */
+  public static ComposeCombineFnBuilder compose() {
+    return new ComposeCombineFnBuilder();
+  }
+
+ /////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * A builder class to construct a composed {@link PerKeyCombineFn}.
+   *
+   * <p>Every {@code with} overload begins a brand-new composition that holds exactly the one
+   * component it was given. Further components are appended by calling {@code with} on the
+   * returned object, so a composition obtained from this builder is never empty.
+   */
+  public static class ComposeKeyedCombineFnBuilder {
+    /**
+     * Starts a composition whose first component is a {@link KeyedCombineFn}.
+     *
+     * <p>The returned {@link ComposedKeyedCombineFn} feeds every {@code DataT} through
+     * {@code extractInputFn}, combines the extracted values with {@code keyedCombineFn}, and
+     * reports the combined value under {@code outputTag} in a {@link CoCombineResult}.
+     */
+    public <K, DataT, InputT, OutputT> ComposedKeyedCombineFn<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        KeyedCombineFn<K, InputT, ?, OutputT> keyedCombineFn,
+        TupleTag<OutputT> outputTag) {
+      ComposedKeyedCombineFn<DataT, K> empty = new ComposedKeyedCombineFn<DataT, K>();
+      return empty.with(extractInputFn, keyedCombineFn, outputTag);
+    }
+
+    /**
+     * Starts a composition whose first component is a {@link KeyedCombineFnWithContext}.
+     *
+     * <p>The returned {@link ComposedKeyedCombineFnWithContext} feeds every {@code DataT}
+     * through {@code extractInputFn}, combines the extracted values with
+     * {@code keyedCombineFnWithContext}, and reports the combined value under
+     * {@code outputTag} in a {@link CoCombineResult}.
+     */
+    public <K, DataT, InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        KeyedCombineFnWithContext<K, InputT, ?, OutputT> keyedCombineFnWithContext,
+        TupleTag<OutputT> outputTag) {
+      ComposedKeyedCombineFnWithContext<DataT, K> empty =
+          new ComposedKeyedCombineFnWithContext<DataT, K>();
+      return empty.with(extractInputFn, keyedCombineFnWithContext, outputTag);
+    }
+
+    /**
+     * Starts a keyed composition from a global {@link CombineFn}, which is first adapted with
+     * {@code asKeyedFn()}.
+     */
+    public <K, DataT, InputT, OutputT> ComposedKeyedCombineFn<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        CombineFn<InputT, ?, OutputT> combineFn,
+        TupleTag<OutputT> outputTag) {
+      return new ComposedKeyedCombineFn<DataT, K>()
+          .with(extractInputFn, combineFn.<K>asKeyedFn(), outputTag);
+    }
+
+    /**
+     * Starts a keyed composition from a global {@link CombineFnWithContext}, which is first
+     * adapted with {@code asKeyedFn()}.
+     */
+    public <K, DataT, InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        CombineFnWithContext<InputT, ?, OutputT> combineFnWithContext,
+        TupleTag<OutputT> outputTag) {
+      return new ComposedKeyedCombineFnWithContext<DataT, K>()
+          .with(extractInputFn, combineFnWithContext.<K>asKeyedFn(), outputTag);
+    }
+  }
+
+  /**
+   * A builder class to construct a composed {@link GlobalCombineFn}.
+   *
+   * <p>Each {@code with} overload begins a fresh composition containing exactly the one
+   * component it is given; more components are added on the returned object.
+   */
+  public static class ComposeCombineFnBuilder {
+    /**
+     * Starts a composition whose first component is a {@link CombineFn}.
+     *
+     * <p>The returned {@link ComposedCombineFn} feeds each {@code DataT} through
+     * {@code extractInputFn}, combines the results with {@code combineFn}, and stores the
+     * combined value under {@code outputTag} in a {@link CoCombineResult}.
+     */
+    public <DataT, InputT, OutputT> ComposedCombineFn<DataT> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        CombineFn<InputT, ?, OutputT> combineFn,
+        TupleTag<OutputT> outputTag) {
+      ComposedCombineFn<DataT> empty = new ComposedCombineFn<DataT>();
+      return empty.with(extractInputFn, combineFn, outputTag);
+    }
+
+    /**
+     * Starts a composition whose first component is a {@link CombineFnWithContext}.
+     *
+     * <p>The returned {@link ComposedCombineFnWithContext} feeds each {@code DataT} through
+     * {@code extractInputFn}, combines the results with {@code combineFnWithContext}, and
+     * stores the combined value under {@code outputTag} in a {@link CoCombineResult}.
+     */
+    public <DataT, InputT, OutputT> ComposedCombineFnWithContext<DataT> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        CombineFnWithContext<InputT, ?, OutputT> combineFnWithContext,
+        TupleTag<OutputT> outputTag) {
+      ComposedCombineFnWithContext<DataT> empty = new ComposedCombineFnWithContext<DataT>();
+      return empty.with(extractInputFn, combineFnWithContext, outputTag);
+    }
+  }
+
+ /////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * A tuple of outputs produced by a composed combine function.
+   *
+   * <p>See {@link #compose()} or {@link #composeKeyed()} for details.
+   */
+  public static class CoCombineResult implements Serializable {
+
+    /**
+     * Placeholder stored in {@link #valuesMap} in place of a {@code null} output, because
+     * {@link ImmutableMap} does not accept {@code null} values.
+     */
+    private enum NullValue {
+      INSTANCE;
+    }
+
+    /** Immutable copy of the outputs, keyed by the {@link TupleTag} each was produced under. */
+    private final Map<TupleTag<?>, Object> valuesMap;
+
+    /**
+     * The constructor of {@link CoCombineResult}.
+     *
+     * <p>{@code valuesMap} may contain {@code null} values: each one is recorded as a
+     * placeholder, so {@link #get} returns {@code null} for that {@link TupleTag}. The map is
+     * copied, so later changes to {@code valuesMap} are not reflected in this result.
+     *
+     * @throws NullPointerException if any key in {@code valuesMap} is null
+     */
+    CoCombineResult(Map<TupleTag<?>, Object> valuesMap) {
+      ImmutableMap.Builder<TupleTag<?>, Object> builder = ImmutableMap.builder();
+      for (Entry<TupleTag<?>, Object> entry : valuesMap.entrySet()) {
+        Object value = entry.getValue();
+        builder.put(entry.getKey(), value == null ? NullValue.INSTANCE : value);
+      }
+      this.valuesMap = builder.build();
+    }
+
+    /**
+     * Returns the value represented by the given {@link TupleTag}.
+     *
+     * <p>Returns {@code null} when the component that produced this output returned
+     * {@code null}.
+     *
+     * @throws IllegalArgumentException if {@code tag} is not present in this result
+     */
+    @SuppressWarnings("unchecked")
+    public <V> V get(TupleTag<V> tag) {
+      // The %s template defers building the message until the check actually fails.
+      checkArgument(
+          valuesMap.containsKey(tag), "TupleTag %s is not in the CoCombineResult", tag);
+      Object value = valuesMap.get(tag);
+      return value == NullValue.INSTANCE ? null : (V) value;
+    }
+  }
+
+ /////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * A composed {@link CombineFn} that applies multiple {@link CombineFn CombineFns}.
+ *
+ * <p>For each {@link CombineFn} it extracts inputs from {@code DataT} with
+ * the {@code extractInputFn} and combines them,
+ * and then it outputs each combined value with a {@link TupleTag} to a
+ * {@link CoCombineResult}.
+ */
+ public static class ComposedCombineFn<DataT> extends CombineFn<DataT, Object[], CoCombineResult> {
+
+ private final List<CombineFn<Object, Object, Object>> combineFns;
+ private final List<SerializableFunction<DataT, Object>> extractInputFns;
+ private final List<TupleTag<?>> outputTags;
+ private final int combineFnCount;
+
+ private ComposedCombineFn() {
+ this.extractInputFns = ImmutableList.of();
+ this.combineFns = ImmutableList.of();
+ this.outputTags = ImmutableList.of();
+ this.combineFnCount = 0;
+ }
+
+ private ComposedCombineFn(
+ ImmutableList<SerializableFunction<DataT, ?>> extractInputFns,
+ ImmutableList<CombineFn<?, ?, ?>> combineFns,
+ ImmutableList<TupleTag<?>> outputTags) {
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ List<SerializableFunction<DataT, Object>> castedExtractInputFns = (List) extractInputFns;
+ this.extractInputFns = castedExtractInputFns;
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ List<CombineFn<Object, Object, Object>> castedCombineFns = (List) combineFns;
+ this.combineFns = castedCombineFns;
+
+ this.outputTags = outputTags;
+ this.combineFnCount = this.combineFns.size();
+ }
+
+ /**
+ * Returns a {@link ComposedCombineFn} with an additional {@link CombineFn}.
+ */
+ public <InputT, OutputT> ComposedCombineFn<DataT> with(
+ SimpleFunction<DataT, InputT> extractInputFn,
+ CombineFn<InputT, ?, OutputT> combineFn,
+ TupleTag<OutputT> outputTag) {
+ checkUniqueness(outputTags, outputTag);
+ return new ComposedCombineFn<>(
+ ImmutableList.<SerializableFunction<DataT, ?>>builder()
+ .addAll(extractInputFns)
+ .add(extractInputFn)
+ .build(),
+ ImmutableList.<CombineFn<?, ?, ?>>builder()
+ .addAll(combineFns)
+ .add(combineFn)
+ .build(),
+ ImmutableList.<TupleTag<?>>builder()
+ .addAll(outputTags)
+ .add(outputTag)
+ .build());
+ }
+
+ /**
+ * Returns a {@link ComposedCombineFnWithContext} with an additional
+ * {@link CombineFnWithContext}.
+ */
+ public <InputT, OutputT> ComposedCombineFnWithContext<DataT> with(
+ SimpleFunction<DataT, InputT> extractInputFn,
+ CombineFnWithContext<InputT, ?, OutputT> combineFn,
+ TupleTag<OutputT> outputTag) {
+ checkUniqueness(outputTags, outputTag);
+ List<CombineFnWithContext<Object, Object, Object>> fnsWithContext = Lists.newArrayList();
+ for (CombineFn<Object, Object, Object> fn : combineFns) {
+ fnsWithContext.add(toFnWithContext(fn));
+ }
+ return new ComposedCombineFnWithContext<>(
+ ImmutableList.<SerializableFunction<DataT, ?>>builder()
+ .addAll(extractInputFns)
+ .add(extractInputFn)
+ .build(),
+ ImmutableList.<CombineFnWithContext<?, ?, ?>>builder()
+ .addAll(fnsWithContext)
+ .add(combineFn)
+ .build(),
+ ImmutableList.<TupleTag<?>>builder()
+ .addAll(outputTags)
+ .add(outputTag)
+ .build());
+ }
+
+ @Override
+ public Object[] createAccumulator() {
+ Object[] accumsArray = new Object[combineFnCount];
+ for (int i = 0; i < combineFnCount; ++i) {
+ accumsArray[i] = combineFns.get(i).createAccumulator();
+ }
+ return accumsArray;
+ }
+
+ @Override
+ public Object[] addInput(Object[] accumulator, DataT value) {
+ for (int i = 0; i < combineFnCount; ++i) {
+ Object input = extractInputFns.get(i).apply(value);
+ accumulator[i] = combineFns.get(i).addInput(accumulator[i], input);
+ }
+ return accumulator;
+ }
+
+ @Override
+ public Object[] mergeAccumulators(Iterable<Object[]> accumulators) {
+ Iterator<Object[]> iter = accumulators.iterator();
+ if (!iter.hasNext()) {
+ return createAccumulator();
+ } else {
+ // Reuses the first accumulator, and overwrites its values.
+ // It is safe because {@code accum[i]} only depends on
+ // the i-th component of each accumulator.
+ Object[] accum = iter.next();
+ for (int i = 0; i < combineFnCount; ++i) {
+ accum[i] = combineFns.get(i).mergeAccumulators(new ProjectionIterable(accumulators, i));
+ }
+ return accum;
+ }
+ }
+
+ @Override
+ public CoCombineResult extractOutput(Object[] accumulator) {
+ Map<TupleTag<?>, Object> valuesMap = Maps.newHashMap();
+ for (int i = 0; i < combineFnCount; ++i) {
+ valuesMap.put(
+ outputTags.get(i),
+ combineFns.get(i).extractOutput(accumulator[i]));
+ }
+ return new CoCombineResult(valuesMap);
+ }
+
+ @Override
+ public Object[] compact(Object[] accumulator) {
+ for (int i = 0; i < combineFnCount; ++i) {
+ accumulator[i] = combineFns.get(i).compact(accumulator[i]);
+ }
+ return accumulator;
+ }
+
+ @Override
+ public Coder<Object[]> getAccumulatorCoder(CoderRegistry registry, Coder<DataT> dataCoder)
+ throws CannotProvideCoderException {
+ List<Coder<Object>> coders = Lists.newArrayList();
+ for (int i = 0; i < combineFnCount; ++i) {
+ Coder<Object> inputCoder =
+ registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder);
+ coders.add(combineFns.get(i).getAccumulatorCoder(registry, inputCoder));
+ }
+ return new ComposedAccumulatorCoder(coders);
+ }
+ }
+
+  /**
+   * A {@link CombineFnWithContext} that runs several
+   * {@link CombineFnWithContext CombineFnWithContexts} side by side over the same input.
+   *
+   * <p>Component {@code i} reads its input through the {@code i}-th {@code extractInputFn},
+   * owns slot {@code i} of the {@code Object[]} accumulator, and publishes its result under the
+   * {@code i}-th {@link TupleTag} of the resulting {@link CoCombineResult}. The same
+   * {@link Context} is handed to every component.
+   */
+  public static class ComposedCombineFnWithContext<DataT>
+      extends CombineFnWithContext<DataT, Object[], CoCombineResult> {
+
+    private final List<SerializableFunction<DataT, Object>> extractInputFns;
+    private final List<CombineFnWithContext<Object, Object, Object>> combineFnWithContexts;
+    private final List<TupleTag<?>> outputTags;
+    private final int combineFnCount;
+
+    /** Creates a composition with no components; {@code with} appends them. */
+    private ComposedCombineFnWithContext() {
+      this(
+          ImmutableList.<SerializableFunction<DataT, ?>>of(),
+          ImmutableList.<CombineFnWithContext<?, ?, ?>>of(),
+          ImmutableList.<TupleTag<?>>of());
+    }
+
+    /**
+     * Creates a composition from parallel lists whose {@code i}-th entries together describe
+     * component {@code i}.
+     */
+    private ComposedCombineFnWithContext(
+        ImmutableList<SerializableFunction<DataT, ?>> extractInputFns,
+        ImmutableList<CombineFnWithContext<?, ?, ?>> combineFnWithContexts,
+        ImmutableList<TupleTag<?>> outputTags) {
+      // Erase the per-component type parameters to Object so all components fit in one list;
+      // slot i is only ever handled by component i, which keeps the casts sound.
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      List<SerializableFunction<DataT, Object>> erasedExtractors = (List) extractInputFns;
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      List<CombineFnWithContext<Object, Object, Object>> erasedFns =
+          (List) combineFnWithContexts;
+
+      this.extractInputFns = erasedExtractors;
+      this.combineFnWithContexts = erasedFns;
+      this.outputTags = outputTags;
+      this.combineFnCount = erasedFns.size();
+    }
+
+    /**
+     * Returns a new {@link ComposedCombineFnWithContext} holding every current component plus
+     * {@code globalCombineFn} (converted to a {@link CombineFnWithContext}).
+     */
+    public <InputT, OutputT> ComposedCombineFnWithContext<DataT> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        GlobalCombineFn<InputT, ?, OutputT> globalCombineFn,
+        TupleTag<OutputT> outputTag) {
+      checkUniqueness(outputTags, outputTag);
+      ImmutableList<SerializableFunction<DataT, ?>> nextExtractInputFns =
+          ImmutableList.<SerializableFunction<DataT, ?>>builder()
+              .addAll(extractInputFns)
+              .add(extractInputFn)
+              .build();
+      ImmutableList<CombineFnWithContext<?, ?, ?>> nextCombineFns =
+          ImmutableList.<CombineFnWithContext<?, ?, ?>>builder()
+              .addAll(combineFnWithContexts)
+              .add(toFnWithContext(globalCombineFn))
+              .build();
+      ImmutableList<TupleTag<?>> nextOutputTags =
+          ImmutableList.<TupleTag<?>>builder()
+              .addAll(outputTags)
+              .add(outputTag)
+              .build();
+      return new ComposedCombineFnWithContext<>(
+          nextExtractInputFns, nextCombineFns, nextOutputTags);
+    }
+
+    /** Builds an accumulator array whose slot {@code i} is component {@code i}'s accumulator. */
+    @Override
+    public Object[] createAccumulator(Context c) {
+      Object[] accumulator = new Object[combineFnCount];
+      int slot = 0;
+      for (CombineFnWithContext<Object, Object, Object> fn : combineFnWithContexts) {
+        accumulator[slot++] = fn.createAccumulator(c);
+      }
+      return accumulator;
+    }
+
+    /** Extracts each component's input from {@code value} and adds it into that slot. */
+    @Override
+    public Object[] addInput(Object[] accumulator, DataT value, Context c) {
+      for (int slot = 0; slot < combineFnCount; slot++) {
+        CombineFnWithContext<Object, Object, Object> fn = combineFnWithContexts.get(slot);
+        Object input = extractInputFns.get(slot).apply(value);
+        accumulator[slot] = fn.addInput(accumulator[slot], input, c);
+      }
+      return accumulator;
+    }
+
+    /**
+     * Merges slot-by-slot; the first accumulator is reused as the result, and an empty
+     * {@code accumulators} yields a fresh accumulator.
+     */
+    @Override
+    public Object[] mergeAccumulators(Iterable<Object[]> accumulators, Context c) {
+      Iterator<Object[]> it = accumulators.iterator();
+      if (!it.hasNext()) {
+        return createAccumulator(c);
+      }
+      // Overwriting slot i of the first accumulator is safe: merging slot i only depends on
+      // the i-th component of each accumulator, and later slots are still untouched.
+      Object[] merged = it.next();
+      for (int slot = 0; slot < combineFnCount; slot++) {
+        CombineFnWithContext<Object, Object, Object> fn = combineFnWithContexts.get(slot);
+        merged[slot] = fn.mergeAccumulators(new ProjectionIterable(accumulators, slot), c);
+      }
+      return merged;
+    }
+
+    /** Collects each component's output into a {@link CoCombineResult} keyed by its tag. */
+    @Override
+    public CoCombineResult extractOutput(Object[] accumulator, Context c) {
+      Map<TupleTag<?>, Object> outputs = Maps.newHashMap();
+      for (int slot = 0; slot < combineFnCount; slot++) {
+        TupleTag<?> tag = outputTags.get(slot);
+        outputs.put(tag, combineFnWithContexts.get(slot).extractOutput(accumulator[slot], c));
+      }
+      return new CoCombineResult(outputs);
+    }
+
+    /** Lets every component compact its own slot in place. */
+    @Override
+    public Object[] compact(Object[] accumulator, Context c) {
+      for (int slot = 0; slot < combineFnCount; slot++) {
+        CombineFnWithContext<Object, Object, Object> fn = combineFnWithContexts.get(slot);
+        accumulator[slot] = fn.compact(accumulator[slot], c);
+      }
+      return accumulator;
+    }
+
+    /**
+     * Returns a coder for the accumulator array, built from each component's accumulator coder.
+     * The input coder of component {@code i} is inferred from its {@code extractInputFn}.
+     */
+    @Override
+    public Coder<Object[]> getAccumulatorCoder(CoderRegistry registry, Coder<DataT> dataCoder)
+        throws CannotProvideCoderException {
+      List<Coder<Object>> slotCoders = Lists.newArrayList();
+      for (int slot = 0; slot < combineFnCount; slot++) {
+        Coder<Object> inputCoder =
+            registry.getDefaultOutputCoder(extractInputFns.get(slot), dataCoder);
+        slotCoders.add(
+            combineFnWithContexts.get(slot).getAccumulatorCoder(registry, inputCoder));
+      }
+      return new ComposedAccumulatorCoder(slotCoders);
+    }
+  }
+
+ /**
+ * A composed {@link KeyedCombineFn} that applies multiple {@link KeyedCombineFn KeyedCombineFns}.
+ *
+ * <p>For each {@link KeyedCombineFn} it extracts inputs from {@code DataT} with
+ * the {@code extractInputFn} and combines them,
+ * and then it outputs each combined value with a {@link TupleTag} to a
+ * {@link CoCombineResult}.
+ */
+ public static class ComposedKeyedCombineFn<DataT, K>
+ extends KeyedCombineFn<K, DataT, Object[], CoCombineResult> {
+
+ private final List<SerializableFunction<DataT, Object>> extractInputFns;
+ private final List<KeyedCombineFn<K, Object, Object, Object>> keyedCombineFns;
+ private final List<TupleTag<?>> outputTags;
+ private final int combineFnCount;
+
+ private ComposedKeyedCombineFn() {
+ this.extractInputFns = ImmutableList.of();
+ this.keyedCombineFns = ImmutableList.of();
+ this.outputTags = ImmutableList.of();
+ this.combineFnCount = 0;
+ }
+
+ private ComposedKeyedCombineFn(
+ ImmutableList<SerializableFunction<DataT, ?>> extractInputFns,
+ ImmutableList<KeyedCombineFn<K, ?, ?, ?>> keyedCombineFns,
+ ImmutableList<TupleTag<?>> outputTags) {
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ List<SerializableFunction<DataT, Object>> castedExtractInputFns = (List) extractInputFns;
+ this.extractInputFns = castedExtractInputFns;
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ List<KeyedCombineFn<K, Object, Object, Object>> castedKeyedCombineFns =
+ (List) keyedCombineFns;
+ this.keyedCombineFns = castedKeyedCombineFns;
+ this.outputTags = outputTags;
+ this.combineFnCount = this.keyedCombineFns.size();
+ }
+
+ /**
+ * Returns a {@link ComposedKeyedCombineFn} with an additional {@link KeyedCombineFn}.
+ */
+ public <InputT, OutputT> ComposedKeyedCombineFn<DataT, K> with(
+ SimpleFunction<DataT, InputT> extractInputFn,
+ KeyedCombineFn<K, InputT, ?, OutputT> keyedCombineFn,
+ TupleTag<OutputT> outputTag) {
+ checkUniqueness(outputTags, outputTag);
+ return new ComposedKeyedCombineFn<>(
+ ImmutableList.<SerializableFunction<DataT, ?>>builder()
+ .addAll(extractInputFns)
+ .add(extractInputFn)
+ .build(),
+ ImmutableList.<KeyedCombineFn<K, ?, ?, ?>>builder()
+ .addAll(keyedCombineFns)
+ .add(keyedCombineFn)
+ .build(),
+ ImmutableList.<TupleTag<?>>builder()
+ .addAll(outputTags)
+ .add(outputTag)
+ .build());
+ }
+
+ /**
+ * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional
+ * {@link KeyedCombineFnWithContext}.
+ */
+ public <InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with(
+ SimpleFunction<DataT, InputT> extractInputFn,
+ KeyedCombineFnWithContext<K, InputT, ?, OutputT> keyedCombineFn,
+ TupleTag<OutputT> outputTag) {
+ checkUniqueness(outputTags, outputTag);
+ List<KeyedCombineFnWithContext<K, Object, Object, Object>> fnsWithContext =
+ Lists.newArrayList();
+ for (KeyedCombineFn<K, Object, Object, Object> fn : keyedCombineFns) {
+ fnsWithContext.add(toFnWithContext(fn));
+ }
+ return new ComposedKeyedCombineFnWithContext<>(
+ ImmutableList.<SerializableFunction<DataT, ?>>builder()
+ .addAll(extractInputFns)
+ .add(extractInputFn)
+ .build(),
+ ImmutableList.<KeyedCombineFnWithContext<K, ?, ?, ?>>builder()
+ .addAll(fnsWithContext)
+ .add(keyedCombineFn)
+ .build(),
+ ImmutableList.<TupleTag<?>>builder()
+ .addAll(outputTags)
+ .add(outputTag)
+ .build());
+ }
+
+ /**
+ * Returns a {@link ComposedKeyedCombineFn} with an additional {@link CombineFn}.
+ */
+ public <InputT, OutputT> ComposedKeyedCombineFn<DataT, K> with(
+ SimpleFunction<DataT, InputT> extractInputFn,
+ CombineFn<InputT, ?, OutputT> keyedCombineFn,
+ TupleTag<OutputT> outputTag) {
+ return with(extractInputFn, keyedCombineFn.<K>asKeyedFn(), outputTag);
+ }
+
+ /**
+ * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional
+ * {@link CombineFnWithContext}.
+ */
+ public <InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with(
+ SimpleFunction<DataT, InputT> extractInputFn,
+ CombineFnWithContext<InputT, ?, OutputT> keyedCombineFn,
+ TupleTag<OutputT> outputTag) {
+ return with(extractInputFn, keyedCombineFn.<K>asKeyedFn(), outputTag);
+ }
+
+ @Override
+ public Object[] createAccumulator(K key) {
+ Object[] accumsArray = new Object[combineFnCount];
+ for (int i = 0; i < combineFnCount; ++i) {
+ accumsArray[i] = keyedCombineFns.get(i).createAccumulator(key);
+ }
+ return accumsArray;
+ }
+
+ @Override
+ public Object[] addInput(K key, Object[] accumulator, DataT value) {
+ for (int i = 0; i < combineFnCount; ++i) {
+ Object input = extractInputFns.get(i).apply(value);
+ accumulator[i] = keyedCombineFns.get(i).addInput(key, accumulator[i], input);
+ }
+ return accumulator;
+ }
+
+ @Override
+ public Object[] mergeAccumulators(K key, final Iterable<Object[]> accumulators) {
+ Iterator<Object[]> iter = accumulators.iterator();
+ if (!iter.hasNext()) {
+ return createAccumulator(key);
+ } else {
+ // Reuses the first accumulator, and overwrites its values.
+ // It is safe because {@code accum[i]} only depends on
+ // the i-th component of each accumulator.
+ Object[] accum = iter.next();
+ for (int i = 0; i < combineFnCount; ++i) {
+ accum[i] = keyedCombineFns.get(i).mergeAccumulators(
+ key, new ProjectionIterable(accumulators, i));
+ }
+ return accum;
+ }
+ }
+
+ @Override
+ public CoCombineResult extractOutput(K key, Object[] accumulator) {
+ Map<TupleTag<?>, Object> valuesMap = Maps.newHashMap();
+ for (int i = 0; i < combineFnCount; ++i) {
+ valuesMap.put(
+ outputTags.get(i),
+ keyedCombineFns.get(i).extractOutput(key, accumulator[i]));
+ }
+ return new CoCombineResult(valuesMap);
+ }
+
+ @Override
+ public Object[] compact(K key, Object[] accumulator) {
+ for (int i = 0; i < combineFnCount; ++i) {
+ accumulator[i] = keyedCombineFns.get(i).compact(key, accumulator[i]);
+ }
+ return accumulator;
+ }
+
+ @Override
+ public Coder<Object[]> getAccumulatorCoder(
+ CoderRegistry registry, Coder<K> keyCoder, Coder<DataT> dataCoder)
+ throws CannotProvideCoderException {
+ List<Coder<Object>> coders = Lists.newArrayList();
+ for (int i = 0; i < combineFnCount; ++i) {
+ Coder<Object> inputCoder =
+ registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder);
+ coders.add(keyedCombineFns.get(i).getAccumulatorCoder(registry, keyCoder, inputCoder));
+ }
+ return new ComposedAccumulatorCoder(coders);
+ }
+ }
+
+ /**
+ * A composed {@link KeyedCombineFnWithContext} that applies multiple
+ * {@link KeyedCombineFnWithContext KeyedCombineFnWithContexts}.
+ *
+ * <p>For each {@link KeyedCombineFnWithContext} it extracts inputs from {@code DataT} with
+ * the {@code extractInputFn} and combines them,
+ * and then it outputs each combined value with a {@link TupleTag} to a
+ * {@link CoCombineResult}.
+ */
+ public static class ComposedKeyedCombineFnWithContext<DataT, K>
+ extends KeyedCombineFnWithContext<K, DataT, Object[], CoCombineResult> {
+
+ private final List<SerializableFunction<DataT, Object>> extractInputFns;
+ private final List<KeyedCombineFnWithContext<K, Object, Object, Object>> keyedCombineFns;
+ private final List<TupleTag<?>> outputTags;
+ private final int combineFnCount;
+
+    /** Creates a composition with no components; {@code with} appends them. */
+    private ComposedKeyedCombineFnWithContext() {
+      this(
+          ImmutableList.<SerializableFunction<DataT, ?>>of(),
+          ImmutableList.<KeyedCombineFnWithContext<K, ?, ?, ?>>of(),
+          ImmutableList.<TupleTag<?>>of());
+    }
+
+    /**
+     * Creates a composition from parallel lists whose {@code i}-th entries together describe
+     * component {@code i}.
+     */
+    private ComposedKeyedCombineFnWithContext(
+        ImmutableList<SerializableFunction<DataT, ?>> extractInputFns,
+        ImmutableList<KeyedCombineFnWithContext<K, ?, ?, ?>> keyedCombineFns,
+        ImmutableList<TupleTag<?>> outputTags) {
+      // Erase the per-component type parameters to Object so all components fit in one list;
+      // slot i is only ever handled by component i, which keeps the casts sound.
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      List<SerializableFunction<DataT, Object>> erasedExtractors = (List) extractInputFns;
+      @SuppressWarnings({"unchecked", "rawtypes"})
+      List<KeyedCombineFnWithContext<K, Object, Object, Object>> erasedFns =
+          (List) keyedCombineFns;
+
+      this.outputTags = outputTags;
+      this.extractInputFns = erasedExtractors;
+      this.keyedCombineFns = erasedFns;
+      this.combineFnCount = erasedFns.size();
+    }
+
+    /**
+     * Returns a new {@link ComposedKeyedCombineFnWithContext} holding every current component
+     * plus {@code perKeyCombineFn} (converted to a {@link KeyedCombineFnWithContext}); this
+     * instance is left unchanged.
+     */
+    public <InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        PerKeyCombineFn<K, InputT, ?, OutputT> perKeyCombineFn,
+        TupleTag<OutputT> outputTag) {
+      checkUniqueness(outputTags, outputTag);
+      ImmutableList<SerializableFunction<DataT, ?>> nextExtractInputFns =
+          ImmutableList.<SerializableFunction<DataT, ?>>builder()
+              .addAll(extractInputFns)
+              .add(extractInputFn)
+              .build();
+      ImmutableList<KeyedCombineFnWithContext<K, ?, ?, ?>> nextCombineFns =
+          ImmutableList.<KeyedCombineFnWithContext<K, ?, ?, ?>>builder()
+              .addAll(keyedCombineFns)
+              .add(toFnWithContext(perKeyCombineFn))
+              .build();
+      ImmutableList<TupleTag<?>> nextOutputTags =
+          ImmutableList.<TupleTag<?>>builder()
+              .addAll(outputTags)
+              .add(outputTag)
+              .build();
+      return new ComposedKeyedCombineFnWithContext<>(
+          nextExtractInputFns, nextCombineFns, nextOutputTags);
+    }
+
+    /**
+     * Returns a {@link ComposedKeyedCombineFnWithContext} with an additional
+     * {@link GlobalCombineFn}.
+     *
+     * <p>{@code globalCombineFn} is adapted with {@code asKeyedFn()} and then added like any
+     * other {@link PerKeyCombineFn}.
+     */
+    public <InputT, OutputT> ComposedKeyedCombineFnWithContext<DataT, K> with(
+        SimpleFunction<DataT, InputT> extractInputFn,
+        GlobalCombineFn<InputT, ?, OutputT> globalCombineFn,
+        TupleTag<OutputT> outputTag) {
+      return with(extractInputFn, globalCombineFn.<K>asKeyedFn(), outputTag);
+    }
+
+ @Override
+ public Object[] createAccumulator(K key, Context c) {
+ Object[] accumsArray = new Object[combineFnCount];
+ for (int i = 0; i < combineFnCount; ++i) {
+ accumsArray[i] = keyedCombineFns.get(i).createAccumulator(key, c);
+ }
+ return accumsArray;
+ }
+
+ @Override
+ public Object[] addInput(K key, Object[] accumulator, DataT value, Context c) {
+ for (int i = 0; i < combineFnCount; ++i) {
+ Object input = extractInputFns.get(i).apply(value);
+ accumulator[i] = keyedCombineFns.get(i).addInput(key, accumulator[i], input, c);
+ }
+ return accumulator;
+ }
+
+ @Override
+ public Object[] mergeAccumulators(K key, Iterable<Object[]> accumulators, Context c) {
+ Iterator<Object[]> iter = accumulators.iterator();
+ if (!iter.hasNext()) {
+ return createAccumulator(key, c);
+ } else {
+ // Reuses the first accumulator, and overwrites its values.
+ // It is safe because {@code accum[i]} only depends on
+ // the i-th component of each accumulator.
+ Object[] accum = iter.next();
+ for (int i = 0; i < combineFnCount; ++i) {
+ accum[i] = keyedCombineFns.get(i).mergeAccumulators(
+ key, new ProjectionIterable(accumulators, i), c);
+ }
+ return accum;
+ }
+ }
+
+ @Override
+ public CoCombineResult extractOutput(K key, Object[] accumulator, Context c) {
+ Map<TupleTag<?>, Object> valuesMap = Maps.newHashMap();
+ for (int i = 0; i < combineFnCount; ++i) {
+ valuesMap.put(
+ outputTags.get(i),
+ keyedCombineFns.get(i).extractOutput(key, accumulator[i], c));
+ }
+ return new CoCombineResult(valuesMap);
+ }
+
+ @Override
+ public Object[] compact(K key, Object[] accumulator, Context c) {
+ for (int i = 0; i < combineFnCount; ++i) {
+ accumulator[i] = keyedCombineFns.get(i).compact(key, accumulator[i], c);
+ }
+ return accumulator;
+ }
+
+ @Override
+ public Coder<Object[]> getAccumulatorCoder(
+ CoderRegistry registry, Coder<K> keyCoder, Coder<DataT> dataCoder)
+ throws CannotProvideCoderException {
+ List<Coder<Object>> coders = Lists.newArrayList();
+ for (int i = 0; i < combineFnCount; ++i) {
+ Coder<Object> inputCoder =
+ registry.getDefaultOutputCoder(extractInputFns.get(i), dataCoder);
+ coders.add(keyedCombineFns.get(i).getAccumulatorCoder(
+ registry, keyCoder, inputCoder));
+ }
+ return new ComposedAccumulatorCoder(coders);
+ }
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+
+ private static class ProjectionIterable implements Iterable<Object> {
+ private final Iterable<Object[]> iterable;
+ private final int column;
+
+ private ProjectionIterable(Iterable<Object[]> iterable, int column) {
+ this.iterable = iterable;
+ this.column = column;
+ }
+
+ @Override
+ public Iterator<Object> iterator() {
+ final Iterator<Object[]> iter = iterable.iterator();
+ return new Iterator<Object>() {
+ @Override
+ public boolean hasNext() {
+ return iter.hasNext();
+ }
+
+ @Override
+ public Object next() {
+ return iter.next()[column];
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ };
+ }
+ }
+
+ private static class ComposedAccumulatorCoder extends StandardCoder<Object[]> {
+ private List<Coder<Object>> coders;
+ private int codersCount;
+
+ public ComposedAccumulatorCoder(List<Coder<Object>> coders) {
+ this.coders = ImmutableList.copyOf(coders);
+ this.codersCount = coders.size();
+ }
+
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ @JsonCreator
+ public static ComposedAccumulatorCoder of(
+ @JsonProperty(PropertyNames.COMPONENT_ENCODINGS)
+ List<Coder<?>> components) {
+ return new ComposedAccumulatorCoder((List) components);
+ }
+
+ @Override
+ public void encode(Object[] value, OutputStream outStream, Context context)
+ throws CoderException, IOException {
+ checkArgument(value.length == codersCount);
+ Context nestedContext = context.nested();
+ for (int i = 0; i < codersCount; ++i) {
+ coders.get(i).encode(value[i], outStream, nestedContext);
+ }
+ }
+
+ @Override
+ public Object[] decode(InputStream inStream, Context context)
+ throws CoderException, IOException {
+ Object[] ret = new Object[codersCount];
+ Context nestedContext = context.nested();
+ for (int i = 0; i < codersCount; ++i) {
+ ret[i] = coders.get(i).decode(inStream, nestedContext);
+ }
+ return ret;
+ }
+
+ @Override
+ public List<? extends Coder<?>> getCoderArguments() {
+ return coders;
+ }
+
+ @Override
+ public void verifyDeterministic() throws NonDeterministicException {
+ for (int i = 0; i < codersCount; ++i) {
+ coders.get(i).verifyDeterministic();
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ private static <InputT, AccumT, OutputT> CombineFnWithContext<InputT, AccumT, OutputT>
+ toFnWithContext(GlobalCombineFn<InputT, AccumT, OutputT> globalCombineFn) {
+ if (globalCombineFn instanceof CombineFnWithContext) {
+ return (CombineFnWithContext<InputT, AccumT, OutputT>) globalCombineFn;
+ } else {
+ final CombineFn<InputT, AccumT, OutputT> combineFn =
+ (CombineFn<InputT, AccumT, OutputT>) globalCombineFn;
+ return new CombineFnWithContext<InputT, AccumT, OutputT>() {
+ @Override
+ public AccumT createAccumulator(Context c) {
+ return combineFn.createAccumulator();
+ }
+ @Override
+ public AccumT addInput(AccumT accumulator, InputT input, Context c) {
+ return combineFn.addInput(accumulator, input);
+ }
+ @Override
+ public AccumT mergeAccumulators(Iterable<AccumT> accumulators, Context c) {
+ return combineFn.mergeAccumulators(accumulators);
+ }
+ @Override
+ public OutputT extractOutput(AccumT accumulator, Context c) {
+ return combineFn.extractOutput(accumulator);
+ }
+ @Override
+ public AccumT compact(AccumT accumulator, Context c) {
+ return combineFn.compact(accumulator);
+ }
+ @Override
+ public OutputT defaultValue() {
+ return combineFn.defaultValue();
+ }
+ @Override
+ public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<InputT> inputCoder)
+ throws CannotProvideCoderException {
+ return combineFn.getAccumulatorCoder(registry, inputCoder);
+ }
+ @Override
+ public Coder<OutputT> getDefaultOutputCoder(
+ CoderRegistry registry, Coder<InputT> inputCoder) throws CannotProvideCoderException {
+ return combineFn.getDefaultOutputCoder(registry, inputCoder);
+ }
+ };
+ }
+ }
+
+ private static <K, InputT, AccumT, OutputT> KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>
+ toFnWithContext(PerKeyCombineFn<K, InputT, AccumT, OutputT> perKeyCombineFn) {
+ if (perKeyCombineFn instanceof KeyedCombineFnWithContext) {
+ @SuppressWarnings("unchecked")
+ KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> keyedCombineFnWithContext =
+ (KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>) perKeyCombineFn;
+ return keyedCombineFnWithContext;
+ } else {
+ @SuppressWarnings("unchecked")
+ final KeyedCombineFn<K, InputT, AccumT, OutputT> keyedCombineFn =
+ (KeyedCombineFn<K, InputT, AccumT, OutputT>) perKeyCombineFn;
+ return new KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>() {
+ @Override
+ public AccumT createAccumulator(K key, Context c) {
+ return keyedCombineFn.createAccumulator(key);
+ }
+ @Override
+ public AccumT addInput(K key, AccumT accumulator, InputT value, Context c) {
+ return keyedCombineFn.addInput(key, accumulator, value);
+ }
+ @Override
+ public AccumT mergeAccumulators(K key, Iterable<AccumT> accumulators, Context c) {
+ return keyedCombineFn.mergeAccumulators(key, accumulators);
+ }
+ @Override
+ public OutputT extractOutput(K key, AccumT accumulator, Context c) {
+ return keyedCombineFn.extractOutput(key, accumulator);
+ }
+ @Override
+ public AccumT compact(K key, AccumT accumulator, Context c) {
+ return keyedCombineFn.compact(key, accumulator);
+ }
+ @Override
+ public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<K> keyCoder,
+ Coder<InputT> inputCoder) throws CannotProvideCoderException {
+ return keyedCombineFn.getAccumulatorCoder(registry, keyCoder, inputCoder);
+ }
+ @Override
+ public Coder<OutputT> getDefaultOutputCoder(CoderRegistry registry, Coder<K> keyCoder,
+ Coder<InputT> inputCoder) throws CannotProvideCoderException {
+ return keyedCombineFn.getDefaultOutputCoder(registry, keyCoder, inputCoder);
+ }
+ };
+ }
+ }
+
+ private static <OutputT> void checkUniqueness(
+ List<TupleTag<?>> registeredTags, TupleTag<OutputT> outputTag) {
+ checkArgument(
+ !registeredTags.contains(outputTag),
+ "Cannot compose with tuple tag %s because it is already present in the composition.",
+ outputTag);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/23b43780/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java
----------------------------------------------------------------------
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java
new file mode 100644
index 0000000..ad37708
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/transforms/CombineFnsTest.java
@@ -0,0 +1,413 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.transforms;
+
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder;
+import com.google.cloud.dataflow.sdk.coders.Coder;
+import com.google.cloud.dataflow.sdk.coders.CoderException;
+import com.google.cloud.dataflow.sdk.coders.KvCoder;
+import com.google.cloud.dataflow.sdk.coders.NullableCoder;
+import com.google.cloud.dataflow.sdk.coders.StandardCoder;
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
+import com.google.cloud.dataflow.sdk.testing.RunnableOnService;
+import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.transforms.Combine.BinaryCombineFn;
+import com.google.cloud.dataflow.sdk.transforms.CombineFns.CoCombineResult;
+import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext;
+import com.google.cloud.dataflow.sdk.transforms.Max.MaxIntegerFn;
+import com.google.cloud.dataflow.sdk.transforms.Min.MinIntegerFn;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PCollectionView;
+import com.google.cloud.dataflow.sdk.values.TupleTag;
+import com.google.common.collect.ImmutableList;
+
+import org.hamcrest.Matchers;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Unit tests for {@link CombineFns}.
+ */
+@RunWith(JUnit4.class)
+public class CombineFnsTest {
+ @Rule public ExpectedException expectedException = ExpectedException.none();
+
+ @Test
+ public void testDuplicatedTags() {
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("it is already present in the composition");
+
+ TupleTag<Integer> tag = new TupleTag<Integer>();
+ CombineFns.compose()
+ .with(new GetIntegerFunction(), new MaxIntegerFn(), tag)
+ .with(new GetIntegerFunction(), new MinIntegerFn(), tag);
+ }
+
+ @Test
+ public void testDuplicatedTagsKeyed() {
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("it is already present in the composition");
+
+ TupleTag<Integer> tag = new TupleTag<Integer>();
+ CombineFns.composeKeyed()
+ .with(new GetIntegerFunction(), new MaxIntegerFn(), tag)
+ .with(new GetIntegerFunction(), new MinIntegerFn(), tag);
+ }
+
+ @Test
+ public void testDuplicatedTagsWithContext() {
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("it is already present in the composition");
+
+ TupleTag<UserString> tag = new TupleTag<UserString>();
+ CombineFns.compose()
+ .with(
+ new GetUserStringFunction(),
+ new ConcatStringWithContext(null /* view */).forKey("G", StringUtf8Coder.of()),
+ tag)
+ .with(
+ new GetUserStringFunction(),
+ new ConcatStringWithContext(null /* view */).forKey("G", StringUtf8Coder.of()),
+ tag);
+ }
+
+ @Test
+ public void testDuplicatedTagsWithContextKeyed() {
+ expectedException.expect(IllegalArgumentException.class);
+ expectedException.expectMessage("it is already present in the composition");
+
+ TupleTag<UserString> tag = new TupleTag<UserString>();
+ CombineFns.composeKeyed()
+ .with(
+ new GetUserStringFunction(),
+ new ConcatStringWithContext(null /* view */),
+ tag)
+ .with(
+ new GetUserStringFunction(),
+ new ConcatStringWithContext(null /* view */),
+ tag);
+ }
+
+ @Test
+ @Category(RunnableOnService.class)
+ public void testComposedCombine() {
+ Pipeline p = TestPipeline.create();
+ p.getCoderRegistry().registerCoder(UserString.class, UserStringCoder.of());
+
+ PCollection<KV<String, KV<Integer, UserString>>> perKeyInput = p.apply(
+ Create.timestamped(
+ Arrays.asList(
+ KV.of("a", KV.of(1, UserString.of("1"))),
+ KV.of("a", KV.of(1, UserString.of("1"))),
+ KV.of("a", KV.of(4, UserString.of("4"))),
+ KV.of("b", KV.of(1, UserString.of("1"))),
+ KV.of("b", KV.of(13, UserString.of("13")))),
+ Arrays.asList(0L, 4L, 7L, 10L, 16L))
+ .withCoder(KvCoder.of(
+ StringUtf8Coder.of(),
+ KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of()))));
+
+ TupleTag<Integer> maxIntTag = new TupleTag<Integer>();
+ TupleTag<UserString> concatStringTag = new TupleTag<UserString>();
+ PCollection<KV<String, KV<Integer, String>>> combineGlobally = perKeyInput
+ .apply(Values.<KV<Integer, UserString>>create())
+ .apply(Combine.globally(CombineFns.compose()
+ .with(
+ new GetIntegerFunction(),
+ new MaxIntegerFn(),
+ maxIntTag)
+ .with(
+ new GetUserStringFunction(),
+ new ConcatString(),
+ concatStringTag)))
+ .apply(WithKeys.<String, CoCombineResult>of("global"))
+ .apply(
+ "ExtractGloballyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
+
+ PCollection<KV<String, KV<Integer, String>>> combinePerKey = perKeyInput
+ .apply(Combine.perKey(CombineFns.composeKeyed()
+ .with(
+ new GetIntegerFunction(),
+ new MaxIntegerFn().<String>asKeyedFn(),
+ maxIntTag)
+ .with(
+ new GetUserStringFunction(),
+ new ConcatString().<String>asKeyedFn(),
+ concatStringTag)))
+ .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
+ DataflowAssert.that(combineGlobally).containsInAnyOrder(
+ KV.of("global", KV.of(13, "111134")));
+ DataflowAssert.that(combinePerKey).containsInAnyOrder(
+ KV.of("a", KV.of(4, "114")),
+ KV.of("b", KV.of(13, "113")));
+ p.run();
+ }
+
+ @Test
+ @Category(RunnableOnService.class)
+ public void testComposedCombineWithContext() {
+ Pipeline p = TestPipeline.create();
+ p.getCoderRegistry().registerCoder(UserString.class, UserStringCoder.of());
+
+ PCollectionView<String> view = p
+ .apply(Create.of("I"))
+ .apply(View.<String>asSingleton());
+
+ PCollection<KV<String, KV<Integer, UserString>>> perKeyInput = p.apply(
+ Create.timestamped(
+ Arrays.asList(
+ KV.of("a", KV.of(1, UserString.of("1"))),
+ KV.of("a", KV.of(1, UserString.of("1"))),
+ KV.of("a", KV.of(4, UserString.of("4"))),
+ KV.of("b", KV.of(1, UserString.of("1"))),
+ KV.of("b", KV.of(13, UserString.of("13")))),
+ Arrays.asList(0L, 4L, 7L, 10L, 16L))
+ .withCoder(KvCoder.of(
+ StringUtf8Coder.of(),
+ KvCoder.of(BigEndianIntegerCoder.of(), UserStringCoder.of()))));
+
+ TupleTag<Integer> maxIntTag = new TupleTag<Integer>();
+ TupleTag<UserString> concatStringTag = new TupleTag<UserString>();
+ PCollection<KV<String, KV<Integer, String>>> combineGlobally = perKeyInput
+ .apply(Values.<KV<Integer, UserString>>create())
+ .apply(Combine.globally(CombineFns.compose()
+ .with(
+ new GetIntegerFunction(),
+ new MaxIntegerFn(),
+ maxIntTag)
+ .with(
+ new GetUserStringFunction(),
+ new ConcatStringWithContext(view).forKey("G", StringUtf8Coder.of()),
+ concatStringTag))
+ .withoutDefaults()
+ .withSideInputs(ImmutableList.of(view)))
+ .apply(WithKeys.<String, CoCombineResult>of("global"))
+ .apply(
+ "ExtractGloballyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
+
+ PCollection<KV<String, KV<Integer, String>>> combinePerKey = perKeyInput
+ .apply(Combine.perKey(CombineFns.composeKeyed()
+ .with(
+ new GetIntegerFunction(),
+ new MaxIntegerFn().<String>asKeyedFn(),
+ maxIntTag)
+ .with(
+ new GetUserStringFunction(),
+ new ConcatStringWithContext(view),
+ concatStringTag))
+ .withSideInputs(ImmutableList.of(view)))
+ .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
+ DataflowAssert.that(combineGlobally).containsInAnyOrder(
+ KV.of("global", KV.of(13, "111134GI")));
+ DataflowAssert.that(combinePerKey).containsInAnyOrder(
+ KV.of("a", KV.of(4, "114Ia")),
+ KV.of("b", KV.of(13, "113Ib")));
+ p.run();
+ }
+
+ @Test
+ @Category(RunnableOnService.class)
+ public void testComposedCombineNullValues() {
+ Pipeline p = TestPipeline.create();
+ p.getCoderRegistry().registerCoder(UserString.class, NullableCoder.of(UserStringCoder.of()));
+ p.getCoderRegistry().registerCoder(String.class, NullableCoder.of(StringUtf8Coder.of()));
+
+ PCollection<KV<String, KV<Integer, UserString>>> perKeyInput = p.apply(
+ Create.timestamped(
+ Arrays.asList(
+ KV.of("a", KV.of(1, UserString.of("1"))),
+ KV.of("a", KV.of(1, UserString.of("1"))),
+ KV.of("a", KV.of(4, UserString.of("4"))),
+ KV.of("b", KV.of(1, UserString.of("1"))),
+ KV.of("b", KV.of(13, UserString.of("13")))),
+ Arrays.asList(0L, 4L, 7L, 10L, 16L))
+ .withCoder(KvCoder.of(
+ StringUtf8Coder.of(),
+ KvCoder.of(
+ BigEndianIntegerCoder.of(), NullableCoder.of(UserStringCoder.of())))));
+
+ TupleTag<Integer> maxIntTag = new TupleTag<Integer>();
+ TupleTag<UserString> concatStringTag = new TupleTag<UserString>();
+
+ PCollection<KV<String, KV<Integer, String>>> combinePerKey = perKeyInput
+ .apply(Combine.perKey(CombineFns.composeKeyed()
+ .with(
+ new GetIntegerFunction(),
+ new MaxIntegerFn().<String>asKeyedFn(),
+ maxIntTag)
+ .with(
+ new GetUserStringFunction(),
+ new OutputNullString().<String>asKeyedFn(),
+ concatStringTag)))
+ .apply("ExtractPerKeyResult", ParDo.of(new ExtractResultDoFn(maxIntTag, concatStringTag)));
+ DataflowAssert.that(combinePerKey).containsInAnyOrder(
+ KV.of("a", KV.of(4, (String) null)),
+ KV.of("b", KV.of(13, (String) null)));
+ p.run();
+ }
+
+ private static class UserString implements Serializable {
+ private String strValue;
+
+ static UserString of(String strValue) {
+ UserString ret = new UserString();
+ ret.strValue = strValue;
+ return ret;
+ }
+ }
+
+ private static class UserStringCoder extends StandardCoder<UserString> {
+ public static UserStringCoder of() {
+ return INSTANCE;
+ }
+
+ private static final UserStringCoder INSTANCE = new UserStringCoder();
+
+ @Override
+ public void encode(UserString value, OutputStream outStream, Context context)
+ throws CoderException, IOException {
+ StringUtf8Coder.of().encode(value.strValue, outStream, context);
+ }
+
+ @Override
+ public UserString decode(InputStream inStream, Context context)
+ throws CoderException, IOException {
+ return UserString.of(StringUtf8Coder.of().decode(inStream, context));
+ }
+
+ @Override
+ public List<? extends Coder<?>> getCoderArguments() {
+ return null;
+ }
+
+ @Override
+ public void verifyDeterministic() throws NonDeterministicException {}
+ }
+
+ private static class GetIntegerFunction
+ extends SimpleFunction<KV<Integer, UserString>, Integer> {
+ @Override
+ public Integer apply(KV<Integer, UserString> input) {
+ return input.getKey();
+ }
+ }
+
+ private static class GetUserStringFunction
+ extends SimpleFunction<KV<Integer, UserString>, UserString> {
+ @Override
+ public UserString apply(KV<Integer, UserString> input) {
+ return input.getValue();
+ }
+ }
+
+ private static class ConcatString extends BinaryCombineFn<UserString> {
+ @Override
+ public UserString apply(UserString left, UserString right) {
+ String retStr = left.strValue + right.strValue;
+ char[] chars = retStr.toCharArray();
+ Arrays.sort(chars);
+ return UserString.of(new String(chars));
+ }
+ }
+
+ private static class OutputNullString extends BinaryCombineFn<UserString> {
+ @Override
+ public UserString apply(UserString left, UserString right) {
+ return null;
+ }
+ }
+
+ private static class ConcatStringWithContext
+ extends KeyedCombineFnWithContext<String, UserString, UserString, UserString> {
+ private final PCollectionView<String> view;
+
+ private ConcatStringWithContext(PCollectionView<String> view) {
+ this.view = view;
+ }
+
+ @Override
+ public UserString createAccumulator(String key, CombineWithContext.Context c) {
+ return UserString.of(key + c.sideInput(view));
+ }
+
+ @Override
+ public UserString addInput(
+ String key, UserString accumulator, UserString input, CombineWithContext.Context c) {
+ assertThat(accumulator.strValue, Matchers.startsWith(key + c.sideInput(view)));
+ accumulator.strValue += input.strValue;
+ return accumulator;
+ }
+
+ @Override
+ public UserString mergeAccumulators(
+ String key, Iterable<UserString> accumulators, CombineWithContext.Context c) {
+ String keyPrefix = key + c.sideInput(view);
+ String all = keyPrefix;
+ for (UserString accumulator : accumulators) {
+ assertThat(accumulator.strValue, Matchers.startsWith(keyPrefix));
+ all += accumulator.strValue.substring(keyPrefix.length());
+ accumulator.strValue = "cleared in mergeAccumulators";
+ }
+ return UserString.of(all);
+ }
+
+ @Override
+ public UserString extractOutput(
+ String key, UserString accumulator, CombineWithContext.Context c) {
+ assertThat(accumulator.strValue, Matchers.startsWith(key + c.sideInput(view)));
+ char[] chars = accumulator.strValue.toCharArray();
+ Arrays.sort(chars);
+ return UserString.of(new String(chars));
+ }
+ }
+
+ private static class ExtractResultDoFn
+ extends DoFn<KV<String, CoCombineResult>, KV<String, KV<Integer, String>>>{
+
+ private final TupleTag<Integer> maxIntTag;
+ private final TupleTag<UserString> concatStringTag;
+
+ ExtractResultDoFn(TupleTag<Integer> maxIntTag, TupleTag<UserString> concatStringTag) {
+ this.maxIntTag = maxIntTag;
+ this.concatStringTag = concatStringTag;
+ }
+
+ @Override
+ public void processElement(ProcessContext c) throws Exception {
+ UserString userString = c.element().getValue().get(concatStringTag);
+ KV<Integer, String> value = KV.of(
+ c.element().getValue().get(maxIntTag),
+ userString == null ? null : userString.strValue);
+ c.output(KV.of(c.element().getKey(), value));
+ }
+ }
+}
[2/2] incubator-beam git commit: This closes #23
Posted by bc...@apache.org.
This closes #23
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/c3032600
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/c3032600
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/c3032600
Branch: refs/heads/master
Commit: c30326007370bff7f678a3bf4d3ea338287225c7
Parents: ac63fd6 23b4378
Author: bchambers <bc...@google.com>
Authored: Thu Mar 17 14:06:13 2016 -0700
Committer: bchambers <bc...@google.com>
Committed: Thu Mar 17 14:06:13 2016 -0700
----------------------------------------------------------------------
.../dataflow/sdk/transforms/CombineFns.java | 1100 ++++++++++++++++++
.../dataflow/sdk/transforms/CombineFnsTest.java | 413 +++++++
2 files changed, 1513 insertions(+)
----------------------------------------------------------------------