You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@beam.apache.org by lc...@apache.org on 2022/08/17 22:19:58 UTC

[beam] branch master updated: [#21935] Reject ill formed GroupByKey coders during pipeline.run validation within Beam Java SDK. (#22702)

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new ec44ac9a9e3 [#21935] Reject ill formed GroupByKey coders during pipeline.run validation within Beam Java SDK. (#22702)
ec44ac9a9e3 is described below

commit ec44ac9a9e305a813be54f81bef8eac36df02d0c
Author: Luke Cwik <lc...@google.com>
AuthorDate: Wed Aug 17 15:19:52 2022 -0700

    [#21935] Reject ill formed GroupByKey coders during pipeline.run validation within Beam Java SDK. (#22702)
---
 .../main/java/org/apache/beam/sdk/Pipeline.java    |  4 +-
 .../org/apache/beam/sdk/transforms/GroupByKey.java | 30 +++++++++--
 .../org/apache/beam/sdk/transforms/PTransform.java | 13 +++++
 .../apache/beam/sdk/transforms/GroupByKeyTest.java | 63 +++++++++++++---------
 4 files changed, 80 insertions(+), 30 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
index d3cbf6dbf54..7d8e101334a 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
@@ -659,14 +659,14 @@ public class Pipeline {
     @Override
     public CompositeBehavior enterCompositeTransform(Node node) {
       if (node.getTransform() != null) {
-        node.getTransform().validate(options);
+        node.getTransform().validate(options, node.getInputs(), node.getOutputs());
       }
       return CompositeBehavior.ENTER_TRANSFORM;
     }
 
     @Override
     public void visitPrimitiveTransform(Node node) {
-      node.getTransform().validate(options);
+      node.getTransform().validate(options, node.getInputs(), node.getOutputs());
     }
   }
 
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
index 6dc7aaa3e3f..63eb914ede0 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupByKey.java
@@ -17,10 +17,12 @@
  */
 package org.apache.beam.sdk.transforms;
 
+import java.util.Map;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.Coder.NonDeterministicException;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.windowing.AfterWatermark.AfterWatermarkEarlyAndLate;
 import org.apache.beam.sdk.transforms.windowing.AfterWatermark.FromEndOfWindow;
@@ -33,7 +35,10 @@ import org.apache.beam.sdk.transforms.windowing.WindowFn;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollection.IsBounded;
+import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.checkerframework.checker.nullness.qual.Nullable;
 
 /**
  * {@code GroupByKey<K, V>} takes a {@code PCollection<KV<K, V>>}, groups the values by key and
@@ -169,6 +174,25 @@ public class GroupByKey<K, V>
     }
   }
 
+  @Override
+  public void validate(
+      @Nullable PipelineOptions options,
+      Map<TupleTag<?>, PCollection<?>> inputs,
+      Map<TupleTag<?>, PCollection<?>> outputs) {
+    PCollection<?> input = Iterables.getOnlyElement(inputs.values());
+    KvCoder<K, V> inputCoder = getInputKvCoder(input.getCoder());
+
+    // Ensure that the output coder key and value types aren't different.
+    Coder<?> outputCoder = Iterables.getOnlyElement(outputs.values()).getCoder();
+    KvCoder<?, ?> expectedOutputCoder = getOutputKvCoder(inputCoder);
+    if (!expectedOutputCoder.equals(outputCoder)) {
+      throw new IllegalStateException(
+          String.format(
+              "the GroupByKey requires its output coder to be %s but found %s.",
+              expectedOutputCoder, outputCoder));
+    }
+  }
+
   // Note that Never trigger finishes *at* GC time so it is OK, and
   // AfterWatermark.fromEndOfWindow() finishes at end-of-window time so it is
   // OK if there is no allowed lateness.
@@ -235,7 +259,7 @@ public class GroupByKey<K, V>
    * Returns the {@code Coder} of the input to this transform, which should be a {@code KvCoder}.
    */
   @SuppressWarnings("unchecked")
-  static <K, V> KvCoder<K, V> getInputKvCoder(Coder<KV<K, V>> inputCoder) {
+  static <K, V> KvCoder<K, V> getInputKvCoder(Coder<?> inputCoder) {
     if (!(inputCoder instanceof KvCoder)) {
       throw new IllegalStateException("GroupByKey requires its input to use KvCoder");
     }
@@ -249,12 +273,12 @@ public class GroupByKey<K, V>
    * {@code Coder} of the keys of the output of this transform.
    */
   public static <K, V> Coder<K> getKeyCoder(Coder<KV<K, V>> inputCoder) {
-    return getInputKvCoder(inputCoder).getKeyCoder();
+    return GroupByKey.<K, V>getInputKvCoder(inputCoder).getKeyCoder();
   }
 
   /** Returns the {@code Coder} of the values of the input to this transform. */
   public static <K, V> Coder<V> getInputValueCoder(Coder<KV<K, V>> inputCoder) {
-    return getInputKvCoder(inputCoder).getValueCoder();
+    return GroupByKey.<K, V>getInputKvCoder(inputCoder).getValueCoder();
   }
 
   /** Returns the {@code Coder} of the {@code Iterable} values of the output of this transform. */
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/PTransform.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/PTransform.java
index c83bcb49b74..c01a53f3ebb 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/PTransform.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/PTransform.java
@@ -161,6 +161,19 @@ public abstract class PTransform<InputT extends PInput, OutputT extends POutput>
    */
   public void validate(@Nullable PipelineOptions options) {}
 
+  /**
+   * Called before running the Pipeline to verify this transform, its inputs, and outputs are fully
+   * and correctly specified.
+   *
+   * <p>By default, delegates to {@link #validate(PipelineOptions)}.
+   */
+  public void validate(
+      @Nullable PipelineOptions options,
+      Map<TupleTag<?>, PCollection<?>> inputs,
+      Map<TupleTag<?>, PCollection<?>> outputs) {
+    validate(options);
+  }
+
   /**
    * Returns all {@link PValue PValues} that are consumed as inputs to this {@link PTransform} that
    * are independent of the expansion of the {@link InputT} within {@link #expand(PInput)}.
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
index 21c3f24b713..9d6bd0131a7 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyTest.java
@@ -24,6 +24,7 @@ import static org.hamcrest.CoreMatchers.hasItem;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder;
+import static org.junit.Assert.assertThrows;
 
 import java.io.DataInputStream;
 import java.io.DataOutputStream;
@@ -43,8 +44,10 @@ import org.apache.beam.sdk.coders.AtomicCoder;
 import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderProviders;
+import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.MapCoder;
+import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.io.GenerateSequence;
@@ -92,7 +95,6 @@ import org.junit.Rule;
 import org.junit.Test;
 import org.junit.experimental.categories.Category;
 import org.junit.experimental.runners.Enclosed;
-import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
@@ -107,8 +109,6 @@ public class GroupByKeyTest implements Serializable {
   /** Shared test base class with setup/teardown helpers. */
   public abstract static class SharedTestBase {
     @Rule public transient TestPipeline p = TestPipeline.create();
-
-    @Rule public transient ExpectedException thrown = ExpectedException.none();
   }
 
   /** Tests validating basic {@link GroupByKey} scenarios. */
@@ -306,7 +306,6 @@ public class GroupByKeyTest implements Serializable {
 
     @Test
     public void testGroupByKeyNonDeterministic() throws Exception {
-
       List<KV<Map<String, String>, Integer>> ungroupedPairs = Arrays.asList();
 
       PCollection<KV<Map<String, String>, Integer>> input =
@@ -317,9 +316,31 @@ public class GroupByKeyTest implements Serializable {
                           MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()),
                           BigEndianIntegerCoder.of())));
 
-      thrown.expect(IllegalStateException.class);
-      thrown.expectMessage("must be deterministic");
-      input.apply(GroupByKey.create());
+      assertThrows(
+          "must be deterministic",
+          IllegalStateException.class,
+          () -> input.apply(GroupByKey.create()));
+    }
+
+    @Test
+    public void testGroupByKeyOutputCoderUnmodifiedAfterApplyAndBeforePipelineRun()
+        throws Exception {
+      List<KV<String, Integer>> ungroupedPairs = Arrays.asList();
+
+      PCollection<KV<String, Integer>> input =
+          p.apply(
+              Create.of(ungroupedPairs)
+                  .withCoder(KvCoder.of(StringUtf8Coder.of(), BigEndianIntegerCoder.of())));
+
+      // Apply with a known good coder
+      PCollection<KV<String, Iterable<Integer>>> output = input.apply(GroupByKey.create());
+
+      // Change the output to have a different coder that doesn't match the input coder types
+      output.setCoder(
+          KvCoder.of(
+              SerializableCoder.of(String.class), IterableCoder.of(BigEndianIntegerCoder.of())));
+      assertThrows(
+          "the GroupByKey requires its output coder", IllegalStateException.class, () -> p.run());
     }
 
     // AfterPane.elementCountAtLeast(1) is not OK
@@ -332,9 +353,8 @@ public class GroupByKeyTest implements Serializable {
                       .discardingFiredPanes()
                       .triggering(AfterPane.elementCountAtLeast(1)));
 
-      thrown.expect(IllegalArgumentException.class);
-      thrown.expectMessage("Unsafe trigger");
-      input.apply(GroupByKey.create());
+      assertThrows(
+          "Unsafe trigger", IllegalArgumentException.class, () -> input.apply(GroupByKey.create()));
     }
 
     // AfterWatermark.pastEndOfWindow() is OK with 0 allowed lateness
@@ -380,9 +400,8 @@ public class GroupByKeyTest implements Serializable {
                       .triggering(AfterWatermark.pastEndOfWindow())
                       .withAllowedLateness(Duration.millis(10)));
 
-      thrown.expect(IllegalArgumentException.class);
-      thrown.expectMessage("Unsafe trigger");
-      input.apply(GroupByKey.create());
+      assertThrows(
+          "Unsafe trigger", IllegalArgumentException.class, () -> input.apply(GroupByKey.create()));
     }
 
     // AfterWatermark.pastEndOfWindow().withEarlyFirings() is not OK with > 0 allowed lateness
@@ -398,9 +417,8 @@ public class GroupByKeyTest implements Serializable {
                               .withEarlyFirings(AfterPane.elementCountAtLeast(1)))
                       .withAllowedLateness(Duration.millis(10)));
 
