You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@beam.apache.org by lc...@apache.org on 2022/08/17 22:19:58 UTC
[beam] branch master updated: [#21935] Reject ill formed GroupByKey coders during pipeline.run validation within Beam Java SDK. (#22702)
This is an automated email from the ASF dual-hosted git repository.
lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new ec44ac9a9e3 [#21935] Reject ill formed GroupByKey coders during pipeline.run validation within Beam Java SDK. (#22702)
ec44ac9a9e3 is described below
commit ec44ac9a9e305a813be54f81bef8eac36df02d0c
Author: Luke Cwik <lc...@google.com>
AuthorDate: Wed Aug 17 15:19:52 2022 -0700
[#21935] Reject ill formed GroupByKey coders during pipeline.run validation within Beam Java SDK. (#22702)
---
.../main/java/org/apache/beam/sdk/Pipeline.java | 4 +-
.../org/apache/beam/sdk/transforms/GroupByKey.java | 30 +++++++++--
.../org/apache/beam/sdk/transforms/PTransform.java | 13 +++++
.../apache/beam/sdk/transforms/GroupByKeyTest.java | 63 +++++++++++++---------
4 files changed, 80 insertions(+), 30 deletions(-)
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
index d3cbf6dbf54..7d8e101334a 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
@@ -659,14 +659,14 @@ public class Pipeline {
@Override
public CompositeBehavior enterCompositeTransform(Node node) {
if (node.getTransform() != null) {
- node.getTransform().validate(options);
+ node.getTransform().validate(options, node.getInputs(), node.getOutputs());
}
return CompositeBehavior.ENTER_TRANSFORM;
}
@Override
public void visitPrimitiveTransform(Node node) {
- node.getTransform().validate(options);
+ node.getTransform().validate(options, node.getInputs(), node.getOutputs());
}
}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
index 6dc7aaa3e3f..63eb914ede0 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
@@ -17,10 +17,12 @@
*/
package org.apache.beam.sdk.transforms;
+import java.util.Map;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.Coder.NonDeterministicException;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.windowing.AfterWatermark.AfterWatermarkEarlyAndLate;
import org.apache.beam.sdk.transforms.windowing.AfterWatermark.FromEndOfWindow;
@@ -33,7 +35,10 @@ import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollection.IsBounded;
+import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.checkerframework.checker.nullness.qual.Nullable;
/**
* {@code GroupByKey<K, V>} takes a {@code PCollection<KV<K, V>>}, groups the values by key and
@@ -169,6 +174,25 @@ public class GroupByKey<K, V>
}
}
+ @Override
+ public void validate(
+ @Nullable PipelineOptions options,
+ Map<TupleTag<?>, PCollection<?>> inputs,
+ Map<TupleTag<?>, PCollection<?>> outputs) {
+ PCollection<?> input = Iterables.getOnlyElement(inputs.values());
+ KvCoder<K, V> inputCoder = getInputKvCoder(input.getCoder());
+
+ // Ensure that the output coder key and value types aren't different.
+ Coder<?> outputCoder = Iterables.getOnlyElement(outputs.values()).getCoder();
+ KvCoder<?, ?> expectedOutputCoder = getOutputKvCoder(inputCoder);
+ if (!expectedOutputCoder.equals(outputCoder)) {
+ throw new IllegalStateException(
+ String.format(
+ "the GroupByKey requires its output coder to be %s but found %s.",
+ expectedOutputCoder, outputCoder));
+ }
+ }
+
// Note that Never trigger finishes *at* GC time so it is OK, and
// AfterWatermark.fromEndOfWindow() finishes at end-of-window time so it is
// OK if there is no allowed lateness.
@@ -235,7 +259,7 @@ public class GroupByKey<K, V>
* Returns the {@code Coder} of the input to this transform, which should be a {@code KvCoder}.
*/
@SuppressWarnings("unchecked")
- static <K, V> KvCoder<K, V> getInputKvCoder(Coder<KV<K, V>> inputCoder) {
+ static <K, V> KvCoder<K, V> getInputKvCoder(Coder<?> inputCoder) {
if (!(inputCoder instanceof KvCoder)) {
throw new IllegalStateException("GroupByKey requires its input to use KvCoder");
}
@@ -249,12 +273,12 @@ public class GroupByKey<K, V>
* {@code Coder} of the keys of the output of this transform.
*/
public static <K, V> Coder<K> getKeyCoder(Coder<KV<K, V>> inputCoder) {
- return getInputKvCoder(inputCoder).getKeyCoder();
+ return GroupByKey.<K, V>getInputKvCoder(inputCoder).getKeyCoder();
}
/** Returns the {@code Coder} of the values of the input to this transform. */
public static <K, V> Coder<V> getInputValueCoder(Coder<KV<K, V>> inputCoder) {
- return getInputKvCoder(inputCoder).getValueCoder();
+ return GroupByKey.<K, V>getInputKvCoder(inputCoder).getValueCoder();
}
/** Returns the {@code Coder} of the {@code Iterable} values of the output of this transform. */
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/PTransform.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/PTransform.java
index c83bcb49b74..c01a53f3ebb 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/PTransform.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/PTransform.java
@@ -161,6 +161,19 @@ public abstract class PTransform<InputT extends PInput, OutputT extends POutput>
*/
public void validate(@Nullable PipelineOptions options) {}
+ /**
+ * Called before running the Pipeline to verify this transform, its inputs, and outputs are fully
+ * and correctly specified.
+ *
+ * <p>By default, delegates to {@link #validate(PipelineOptions)}.
+ */
+ public void validate(
+ @Nullable PipelineOptions options,
+ Map<TupleTag<?>, PCollection<?>> inputs,
+ Map<TupleTag<?>, PCollection<?>> outputs) {
+ validate(options);
+ }
+
/**
* Returns all {@link PValue PValues} that are consumed as inputs to this {@link PTransform} that
* are independent of the expansion of the {@link InputT} within {@link #expand(PInput)}.
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
index 21c3f24b713..9d6bd0131a7 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
@@ -24,6 +24,7 @@ import static org.hamcrest.CoreMatchers.hasItem;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
+import static org.junit.Assert.assertThrows;
import java.io.DataInputStream;
import java.io.DataOutputStream;
@@ -43,8 +44,10 @@ import org.apache.beam.sdk.coders.AtomicCoder;
import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderProviders;
+import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.MapCoder;
+import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.io.GenerateSequence;
@@ -92,7 +95,6 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.experimental.runners.Enclosed;
-import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -107,8 +109,6 @@ public class GroupByKeyTest implements Serializable {
/** Shared test base class with setup/teardown helpers. */
public abstract static class SharedTestBase {
@Rule public transient TestPipeline p = TestPipeline.create();
-
- @Rule public transient ExpectedException thrown = ExpectedException.none();
}
/** Tests validating basic {@link GroupByKey} scenarios. */
@@ -306,7 +306,6 @@ public class GroupByKeyTest implements Serializable {
@Test
public void testGroupByKeyNonDeterministic() throws Exception {
-
List<KV<Map<String, String>, Integer>> ungroupedPairs = Arrays.asList();
PCollection<KV<Map<String, String>, Integer>> input =
@@ -317,9 +316,31 @@ public class GroupByKeyTest implements Serializable {
MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()),
BigEndianIntegerCoder.of())));
- thrown.expect(IllegalStateException.class);
- thrown.expectMessage("must be deterministic");
- input.apply(GroupByKey.create());
+ assertThrows(
+ "must be deterministic",
+ IllegalStateException.class,
+ () -> input.apply(GroupByKey.create()));
+ }
+
+ @Test
+ public void testGroupByKeyOutputCoderUnmodifiedAfterApplyAndBeforePipelineRun()
+ throws Exception {
+ List<KV<String, Integer>> ungroupedPairs = Arrays.asList();
+
+ PCollection<KV<String, Integer>> input =
+ p.apply(
+ Create.of(ungroupedPairs)
+ .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())));
+
+ // Apply with a known good coder
+ PCollection<KV<String, Iterable<Integer>>> output = input.apply(GroupByKey.create());
+
+ // Change the output to have a different coder that doesn't match the input coder types
+ output.setCoder(
+ KvCoder.of(
+ SerializableCoder.of(String.class), IterableCoder.of(BigEndianIntegerCoder.of())));
+ assertThrows(
+ "the GroupByKey requires its output coder", IllegalStateException.class, () -> p.run());
}
// AfterPane.elementCountAtLeast(1) is not OK
@@ -332,9 +353,8 @@ public class GroupByKeyTest implements Serializable {
.discardingFiredPanes()
.triggering(AfterPane.elementCountAtLeast(1)));
- thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage("Unsafe trigger");
- input.apply(GroupByKey.create());
+ assertThrows(
+ "Unsafe trigger", IllegalArgumentException.class, () -> input.apply(GroupByKey.create()));
}
// AfterWatermark.pastEndOfWindow() is OK with 0 allowed lateness
@@ -380,9 +400,8 @@ public class GroupByKeyTest implements Serializable {
.triggering(AfterWatermark.pastEndOfWindow())
.withAllowedLateness(Duration.millis(10)));
- thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage("Unsafe trigger");
- input.apply(GroupByKey.create());
+ assertThrows(
+ "Unsafe trigger", IllegalArgumentException.class, () -> input.apply(GroupByKey.create()));
}
// AfterWatermark.pastEndOfWindow().withEarlyFirings() is not OK with > 0 allowed lateness
@@ -398,9 +417,8 @@ public class GroupByKeyTest implements Serializable {
.withEarlyFirings(AfterPane.elementCountAtLeast(1)))
.withAllowedLateness(Duration.millis(10)));
- thrown.expect(IllegalArgumentException.class);
- thrown.expectMessage("Unsafe trigger");
- input.apply(GroupByKey.create());
+ assertThrows(
+ "Unsafe trigger", IllegalArgumentException.class, () -> input.apply(GroupByKey.create()));
}
// AfterWatermark.pastEndOfWindow().withLateFirings() is always OK
@@ -423,7 +441,6 @@ public class GroupByKeyTest implements Serializable {
@Test
@Category(NeedsRunner.class)
public void testRemerge() {
-
List<KV<String, Integer>> ungroupedPairs = Arrays.asList();
PCollection<KV<String, Integer>> input =
@@ -450,7 +467,6 @@ public class GroupByKeyTest implements Serializable {
@Test
public void testGroupByKeyDirectUnbounded() {
-
PCollection<KV<String, Integer>> input =
p.apply(
new PTransform<PBegin, PCollection<KV<String, Integer>>>() {
@@ -464,12 +480,11 @@ public class GroupByKeyTest implements Serializable {
}
});
- thrown.expect(IllegalStateException.class);
- thrown.expectMessage(
+ assertThrows(
"GroupByKey cannot be applied to non-bounded PCollection in the GlobalWindow without "
- + "a trigger. Use a Window.into or Window.triggering transform prior to GroupByKey.");
-
- input.apply("GroupByKey", GroupByKey.create());
+ + "a trigger. Use a Window.into or Window.triggering transform prior to GroupByKey.",
+ IllegalStateException.class,
+ () -> input.apply("GroupByKey", GroupByKey.create()));
}
/**
@@ -480,7 +495,6 @@ public class GroupByKeyTest implements Serializable {
@Test
@Category(ValidatesRunner.class)
public void testTimestampCombinerEarliest() {
-
p.apply(
Create.timestamped(
TimestampedValue.of(KV.of(0, "hello"), new Instant(0)),
@@ -745,7 +759,6 @@ public class GroupByKeyTest implements Serializable {
@Test
@Category(NeedsRunner.class)
public void testIdentityWindowFnPropagation() {
-
List<KV<String, Integer>> ungroupedPairs = Arrays.asList();
PCollection<KV<String, Integer>> input =