You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@beam.apache.org by ro...@apache.org on 2020/11/24 18:30:21 UTC

[beam] branch master updated: Add an option to GroupIntoBatches to output ShardedKeys. Update Dataflow pipeline translation accordingly.

This is an automated email from the ASF dual-hosted git repository.

robinyqiu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b51d64e  Add an option to GroupIntoBatches to output ShardedKeys. Update Dataflow pipeline translation accordingly.
     new f4d889f  Merge pull request #13208 from nehsyc/fix_override
b51d64e is described below

commit b51d64e0eee662a1cc75f1a558ef99c2e812813e
Author: sychen <sy...@google.com>
AuthorDate: Tue Oct 27 14:30:58 2020 -0700

    Add an option to GroupIntoBatches to output ShardedKeys. Update Dataflow pipeline translation accordingly.
---
 .../core/construction/PTransformMatchers.java      |   2 +-
 .../dataflow/DataflowPipelineTranslator.java       |  14 ++-
 .../beam/runners/dataflow/DataflowRunner.java      |  29 ++++-
 .../runners/dataflow/GroupIntoBatchesOverride.java | 132 ++++++++++++++++++---
 .../beam/runners/dataflow/util/PropertyNames.java  |   1 +
 .../dataflow/DataflowPipelineTranslatorTest.java   |  67 ++++++-----
 .../beam/runners/dataflow/DataflowRunnerTest.java  |  64 ++++++++--
 .../beam/sdk/transforms/GroupIntoBatches.java      |  70 ++++++++++-
 .../beam/sdk/transforms/GroupIntoBatchesTest.java  |  82 +++++++++++++
 9 files changed, 392 insertions(+), 69 deletions(-)

diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
index 0bfbc4b..781ad52 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
@@ -487,7 +487,7 @@ public class PTransformMatchers {
     return new PTransformMatcher() {
       @Override
       public boolean matches(AppliedPTransform<?, ?, ?> application) {
-        return application.getTransform().getClass().equals(GroupIntoBatches.class);
+        return application.getTransform().getClass().equals(GroupIntoBatches.WithShardedKey.class);
       }
 
       @Override
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
index a5c25e9..51f8fc2 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
@@ -654,10 +654,6 @@ public class DataflowPipelineTranslator {
       if (value instanceof PValue) {
         PValue pvalue = (PValue) value;
         addInput(name, translator.asOutputReference(pvalue, translator.getProducer(pvalue)));
-        if (value instanceof PCollection
-            && translator.runner.doesPCollectionRequireAutoSharding((PCollection<?>) value)) {
-          addInput(PropertyNames.ALLOWS_SHARDABLE_STATE, "true");
-        }
       } else {
         throw new IllegalStateException("Input must be a PValue");
       }
@@ -696,6 +692,16 @@ public class DataflowPipelineTranslator {
     private void addOutput(String name, PValue value, Coder<?> valueCoder) {
       translator.registerOutputName(value, name);
 
+      // If the output requires runner-determined sharding, also append necessary input properties.
+      if (value instanceof PCollection
+          && translator.runner.doesPCollectionRequireAutoSharding((PCollection<?>) value)) {
+        addInput(PropertyNames.ALLOWS_SHARDABLE_STATE, "true");
+        // Currently we only allow auto-sharding to be enabled through the GroupIntoBatches
+        // transform. So we also add the following property which GroupIntoBatchesDoFn has, to allow
+        // the backend to perform graph optimization.
+        addInput(PropertyNames.PRESERVES_KEYS, "true");
+      }
+
       Map<String, Object> properties = getProperties();
       @Nullable List<Map<String, Object>> outputInfoList = null;
       try {
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index 20e3272..428c998 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -494,7 +494,8 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
       overridesBuilder.add(
           PTransformOverride.of(
               PTransformMatchers.groupWithShardableStates(),
-              new GroupIntoBatchesOverride.StreamingGroupIntoBatchesOverrideFactory(this)));
+              new GroupIntoBatchesOverride.StreamingGroupIntoBatchesWithShardedKeyOverrideFactory(
+                  this)));
 
       if (!fnApiEnabled) {
         overridesBuilder
@@ -526,9 +527,14 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
           // Replace GroupIntoBatches before the state/timer replacements below since
           // GroupIntoBatches internally uses a stateful DoFn.
           .add(
-          PTransformOverride.of(
-              PTransformMatchers.classEqualTo(GroupIntoBatches.class),
-              new GroupIntoBatchesOverride.BatchGroupIntoBatchesOverrideFactory()));
+              PTransformOverride.of(
+                  PTransformMatchers.classEqualTo(GroupIntoBatches.class),
+                  new GroupIntoBatchesOverride.BatchGroupIntoBatchesOverrideFactory<>()))
+          .add(
+              PTransformOverride.of(
+                  PTransformMatchers.classEqualTo(GroupIntoBatches.WithShardedKey.class),
+                  new GroupIntoBatchesOverride
+                      .BatchGroupIntoBatchesWithShardedKeyOverrideFactory<>()));
 
       overridesBuilder
           // State and timer pardos are implemented by expansion to GBK-then-ParDo
@@ -1281,8 +1287,19 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
   }
 
   void maybeRecordPCollectionWithAutoSharding(PCollection<?> pcol) {
-    if (hasExperiment(options, "enable_streaming_auto_sharding")
-        && !hasExperiment(options, "beam_fn_api")) {
+    if (hasExperiment(options, "beam_fn_api")) {
+      LOG.warn(
+          "Runner determined sharding not available in Dataflow for GroupIntoBatches for portable "
+              + "jobs. Default sharding will be applied.");
+      return;
+    }
+    if (!options.isEnableStreamingEngine()) {
+      LOG.warn(
+          "Runner determined sharding not available in Dataflow for GroupIntoBatches for Streaming "
+              + "Appliance jobs. Default sharding will be applied.");
+      return;
+    }
+    if (hasExperiment(options, "enable_streaming_auto_sharding")) {
       pcollectionsRequiringAutoSharding.add(pcol);
     }
   }
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/GroupIntoBatchesOverride.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/GroupIntoBatchesOverride.java
index b6a13ef..ea92f9d 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/GroupIntoBatchesOverride.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/GroupIntoBatchesOverride.java
@@ -17,18 +17,24 @@
  */
 package org.apache.beam.runners.dataflow;
 
+import java.nio.ByteBuffer;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.UUID;
 import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.GroupIntoBatches;
+import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.util.ShardedKey;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TupleTag;
@@ -48,7 +54,7 @@ public class GroupIntoBatchesOverride {
                 transform) {
       return PTransformReplacement.of(
           PTransformReplacements.getSingletonMainInput(transform),
-          new BatchGroupIntoBatches(transform.getTransform().getBatchSize()));
+          new BatchGroupIntoBatches<>(transform.getTransform().getBatchSize()));
     }
 
     @Override
@@ -92,54 +98,142 @@ public class GroupIntoBatchesOverride {
     }
   }
 
-  static class StreamingGroupIntoBatchesOverrideFactory<K, V>
+  static class BatchGroupIntoBatchesWithShardedKeyOverrideFactory<K, V>
       implements PTransformOverrideFactory<
-          PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupIntoBatches<K, V>> {
+          PCollection<KV<K, V>>,
+          PCollection<KV<ShardedKey<K>, Iterable<V>>>,
+          GroupIntoBatches<K, V>.WithShardedKey> {
+
+    @Override
+    public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, Iterable<V>>>>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<KV<K, V>>,
+                    PCollection<KV<ShardedKey<K>, Iterable<V>>>,
+                    GroupIntoBatches<K, V>.WithShardedKey>
+                transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new BatchGroupIntoBatchesWithShardedKey<>(transform.getTransform().getBatchSize()));
+    }
+
+    @Override
+    public Map<PCollection<?>, ReplacementOutput> mapOutputs(
+        Map<TupleTag<?>, PCollection<?>> outputs,
+        PCollection<KV<ShardedKey<K>, Iterable<V>>> newOutput) {
+      return ReplacementOutputs.singleton(outputs, newOutput);
+    }
+  }
+
+  /**
+   * Specialized implementation of {@link GroupIntoBatches.WithShardedKey} for bounded Dataflow
+   * pipelines.
+   */
+  static class BatchGroupIntoBatchesWithShardedKey<K, V>
+      extends PTransform<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, Iterable<V>>>> {
+
+    private final long batchSize;
+
+    private BatchGroupIntoBatchesWithShardedKey(long batchSize) {
+      this.batchSize = batchSize;
+    }
+
+    @Override
+    public PCollection<KV<ShardedKey<K>, Iterable<V>>> expand(PCollection<KV<K, V>> input) {
+      return shardKeys(input).apply(new BatchGroupIntoBatches<>(batchSize));
+    }
+  }
+
+  static class StreamingGroupIntoBatchesWithShardedKeyOverrideFactory<K, V>
+      implements PTransformOverrideFactory<
+          PCollection<KV<K, V>>,
+          PCollection<KV<ShardedKey<K>, Iterable<V>>>,
+          GroupIntoBatches<K, V>.WithShardedKey> {
 
     private final DataflowRunner runner;
 
-    StreamingGroupIntoBatchesOverrideFactory(DataflowRunner runner) {
+    StreamingGroupIntoBatchesWithShardedKeyOverrideFactory(DataflowRunner runner) {
       this.runner = runner;
     }
 
     @Override
-    public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>>
+    public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, Iterable<V>>>>
         getReplacementTransform(
             AppliedPTransform<
-                    PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupIntoBatches<K, V>>
+                    PCollection<KV<K, V>>,
+                    PCollection<KV<ShardedKey<K>, Iterable<V>>>,
+                    GroupIntoBatches<K, V>.WithShardedKey>
                 transform) {
       return PTransformReplacement.of(
           PTransformReplacements.getSingletonMainInput(transform),
-          new StreamingGroupIntoBatches(runner, transform.getTransform()));
+          new StreamingGroupIntoBatchesWithShardedKey<>(
+              runner,
+              transform.getTransform(),
+              PTransformReplacements.getSingletonMainOutput(transform)));
     }
 
     @Override
     public Map<PCollection<?>, ReplacementOutput> mapOutputs(
-        Map<TupleTag<?>, PCollection<?>> outputs, PCollection<KV<K, Iterable<V>>> newOutput) {
+        Map<TupleTag<?>, PCollection<?>> outputs,
+        PCollection<KV<ShardedKey<K>, Iterable<V>>> newOutput) {
       return ReplacementOutputs.singleton(outputs, newOutput);
     }
   }
 
   /**
-   * Specialized implementation of {@link GroupIntoBatches} for unbounded Dataflow pipelines. The
-   * override does the same thing as the original transform but additionally record the input to add
-   * corresponding properties during the graph translation.
+   * Specialized implementation of {@link GroupIntoBatches.WithShardedKey} for unbounded Dataflow
+   * pipelines. The override does the same thing as the original transform but additionally records
+   * the output in order to append required step properties during the graph translation.
    */
-  static class StreamingGroupIntoBatches<K, V>
-      extends PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> {
+  static class StreamingGroupIntoBatchesWithShardedKey<K, V>
+      extends PTransform<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, Iterable<V>>>> {
 
     private final transient DataflowRunner runner;
-    private final GroupIntoBatches<K, V> original;
+    private final GroupIntoBatches<K, V>.WithShardedKey originalTransform;
+    private final transient PCollection<KV<ShardedKey<K>, Iterable<V>>> originalOutput;
 
-    public StreamingGroupIntoBatches(DataflowRunner runner, GroupIntoBatches<K, V> original) {
+    public StreamingGroupIntoBatchesWithShardedKey(
+        DataflowRunner runner,
+        GroupIntoBatches<K, V>.WithShardedKey original,
+        PCollection<KV<ShardedKey<K>, Iterable<V>>> output) {
       this.runner = runner;
-      this.original = original;
+      this.originalTransform = original;
+      this.originalOutput = output;
     }
 
     @Override
-    public PCollection<KV<K, Iterable<V>>> expand(PCollection<KV<K, V>> input) {
-      runner.maybeRecordPCollectionWithAutoSharding(input);
-      return input.apply(original);
+    public PCollection<KV<ShardedKey<K>, Iterable<V>>> expand(PCollection<KV<K, V>> input) {
+      // Record the output PCollection of the original transform since the new output will be
+      // replaced by the original one when the replacement transform is wired to other nodes in the
+      // graph, although the old and the new outputs are effectively the same.
+      runner.maybeRecordPCollectionWithAutoSharding(originalOutput);
+      return input.apply(originalTransform);
     }
   }
+
+  private static final UUID workerUuid = UUID.randomUUID();
+
+  private static <K, V> PCollection<KV<ShardedKey<K>, V>> shardKeys(PCollection<KV<K, V>> input) {
+    KvCoder<K, V> inputCoder = (KvCoder<K, V>) input.getCoder();
+    org.apache.beam.sdk.coders.Coder<K> keyCoder =
+        (org.apache.beam.sdk.coders.Coder<K>) inputCoder.getCoderArguments().get(0);
+    org.apache.beam.sdk.coders.Coder<V> valueCoder =
+        (org.apache.beam.sdk.coders.Coder<V>) inputCoder.getCoderArguments().get(1);
+    return input
+        .apply(
+            "Shard Keys",
+            MapElements.via(
+                new SimpleFunction<KV<K, V>, KV<ShardedKey<K>, V>>() {
+                  @Override
+                  public KV<ShardedKey<K>, V> apply(KV<K, V> input) {
+                    long tid = Thread.currentThread().getId();
+                    ByteBuffer buffer = ByteBuffer.allocate(3 * Long.BYTES);
+                    buffer.putLong(workerUuid.getMostSignificantBits());
+                    buffer.putLong(workerUuid.getLeastSignificantBits());
+                    buffer.putLong(tid);
+                    return KV.of(ShardedKey.of(input.getKey(), buffer.array()), input.getValue());
+                  }
+                }))
+        .setCoder(KvCoder.of(ShardedKey.Coder.of(keyCoder), valueCoder));
+  }
 }
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java
index dacfdc1..f48bcac 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java
@@ -64,6 +64,7 @@ public class PropertyNames {
   public static final String VALUE = "value";
   public static final String WINDOWING_STRATEGY = "windowing_strategy";
   public static final String DISPLAY_DATA = "display_data";
+  public static final String PRESERVES_KEYS = "preserves_keys";
   /**
    * @deprecated Uses the incorrect terminology. {@link #RESTRICTION_ENCODING}. Should be removed
    *     once non FnAPI SplittableDoFn expansion for Dataflow is removed.
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
index 202c1e7..a140b1e 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
@@ -78,6 +78,7 @@ import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.extensions.gcp.auth.TestCredential;
+import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
 import org.apache.beam.sdk.extensions.gcp.util.GcsUtil;
 import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath;
 import org.apache.beam.sdk.io.FileSystems;
@@ -1120,17 +1121,28 @@ public class DataflowPipelineTranslatorTest implements Serializable {
     assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind());
   }
 
-  @Test
-  public void testStreamingGroupIntoBatchesTranslation() throws Exception {
+  private Map<String, Object> runGroupIntoBatchesAndGetStepProperties(
+      Boolean withShardedKey, Boolean usesFnApi) throws IOException {
     DataflowPipelineOptions options = buildPipelineOptions();
-    options.setExperiments(Arrays.asList("enable_streaming_auto_sharding"));
+    options.setExperiments(
+        Arrays.asList(
+            "enable_streaming_auto_sharding",
+            GcpOptions.STREAMING_ENGINE_EXPERIMENT,
+            GcpOptions.WINDMILL_SERVICE_EXPERIMENT));
+    if (usesFnApi) {
+      options.setExperiments(Arrays.asList("beam_fn_api"));
+    }
     options.setStreaming(true);
     DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options);
 
     Pipeline pipeline = Pipeline.create(options);
-    pipeline
-        .apply(Create.of(Arrays.asList(KV.of(1, "1"), KV.of(2, "2"), KV.of(3, "3"))))
-        .apply(GroupIntoBatches.ofSize(2));
+    PCollection<KV<Integer, String>> input =
+        pipeline.apply(Create.of(Arrays.asList(KV.of(1, "1"), KV.of(2, "2"), KV.of(3, "3"))));
+    if (withShardedKey) {
+      input.apply(GroupIntoBatches.<Integer, String>ofSize(3).withShardedKey());
+    } else {
+      input.apply(GroupIntoBatches.ofSize(3));
+    }
 
     DataflowRunner runner = DataflowRunner.fromOptions(options);
     runner.replaceTransforms(pipeline);
@@ -1142,40 +1154,37 @@ public class DataflowPipelineTranslatorTest implements Serializable {
             .getJob();
     List<Step> steps = job.getSteps();
     Step shardedStateStep = steps.get(steps.size() - 1);
-    Map<String, Object> properties = shardedStateStep.getProperties();
+    return shardedStateStep.getProperties();
+  }
+
+  @Test
+  public void testStreamingGroupIntoBatchesTranslation() throws Exception {
+    Map<String, Object> properties = runGroupIntoBatchesAndGetStepProperties(false, false);
+    assertTrue(properties.containsKey(PropertyNames.USES_KEYED_STATE));
+    assertEquals("true", getString(properties, PropertyNames.USES_KEYED_STATE));
+    assertFalse(properties.containsKey(PropertyNames.ALLOWS_SHARDABLE_STATE));
+    assertFalse(properties.containsKey(PropertyNames.PRESERVES_KEYS));
+  }
+
+  @Test
+  public void testStreamingGroupIntoBatchesWithShardedKeyTranslation() throws Exception {
+    Map<String, Object> properties = runGroupIntoBatchesAndGetStepProperties(true, false);
     assertTrue(properties.containsKey(PropertyNames.USES_KEYED_STATE));
     assertEquals("true", getString(properties, PropertyNames.USES_KEYED_STATE));
     assertTrue(properties.containsKey(PropertyNames.ALLOWS_SHARDABLE_STATE));
     assertEquals("true", getString(properties, PropertyNames.ALLOWS_SHARDABLE_STATE));
+    assertTrue(properties.containsKey(PropertyNames.PRESERVES_KEYS));
+    assertEquals("true", getString(properties, PropertyNames.PRESERVES_KEYS));
   }
 
   @Test
   public void testStreamingGroupIntoBatchesTranslationFnApi() throws Exception {
-    DataflowPipelineOptions options = buildPipelineOptions();
-    options.setExperiments(Arrays.asList("enable_windmill_auto_sharding"));
-    options.setExperiments(Arrays.asList("beam_fn_api"));
-    options.setStreaming(true);
-    DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options);
-
-    Pipeline pipeline = Pipeline.create(options);
-    pipeline
-        .apply(Create.of(Arrays.asList(KV.of(1, "1"), KV.of(2, "2"), KV.of(3, "3"))))
-        .apply(GroupIntoBatches.ofSize(2));
-
-    DataflowRunner runner = DataflowRunner.fromOptions(options);
-    runner.replaceTransforms(pipeline);
-    SdkComponents sdkComponents = createSdkComponents(options);
-    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline, sdkComponents, true);
-    Job job =
-        translator
-            .translate(pipeline, pipelineProto, sdkComponents, runner, Collections.emptyList())
-            .getJob();
-    List<Step> steps = job.getSteps();
-    Step shardedStateStep = steps.get(steps.size() - 1);
-    Map<String, Object> properties = shardedStateStep.getProperties();
+    Map<String, Object> properties = runGroupIntoBatchesAndGetStepProperties(true, true);
     assertTrue(properties.containsKey(PropertyNames.USES_KEYED_STATE));
     assertEquals("true", getString(properties, PropertyNames.USES_KEYED_STATE));
+    // "allows_shardable_state" is currently unsupported for portable jobs.
     assertFalse(properties.containsKey(PropertyNames.ALLOWS_SHARDABLE_STATE));
+    assertFalse(properties.containsKey(PropertyNames.PRESERVES_KEYS));
   }
 
   @Test
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
index d5e8fa7..cb74bff 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
@@ -38,6 +38,7 @@ import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.junit.Assume.assumeFalse;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
@@ -131,14 +132,17 @@ import org.apache.beam.sdk.testing.ValidatesRunner;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupIntoBatches;
+import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.SerializableFunctions;
+import org.apache.beam.sdk.transforms.SimpleFunction;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.transforms.windowing.Sessions;
 import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.util.ShardedKey;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PValues;
@@ -1605,12 +1609,32 @@ public class DataflowRunnerTest implements Serializable {
     verifyMergingStatefulParDoRejected(options);
   }
 
-  private void verifyGroupIntoBatchesOverride(Pipeline p) {
+  private void verifyGroupIntoBatchesOverride(
+      Pipeline p, Boolean withShardedKey, Boolean expectOverriden) {
     final int batchSize = 2;
     List<KV<String, Integer>> testValues =
         Arrays.asList(KV.of("A", 1), KV.of("B", 0), KV.of("A", 2), KV.of("A", 4), KV.of("A", 8));
-    PCollection<KV<String, Iterable<Integer>>> output =
-        p.apply(Create.of(testValues)).apply(GroupIntoBatches.ofSize(batchSize));
+    PCollection<KV<String, Integer>> input = p.apply(Create.of(testValues));
+    PCollection<KV<String, Iterable<Integer>>> output;
+    if (withShardedKey) {
+      output =
+          input
+              .apply(GroupIntoBatches.<String, Integer>ofSize(batchSize).withShardedKey())
+              .apply(
+                  "StripShardId",
+                  MapElements.via(
+                      new SimpleFunction<
+                          KV<ShardedKey<String>, Iterable<Integer>>,
+                          KV<String, Iterable<Integer>>>() {
+                        @Override
+                        public KV<String, Iterable<Integer>> apply(
+                            KV<ShardedKey<String>, Iterable<Integer>> input) {
+                          return KV.of(input.getKey().getKey(), input.getValue());
+                        }
+                      }));
+    } else {
+      output = input.apply(GroupIntoBatches.ofSize(batchSize));
+    }
     PAssert.thatMultimap(output)
         .satisfies(
             new SerializableFunction<Map<String, Iterable<Iterable<Integer>>>, Void>() {
@@ -1642,23 +1666,41 @@ public class DataflowRunnerTest implements Serializable {
           public CompositeBehavior enterCompositeTransform(Node node) {
             if (p.getOptions().as(StreamingOptions.class).isStreaming()
                 && node.getTransform()
-                    instanceof GroupIntoBatchesOverride.StreamingGroupIntoBatches) {
+                    instanceof GroupIntoBatchesOverride.StreamingGroupIntoBatchesWithShardedKey) {
               sawGroupIntoBatchesOverride.set(true);
             }
             if (!p.getOptions().as(StreamingOptions.class).isStreaming()
                 && node.getTransform() instanceof GroupIntoBatchesOverride.BatchGroupIntoBatches) {
               sawGroupIntoBatchesOverride.set(true);
             }
+            if (!p.getOptions().as(StreamingOptions.class).isStreaming()
+                && node.getTransform()
+                    instanceof GroupIntoBatchesOverride.BatchGroupIntoBatchesWithShardedKey) {
+              sawGroupIntoBatchesOverride.set(true);
+            }
             return CompositeBehavior.ENTER_TRANSFORM;
           }
         });
-    assertTrue(sawGroupIntoBatchesOverride.get());
+    if (expectOverriden) {
+      assertTrue(sawGroupIntoBatchesOverride.get());
+    } else {
+      assertFalse(sawGroupIntoBatchesOverride.get());
+    }
   }
 
   @Test
   @Category(ValidatesRunner.class)
   public void testBatchGroupIntoBatchesOverride() {
-    verifyGroupIntoBatchesOverride(pipeline);
+    // Ignore this test for streaming pipelines.
+    assumeFalse(pipeline.getOptions().as(StreamingOptions.class).isStreaming());
+    verifyGroupIntoBatchesOverride(pipeline, false, true);
+  }
+
+  @Test
+  public void testBatchGroupIntoBatchesWithShardedKeyOverride() throws IOException {
+    PipelineOptions options = buildPipelineOptions();
+    Pipeline p = Pipeline.create(options);
+    verifyGroupIntoBatchesOverride(p, true, true);
   }
 
   @Test
@@ -1666,7 +1708,15 @@ public class DataflowRunnerTest implements Serializable {
     PipelineOptions options = buildPipelineOptions();
     options.as(StreamingOptions.class).setStreaming(true);
     Pipeline p = Pipeline.create(options);
-    verifyGroupIntoBatchesOverride(p);
+    verifyGroupIntoBatchesOverride(p, false, false);
+  }
+
+  @Test
+  public void testStreamingGroupIntoBatchesWithShardedKeyOverride() throws IOException {
+    PipelineOptions options = buildPipelineOptions();
+    options.as(StreamingOptions.class).setStreaming(true);
+    Pipeline p = Pipeline.create(options);
+    verifyGroupIntoBatchesOverride(p, true, true);
   }
 
   private void testStreamingWriteOverride(PipelineOptions options, int expectedNumShards) {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
index ae3b0f2..1ac6493 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
@@ -19,7 +19,10 @@ package org.apache.beam.sdk.transforms;
 
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
 
+import java.nio.ByteBuffer;
+import java.util.UUID;
 import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.state.BagState;
@@ -32,6 +35,7 @@ import org.apache.beam.sdk.state.TimerSpec;
 import org.apache.beam.sdk.state.TimerSpecs;
 import org.apache.beam.sdk.state.ValueState;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.ShardedKey;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
@@ -89,6 +93,7 @@ public class GroupIntoBatches<K, InputT>
 
   private final long batchSize;
   @Nullable private final Duration maxBufferingDuration;
+  private static final UUID workerUuid = UUID.randomUUID();
 
   private GroupIntoBatches(long batchSize, @Nullable Duration maxBufferingDuration) {
     this.batchSize = batchSize;
@@ -105,8 +110,8 @@ public class GroupIntoBatches<K, InputT>
   }
 
   /**
-   * Set a time limit (in processing time) on how long an incomplete batch of elements is allowed to
-   * be buffered. Once a batch is flushed to output, the timer is reset.
+   * Sets a time limit (in processing time) on how long an incomplete batch of elements is allowed
+   * to be buffered. Once a batch is flushed to output, the timer is reset.
    */
   public GroupIntoBatches<K, InputT> withMaxBufferingDuration(Duration duration) {
     checkArgument(
@@ -114,6 +119,65 @@ public class GroupIntoBatches<K, InputT>
     return new GroupIntoBatches<>(batchSize, duration);
   }
 
+  /**
+   * Outputs batched elements associated with sharded input keys. By default, keys are sharded such
+   * that input elements with the same key are spread across all threads executing the transform
+   * (the shard id combines a randomly generated worker UUID with the current thread's id). Runners
+   * may override the default sharding to achieve better load balancing at execution time.
+   */
+  @Experimental
+  public WithShardedKey withShardedKey() {
+    return new WithShardedKey();
+  }
+
+  /** Variant of {@link GroupIntoBatches} that emits batches keyed by {@link ShardedKey}. */
+  public class WithShardedKey
+      extends PTransform<
+          PCollection<KV<K, InputT>>, PCollection<KV<ShardedKey<K>, Iterable<InputT>>>> {
+    private WithShardedKey() {}
+
+    /** Returns the maximum number of elements per output batch. */
+    public long getBatchSize() {
+      return batchSize;
+    }
+
+    @Override
+    public PCollection<KV<ShardedKey<K>, Iterable<InputT>>> expand(
+        PCollection<KV<K, InputT>> input) {
+      Duration lateness = input.getWindowingStrategy().getAllowedLateness();
+      checkArgument(
+          input.getCoder() instanceof KvCoder,
+          "coder specified in the input PCollection is not a KvCoder");
+      KvCoder<K, InputT> kvCoder = (KvCoder<K, InputT>) input.getCoder();
+      Coder<K> originalKeyCoder = (Coder<K>) kvCoder.getCoderArguments().get(0);
+      Coder<InputT> elementCoder = (Coder<InputT>) kvCoder.getCoderArguments().get(1);
+      // One sharded-key coder serves both the intermediate PCollection and the batching DoFn.
+      Coder<ShardedKey<K>> shardedKeyCoder = ShardedKey.Coder.of(originalKeyCoder);
+
+      return input
+          .apply(
+              MapElements.via(
+                  new SimpleFunction<KV<K, InputT>, KV<ShardedKey<K>, InputT>>() {
+                    @Override
+                    public KV<ShardedKey<K>, InputT> apply(KV<K, InputT> element) {
+                      // Shard id = worker UUID (two longs) followed by the current thread id.
+                      ByteBuffer shardId =
+                          ByteBuffer.allocate(3 * Long.BYTES)
+                              .putLong(workerUuid.getMostSignificantBits())
+                              .putLong(workerUuid.getLeastSignificantBits())
+                              .putLong(Thread.currentThread().getId());
+                      K key = element.getKey();
+                      return KV.of(ShardedKey.of(key, shardId.array()), element.getValue());
+                    }
+                  }))
+          .setCoder(KvCoder.of(shardedKeyCoder, elementCoder))
+          .apply(
+              ParDo.of(
+                  new GroupIntoBatchesDoFn<>(
+                      batchSize, lateness, maxBufferingDuration, shardedKeyCoder, elementCoder)));
+    }
+  }
+
   @Override
   public PCollection<KV<K, Iterable<InputT>>> expand(PCollection<KV<K, InputT>> input) {
     Duration allowedLateness = input.getWindowingStrategy().getAllowedLateness();
@@ -121,7 +185,7 @@ public class GroupIntoBatches<K, InputT>
     checkArgument(
         input.getCoder() instanceof KvCoder,
         "coder specified in the input PCollection is not a KvCoder");
-    KvCoder inputCoder = (KvCoder) input.getCoder();
+    KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
     Coder<K> keyCoder = (Coder<K>) inputCoder.getCoderArguments().get(0);
     Coder<InputT> valueCoder = (Coder<InputT>) inputCoder.getCoderArguments().get(1);
 
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java
index f7b0906..3334875 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java
@@ -44,6 +44,7 @@ import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
 import org.apache.beam.sdk.transforms.windowing.Repeatedly;
 import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.util.ShardedKey;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TimestampedValue;
@@ -125,6 +126,87 @@ public class GroupIntoBatchesTest implements Serializable {
     pipeline.run();
   }
 
+  @Test
+  @Category({NeedsRunner.class, UsesTimersInParDo.class, UsesStatefulParDo.class})
+  public void testWithShardedKeyInGlobalWindow() {
+    // With default sharding the number of sub-shards per key is nondeterministic, so create a
+    // large number of input elements with a small batch size and check that no batch is larger
+    // than the specified size and that the total batch count lies within its feasible bounds.
+    int numElements = 10000;
+    int batchSize = 5;
+    PCollection<KV<ShardedKey<String>, Iterable<String>>> collection =
+        pipeline
+            .apply("Input data", Create.of(createTestData(numElements)))
+            .apply(GroupIntoBatches.<String, String>ofSize(batchSize).withShardedKey())
+            .setCoder(
+                KvCoder.of(
+                    ShardedKey.Coder.of(StringUtf8Coder.of()),
+                    IterableCoder.of(StringUtf8Coder.of())));
+    PAssert.that("Incorrect batch size in one or more elements", collection)
+        .satisfies(
+            new SerializableFunction<Iterable<KV<ShardedKey<String>, Iterable<String>>>, Void>() {
+
+              private boolean checkBatchSizes(
+                  Iterable<KV<ShardedKey<String>, Iterable<String>>> listToCheck) {
+                for (KV<ShardedKey<String>, Iterable<String>> element : listToCheck) {
+                  if (Iterables.size(element.getValue()) > batchSize) {
+                    return false;
+                  }
+                }
+                return true;
+              }
+
+              @Override
+              public Void apply(Iterable<KV<ShardedKey<String>, Iterable<String>>> input) {
+                assertTrue(checkBatchSizes(input));
+                return null;
+              }
+            });
+    PCollection<KV<Integer, Long>> numBatchesByBatchSize =
+        collection
+            .apply(
+                "KeyByBatchSize",
+                MapElements.via(
+                    new SimpleFunction<
+                        KV<ShardedKey<String>, Iterable<String>>, KV<Integer, Integer>>() {
+                      @Override
+                      public KV<Integer, Integer> apply(
+                          KV<ShardedKey<String>, Iterable<String>> input) {
+                        // Named distinctly so it does not shadow the enclosing batchSize.
+                        int actualBatchSize = Iterables.size(input.getValue());
+                        return KV.of(actualBatchSize, 1);
+                      }
+                    }))
+            .apply("CountBatchesBySize", Count.perKey());
+    PAssert.that("Expecting majority of the batches are full", numBatchesByBatchSize)
+        .satisfies(
+            (SerializableFunction<Iterable<KV<Integer, Long>>, Void>)
+                listOfBatchSize -> {
+                  long numFullBatches = 0L;
+                  long totalNumBatches = 0L;
+                  for (KV<Integer, Long> batchSizeAndCount : listOfBatchSize) {
+                    if (batchSizeAndCount.getKey() == batchSize) {
+                      numFullBatches += batchSizeAndCount.getValue();
+                    }
+                    totalNumBatches += batchSizeAndCount.getValue();
+                  }
+                  // Every batch holds between 1 and batchSize elements, so the total number of
+                  // batches must lie in [numElements / batchSize, numElements].
+                  assertTrue(
+                      String.format(
+                          "total number of batches should be in the range [%d, %d] but got %d",
+                          numElements / batchSize, numElements, totalNumBatches),
+                      totalNumBatches >= numElements / batchSize && totalNumBatches <= numElements);
+                  assertTrue(
+                      String.format(
+                          "number of full batches vs. total number of batches in total: %d vs. %d",
+                          numFullBatches, totalNumBatches),
+                      numFullBatches > totalNumBatches / 2);
+                  return null;
+                });
+    pipeline.run();
+  }
+
   /** test behavior when the number of input elements is not evenly divisible by batch size. */
   @Test
   @Category({NeedsRunner.class, UsesTimersInParDo.class, UsesStatefulParDo.class})