You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text version).
Posted to commits@beam.apache.org by jk...@apache.org on 2017/12/06 00:35:17 UTC

[beam] branch master updated: Adds a deduplication key to Watch, and uses it to handle growing files in FileIO.match

This is an automated email from the ASF dual-hosted git repository.

jkff pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 38d94c8  Adds a deduplication key to Watch, and uses it to handle growing files in FileIO.match
     new df326cb  This closes #4190: [BEAM-3030] Adds a deduplication key to Watch, and uses it to handle growing files in FileIO.match
38d94c8 is described below

commit 38d94c8e3900a5e59ae8e3be772583da6cb2c8c6
Author: Eugene Kirpichov <ki...@google.com>
AuthorDate: Tue Nov 28 17:02:06 2017 -0800

    Adds a deduplication key to Watch, and uses it to handle growing files in FileIO.match
---
 .../main/java/org/apache/beam/sdk/io/FileIO.java   |  37 +++-
 .../java/org/apache/beam/sdk/transforms/Watch.java | 204 ++++++++++++++-------
 .../org/apache/beam/sdk/transforms/WatchTest.java  | 119 +++++++++---
 3 files changed, 259 insertions(+), 101 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
index a244c07..4e7124a 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
@@ -33,13 +33,17 @@ import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
 import org.apache.beam.sdk.io.fs.MatchResult;
 import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.transforms.Contextful;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Requirements;
 import org.apache.beam.sdk.transforms.Reshuffle;
+import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.Watch;
+import org.apache.beam.sdk.transforms.Watch.Growth.PollFn;
 import org.apache.beam.sdk.transforms.Watch.Growth.TerminationCondition;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.display.HasDisplayData;
@@ -69,6 +73,11 @@ public class FileIO {
    * <p>By default, a filepattern matching no resources is treated according to {@link
    * EmptyMatchTreatment#DISALLOW}. To configure this behavior, use {@link
    * Match#withEmptyMatchTreatment}.
+   *
+   * <p>Returned {@link MatchResult.Metadata} are deduplicated by filename. For example, if this
+   * transform observes a file with the same name several times with different metadata (e.g.
+   * because the file is growing), it will emit the metadata the first time this file is observed,
+   * and will ignore future changes to this file.
    */
   public static Match match() {
     return new AutoValue_FileIO_Match.Builder()
@@ -317,13 +326,17 @@ public class FileIO {
             "Match filepatterns",
             ParDo.of(new MatchFn(getConfiguration().getEmptyMatchTreatment())));
       } else {
-        res = input
-            .apply(
-                "Continuously match filepatterns",
-                Watch.growthOf(new MatchPollFn())
-                    .withPollInterval(getConfiguration().getWatchInterval())
-                    .withTerminationPerInput(getConfiguration().getWatchTerminationCondition()))
-            .apply(Values.<MatchResult.Metadata>create());
+        res =
+            input
+                .apply(
+                    "Continuously match filepatterns",
+                    Watch.growthOf(
+                            Contextful.<PollFn<String, MatchResult.Metadata>>of(
+                                new MatchPollFn(), Requirements.empty()),
+                            new ExtractFilenameFn())
+                        .withPollInterval(getConfiguration().getWatchInterval())
+                        .withTerminationPerInput(getConfiguration().getWatchTerminationCondition()))
+                .apply(Values.<MatchResult.Metadata>create());
       }
       return res.apply(Reshuffle.<MatchResult.Metadata>viaRandomKey());
     }
@@ -346,7 +359,7 @@ public class FileIO {
       }
     }
 
-    private static class MatchPollFn extends Watch.Growth.PollFn<String, MatchResult.Metadata> {
+    private static class MatchPollFn extends PollFn<String, MatchResult.Metadata> {
       @Override
       public Watch.Growth.PollResult<MatchResult.Metadata> apply(String element, Context c)
           throws Exception {
@@ -354,6 +367,14 @@ public class FileIO {
             Instant.now(), FileSystems.match(element, EmptyMatchTreatment.ALLOW).metadata());
       }
     }
+
+    private static class ExtractFilenameFn
+        implements SerializableFunction<MatchResult.Metadata, String> {
+      @Override
+      public String apply(MatchResult.Metadata input) {
+        return input.resourceId().toString();
+      }
+    }
   }
 
   /** Implementation of {@link #readMatches}. */
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java
index 75c2fe4..4b31ae7 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Watch.java
@@ -58,6 +58,7 @@ import org.apache.beam.sdk.coders.MapCoder;
 import org.apache.beam.sdk.coders.NullableCoder;
 import org.apache.beam.sdk.coders.StructuredCoder;
 import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.transforms.Contextful.Fn;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.values.KV;
