You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ke...@apache.org on 2016/08/06 02:53:02 UTC

[44/51] [abbrv] incubator-beam git commit: Use input type in coder inference for MapElements and FlatMapElements

Use input type in coder inference for MapElements and FlatMapElements

Previously, the input TypeDescriptor was unknown, so we would fail
to infer a coder for things like MapElements.via(SimpleFunction<T, T>)
even if the input PCollection provided a coder for T.

Now, the input type is plumbed appropriately and the coder is inferred.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/4ac5cafe
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/4ac5cafe
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/4ac5cafe

Branch: refs/heads/python-sdk
Commit: 4ac5cafe90a371cf616f97cb202d5016b68616d1
Parents: 8daf518
Author: Kenneth Knowles <kl...@google.com>
Authored: Fri Jul 29 10:35:01 2016 -0700
Committer: Kenneth Knowles <kl...@google.com>
Committed: Thu Aug 4 20:18:59 2016 -0700

----------------------------------------------------------------------
 .../beam/sdk/transforms/FlatMapElements.java    | 126 +++++++++++++------
 .../apache/beam/sdk/transforms/MapElements.java |  60 +++++----
 .../beam/sdk/transforms/SimpleFunction.java     |  34 +++++
 .../sdk/transforms/FlatMapElementsTest.java     |  48 +++++++
 .../beam/sdk/transforms/MapElementsTest.java    |  84 +++++++++++++
 5 files changed, 288 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
index 694592e..04d993c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/FlatMapElements.java
@@ -17,8 +17,10 @@
  */
 package org.apache.beam.sdk.transforms;
 
+import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
 
 import java.lang.reflect.ParameterizedType;
 
@@ -45,8 +47,16 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
    * descriptor need not be provided.
    */
   public static <InputT, OutputT> MissingOutputTypeDescriptor<InputT, OutputT>
