Posted to commits@beam.apache.org by re...@apache.org on 2023/07/22 17:40:51 UTC

[beam] branch master updated: Merge pull request #27617: Support withFanout and withHotKeyFanout on schema group transform

This is an automated email from the ASF dual-hosted git repository.

reuvenlax pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 05305ede453 Merge pull request #27617: Support withFanout and withHotKeyFanout on schema group transform
05305ede453 is described below

commit 05305ede45366f158f27fc2b83b9ce00db4df2ab
Author: Reuven Lax <re...@google.com>
AuthorDate: Sat Jul 22 10:40:43 2023 -0700

    Merge pull request #27617: Support withFanout and withHotKeyFanout on schema group transform
---
 .../apache/beam/sdk/schemas/transforms/Group.java  | 162 ++++++++++++++++-----
 .../beam/sdk/schemas/transforms/GroupTest.java     |  96 +++++++++---
 .../sdk/extensions/sql/impl/rel/BeamWindowRel.java |   5 +-
 3 files changed, 201 insertions(+), 62 deletions(-)
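
For context, the new methods forward fanout to Combine.globally(...).withFanout and Combine.perKey(...).withHotKeyFanout. A minimal usage sketch follows (not part of the patch itself), assuming the usual Beam imports (Group, Count, Sum, SerializableFunctions, PCollection, Row) and an input PCollection<MyEvent> named "events" with a registered schema; MyEvent and the field names userId/cost/total_cost are illustrative placeholders:

    // Global schema aggregation with a fixed pre-combine fanout.
    PCollection<Long> count =
        events.apply(
            Group.<MyEvent>globally()
                .aggregate(Count.combineFn())
                .withFanout(10));

    // Keyed schema aggregation with hot-key fanout; either a plain int or a
    // per-key SerializableFunction<Row, Integer> is accepted.
    PCollection<Row> totals =
        events.apply(
            Group.<MyEvent>byFieldNames("userId")
                .aggregateField("cost", Sum.ofLongs(), "total_cost")
                .withHotKeyFanout(SerializableFunctions.constant(5)));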

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
index fb48ceed311..fe8933d24d5 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/transforms/Group.java
@@ -17,7 +17,9 @@
  */
 package org.apache.beam.sdk.schemas.transforms;
 
+import com.google.auto.value.AutoOneOf;
 import com.google.auto.value.AutoValue;
+import java.io.Serializable;
 import java.util.List;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
@@ -34,6 +36,7 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.WithKeys;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
@@ -42,6 +45,7 @@ import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TypeDescriptors;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
+import org.checkerframework.checker.nullness.qual.Nullable;
 
 /**
  * A generic grouping transform for schema {@link PCollection}s.
@@ -153,7 +157,7 @@ public class Group {
      */
     public <OutputT> CombineGlobally<InputT, OutputT> aggregate(
         CombineFn<InputT, ?, OutputT> combineFn) {
-      return new CombineGlobally<>(combineFn);
+      return new CombineGlobally<>(combineFn, 0);
     }
 
     /**
@@ -169,10 +173,8 @@ public class Group {
       return new CombineFieldsGlobally<>(
           SchemaAggregateFn.create()
               .aggregateFields(
-                  FieldAccessDescriptor.withFieldNames(inputFieldName),
-                  false,
-                  fn,
-                  outputFieldName));
+                  FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputFieldName),
+          0);
     }
 
     public <CombineInputT, AccumT, CombineOutputT>
@@ -183,7 +185,8 @@ public class Group {
       return new CombineFieldsGlobally<>(
           SchemaAggregateFn.create()
               .aggregateFields(
-                  FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName));
+                  FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName),
+          0);
     }
 
     /** The same as {@link #aggregateField} but using field id. */
@@ -194,7 +197,8 @@ public class Group {
       return new CombineFieldsGlobally<>(
           SchemaAggregateFn.create()
               .aggregateFields(
-                  FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName));
+                  FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName),
+          0);
     }
 
     public <CombineInputT, AccumT, CombineOutputT>
@@ -205,7 +209,8 @@ public class Group {
       return new CombineFieldsGlobally<>(
           SchemaAggregateFn.create()
               .aggregateFields(
-                  FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName));
+                  FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName),
+          0);
     }
 
     /**
@@ -221,7 +226,8 @@ public class Group {
       return new CombineFieldsGlobally<>(
           SchemaAggregateFn.create()
               .aggregateFields(
-                  FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField));
+                  FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField),
+          0);
     }
 
     public <CombineInputT, AccumT, CombineOutputT>
@@ -232,7 +238,8 @@ public class Group {
       return new CombineFieldsGlobally<>(
           SchemaAggregateFn.create()
               .aggregateFields(
-                  FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField));
+                  FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField),
+          0);
     }
 
     /** The same as {@link #aggregateField} but using field id. */
@@ -241,7 +248,8 @@ public class Group {
       return new CombineFieldsGlobally<>(
           SchemaAggregateFn.create()
               .aggregateFields(
-                  FieldAccessDescriptor.withFieldIds(inputFielId), false, fn, outputField));
+                  FieldAccessDescriptor.withFieldIds(inputFielId), false, fn, outputField),
+          0);
     }
 
     public <CombineInputT, AccumT, CombineOutputT>
@@ -252,7 +260,8 @@ public class Group {
       return new CombineFieldsGlobally<>(
           SchemaAggregateFn.create()
               .aggregateFields(
-                  FieldAccessDescriptor.withFieldIds(inputFielId), true, fn, outputField));
+                  FieldAccessDescriptor.withFieldIds(inputFielId), true, fn, outputField),
+          0);
     }
 
     /**
@@ -298,8 +307,8 @@ public class Group {
         CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
         String outputFieldName) {
       return new CombineFieldsGlobally<>(
-          SchemaAggregateFn.create()
-              .aggregateFields(fieldsToAggregate, false, fn, outputFieldName));
+          SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, false, fn, outputFieldName),
+          0);
     }
 
     /**
@@ -335,7 +344,7 @@ public class Group {
         CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
         Field outputField) {
       return new CombineFieldsGlobally<>(
-          SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, false, fn, outputField));
+          SchemaAggregateFn.create().aggregateFields(fieldsToAggregate, false, fn, outputField), 0);
     }
 
     @Override
@@ -351,14 +360,20 @@ public class Group {
   public static class CombineGlobally<InputT, OutputT>
       extends PTransform<PCollection<InputT>, PCollection<OutputT>> {
     final CombineFn<InputT, ?, OutputT> combineFn;
+    int fanout;
 
-    CombineGlobally(CombineFn<InputT, ?, OutputT> combineFn) {
+    CombineGlobally(CombineFn<InputT, ?, OutputT> combineFn, int fanout) {
       this.combineFn = combineFn;
+      this.fanout = fanout;
+    }
+
+    public CombineGlobally<InputT, OutputT> withFanout(int fanout) {
+      return new CombineGlobally<>(combineFn, fanout);
     }
 
     @Override
     public PCollection<OutputT> expand(PCollection<InputT> input) {
-      return input.apply("globalCombine", Combine.globally(combineFn));
+      return input.apply("globalCombine", Combine.globally(combineFn).withFanout(fanout));
     }
   }
 
@@ -420,9 +435,11 @@ public class Group {
    */
   public static class CombineFieldsGlobally<InputT> extends AggregateCombiner<InputT> {
     private final SchemaAggregateFn.Inner schemaAggregateFn;
+    private final int fanout;
 
-    CombineFieldsGlobally(SchemaAggregateFn.Inner schemaAggregateFn) {
+    CombineFieldsGlobally(SchemaAggregateFn.Inner schemaAggregateFn, int fanout) {
       this.schemaAggregateFn = schemaAggregateFn;
+      this.fanout = fanout;
     }
 
     /**
@@ -431,7 +448,7 @@ public class Group {
      * determined by the output types of all the composed combiners.
      */
     public static CombineFieldsGlobally<?> create() {
-      return new CombineFieldsGlobally<>(SchemaAggregateFn.create());
+      return new CombineFieldsGlobally<>(SchemaAggregateFn.create(), 0);
     }
 
     /**
@@ -450,7 +467,8 @@ public class Group {
         String outputFieldName) {
       return new CombineFieldsGlobally<>(
           schemaAggregateFn.aggregateFields(
-              FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputFieldName));
+              FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputFieldName),
+          fanout);
     }
 
     public <CombineInputT, AccumT, CombineOutputT>
@@ -460,7 +478,8 @@ public class Group {
             String outputFieldName) {
       return new CombineFieldsGlobally<>(
           schemaAggregateFn.aggregateFields(
-              FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName));
+              FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputFieldName),
+          fanout);
     }
 
     public <CombineInputT, AccumT, CombineOutputT> CombineFieldsGlobally<InputT> aggregateField(
@@ -469,7 +488,8 @@ public class Group {
         String outputFieldName) {
       return new CombineFieldsGlobally<>(
           schemaAggregateFn.aggregateFields(
-              FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName));
+              FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputFieldName),
+          fanout);
     }
 
     public <CombineInputT, AccumT, CombineOutputT>
@@ -479,7 +499,8 @@ public class Group {
             String outputFieldName) {
       return new CombineFieldsGlobally<>(
           schemaAggregateFn.aggregateFields(
-              FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName));
+              FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputFieldName),
+          fanout);
     }
 
     /**
@@ -495,7 +516,8 @@ public class Group {
         Field outputField) {
       return new CombineFieldsGlobally<>(
           schemaAggregateFn.aggregateFields(
-              FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField));
+              FieldAccessDescriptor.withFieldNames(inputFieldName), false, fn, outputField),
+          fanout);
     }
 
     public <CombineInputT, AccumT, CombineOutputT>
@@ -505,7 +527,8 @@ public class Group {
             Field outputField) {
       return new CombineFieldsGlobally<>(
           schemaAggregateFn.aggregateFields(
-              FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField));
+              FieldAccessDescriptor.withFieldNames(inputFieldName), true, fn, outputField),
+          fanout);
     }
 
     @Override
@@ -513,7 +536,8 @@ public class Group {
         int inputFieldId, CombineFn<CombineInputT, AccumT, CombineOutputT> fn, Field outputField) {
       return new CombineFieldsGlobally<>(
           schemaAggregateFn.aggregateFields(
-              FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputField));
+              FieldAccessDescriptor.withFieldIds(inputFieldId), false, fn, outputField),
+          fanout);
     }
 
     public <CombineInputT, AccumT, CombineOutputT>
@@ -523,7 +547,8 @@ public class Group {
             Field outputField) {
       return new CombineFieldsGlobally<>(
           schemaAggregateFn.aggregateFields(
-              FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputField));
+              FieldAccessDescriptor.withFieldIds(inputFieldId), true, fn, outputField),
+          fanout);
     }
 
     /**
@@ -568,7 +593,8 @@ public class Group {
         CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
         String outputFieldName) {
       return new CombineFieldsGlobally<>(
-          schemaAggregateFn.aggregateFields(fieldAccessDescriptor, false, fn, outputFieldName));
+          schemaAggregateFn.aggregateFields(fieldAccessDescriptor, false, fn, outputFieldName),
+          fanout);
     }
 
     /**
@@ -605,13 +631,17 @@ public class Group {
         CombineFn<CombineInputT, AccumT, CombineOutputT> fn,
         Field outputField) {
       return new CombineFieldsGlobally<>(
-          schemaAggregateFn.aggregateFields(fieldAccessDescriptor, false, fn, outputField));
+          schemaAggregateFn.aggregateFields(fieldAccessDescriptor, false, fn, outputField), fanout);
+    }
+
+    public CombineFieldsGlobally<InputT> withFanout(int fanout) {
+      return new CombineFieldsGlobally<>(schemaAggregateFn, fanout);
     }
 
     @Override
     public PCollection<Row> expand(PCollection<InputT> input) {
       SchemaAggregateFn.Inner fn = schemaAggregateFn.withSchema(input.getSchema());
-      Combine.Globally<Row, Row> combineFn = Combine.globally(fn);
+      Combine.Globally<Row, Row> combineFn = Combine.globally(fn).withFanout(fanout);
       if (!(input.getWindowingStrategy().getWindowFn() instanceof GlobalWindows)) {
         combineFn = combineFn.withoutDefaults();
       }
@@ -631,6 +661,7 @@ public class Group {
    */
   @AutoValue
   public abstract static class ByFields<InputT> extends AggregateCombiner<InputT> {
+
     abstract FieldAccessDescriptor getFieldAccessDescriptor();
 
     abstract String getKeyField();
@@ -651,11 +682,11 @@ public class Group {
       abstract ByFields<InputT> build();
     }
 
-    class ToKv extends PTransform<PCollection<InputT>, PCollection<KV<Row, Iterable<Row>>>> {
+    class ToKV extends PTransform<PCollection<InputT>, PCollection<KV<Row, Row>>> {
       private RowSelector rowSelector;
 
       @Override
-      public PCollection<KV<Row, Iterable<Row>>> expand(PCollection<InputT> input) {
+      public PCollection<KV<Row, Row>> expand(PCollection<InputT> input) {
         Schema schema = input.getSchema();
         FieldAccessDescriptor resolved = getFieldAccessDescriptor().resolve(schema);
         rowSelector = new RowSelectorContainer(schema, resolved, true);
@@ -666,13 +697,12 @@ public class Group {
             .apply(
                 "selectKeys",
                 WithKeys.of((Row e) -> rowSelector.select(e)).withKeyType(TypeDescriptors.rows()))
-            .setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(schema)))
-            .apply("GroupByKey", GroupByKey.create());
+            .setCoder(KvCoder.of(SchemaCoder.of(keySchema), SchemaCoder.of(schema)));
       }
     }
 
-    public ToKv getToKvs() {
-      return new ToKv();
+    public ToKV getToKV() {
+      return new ToKV();
     }
 
     private static <InputT> ByFields<InputT> of(FieldAccessDescriptor fieldAccessDescriptor) {
@@ -919,7 +949,8 @@ public class Group {
               .build();
 
       return input
-          .apply("ToKvs", getToKvs())
+          .apply("ToKvs", getToKV())
+          .apply("GroupByKey", GroupByKey.create())
           .apply(
               "ToRow",
               ParDo.of(
@@ -942,6 +973,31 @@ public class Group {
    */
   @AutoValue
   public abstract static class CombineFieldsByFields<InputT> extends AggregateCombiner<InputT> {
+
+    @AutoOneOf(Fanout.Kind.class)
+    public abstract static class Fanout implements Serializable {
+      public enum Kind {
+        NUMBER,
+        FUNCTION
+      }
+
+      public abstract Kind getKind();
+
+      public abstract Integer getNumber();
+
+      public abstract SerializableFunction<Row, Integer> getFunction();
+
+      public static Fanout of(int n) {
+        return AutoOneOf_Group_CombineFieldsByFields_Fanout.number(n);
+      }
+
+      public static Fanout of(SerializableFunction<Row, Integer> f) {
+        return AutoOneOf_Group_CombineFieldsByFields_Fanout.function(f);
+      }
+    }
+
+    abstract @Nullable Fanout getFanout();
+
     abstract ByFields<InputT> getByFields();
 
     abstract SchemaAggregateFn.Inner getSchemaAggregateFn();
@@ -954,6 +1010,8 @@ public class Group {
 
     @AutoValue.Builder
     abstract static class Builder<InputT> {
+      public abstract Builder<InputT> setFanout(@Nullable Fanout value);
+
       abstract Builder<InputT> setByFields(ByFields<InputT> byFields);
 
       abstract Builder<InputT> setSchemaAggregateFn(SchemaAggregateFn.Inner schemaAggregateFn);
@@ -988,6 +1046,14 @@ public class Group {
       return toBuilder().setValueField(valueField).build();
     }
 
+    public CombineFieldsByFields<InputT> withHotKeyFanout(int n) {
+      return toBuilder().setFanout(Fanout.of(n)).build();
+    }
+
+    public CombineFieldsByFields<InputT> withHotKeyFanout(SerializableFunction<Row, Integer> f) {
+      return toBuilder().setFanout(Fanout.of(f)).build();
+    }
+
     /**
      * Build up an aggregation function over the input elements.
      *
@@ -1187,9 +1253,25 @@ public class Group {
           .build();
     }
 
+    PTransform<PCollection<KV<Row, Row>>, PCollection<KV<Row, Row>>> getCombineTransform(
+        Schema schema) {
+      SchemaAggregateFn.Inner fn = getSchemaAggregateFn().withSchema(schema);
+      @Nullable Fanout fanout = getFanout();
+      if (fanout != null) {
+        switch (fanout.getKind()) {
+          case NUMBER:
+            return Combine.<Row, Row, Row>perKey(fn).withHotKeyFanout(fanout.getNumber());
+          case FUNCTION:
+            return Combine.<Row, Row, Row>perKey(fn).withHotKeyFanout(fanout.getFunction());
+          default:
+            throw new RuntimeException("Unexpected kind: " + fanout.getKind());
+        }
+      }
+      return Combine.perKey(fn);
+    }
+
     @Override
     public PCollection<Row> expand(PCollection<InputT> input) {
-      SchemaAggregateFn.Inner fn = getSchemaAggregateFn().withSchema(input.getSchema());
 
       Schema keySchema = getByFields().getKeySchema(input.getSchema());
       Schema outputSchema =
@@ -1199,8 +1281,8 @@ public class Group {
               .build();
 
       return input
-          .apply("ToKvs", getByFields().getToKvs())
-          .apply("Combine", Combine.groupedValues(fn))
+          .apply("ToKvs", getByFields().getToKV())
+          .apply("Combine", getCombineTransform(input.getSchema()))
           .apply(
               "ToRow",
               ParDo.of(
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java
index b4b074e0ca0..f6f33208a10 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/schemas/transforms/GroupTest.java
@@ -30,6 +30,7 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Iterator;
 import java.util.List;
+import javax.annotation.Nullable;
 import org.apache.beam.sdk.schemas.AutoValueSchema;
 import org.apache.beam.sdk.schemas.NoSuchSchemaException;
 import org.apache.beam.sdk.schemas.Schema;
@@ -50,6 +51,7 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Sample;
 import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.SerializableFunctions;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.Top;
 import org.apache.beam.sdk.values.PCollection;
@@ -260,17 +262,30 @@ public class GroupTest implements Serializable {
 
   @Test
   @Category(NeedsRunner.class)
-  public void testGlobalAggregation() {
+  public void testGlobalAggregationWithoutFanout() {
+    globalAggregationWithFanout(false);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testGlobalAggregationWithFanout() {
+    globalAggregationWithFanout(true);
+  }
+
+  public void globalAggregationWithFanout(boolean withFanout) {
     Collection<Basic> elements =
         ImmutableList.of(
             Basic.of("key1", 1, "value1"),
             Basic.of("key1", 1, "value2"),
             Basic.of("key2", 2, "value3"),
             Basic.of("key2", 2, "value4"));
-    PCollection<Long> count =
-        pipeline
-            .apply(Create.of(elements))
-            .apply(Group.<Basic>globally().aggregate(Count.combineFn()));
+
+    Group.CombineGlobally<Basic, Long> transform =
+        Group.<Basic>globally().aggregate(Count.combineFn());
+    if (withFanout) {
+      transform = transform.withFanout(10);
+    }
+    PCollection<Long> count = pipeline.apply(Create.of(elements)).apply(transform);
     PAssert.that(count).containsInAnyOrder(4L);
 
     pipeline.run();
@@ -426,7 +441,17 @@ public class GroupTest implements Serializable {
 
   @Test
   @Category(NeedsRunner.class)
-  public void testAggregateByMultipleFields() {
+  public void testAggregateByMultipleFieldsWithoutFanout() {
+    aggregateByMultipleFieldsWithFanout(false);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testAggregateByMultipleFieldsWithFanout() {
+    aggregateByMultipleFieldsWithFanout(true);
+  }
+
+  public void aggregateByMultipleFieldsWithFanout(boolean withFanout) {
     Collection<Aggregate> elements =
         ImmutableList.of(
             Aggregate.of(1, 1, 2),
@@ -435,12 +460,14 @@ public class GroupTest implements Serializable {
             Aggregate.of(4, 2, 5));
 
     List<String> fieldNames = Lists.newArrayList("field1", "field2");
-    PCollection<Row> aggregate =
-        pipeline
-            .apply(Create.of(elements))
-            .apply(
-                Group.<Aggregate>globally()
-                    .aggregateFields(fieldNames, new MultipleFieldCombineFn(), "field1+field2"));
+
+    Group.CombineFieldsGlobally<Aggregate> transform =
+        Group.<Aggregate>globally()
+            .aggregateFields(fieldNames, new MultipleFieldCombineFn(), "field1+field2");
+    if (withFanout) {
+      transform = transform.withFanout(10);
+    }
+    PCollection<Row> aggregate = pipeline.apply(Create.of(elements)).apply(transform);
 
     Schema outputSchema = Schema.builder().addInt64Field("field1+field2").build();
     Row expectedRow = Row.withSchema(outputSchema).addValues(16L).build();
@@ -462,7 +489,25 @@ public class GroupTest implements Serializable {
 
   @Test
   @Category(NeedsRunner.class)
-  public void testByKeyWithSchemaAggregateFnNestedFields() {
+  public void testByKeyWithSchemaAggregateFnNestedFieldsNoFanout() {
+    byKeyWithSchemaAggregateFnNestedFieldsWithFanout(null);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testByKeyWithSchemaAggregateFnNestedFieldsWithNumberFanout() {
+    byKeyWithSchemaAggregateFnNestedFieldsWithFanout(Group.CombineFieldsByFields.Fanout.of(10));
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testByKeyWithSchemaAggregateFnNestedFieldsWithFunctionFanout() {
+    byKeyWithSchemaAggregateFnNestedFieldsWithFanout(
+        Group.CombineFieldsByFields.Fanout.of(SerializableFunctions.constant(10)));
+  }
+
+  public void byKeyWithSchemaAggregateFnNestedFieldsWithFanout(
+      @Nullable Group.CombineFieldsByFields.Fanout fanout) {
     Collection<OuterAggregate> elements =
         ImmutableList.of(
             OuterAggregate.of(Aggregate.of(1, 1, 2)),
@@ -470,14 +515,23 @@ public class GroupTest implements Serializable {
             OuterAggregate.of(Aggregate.of(3, 2, 4)),
             OuterAggregate.of(Aggregate.of(4, 2, 5)));
 
-    PCollection<Row> aggregations =
-        pipeline
-            .apply(Create.of(elements))
-            .apply(
-                Group.<OuterAggregate>byFieldNames("inner.field2")
-                    .aggregateField("inner.field1", Sum.ofLongs(), "field1_sum")
-                    .aggregateField("inner.field3", Sum.ofIntegers(), "field3_sum")
-                    .aggregateField("inner.field1", Top.largestLongsFn(1), "field1_top"));
+    Group.CombineFieldsByFields<OuterAggregate> transform =
+        Group.<OuterAggregate>byFieldNames("inner.field2")
+            .aggregateField("inner.field1", Sum.ofLongs(), "field1_sum")
+            .aggregateField("inner.field3", Sum.ofIntegers(), "field3_sum")
+            .aggregateField("inner.field1", Top.largestLongsFn(1), "field1_top");
+    if (fanout != null) {
+      switch (fanout.getKind()) {
+        case NUMBER:
+          transform = transform.withHotKeyFanout(fanout.getNumber());
+          break;
+        case FUNCTION:
+          transform = transform.withHotKeyFanout(fanout.getFunction());
+          break;
+      }
+    }
+
+    PCollection<Row> aggregations = pipeline.apply(Create.of(elements)).apply(transform);
 
     Schema keySchema = Schema.builder().addInt64Field("field2").build();
     Schema valueSchema =
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java
index 906258b164a..be88229e755 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamWindowRel.java
@@ -35,6 +35,7 @@ import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.values.KV;
@@ -254,7 +255,9 @@ public class BeamWindowRel extends Window implements BeamRelNode {
           org.apache.beam.sdk.schemas.transforms.Group.ByFields<Row> myg =
               org.apache.beam.sdk.schemas.transforms.Group.byFieldIds(af.partitionKeys);
           PCollection<KV<Row, Iterable<Row>>> partitionBy =
-              inputData.apply(prefix + "partitionBy", myg.getToKvs());
+              inputData
+                  .apply(prefix + "partitionByKV", myg.getToKV())
+                  .apply(prefix + "partitionByGK", GroupByKey.create());
           partitioned =
               partitionBy
                   .apply(prefix + "selectOnlyValues", ParDo.of(new SelectOnlyValues()))