You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2020/03/07 01:56:28 UTC

[beam] branch master updated: [BEAM-9464] Fix WithKeys to respect parameterized types

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b5301d9  [BEAM-9464] Fix WithKeys to respect parameterized types
     new e593074  Merge pull request #11064 from lukecwik/splittabledofn4
b5301d9 is described below

commit b5301d90f771b956ff94c346fc2a44d2e0590439
Author: Luke Cwik <lc...@google.com>
AuthorDate: Fri Mar 6 13:53:57 2020 -0800

    [BEAM-9464] Fix WithKeys to respect parameterized types
---
 .../org/apache/beam/sdk/transforms/WithKeys.java    | 21 +++++++++++----------
 .../apache/beam/sdk/transforms/WithKeysTest.java    | 20 ++++++++++++++++++++
 2 files changed, 31 insertions(+), 10 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java
index 60bd4a9..c5fe782 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java
@@ -28,6 +28,7 @@ import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
 
 /**
  * {@code WithKeys<K, V>} takes a {@code PCollection<V>}, and either a constant key of type {@code
@@ -74,17 +75,20 @@ public class WithKeys<K, V> extends PTransform<PCollection<V>, PCollection<KV<K,
    */
   @SuppressWarnings("unchecked")
   public static <K, V> WithKeys<K, V> of(@Nullable final K key) {
-    return new WithKeys<>(value -> key, (Class<K>) (key == null ? Void.class : key.getClass()));
+    return new WithKeys<>(
+        value -> key,
+        (TypeDescriptor<K>)
+            (key == null ? TypeDescriptors.voids() : TypeDescriptor.of(key.getClass())));
   }
 
   /////////////////////////////////////////////////////////////////////////////
 
   private SerializableFunction<V, K> fn;
-  @CheckForNull private transient Class<K> keyClass;
+  @CheckForNull private transient TypeDescriptor<K> keyType;
 
-  private WithKeys(SerializableFunction<V, K> fn, Class<K> keyClass) {
+  private WithKeys(SerializableFunction<V, K> fn, TypeDescriptor<K> keyType) {
     this.fn = fn;
-    this.keyClass = keyClass;
+    this.keyType = keyType;
   }
 
   /**
@@ -95,10 +99,7 @@ public class WithKeys<K, V> extends PTransform<PCollection<V>, PCollection<KV<K,
    * PCollection}.
    */
   public WithKeys<K, V> withKeyType(TypeDescriptor<K> keyType) {
-    // Safe cast
-    @SuppressWarnings("unchecked")
-    Class<K> rawType = (Class<K>) keyType.getRawType();
-    return new WithKeys<>(fn, rawType);
+    return new WithKeys<>(fn, keyType);
   }
 
   @Override
@@ -117,10 +118,10 @@ public class WithKeys<K, V> extends PTransform<PCollection<V>, PCollection<KV<K,
     try {
       Coder<K> keyCoder;
       CoderRegistry coderRegistry = in.getPipeline().getCoderRegistry();
-      if (keyClass == null) {
+      if (keyType == null) {
         keyCoder = coderRegistry.getOutputCoder(fn, in.getCoder());
       } else {
-        keyCoder = coderRegistry.getCoder(TypeDescriptor.of(keyClass));
+        keyCoder = coderRegistry.getCoder(keyType);
       }
       // TODO: Remove when we can set the coder inference context.
       result.setCoder(KvCoder.of(keyCoder, in.getCoder()));
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java
index 1baa3d4..5a8da19 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java
@@ -20,6 +20,7 @@ package org.apache.beam.sdk.transforms;
 import static org.junit.Assert.assertEquals;
 
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.testing.NeedsRunner;
@@ -28,6 +29,7 @@ import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.experimental.categories.Category;
@@ -144,6 +146,24 @@ public class WithKeysTest {
 
   @Test
   @Category(NeedsRunner.class)
+  public void withLambdaAndParameterizedTypeDescriptorShouldSucceed() {
+
+    PCollection<String> values = p.apply(Create.of("1234", "3210"));
+    PCollection<KV<List<String>, String>> kvs =
+        values.apply(
+            WithKeys.of((SerializableFunction<String, List<String>>) Collections::singletonList)
+                .withKeyType(TypeDescriptors.lists(TypeDescriptors.strings())));
+
+    PAssert.that(kvs)
+        .containsInAnyOrder(
+            KV.of(Collections.singletonList("1234"), "1234"),
+            KV.of(Collections.singletonList("3210"), "3210"));
+
+    p.run();
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
   public void withLambdaAndNoTypeDescriptorShouldThrow() {
 
     PCollection<String> values = p.apply(Create.of("1234", "3210", "0", "-12"));