@@ -117,29 +118,46 @@ public class Watch {
   private static final Logger LOG = LoggerFactory.getLogger(Watch.class);
 
   /** Watches the growth of the given poll function. See class documentation for more details. */
-  public static <InputT, OutputT> Growth<InputT, OutputT> growthOf(
-      Contextful<Growth.PollFn<InputT, OutputT>> pollFn) {
-    return new AutoValue_Watch_Growth.Builder<InputT, OutputT>()
-        .setTerminationPerInput(Watch.Growth.<InputT>never())
-        .setPollFn(pollFn)
-        .build();
-  }
-
-  /** Watches the growth of the given poll function. See class documentation for more details. */
-  public static <InputT, OutputT> Growth<InputT, OutputT> growthOf(
+  public static <InputT, OutputT> Growth<InputT, OutputT, OutputT> growthOf(
       Growth.PollFn<InputT, OutputT> pollFn, Requirements requirements) {
-    return growthOf(Contextful.of(pollFn, requirements));
+    return new AutoValue_Watch_Growth.Builder<InputT, OutputT, OutputT>()
+        .setTerminationPerInput(Growth.<InputT>never())
+        .setPollFn(Contextful.of(pollFn, requirements))
+        // use null as a signal that this is the identity function and output coder can be
+        // reused as key coder
+        .setOutputKeyFn(null)
+        .build();
   }
 
   /** Watches the growth of the given poll function. See class documentation for more details. */
-  public static <InputT, OutputT> Growth<InputT, OutputT> growthOf(
+  public static <InputT, OutputT> Growth<InputT, OutputT, OutputT> growthOf(
       Growth.PollFn<InputT, OutputT> pollFn) {
     return growthOf(pollFn, Requirements.empty());
   }
 
+  /**
+   * Watches the growth of the given poll function, using the given "key function" to deduplicate
+   * outputs. For example, if OutputT is a filename + file size, this can be a function that returns
+   * just the filename, so that if the same file is observed multiple times with different sizes,
+   * only the first observation is emitted.
+   *
+   * <p>By default, this is the identity function, i.e. the output is used as its own key.
+   */
+  public static <InputT, OutputT, KeyT> Growth<InputT, OutputT, KeyT> growthOf(
+      Contextful<Growth.PollFn<InputT, OutputT>> pollFn,
+      SerializableFunction<OutputT, KeyT> outputKeyFn) {
+    checkArgument(pollFn != null, "pollFn can not be null");
+    checkArgument(outputKeyFn != null, "outputKeyFn can not be null");
+    return new AutoValue_Watch_Growth.Builder<InputT, OutputT, KeyT>()
+        .setTerminationPerInput(Watch.Growth.<InputT>never())
+        .setPollFn(pollFn)
+        .setOutputKeyFn(outputKeyFn)
+        .build();
+  }
+
   /** Implementation of {@link #growthOf}. */
   @AutoValue
-  public abstract static class Growth<InputT, OutputT>
+  public abstract static class Growth<InputT, OutputT, KeyT>
       extends PTransform<PCollection<InputT>, PCollection<KV<InputT, OutputT>>> {
     /** The result of a single invocation of a {@link PollFn}. */
     public static final class PollResult<OutputT> {
@@ -219,7 +237,7 @@ public class Watch {
      * {@link PollResult}.
      */
     public abstract static class PollFn<InputT, OutputT>
-        implements Contextful.Fn<InputT, PollResult<OutputT>> {}
+        implements Fn<InputT, PollResult<OutputT>> {}
 
     /**
      * A strategy for determining whether it is time to stop polling the current input regardless of
@@ -551,6 +569,12 @@ public class Watch {
     abstract Contextful<PollFn<InputT, OutputT>> getPollFn();
 
     @Nullable
+    abstract SerializableFunction<OutputT, KeyT> getOutputKeyFn();
+
+    @Nullable
+    abstract Coder<KeyT> getOutputKeyCoder();
+
+    @Nullable
     abstract Duration getPollInterval();
 
     @Nullable
@@ -559,24 +583,34 @@ public class Watch {
     @Nullable
     abstract Coder<OutputT> getOutputCoder();
 
-    abstract Builder<InputT, OutputT> toBuilder();
+    abstract Builder<InputT, OutputT, KeyT> toBuilder();
 
     @AutoValue.Builder
-    abstract static class Builder<InputT, OutputT> {
-      abstract Builder<InputT, OutputT> setPollFn(Contextful<PollFn<InputT, OutputT>> pollFn);
+    abstract static class Builder<InputT, OutputT, KeyT> {
+      abstract Builder<InputT, OutputT, KeyT> setPollFn(Contextful<PollFn<InputT, OutputT>> pollFn);
+
+      abstract Builder<InputT, OutputT, KeyT> setOutputKeyFn(
+          @Nullable SerializableFunction<OutputT, KeyT> outputKeyFn);
 
-      abstract Builder<InputT, OutputT> setTerminationPerInput(
+      abstract Builder<InputT, OutputT, KeyT> setOutputKeyCoder(Coder<KeyT> outputKeyCoder);
+
+      abstract Builder<InputT, OutputT, KeyT> setTerminationPerInput(
           TerminationCondition<InputT, ?> terminationPerInput);
 
-      abstract Builder<InputT, OutputT> setPollInterval(Duration pollInterval);
+      abstract Builder<InputT, OutputT, KeyT> setPollInterval(Duration pollInterval);
+
+      abstract Builder<InputT, OutputT, KeyT> setOutputCoder(Coder<OutputT> outputCoder);
 
-      abstract Builder<InputT, OutputT> setOutputCoder(Coder<OutputT> outputCoder);
+      abstract Growth<InputT, OutputT, KeyT> build();
+    }
 
-      abstract Growth<InputT, OutputT> build();
+    /** Specifies the coder for the output key. */
+    public Growth<InputT, OutputT, KeyT> withOutputKeyCoder(Coder<KeyT> outputKeyCoder) {
+      return toBuilder().setOutputKeyCoder(outputKeyCoder).build();
     }
 
     /** Specifies a {@link TerminationCondition} that will be independently used for every input. */
-    public Growth<InputT, OutputT> withTerminationPerInput(
+    public Growth<InputT, OutputT, KeyT> withTerminationPerInput(
         TerminationCondition<InputT, ?> terminationPerInput) {
       return toBuilder().setTerminationPerInput(terminationPerInput).build();
     }
@@ -585,7 +619,7 @@ public class Watch {
      * Specifies how long to wait after a call to {@link PollFn} before calling it again (if at all
      * - according to {@link PollResult} and the {@link TerminationCondition}).
      */
-    public Growth<InputT, OutputT> withPollInterval(Duration pollInterval) {
+    public Growth<InputT, OutputT, KeyT> withPollInterval(Duration pollInterval) {
       return toBuilder().setPollInterval(pollInterval).build();
     }
 
@@ -596,7 +630,7 @@ public class Watch {
      * <p>The coder must be deterministic, because the transform will compare encoded outputs for
      * deduplication between polling rounds.
      */
-    public Growth<InputT, OutputT> withOutputCoder(Coder<OutputT> outputCoder) {
+    public Growth<InputT, OutputT, KeyT> withOutputCoder(Coder<OutputT> outputCoder) {
       return toBuilder().setOutputCoder(outputCoder).build();
     }
 
@@ -618,36 +652,68 @@ public class Watch {
           outputCoder = input.getPipeline().getCoderRegistry().getCoder(outputT);
         } catch (CannotProvideCoderException e) {
           throw new RuntimeException(
-              "Unable to infer coder for OutputT. Specify it explicitly using withOutputCoder().");
+              "Unable to infer coder for OutputT ("
+                  + outputT
+                  + "). Specify it explicitly using withOutputCoder().", e);
         }
       }
-      try {
-        outputCoder.verifyDeterministic();
-      } catch (Coder.NonDeterministicException e) {
-        throw new IllegalArgumentException(
-            "Output coder " + outputCoder + " must be deterministic");
+
+      Coder<KeyT> outputKeyCoder = getOutputKeyCoder();
+      SerializableFunction<OutputT, KeyT> outputKeyFn = getOutputKeyFn();
+      if (getOutputKeyFn() == null) {
+        // This by construction can happen only if OutputT == KeyT
+        outputKeyCoder = (Coder) outputCoder;
+        outputKeyFn = (SerializableFunction) SerializableFunctions.identity();
+      } else {
+        if (outputKeyCoder == null) {
+          // Not specified explicitly: infer the key coder from outputKeyFn's output type.
+          TypeDescriptor<KeyT> keyT = TypeDescriptors.outputOf(getOutputKeyFn());
+          try {
+            outputKeyCoder = input.getPipeline().getCoderRegistry().getCoder(keyT);
+          } catch (CannotProvideCoderException e) {
+            throw new RuntimeException(
+                "Unable to infer coder for KeyT ("
+                    + keyT
+                    + "). Specify it explicitly using withOutputKeyCoder().", e);
+          }
+        }
+      }
+      // Outputs are deduplicated by their encoded keys, so the key coder must be deterministic.
+      try {
+        outputKeyCoder.verifyDeterministic();
+      } catch (Coder.NonDeterministicException e) {
+        throw new IllegalArgumentException(
+            "Key coder " + outputKeyCoder + " must be deterministic", e);
       }
 
       return input
-          .apply(ParDo.of(new WatchGrowthFn<>(this, outputCoder))
+          .apply(ParDo.of(new WatchGrowthFn<>(this, outputCoder, outputKeyFn, outputKeyCoder))
           .withSideInputs(getPollFn().getRequirements().getSideInputs()))
           .setCoder(KvCoder.of(input.getCoder(), outputCoder));
     }
   }
 
-  private static class WatchGrowthFn<InputT, OutputT, TerminationStateT>
+  private static class WatchGrowthFn<InputT, OutputT, KeyT, TerminationStateT>
       extends DoFn<InputT, KV<InputT, OutputT>> {
-    private final Watch.Growth<InputT, OutputT> spec;
+    private final Watch.Growth<InputT, OutputT, KeyT> spec;
     private final Coder<OutputT> outputCoder;
-
-    private WatchGrowthFn(Growth<InputT, OutputT> spec, Coder<OutputT> outputCoder) {
+    private final SerializableFunction<OutputT, KeyT> outputKeyFn;
+    private final Coder<KeyT> outputKeyCoder;
+
+    private WatchGrowthFn(
+        Growth<InputT, OutputT, KeyT> spec,
+        Coder<OutputT> outputCoder,
+        SerializableFunction<OutputT, KeyT> outputKeyFn,
+        Coder<KeyT> outputKeyCoder) {
       this.spec = spec;
       this.outputCoder = outputCoder;
+      this.outputKeyFn = outputKeyFn;
+      this.outputKeyCoder = outputKeyCoder;
     }
 
     @ProcessElement
     public ProcessContinuation process(
-        ProcessContext c, final GrowthTracker<OutputT, TerminationStateT> tracker)
+        ProcessContext c, final GrowthTracker<OutputT, KeyT, TerminationStateT> tracker)
         throws Exception {
       if (!tracker.hasPending() && !tracker.currentRestriction().isOutputComplete) {
         LOG.debug("{} - polling input", c.element());
@@ -700,26 +766,27 @@ public class Watch {
     }
 
     @GetInitialRestriction
-    public GrowthState<OutputT, TerminationStateT> getInitialRestriction(InputT element) {
+    public GrowthState<OutputT, KeyT, TerminationStateT> getInitialRestriction(InputT element) {
       return new GrowthState<>(getTerminationCondition().forNewInput(Instant.now(), element));
     }
 
     @NewTracker
-    public GrowthTracker<OutputT, TerminationStateT> newTracker(
-        GrowthState<OutputT, TerminationStateT> restriction) {
-      return new GrowthTracker<>(outputCoder, restriction, getTerminationCondition());
+    public GrowthTracker<OutputT, KeyT, TerminationStateT> newTracker(
+        GrowthState<OutputT, KeyT, TerminationStateT> restriction) {
+      return new GrowthTracker<>(
+          outputKeyFn, outputKeyCoder, restriction, getTerminationCondition());
     }
 
     @GetRestrictionCoder
     @SuppressWarnings({"unchecked", "rawtypes"})
-    public Coder<GrowthState<OutputT, TerminationStateT>> getRestrictionCoder() {
+    public Coder<GrowthState<OutputT, KeyT, TerminationStateT>> getRestrictionCoder() {
       return GrowthStateCoder.of(
           outputCoder, (Coder) spec.getTerminationPerInput().getStateCoder());
     }
   }
 
   @VisibleForTesting
-  static class GrowthState<OutputT, TerminationStateT> {
+  static class GrowthState<OutputT, KeyT, TerminationStateT> {
     // Hashes and timestamps of outputs that have already been output and should be omitted
     // from future polls. Timestamps are preserved to allow garbage-collecting this state
     // in the future, e.g. dropping elements from "completed" and from addNewAsPending() if their
@@ -781,14 +848,14 @@ public class Watch {
   }
 
   @VisibleForTesting
-  static class GrowthTracker<OutputT, TerminationStateT>
-      implements RestrictionTracker<GrowthState<OutputT, TerminationStateT>> {
+  static class GrowthTracker<OutputT, KeyT, TerminationStateT>
+      implements RestrictionTracker<GrowthState<OutputT, KeyT, TerminationStateT>> {
     private final Funnel<OutputT> coderFunnel;
     private final Growth.TerminationCondition<?, TerminationStateT> terminationCondition;
 
     // The restriction describing the entire work to be done by the current ProcessElement call.
     // Changes only in checkpoint().
-    private GrowthState<OutputT, TerminationStateT> state;
+    private GrowthState<OutputT, KeyT, TerminationStateT> state;
 
     // Mutable state changed by the ProcessElement call itself, and used to compute the primary
     // and residual restrictions in checkpoint().
@@ -803,14 +870,19 @@ public class Watch {
     @Nullable private Instant pollWatermark;
     private boolean shouldStop = false;
 
-    GrowthTracker(final Coder<OutputT> outputCoder, GrowthState<OutputT, TerminationStateT> state,
-                  Growth.TerminationCondition<?, TerminationStateT> terminationCondition) {
+    GrowthTracker(
+        final SerializableFunction<OutputT, KeyT> keyFn,
+        final Coder<KeyT> outputKeyCoder,
+        GrowthState<OutputT, KeyT, TerminationStateT> state,
+        Growth.TerminationCondition<?, TerminationStateT> terminationCondition) {
       this.coderFunnel =
           new Funnel<OutputT>() {
             @Override
             public void funnel(OutputT from, PrimitiveSink into) {
               try {
-                outputCoder.encode(from, Funnels.asOutputStream(into));
+                // Rather than hashing the output itself, hash the output key.
+                KeyT outputKey = keyFn.apply(from);
+                outputKeyCoder.encode(outputKey, Funnels.asOutputStream(into));
               } catch (IOException e) {
                 throw new RuntimeException(e);
               }
@@ -825,15 +897,15 @@ public class Watch {
     }
 
     @Override
-    public synchronized GrowthState<OutputT, TerminationStateT> currentRestriction() {
+    public synchronized GrowthState<OutputT, KeyT, TerminationStateT> currentRestriction() {
       return state;
     }
 
     @Override
-    public synchronized GrowthState<OutputT, TerminationStateT> checkpoint() {
+    public synchronized GrowthState<OutputT, KeyT, TerminationStateT> checkpoint() {
       // primary should contain exactly the work claimed in the current ProcessElement call - i.e.
       // claimed outputs become pending, and it shouldn't poll again.
-      GrowthState<OutputT, TerminationStateT> primary =
+      GrowthState<OutputT, KeyT, TerminationStateT> primary =
           new GrowthState<>(
               state.completed /* completed */,
               claimed /* pending */,
@@ -845,9 +917,10 @@ public class Watch {
       // unclaimed pending outputs plus future polling outputs.
       Map<HashCode, Instant> newCompleted = Maps.newHashMap(state.completed);
       for (TimestampedValue<OutputT> claimedOutput : claimed) {
-        newCompleted.put(hash128(claimedOutput.getValue()), claimedOutput.getTimestamp());
+        newCompleted.put(
+            hash128(claimedOutput.getValue()), claimedOutput.getTimestamp());
       }
-      GrowthState<OutputT, TerminationStateT> residual =
+      GrowthState<OutputT, KeyT, TerminationStateT> residual =
           new GrowthState<>(
               newCompleted /* completed */,
               pending /* pending */,
@@ -910,10 +983,14 @@ public class Watch {
           "Should have drained all old pending outputs before adding new, "
               + "but there are %s old pending outputs",
           state.pending.size());
-      List<TimestampedValue<OutputT>> newPending = Lists.newArrayList();
+      // Collect results to include as newly pending, skipping outputs whose key was already
+      // completed or already seen in this poll result. A linked map keeps poll order, so the
+      // stable timestamp sort below breaks ties deterministically.
+      Map<HashCode, TimestampedValue<OutputT>> newPending = Maps.newLinkedHashMap();
       for (TimestampedValue<OutputT> output : pollResult.getOutputs()) {
         OutputT value = output.getValue();
-        if (state.completed.containsKey(hash128(value))) {
+        HashCode hash = hash128(value);
+        if (state.completed.containsKey(hash) || newPending.containsKey(hash)) {
           continue;
         }
         // TODO (https://issues.apache.org/jira/browse/BEAM-2680):
@@ -921,7 +998,7 @@ public class Watch {
         // instead relying on future poll rounds to provide them, in order to avoid
         // blowing up the state. Combined with garbage collection of GrowthState.completed,
         // this would make the transform scalable to very large poll results.
-        newPending.add(TimestampedValue.of(value, output.getTimestamp()));
+        newPending.put(hash, TimestampedValue.of(value, output.getTimestamp()));
       }
       if (!newPending.isEmpty()) {
         terminationState = terminationCondition.onSeenNewOutput(Instant.now(), terminationState);
@@ -936,7 +1013,7 @@ public class Watch {
                           return output.getTimestamp();
                         }
                       })
-                  .sortedCopy(newPending));
+                  .sortedCopy(newPending.values()));
       // If poll result doesn't provide a watermark, assume that future new outputs may
       // arrive with about the same timestamps as the current new outputs.
       if (pollResult.getWatermark() != null) {
@@ -1008,10 +1085,11 @@ public class Watch {
     }
   }
 
-  private static class GrowthStateCoder<OutputT, TerminationStateT>
-      extends StructuredCoder<GrowthState<OutputT, TerminationStateT>> {
-    public static <OutputT, TerminationStateT> GrowthStateCoder<OutputT, TerminationStateT> of(
-        Coder<OutputT> outputCoder, Coder<TerminationStateT> terminationStateCoder) {
+  private static class GrowthStateCoder<OutputT, KeyT, TerminationStateT>
+      extends StructuredCoder<GrowthState<OutputT, KeyT, TerminationStateT>> {
+    public static <OutputT, KeyT, TerminationStateT>
+        GrowthStateCoder<OutputT, KeyT, TerminationStateT> of(
+            Coder<OutputT> outputCoder, Coder<TerminationStateT> terminationStateCoder) {
       return new GrowthStateCoder<>(outputCoder, terminationStateCoder);
     }
 
@@ -1033,7 +1111,7 @@ public class Watch {
     }
 
     @Override
-    public void encode(GrowthState<OutputT, TerminationStateT> value, OutputStream os)
+    public void encode(GrowthState<OutputT, KeyT, TerminationStateT> value, OutputStream os)
         throws IOException {
       completedCoder.encode(value.completed, os);
       pendingCoder.encode(value.pending, os);
@@ -1043,7 +1121,7 @@ public class Watch {
     }
 
     @Override
-    public GrowthState<OutputT, TerminationStateT> decode(InputStream is) throws IOException {
+    public GrowthState<OutputT, KeyT, TerminationStateT> decode(InputStream is) throws IOException {
       Map<HashCode, Instant> completed = completedCoder.decode(is);
       List<TimestampedValue<OutputT>> pending = pendingCoder.decode(is);
       boolean isOutputComplete = BOOLEAN_CODER.decode(is);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WatchTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WatchTest.java
index 8904376..ec6880c 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WatchTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WatchTest.java
@@ -175,6 +175,60 @@ public class WatchTest implements Serializable {
 
   @Test
   @Category({NeedsRunner.class, UsesSplittableParDo.class})
+  public void testMultiplePollsWithKeyExtractor() {
+    List<KV<Integer, String>> polls =
+        Arrays.asList(
+            KV.of(0, "0"),
+            KV.of(10, "10"),
+            KV.of(20, "20"),
+            KV.of(30, "30"),
+            KV.of(40, "40"),
+            KV.of(40, "40.1"),
+            KV.of(20, "20.1"),
+            KV.of(50, "50"),
+            KV.of(10, "10.1"),
+            KV.of(10, "10.2"),
+            KV.of(60, "60"),
+            KV.of(70, "70"),
+            KV.of(60, "60.1"),
+            KV.of(80, "80"),
+            KV.of(40, "40.2"),
+            KV.of(90, "90"),
+            KV.of(90, "90.1"));
+
+    List<Integer> expected = Arrays.asList(0, 10, 20, 30, 40, 50, 60, 70, 80, 90);
+
+    PCollection<Integer> res =
+        p.apply(Create.of("a"))
+            .apply(
+                Watch.growthOf(
+                        Contextful.<PollFn<String, KV<Integer, String>>>of(
+                            new TimedPollFn<String, KV<Integer, String>>(
+                                polls,
+                                standardSeconds(1) /* timeToOutputEverything */,
+                                standardSeconds(3) /* timeToDeclareOutputFinal */,
+                                standardSeconds(30) /* timeToFail */),
+                            Requirements.empty()),
+                        new SerializableFunction<KV<Integer, String>, Integer>() {
+                          @Override
+                          public Integer apply(KV<Integer, String> input) {
+                            return input.getKey();
+                          }
+                        })
+                    .withTerminationPerInput(Watch.Growth.<String>afterTotalOf(standardSeconds(5)))
+                    .withPollInterval(Duration.millis(100))
+                    .withOutputCoder(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of())))
+            .apply("Drop input", Values.<KV<Integer, String>>create())
+            .apply("Drop auxiliary string", Keys.<Integer>create());
+
+    PAssert.that(res).containsInAnyOrder(expected);
+
+    p.run();
+  }
+
+
+  @Test
+  @Category({NeedsRunner.class, UsesSplittableParDo.class})
   public void testMultiplePollsStopAfterTimeSinceNewOutput() {
     List<Integer> all = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
 
@@ -437,20 +491,23 @@ public class WatchTest implements Serializable {
     assertTrue(c.canStopPolling(now.plus(standardSeconds(12)), state));
   }
 
-  private static GrowthTracker<String, Integer> newTracker(GrowthState<String, Integer> state) {
-    return new GrowthTracker<>(StringUtf8Coder.of(), state, never());
+  private static GrowthTracker<String, String, Integer> newTracker(
+      GrowthState<String, String, Integer> state) {
+    return new GrowthTracker<>(
+        SerializableFunctions.<String>identity(), StringUtf8Coder.of(), state, never());
   }
 
-  private static GrowthTracker<String, Integer> newTracker() {
-    return newTracker(new GrowthState<String, Integer>(never().forNewInput(Instant.now(), null)));
+  private static GrowthTracker<String, String, Integer> newTracker() {
+    return newTracker(
+        new GrowthState<String, String, Integer>(never().forNewInput(Instant.now(), null)));
   }
 
   @Test
   public void testGrowthTrackerCheckpointEmpty() {
     // Checkpoint an empty tracker.
-    GrowthTracker<String, Integer> tracker = newTracker();
-    GrowthState<String, Integer> residual = tracker.checkpoint();
-    GrowthState<String, Integer> primary = tracker.currentRestriction();
+    GrowthTracker<String, String, Integer> tracker = newTracker();
+    GrowthState<String, String, Integer> residual = tracker.checkpoint();
+    GrowthState<String, String, Integer> primary = tracker.currentRestriction();
     Watch.Growth.Never<String> condition = never();
     assertEquals(
         primary.toString(condition),
@@ -475,7 +532,7 @@ public class WatchTest implements Serializable {
   @Test
   public void testGrowthTrackerCheckpointNonEmpty() {
     Instant now = Instant.now();
-    GrowthTracker<String, Integer> tracker = newTracker();
+    GrowthTracker<String, String, Integer> tracker = newTracker();
     tracker.addNewAsPending(
         PollResult.incomplete(
                 Arrays.asList(
@@ -493,8 +550,9 @@ public class WatchTest implements Serializable {
     assertTrue(tracker.hasPending());
     assertEquals(now.plus(standardSeconds(3)), tracker.getWatermark());
 
-    GrowthTracker<String, Integer> residualTracker = newTracker(tracker.checkpoint());
-    GrowthTracker<String, Integer> primaryTracker = newTracker(tracker.currentRestriction());
+    GrowthTracker<String, String, Integer> residualTracker = newTracker(tracker.checkpoint());
+    GrowthTracker<String, String, Integer> primaryTracker =
+        newTracker(tracker.currentRestriction());
 
     // Verify primary: should contain what the current tracker claimed, and nothing else.
     assertEquals(now.plus(standardSeconds(1)), primaryTracker.getWatermark());
@@ -530,7 +588,7 @@ public class WatchTest implements Serializable {
   @Test
   public void testGrowthTrackerOutputFullyBeforeCheckpointIncomplete() {
     Instant now = Instant.now();
-    GrowthTracker<String, Integer> tracker = newTracker();
+    GrowthTracker<String, String, Integer> tracker = newTracker();
     tracker.addNewAsPending(
         PollResult.incomplete(
                 Arrays.asList(
@@ -547,8 +605,9 @@ public class WatchTest implements Serializable {
     assertFalse(tracker.hasPending());
     assertEquals(now.plus(standardSeconds(7)), tracker.getWatermark());
 
-    GrowthTracker<String, Integer> residualTracker = newTracker(tracker.checkpoint());
-    GrowthTracker<String, Integer> primaryTracker = newTracker(tracker.currentRestriction());
+    GrowthTracker<String, String, Integer> residualTracker = newTracker(tracker.checkpoint());
+    GrowthTracker<String, String, Integer> primaryTracker =
+        newTracker(tracker.currentRestriction());
 
     // Verify primary: should contain what the current tracker claimed, and nothing else.
     assertEquals(now.plus(standardSeconds(1)), primaryTracker.getWatermark());
@@ -582,7 +641,7 @@ public class WatchTest implements Serializable {
   @Test
   public void testGrowthTrackerPollAfterCheckpointIncompleteWithNewOutputs() {
     Instant now = Instant.now();
-    GrowthTracker<String, Integer> tracker = newTracker();
+    GrowthTracker<String, String, Integer> tracker = newTracker();
     tracker.addNewAsPending(
         PollResult.incomplete(
                 Arrays.asList(
@@ -597,10 +656,10 @@ public class WatchTest implements Serializable {
     assertEquals("c", tracker.tryClaimNextPending().getValue());
     assertEquals("d", tracker.tryClaimNextPending().getValue());
 
-    GrowthState<String, Integer> checkpoint = tracker.checkpoint();
+    GrowthState<String, String, Integer> checkpoint = tracker.checkpoint();
     // Simulate resuming from the checkpoint and adding more elements.
     {
-      GrowthTracker<String, Integer> residualTracker = newTracker(checkpoint);
+      GrowthTracker<String, String, Integer> residualTracker = newTracker(checkpoint);
       residualTracker.addNewAsPending(
           PollResult.incomplete(
                   Arrays.asList(
@@ -623,7 +682,7 @@ public class WatchTest implements Serializable {
     }
     // Try same without an explicitly specified watermark.
     {
-      GrowthTracker<String, Integer> residualTracker = newTracker(checkpoint);
+      GrowthTracker<String, String, Integer> residualTracker = newTracker(checkpoint);
       residualTracker.addNewAsPending(
           PollResult.incomplete(
               Arrays.asList(
@@ -648,7 +707,7 @@ public class WatchTest implements Serializable {
   @Test
   public void testGrowthTrackerPollAfterCheckpointWithoutNewOutputs() {
     Instant now = Instant.now();
-    GrowthTracker<String, Integer> tracker = newTracker();
+    GrowthTracker<String, String, Integer> tracker = newTracker();
     tracker.addNewAsPending(
         PollResult.incomplete(
                 Arrays.asList(
@@ -664,9 +723,9 @@ public class WatchTest implements Serializable {
     assertEquals("d", tracker.tryClaimNextPending().getValue());
 
     // Simulate resuming from the checkpoint but there are no new elements.
-    GrowthState<String, Integer> checkpoint = tracker.checkpoint();
+    GrowthState<String, String, Integer> checkpoint = tracker.checkpoint();
     {
-      GrowthTracker<String, Integer> residualTracker = newTracker(checkpoint);
+      GrowthTracker<String, String, Integer> residualTracker = newTracker(checkpoint);
       residualTracker.addNewAsPending(
           PollResult.incomplete(
                   Arrays.asList(
@@ -682,7 +741,7 @@ public class WatchTest implements Serializable {
     }
     // Try the same without an explicitly specified watermark
     {
-      GrowthTracker<String, Integer> residualTracker = newTracker(checkpoint);
+      GrowthTracker<String, String, Integer> residualTracker = newTracker(checkpoint);
       residualTracker.addNewAsPending(
           PollResult.incomplete(
               Arrays.asList(
@@ -698,7 +757,7 @@ public class WatchTest implements Serializable {
   @Test
   public void testGrowthTrackerPollAfterCheckpointWithoutNewOutputsNoWatermark() {
     Instant now = Instant.now();
-    GrowthTracker<String, Integer> tracker = newTracker();
+    GrowthTracker<String, String, Integer> tracker = newTracker();
     tracker.addNewAsPending(
         PollResult.incomplete(
             Arrays.asList(
@@ -713,8 +772,8 @@ public class WatchTest implements Serializable {
     assertEquals(now.plus(standardSeconds(1)), tracker.getWatermark());
 
     // Simulate resuming from the checkpoint but there are no new elements.
-    GrowthState<String, Integer> checkpoint = tracker.checkpoint();
-    GrowthTracker<String, Integer> residualTracker = newTracker(checkpoint);
+    GrowthState<String, String, Integer> checkpoint = tracker.checkpoint();
+    GrowthTracker<String, String, Integer> residualTracker = newTracker(checkpoint);
     residualTracker.addNewAsPending(
         PollResult.incomplete(
             Arrays.asList(
@@ -730,13 +789,13 @@ public class WatchTest implements Serializable {
   public void testGrowthTrackerRepeatedEmptyPollWatermark() {
     // Empty poll result with no watermark
     {
-      GrowthTracker<String, Integer> tracker = newTracker();
+      GrowthTracker<String, String, Integer> tracker = newTracker();
       tracker.addNewAsPending(
           PollResult.incomplete(Collections.<TimestampedValue<String>>emptyList()));
       assertEquals(BoundedWindow.TIMESTAMP_MIN_VALUE, tracker.getWatermark());
 
       // Simulate resuming from the checkpoint but there are still no new elements.
-      GrowthTracker<String, Integer> residualTracker = newTracker(tracker.checkpoint());
+      GrowthTracker<String, String, Integer> residualTracker = newTracker(tracker.checkpoint());
       tracker.addNewAsPending(
           PollResult.incomplete(Collections.<TimestampedValue<String>>emptyList()));
       // No new elements and no explicit watermark supplied - still no watermark.
@@ -745,14 +804,14 @@ public class WatchTest implements Serializable {
     // Empty poll result with watermark
     {
       Instant now = Instant.now();
-      GrowthTracker<String, Integer> tracker = newTracker();
+      GrowthTracker<String, String, Integer> tracker = newTracker();
       tracker.addNewAsPending(
           PollResult.incomplete(Collections.<TimestampedValue<String>>emptyList())
               .withWatermark(now));
       assertEquals(now, tracker.getWatermark());
 
       // Simulate resuming from the checkpoint but there are still no new elements.
-      GrowthTracker<String, Integer> residualTracker = newTracker(tracker.checkpoint());
+      GrowthTracker<String, String, Integer> residualTracker = newTracker(tracker.checkpoint());
       tracker.addNewAsPending(
           PollResult.incomplete(Collections.<TimestampedValue<String>>emptyList()));
       // No new elements and no explicit watermark supplied - should keep old watermark.
@@ -763,7 +822,7 @@ public class WatchTest implements Serializable {
   @Test
   public void testGrowthTrackerOutputFullyBeforeCheckpointComplete() {
     Instant now = Instant.now();
-    GrowthTracker<String, Integer> tracker = newTracker();
+    GrowthTracker<String, String, Integer> tracker = newTracker();
     tracker.addNewAsPending(
         PollResult.complete(
             Arrays.asList(
@@ -779,7 +838,7 @@ public class WatchTest implements Serializable {
     assertFalse(tracker.hasPending());
     assertEquals(BoundedWindow.TIMESTAMP_MAX_VALUE, tracker.getWatermark());
 
-    GrowthTracker<String, Integer> residualTracker = newTracker(tracker.checkpoint());
+    GrowthTracker<String, String, Integer> residualTracker = newTracker(tracker.checkpoint());
 
     // Verify residual: should be empty, since output was final.
     residualTracker.checkDone();

-- 
To stop receiving notification emails like this one, please contact
"commits@beam.apache.org" <co...@beam.apache.org>.