-  via(SerializableFunction<InputT, ? extends Iterable<OutputT>> fn) {
-    return new MissingOutputTypeDescriptor<>(fn);
+  via(SerializableFunction<? super InputT, ? extends Iterable<OutputT>> fn) {
+
+    // TypeDescriptor interacts poorly with the wildcards needed to correctly express
+    // covariance and contravariance in Java, so instead we cast it to an invariant
+    // function here.
+    @SuppressWarnings("unchecked") // safe covariant cast
+    SerializableFunction<InputT, Iterable<OutputT>> simplerFn =
+        (SerializableFunction<InputT, Iterable<OutputT>>) fn;
+
+    return new MissingOutputTypeDescriptor<>(simplerFn);
   }
 
   /**
@@ -72,16 +82,15 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
    * <p>To use a Java 8 lambda, see {@link #via(SerializableFunction)}.
    */
   public static <InputT, OutputT> FlatMapElements<InputT, OutputT>
-  via(SimpleFunction<InputT, ? extends Iterable<OutputT>> fn) {
-
-    @SuppressWarnings({"rawtypes", "unchecked"}) // safe by static typing
-    TypeDescriptor<Iterable<?>> iterableType = (TypeDescriptor) fn.getOutputTypeDescriptor();
-
-    @SuppressWarnings("unchecked") // safe by correctness of getIterableElementType
-    TypeDescriptor<OutputT> outputType =
-        (TypeDescriptor<OutputT>) getIterableElementType(iterableType);
-
-    return new FlatMapElements<>(fn, outputType);
+  via(SimpleFunction<? super InputT, ? extends Iterable<OutputT>> fn) {
+    // TypeDescriptor interacts poorly with the wildcards needed to correctly express
+    // covariance and contravariance in Java, so instead we cast it to an invariant
+    // function here.
+    @SuppressWarnings("unchecked") // safe covariant cast
+    SimpleFunction<InputT, Iterable<OutputT>> simplerFn =
+        (SimpleFunction<InputT, Iterable<OutputT>>) fn;
+
+    return new FlatMapElements<>(simplerFn, fn.getClass());
   }
 
   /**
@@ -91,18 +100,80 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
    */
   public static final class MissingOutputTypeDescriptor<InputT, OutputT> {
 
-    private final SerializableFunction<InputT, ? extends Iterable<OutputT>> fn;
+    private final SerializableFunction<InputT, Iterable<OutputT>> fn;
 
     private MissingOutputTypeDescriptor(
-        SerializableFunction<InputT, ? extends Iterable<OutputT>> fn) {
+        SerializableFunction<InputT, Iterable<OutputT>> fn) {
       this.fn = fn;
     }
 
     public FlatMapElements<InputT, OutputT> withOutputType(TypeDescriptor<OutputT> outputType) {
-      return new FlatMapElements<>(fn, outputType);
+      TypeDescriptor<Iterable<OutputT>> iterableOutputType = TypeDescriptors.iterables(outputType);
+
+      return new FlatMapElements<>(
+          SimpleFunction.fromSerializableFunctionWithOutputType(fn,
+              iterableOutputType),
+              fn.getClass());
     }
   }
 
+  //////////////////////////////////////////////////////////////////////////////////////////////////
+
+  private final SimpleFunction<InputT, ? extends Iterable<OutputT>> fn;
+  private final DisplayData.Item<?> fnClassDisplayData;
+
+  private FlatMapElements(
+      SimpleFunction<InputT, ? extends Iterable<OutputT>> fn,
+      Class<?> fnClass) {
+    this.fn = fn;
+    this.fnClassDisplayData = DisplayData.item("flatMapFn", fnClass).withLabel("FlatMap Function");
+  }
+
+  @Override
+  public PCollection<OutputT> apply(PCollection<InputT> input) {
+    return input.apply(
+        "FlatMap",
+        ParDo.of(
+            new DoFn<InputT, OutputT>() {
+              private static final long serialVersionUID = 0L;
+
+              @ProcessElement
+              public void processElement(ProcessContext c) {
+                for (OutputT element : fn.apply(c.element())) {
+                  c.output(element);
+                }
+              }
+
+              @Override
+              public TypeDescriptor<InputT> getInputTypeDescriptor() {
+                return fn.getInputTypeDescriptor();
+              }
+
+              @Override
+              public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+                @SuppressWarnings({"rawtypes", "unchecked"}) // safe by static typing
+                TypeDescriptor<Iterable<?>> iterableType =
+                    (TypeDescriptor) fn.getOutputTypeDescriptor();
+
+                @SuppressWarnings("unchecked") // safe by correctness of getIterableElementType
+                TypeDescriptor<OutputT> outputType =
+                    (TypeDescriptor<OutputT>) getIterableElementType(iterableType);
+
+                return outputType;
+              }
+            }));
+  }
+
+  @Override
+  public void populateDisplayData(DisplayData.Builder builder) {
+    super.populateDisplayData(builder);
+    builder.add(fnClassDisplayData);
+  }
+
+  /**
+   * Does a best-effort job of getting the best {@link TypeDescriptor} for the type of the
+   * elements contained in the iterable described by the given {@link TypeDescriptor}.
+   */
   private static TypeDescriptor<?> getIterableElementType(
       TypeDescriptor<Iterable<?>> iterableTypeDescriptor) {
 
@@ -118,29 +189,4 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
         (ParameterizedType) iterableTypeDescriptor.getSupertype(Iterable.class).getType();
     return TypeDescriptor.of(iterableType.getActualTypeArguments()[0]);
   }
-
-  //////////////////////////////////////////////////////////////////////////////////////////////////
-
-  private final SerializableFunction<InputT, ? extends Iterable<OutputT>> fn;
-  private final transient TypeDescriptor<OutputT> outputType;
-
-  private FlatMapElements(
-      SerializableFunction<InputT, ? extends Iterable<OutputT>> fn,
-      TypeDescriptor<OutputT> outputType) {
-    this.fn = fn;
-    this.outputType = outputType;
-  }
-
-  @Override
-  public PCollection<OutputT> apply(PCollection<InputT> input) {
-    return input.apply("Map", ParDo.of(new DoFn<InputT, OutputT>() {
-      private static final long serialVersionUID = 0L;
-      @ProcessElement
-      public void processElement(ProcessContext c) {
-        for (OutputT element : fn.apply(c.element())) {
-          c.output(element);
-        }
-      }
-    })).setTypeDescriptorInternal(outputType);
-  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java
index b7b9a5f..429d3fc 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/MapElements.java
@@ -67,9 +67,9 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
    *     }));
    * }</pre>
    */