-      thrown.expect(IllegalArgumentException.class);
-      thrown.expectMessage("Unsafe trigger");
-      input.apply(GroupByKey.create());
+      assertThrows(
+          "Unsafe trigger", IllegalArgumentException.class, () -> input.apply(GroupByKey.create()));
     }
 
     // AfterWatermark.pastEndOfWindow().withLateFirings() is always OK
@@ -423,7 +441,6 @@ public class GroupByKeyTest implements Serializable {
     @Test
     @Category(NeedsRunner.class)
     public void testRemerge() {
-
       List<KV<String, Integer>> ungroupedPairs = Arrays.asList();
 
       PCollection<KV<String, Integer>> input =
@@ -450,7 +467,6 @@ public class GroupByKeyTest implements Serializable {
 
     @Test
     public void testGroupByKeyDirectUnbounded() {
-
       PCollection<KV<String, Integer>> input =
           p.apply(
               new PTransform<PBegin, PCollection<KV<String, Integer>>>() {
@@ -464,12 +480,11 @@ public class GroupByKeyTest implements Serializable {
                 }
               });
 
-      thrown.expect(IllegalStateException.class);
-      thrown.expectMessage(
+      assertThrows(
           "GroupByKey cannot be applied to non-bounded PCollection in the GlobalWindow without "
-              + "a trigger. Use a Window.into or Window.triggering transform prior to GroupByKey.");
-
-      input.apply("GroupByKey", GroupByKey.create());
+              + "a trigger. Use a Window.into or Window.triggering transform prior to GroupByKey.",
+          IllegalStateException.class,
+          () -> input.apply("GroupByKey", GroupByKey.create()));
     }
 
     /**
@@ -480,7 +495,6 @@ public class GroupByKeyTest implements Serializable {
     @Test
     @Category(ValidatesRunner.class)
     public void testTimestampCombinerEarliest() {
-
       p.apply(
               Create.timestamped(
                   TimestampedValue.of(KV.of(0, "hello"), new Instant(0)),
@@ -745,7 +759,6 @@ public class GroupByKeyTest implements Serializable {
     @Test
     @Category(NeedsRunner.class)
     public void testIdentityWindowFnPropagation() {
-
       List<KV<String, Integer>> ungroupedPairs = Arrays.asList();
 
       PCollection<KV<String, Integer>> input =