-  public static <InputT, OutputT> MapElements<InputT, OutputT>
-  via(final SimpleFunction<InputT, OutputT> fn) {
-    return new MapElements<>(fn, fn.getOutputTypeDescriptor());
+  public static <InputT, OutputT> MapElements<InputT, OutputT> via(
+      final SimpleFunction<InputT, OutputT> fn) {
+    return new MapElements<>(fn, fn.getClass());
   }
 
   /**
@@ -85,42 +85,54 @@ extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
       this.fn = fn;
     }
 
-    public MapElements<InputT, OutputT> withOutputType(TypeDescriptor<OutputT> outputType) {
-      return new MapElements<>(fn, outputType);
+    public MapElements<InputT, OutputT> withOutputType(final TypeDescriptor<OutputT> outputType) {
+      return new MapElements<>(
+          SimpleFunction.fromSerializableFunctionWithOutputType(fn, outputType), fn.getClass());
     }
+
   }
 
   ///////////////////////////////////////////////////////////////////
 
-  private final SerializableFunction<InputT, OutputT> fn;
-  private final transient TypeDescriptor<OutputT> outputType;
+  private final SimpleFunction<InputT, OutputT> fn;
+  private final DisplayData.Item<?> fnClassDisplayData;
 
-  private MapElements(
-      SerializableFunction<InputT, OutputT> fn,
-      TypeDescriptor<OutputT> outputType) {
+  private MapElements(SimpleFunction<InputT, OutputT> fn, Class<?> fnClass) {
     this.fn = fn;
-    this.outputType = outputType;
+    this.fnClassDisplayData = DisplayData.item("mapFn", fnClass).withLabel("Map Function");
   }
 
   @Override
   public PCollection<OutputT> apply(PCollection<InputT> input) {
-    return input.apply("Map", ParDo.of(new DoFn<InputT, OutputT>() {
-      @ProcessElement
-      public void processElement(ProcessContext c) {
-        c.output(fn.apply(c.element()));
-      }
-
-      @Override
-      public void populateDisplayData(DisplayData.Builder builder) {
-        MapElements.this.populateDisplayData(builder);
-      }
-    })).setTypeDescriptorInternal(outputType);
+    return input.apply(
+        "Map",
+        ParDo.of(
+            new DoFn<InputT, OutputT>() {
+              @ProcessElement
+              public void processElement(ProcessContext c) {
+                c.output(fn.apply(c.element()));
+              }
+
+              @Override
+              public void populateDisplayData(DisplayData.Builder builder) {
+                MapElements.this.populateDisplayData(builder);
+              }
+
+              @Override
+              public TypeDescriptor<InputT> getInputTypeDescriptor() {
+                return fn.getInputTypeDescriptor();
+              }
+
+              @Override
+              public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+                return fn.getOutputTypeDescriptor();
+              }
+            }));
   }
 
   @Override
   public void populateDisplayData(DisplayData.Builder builder) {
     super.populateDisplayData(builder);
-    builder.add(DisplayData.item("mapFn", fn.getClass())
-      .withLabel("Map Function"));
+    builder.add(fnClassDisplayData);
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java
index 8894352..6c540cc 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/SimpleFunction.java
@@ -27,6 +27,12 @@ import org.apache.beam.sdk.values.TypeDescriptor;
 public abstract class SimpleFunction<InputT, OutputT>
     implements SerializableFunction<InputT, OutputT> {
 
+  public static <InputT, OutputT>
+      SimpleFunction<InputT, OutputT> fromSerializableFunctionWithOutputType(
+          SerializableFunction<InputT, OutputT> fn, TypeDescriptor<OutputT> outputType) {
+    return new SimpleFunctionWithOutputType<>(fn, outputType);
+  }
+
   /**
    * Returns a {@link TypeDescriptor} capturing what is known statically
    * about the input type of this {@code OldDoFn} instance's most-derived
@@ -52,4 +58,32 @@ public abstract class SimpleFunction<InputT, OutputT>
   public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
     return new TypeDescriptor<OutputT>(this) {};
   }
+
+  /**
+   * A {@link SimpleFunction} built from a {@link SerializableFunction}, having
+   * a known output type that is explicitly set.
+   */
+  private static class SimpleFunctionWithOutputType<InputT, OutputT>
+      extends SimpleFunction<InputT, OutputT> {
+
+    private final SerializableFunction<InputT, OutputT> fn;
+    private final TypeDescriptor<OutputT> outputType;
+
+    public SimpleFunctionWithOutputType(
+        SerializableFunction<InputT, OutputT> fn,
+        TypeDescriptor<OutputT> outputType) {
+      this.fn = fn;
+      this.outputType = outputType;
+    }
+
+    @Override
+    public OutputT apply(InputT input) {
+      return fn.apply(input);
+    }
+
+    @Override
+    public TypeDescriptor<OutputT> getOutputTypeDescriptor() {
+      return outputType;
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java
index 057fd19..781e143 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/FlatMapElementsTest.java
@@ -17,6 +17,8 @@
  */
 package org.apache.beam.sdk.transforms;
 
+import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
+
 import static org.hamcrest.Matchers.equalTo;
 import static org.junit.Assert.assertThat;
 
@@ -24,6 +26,7 @@ import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TypeDescriptor;
@@ -102,6 +105,51 @@ public class FlatMapElementsTest implements Serializable {
     pipeline.run();
   }
 
+  /**
+   * A {@link SimpleFunction} to test that the coder registry can propagate coders
+   * that are bound to type variables.
+   */
+  private static class PolymorphicSimpleFunction<T> extends SimpleFunction<T, Iterable<T>> {
+    @Override
+    public Iterable<T> apply(T input) {
+      return Collections.<T>emptyList();
+    }
+  }
+
+  /**
+   * Basic test of {@link MapElements} coder propagation with a parametric {@link SimpleFunction}.
+   */
+  @Test
+  public void testPolymorphicSimpleFunction() throws Exception {
+    Pipeline pipeline = TestPipeline.create();
+    PCollection<Integer> output = pipeline
+        .apply(Create.of(1, 2, 3))
+
+        // This function must propagate the input T into the output Iterable<T>
+        .apply("Polymorphic Identity", MapElements.via(new PolymorphicSimpleFunction<Integer>()))
+
+        // This is a consumer to ensure that all coder inference logic is executed.
+        .apply("Test Consumer", MapElements.via(new SimpleFunction<Iterable<Integer>, Integer>() {
+          @Override
+          public Integer apply(Iterable<Integer> input) {
+            return 42;
+          }
+        }));
+  }
+
+  @Test
+  public void testSimpleFunctionClassDisplayData() {
+    SimpleFunction<Integer, List<Integer>> simpleFn = new SimpleFunction<Integer, List<Integer>>() {
+      @Override
+      public List<Integer> apply(Integer input) {
+        return Collections.emptyList();
+      }
+    };
+
+    FlatMapElements<?, ?> simpleMap = FlatMapElements.via(simpleFn);
+    assertThat(DisplayData.from(simpleMap), hasDisplayItem("flatMapFn", simpleFn.getClass()));
+  }
+
   @Test
   @Category(NeedsRunner.class)
   public void testVoidValues() throws Exception {

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/4ac5cafe/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java
index b4751d2..dbf8844 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/MapElementsTest.java
@@ -54,6 +54,29 @@ public class MapElementsTest implements Serializable {
   public transient ExpectedException thrown = ExpectedException.none();
 
   /**
+   * A {@link SimpleFunction} to test that the coder registry can propagate coders
+   * that are bound to type variables.
+   */
+  private static class PolymorphicSimpleFunction<T> extends SimpleFunction<T, T> {
+    @Override
+    public T apply(T input) {
+      return input;
+    }
+  }
+
+  /**
+   * A {@link SimpleFunction} to test that the coder registry can propagate coders
+   * that are bound to type variables, when the variable appears nested in the
+   * output.
+   */
+  private static class NestedPolymorphicSimpleFunction<T> extends SimpleFunction<T, KV<T, String>> {
+    @Override
+    public KV<T, String> apply(T input) {
+      return KV.of(input, "hello");
+    }
+  }
+
+  /**
    * Basic test of {@link MapElements} with a {@link SimpleFunction}.
    */
   @Test
@@ -74,6 +97,55 @@ public class MapElementsTest implements Serializable {
   }
 
   /**
+   * Basic test of {@link MapElements} coder propagation with a parametric {@link SimpleFunction}.
+   */
+  @Test
+  public void testPolymorphicSimpleFunction() throws Exception {
+    Pipeline pipeline = TestPipeline.create();
+    PCollection<Integer> output = pipeline
+        .apply(Create.of(1, 2, 3))
+
+        // This is the function that needs to propagate the input T to output T
+        .apply("Polymorphic Identity", MapElements.via(new PolymorphicSimpleFunction<Integer>()))
+
+        // This is a consumer to ensure that all coder inference logic is executed.
+        .apply("Test Consumer", MapElements.via(new SimpleFunction<Integer, Integer>() {
+          @Override
+          public Integer apply(Integer input) {
+            return input;
+          }
+        }));
+  }
+
+  /**
+   * Test of {@link MapElements} coder propagation with a parametric {@link SimpleFunction}
+   * where the type variable occurs nested within other concrete type constructors.
+   */
+  @Test
+  public void testNestedPolymorphicSimpleFunction() throws Exception {
+    Pipeline pipeline = TestPipeline.create();
+    PCollection<Integer> output =
+        pipeline
+            .apply(Create.of(1, 2, 3))
+
+            // This function must propagate the input T into the output KV<T, String>
+            .apply(
+                "Polymorphic Identity",
+                MapElements.via(new NestedPolymorphicSimpleFunction<Integer>()))
+
+            // This is a consumer to ensure that all coder inference logic is executed.
+            .apply(
+                "Test Consumer",
+                MapElements.via(
+                    new SimpleFunction<KV<Integer, String>, Integer>() {
+                      @Override
+                      public Integer apply(KV<Integer, String> input) {
+                        return 42;
+                      }
+                    }));
+  }
+
+  /**
    * Basic test of {@link MapElements} with a {@link SerializableFunction}. This style is
    * generally discouraged in Java 7, in favor of {@link SimpleFunction}.
    */
@@ -148,6 +220,18 @@ public class MapElementsTest implements Serializable {
   }
 
   @Test
+  public void testSimpleFunctionClassDisplayData() {
+    SimpleFunction<?, ?> simpleFn = new SimpleFunction<Integer, Integer>() {
+      @Override
+      public Integer apply(Integer input) {
+        return input;
+      }
+    };
+
+    MapElements<?, ?> simpleMap = MapElements.via(simpleFn);
+    assertThat(DisplayData.from(simpleMap), hasDisplayItem("mapFn", simpleFn.getClass()));
+  }
+  @Test
   public void testSimpleFunctionDisplayData() {
     SimpleFunction<?, ?> simpleFn = new SimpleFunction<Integer, Integer>() {
       @Override